## Imports

In [None]:
import os
import tarfile
import urllib
from scipy import stats

import pandas as pd
from pandas import DataFrame, Series
from pandas.plotting import scatter_matrix

import matplotlib.pyplot as plt

import numpy as np

from sklearn.model_selection import train_test_split, StratifiedShuffleSplit, cross_val_score, GridSearchCV
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder, StandardScaler, FunctionTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

from sklearn.linear_model import LinearRegression
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc, precision_recall_curve, precision_score, recall_score, f1_score, average_precision_score

## Feature Construction

In [None]:
class MissingAttributesAdder(BaseEstimator, TransformerMixin):
    """Transformer that appends 0/1 indicator columns for "?" placeholders.

    For each of ``donor_CMV``, ``recipient_CMV`` and ``ABO_match`` a new
    ``<column>_missing`` column is added that holds 1 where the source value
    equals the string ``"?"`` and 0 otherwise. The input frame is not
    modified; a copy is returned.
    """

    def __init__(self):
        """Create the transformer; it takes no parameters."""

    def fit(self, X, y=None):
        """Do nothing and return ``self`` (there is nothing to learn)."""
        return self

    def transform(self, X: DataFrame):
        """Return a copy of ``X`` with the three ``*_missing`` flags appended.

        Args:
            X: Frame containing the ``donor_CMV``, ``recipient_CMV`` and
                ``ABO_match`` columns.

        Returns:
            A new frame with the original columns followed by
            ``donor_CMV_missing``, ``recipient_CMV_missing`` and
            ``ABO_match_missing``.
        """
        # Indicator column -> source column; dict order fixes the column order.
        flag_sources = {
            "donor_CMV_missing": "donor_CMV",
            "recipient_CMV_missing": "recipient_CMV",
            "ABO_match_missing": "ABO_match",
        }

        flagged = X.copy()
        for flag_column, source_column in flag_sources.items():
            is_placeholder = flagged[source_column] == "?"
            flagged[flag_column] = is_placeholder.astype(int)
        return flagged

## Encoding

In [None]:
def encode_booleans(data: DataFrame):
    zero_mapper = {"absent", "no", "female", "mismatched", "female_to_male", "low", "nonmalignant", "peripheral_blood", "?"}
    one_mapper  = {"present", "yes", "male", "matched", "other", "high", "malignant", "bone_marrow"}

    def map_values(value):
        if value in zero_mapper:
            return 0
        if value in one_mapper:
            return 1
        
        return value

    return data.map(map_values)

## Imputing

In [None]:
def impute_body_mass(data: DataFrame):
    """Fill missing ``recipient_body_mass`` values with group medians.

    Recipients are grouped by ``recipient_gender`` and an age bracket derived
    from ``recipient_age``. A missing body mass is replaced by the median of
    its group; when that group has no usable median (the row has no group, or
    every mass in the group is missing) the median of the whole column is
    used instead.

    Args:
        data: Frame with ``recipient_gender``, ``recipient_age`` and
            ``recipient_body_mass`` columns; it is not modified.

    Returns:
        A copy of ``data`` with the same columns and ``recipient_body_mass``
        imputed.
    """
    data = data.copy()

    # NOTE(review): pd.cut intervals are right-closed and the first bin is
    # (0, 1], so an age of exactly 0 (or a missing age) gets no bracket and
    # falls back to the overall median -- confirm this is intended.
    age_group = pd.cut(
        data["recipient_age"],
        bins=[0., 1., 5., 10., 15., 20., np.inf],
        labels=["<1", "1-5", "5-10", "10-15", "15-20", "20+"],
    )

    body_mass = data["recipient_body_mass"]

    # Per-row median of the row's (gender, age bracket) group. Rows without a
    # group, or whose group holds only missing masses, get NaN here.
    group_median = body_mass.groupby(
        [data["recipient_gender"], age_group], observed=True
    ).transform("median")

    # Group median first, overall median as the fallback. Chaining fillna
    # also covers groups whose median is NaN, which a lookup-with-default
    # would return as-is.
    data["recipient_body_mass"] = body_mass.fillna(group_median).fillna(body_mass.median())

    return data

In [None]:
def impute_cell_dosage(data: DataFrame):
    """Impute missing ``CD3_per_kg`` and ``CD3_CD34_ratio`` values.

    A linear regression of ``CD3_per_kg`` on ``CD34_per_kg`` is fitted on the
    rows where both are present and used to predict the missing
    ``CD3_per_kg`` values. Missing ``CD3_CD34_ratio`` values are then
    recomputed as ``CD3_per_kg / CD34_per_kg``.

    Rows whose ``CD34_per_kg`` is also missing cannot be predicted and are
    left as NaN.

    Args:
        data: Frame with ``CD3_per_kg``, ``CD34_per_kg`` and
            ``CD3_CD34_ratio`` columns; it is not modified.

    Returns:
        A copy of ``data`` with the two columns imputed where possible.
    """
    data = data.copy()

    # Select both columns with a list; a bare tuple is looked up as a single
    # column label and raises KeyError.
    train_data = data[["CD3_per_kg", "CD34_per_kg"]].dropna()

    # Only rows with a usable predictor can be imputed by the regression.
    missing_cd3 = data["CD3_per_kg"].isna() & data["CD34_per_kg"].notna()

    # Skip the model when there is nothing to predict or nothing to fit on;
    # the estimator rejects empty inputs.
    if missing_cd3.any() and not train_data.empty:
        linear_reg = LinearRegression()
        linear_reg.fit(train_data[["CD34_per_kg"]], train_data["CD3_per_kg"])
        data.loc[missing_cd3, "CD3_per_kg"] = linear_reg.predict(
            data.loc[missing_cd3, ["CD34_per_kg"]]
        )

    missing_ratio = data["CD3_CD34_ratio"].isna()
    data.loc[missing_ratio, "CD3_CD34_ratio"] = (
        data.loc[missing_ratio, "CD3_per_kg"] / data.loc[missing_ratio, "CD34_per_kg"]
    )

    return data

## Metrics

In [None]:
def display_metrics(y, y_pred, y_prob):
    """Plot and print evaluation diagnostics for a two-class classifier.

    Shows, in order, a confusion matrix, a ROC curve with its AUC and a
    precision-recall curve with its average precision (each in its own
    5x5 figure), then prints precision, recall and F1.

    Args:
        y: True labels.
        y_pred: Predicted labels.
        y_prob: Per-sample scores; column 1 (``y_prob[:, 1]``) is used as the
            positive-class score. Presumably ``predict_proba`` output --
            verify against the caller.
    """
    panel_size = (5, 5)

    # Confusion matrix
    _, matrix_axes = plt.subplots(figsize=panel_size)
    counts = confusion_matrix(y, y_pred)
    ConfusionMatrixDisplay(
        confusion_matrix=counts,
        display_labels=["no", "yes"],
    ).plot(ax=matrix_axes)
    plt.show()

    # ROC curve, with the chance diagonal for reference
    _, roc_axes = plt.subplots(figsize=panel_size)
    positive_scores = y_prob[:, 1]
    false_pos_rate, true_pos_rate, _ = roc_curve(y, positive_scores)
    area_under_roc = auc(false_pos_rate, true_pos_rate)

    roc_axes.plot(false_pos_rate, true_pos_rate, label=f"AUC = {area_under_roc:.3f}")
    roc_axes.plot([0, 1], [0, 1], linestyle="--")
    roc_axes.set_xlabel("False Positive Rate")
    roc_axes.set_ylabel("True Positive Rate")
    roc_axes.legend()
    plt.show()

    # Precision-recall curve
    _, pr_axes = plt.subplots(figsize=panel_size)
    precisions, recalls, _ = precision_recall_curve(y, positive_scores)
    avg_precision = average_precision_score(y, positive_scores)

    pr_axes.plot(recalls, precisions, label=f"AP = {avg_precision:.3f}")
    pr_axes.set_xlabel("Recall")
    pr_axes.set_ylabel("Precision")
    pr_axes.legend()
    plt.show()

    # Scalar summary; labels are pre-padded so the colons line up.
    summary = {
        "Precision": precision_score(y, y_pred),
        "   Recall": recall_score(y, y_pred),
        " f1-score": f1_score(y, y_pred),
    }
    for label, score in summary.items():
        print(f"{label}: {score:.3f}")