In [None]:
# Look up a kernel connection file with jupyter_client's default search and print
# the path it returns (handy for attaching an external console to this kernel).
from jupyter_client import find_connection_file

connection_file = find_connection_file()
print(connection_file)

In [None]:
# Plotting related
import os
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio

pio.renderers.default = "notebook"

# Data handling and scikit-learn related imports
import pyarrow
import pandas as pd

pd.options.mode.copy_on_write = True
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split, cross_validate, RandomizedSearchCV
from sklearn.metrics import accuracy_score
from scipy.stats import randint, ttest_ind
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from mlxtend.plotting import plot_decision_regions

In [None]:
# Names of the columns of the vaccine table that the rest of the notebook refers to.
# Keeping them in one place means a renamed column only has to be changed here.
dataset_col = "Dataset"
uid_col = "uid"
age_col = "Age"
day_col = "Day"
response_col = "Response"
immage_col = "IMMAGE"
strain_col = "Strain"

In [None]:
def get_data_dir():
    """Return the path of the nearest directory named "data" at or above the CWD.

    The search starts in the current working directory and walks up one parent at
    a time, including the filesystem root itself.

    Returns:
        str: Absolute path of the first "data" directory found.

    Raises:
        FileNotFoundError: If no directory named "data" exists in the current
            working directory or in any of its parents.
    """
    current_dir = os.path.abspath(os.getcwd())

    while True:
        candidate = os.path.join(current_dir, "data")
        # isdir (rather than a name lookup in listdir) ignores plain files called
        # "data" and does not fail on directories we are not allowed to list.
        if os.path.isdir(candidate):
            return candidate

        parent_dir = os.path.dirname(current_dir)
        # dirname() of a root returns the root itself ("/" or "C:\\"), which is a
        # portable way to detect that there is nothing left to climb.
        if parent_dir == current_dir:
            break
        current_dir = parent_dir

    raise FileNotFoundError("Directory 'data' not found in the parent directories.")

In [None]:
# Load the combined vaccine table and discard every row that lacks one of the
# columns the analysis depends on.
data_dir = get_data_dir()
csv_path = os.path.join(data_dir, "../data/all_vaccines.csv")
required_cols = [immage_col, dataset_col, day_col, response_col]
df = pd.read_csv(csv_path).dropna(subset=required_cols)

dataset_names = df.Dataset.unique()

##### Plot distribution of studies' N values

In [None]:
# Count the unique subjects (N) of every study, then show how those counts are distributed
subjects_per_study = df.groupby(dataset_col, as_index=False)[uid_col].nunique()
N_vals = subjects_per_study.rename(columns={uid_col: "N"})
sns.histplot(N_vals.N)
plt.title("N values across studies")

##### Narrow to large datasets only (N > 70)

In [None]:
# Keep only the studies with more than 70 subjects, and the rows that belong to them
is_large_study = N_vals["N"] > 70
N_vals = N_vals.loc[is_large_study]
in_large_study = df["Dataset"].isin(N_vals["Dataset"])
datasets = df.loc[in_large_study]
dataset_names = datasets["Dataset"].unique()
N_vals

In [None]:
# Examine available days per dataset: one row per study, holding the array of
# distinct day labels recorded for it
days = (
    datasets[[dataset_col, uid_col, day_col]].groupby(dataset_col, as_index=False)[day_col].unique()
)
# Day labels of the first listed study. Selected by position so the cell keeps
# working however many studies passed the size filter (a fixed-length boolean
# mask only fits one exact study count).
t = days[day_col].iloc[:1]
# with pd.option_context('display.max_colwidth', None):
#    for index, row in days.iterrows():
#     print(f"Dataset: {row['Dataset']}\nDays: {row['Day']}\n\n")

In [None]:
# Day label to analyse for each study, as collected from the papers.
# Keys are study names, values are entries of the day column.
dataset_day_dict = {
    "GSE41080.SDY212": "HAI.D28",
    "GSE48018.SDY1276": "nAb.D28",
    "GSE48023.SDY1276": "nAb.D28",
    "SDY67": "nAb.D28",
}
# dataset_day_dict[dataset_names[0]]

##### Narrow to a specific dataset and day, then keep only relevant columns

In [None]:
# Narrow to a specific dataset and day, then keep only relevant columns
# strain_index: which strain to keep (position in order of first appearance) when
# the selected study reports more than one strain
strain_index = 0
# dataset_name: the study to analyse; 0 picks the first entry of dataset_names
dataset_name = dataset_names[0]

In [None]:
# Rows of the chosen study, measured on the day registered for it in dataset_day_dict
name_mask = datasets[dataset_col] == dataset_name
day_mask = datasets[day_col] == dataset_day_dict[dataset_name]

# drop=True: the old row labels are of no further use; without it they would be
# kept as an extra "index" column (and "level_0" on the second reset below)
data = datasets.loc[name_mask & day_mask].reset_index(drop=True)

# Sometimes there are multiple strains - so multiple rows per day
strains = data[strain_col].unique()
if len(strains) > 1:
    data = data.loc[data[strain_col] == strains[strain_index]].reset_index(drop=True)

strains_t = data[strain_col].unique()
assert len(strains_t) == 1, f"Expected exactly one strain after filtering, found {len(strains_t)}"
strain = strains_t[0]

# Sometimes there are multiple geo_accession numbers, like in GSE48018.SDY1276, average the IMMAGE, since all else is the same
accessions = data["geo_accession"].unique()
if len(accessions) > 1:
    print("*** Multiple accession detected! Collapsing by averaging on IMMAGE value ***\n")
    # One row per subject: IMMAGE is averaged, every other column takes its first value
    data = data.groupby(uid_col, as_index=False).agg(
        {
            immage_col: "mean",
            **{col: "first" for col in data.columns if col not in [uid_col, immage_col]},
        }
    )

# Take relevant columns only
data = data[[immage_col, response_col, age_col]]

print(f"Working with dataset: {dataset_name}, strain: {strain}")
print(f"Total subjects in study: N={data.shape[0]}")
print(f"available strains: {strains}")

# data.head()

In [None]:
# Responses strictly below the 30th percentile count as "low"
low_response_thr = float(data[response_col].quantile(q=0.3))

# Generate labels
# y=1 marks a response < 30th percentile (strictly below, not <=).
# y=1 therefore means *non*-responder, since that is the group we later want to detect.
is_low_response = data[response_col].lt(low_response_thr)
data["y"] = is_low_response.astype("int64")

# Human-readable version of the label, used in plot legends
data["Label text"] = data["y"].map({0: "Responders", 1: "Non-Responders"})

##### Plot IMMAGE, response, and age values to look at the dynamic range

In [None]:
# Plot IMMAGE, response, and age values to look at the dynamic range
from scipy.stats import probplot

# 2 x 3 grid: one column per variable, histogram on top and box plot underneath
fig, axs = plt.subplots(2, 3, figsize=(18, 6))

for panel_idx, column in enumerate((immage_col, response_col, age_col)):
    hist_ax = axs[0, panel_idx]
    box_ax = axs[1, panel_idx]
    sns.histplot(data=data, x=column, bins=50, ax=hist_ax)
    sns.boxplot(data=data, x=column, ax=box_ax, fill=False)
    hist_ax.set_title(f" {column}")

plt.tight_layout(pad=3.0)  # Adjust the layout so everything fits without overlap
fig.suptitle(f"Values Distribution in {dataset_name}, strain: {strain}")

plt.show()

##### Is there a trend like we expect? (High IMMAGE ⇒ low response)
#### Also show the distributions of IMMAGE values for responders & non-responders

In [None]:
# Classifying with logistic regression - fit on the entire dataset
from math import log


def get_threshold_from_probability(prob, intercept, slope):
    """Invert a one-feature logistic model: the feature value whose predicted probability is ``prob``.

    Solves ``prob = 1 / (1 + exp(-(intercept + slope * x)))`` for ``x``.

    Args:
        prob: Target probability.
        intercept: Intercept of the logistic model.
        slope: Coefficient of the single feature; expected to be non-zero.

    Returns:
        The feature value ``x``. The sigmoid never actually reaches 0 or 1, so for
        ``prob`` at or beyond those bounds the limit is returned instead:
        +inf or -inf, depending on the sign of ``slope``. This covers cut-offs
        such as the first threshold reported by ``roc_curve``, which lies above
        every score. If ``slope`` is zero in that situation, NaN is returned.
    """
    if prob <= 0 or prob >= 1:
        if slope == 0:
            return float("nan")
        # A probability of 1 sits at the high end of the feature when the slope is
        # positive and at the low end when it is negative; a probability of 0 is
        # the mirror image.
        at_high_end = (prob >= 1) == (slope > 0)
        return float("inf") if at_high_end else float("-inf")

    return -1 * (log(1 / prob - 1) + intercept) / slope


# Columns that will receive each model's predicted probability of class 1 (non-responder)
non_responder_col = "p_non_responder"
non_responder_col_age = "p_non_responder_age"
non_responder_col_combined = "p_non_responder_combined"

# One classifier per feature set: IMMAGE alone, age alone (for comparison), and both together
log_regress_immage = LogisticRegression()
log_regress_age = LogisticRegression()
log_regress_combined = LogisticRegression()

# (model, input feature columns, output probability column)
model_specs = (
    (log_regress_immage, [immage_col], non_responder_col),
    (log_regress_age, [age_col], non_responder_col_age),
    (log_regress_combined, [immage_col, age_col], non_responder_col_combined),
)

# Fit every model on the entire dataset
for model, feature_cols, proba_col in model_specs:
    model.fit(data[feature_cols], data["y"])

# Store the second predict_proba column of every model next to the data
for model, feature_cols, proba_col in model_specs:
    proba = pd.DataFrame(model.predict_proba(data[feature_cols]))
    data[proba_col] = proba[1]

##### Use a logistic regression's probabilities to look for a threshold based on the above-threshold non-responder rate

In [None]:
# Define auxiliary functions
from sklearn.metrics import auc, roc_auc_score, roc_curve
from math import log


def calc_and_plot_prob_threshold(data, classifier, fpr, tpr, thresholds, col_name="", plot="both"):
    """Choose the probability cut-off that maximises Youden's index and optionally plot it.

    Args:
        data: DataFrame holding ``col_name`` and a "Label text" column. Only read
            when the threshold panel is drawn.
        classifier: Fitted model exposing ``intercept_`` and ``coef_``. Only the
            first coefficient is used to translate the cut-off into feature units.
        fpr, tpr, thresholds: Arrays describing the ROC curve (as produced by
            ``roc_curve``).
        col_name: Feature column shown in the threshold panel. Required for
            ``plot="both"`` and ``plot="threshold"``.
        plot: Which panels to draw: "both", "ROC", "threshold" or "none". Any
            other value draws the ROC panel only.

    Returns:
        tuple: ``(prob_threshold, feature_threshold)``. ``feature_threshold`` is the
        cut-off expressed in units of the first feature; it is meaningless for a
        multivariate classifier and NaN when it cannot be computed.
    """
    roc_auc = auc(fpr, tpr)
    intercept = classifier.intercept_[0]
    slope = classifier.coef_[0][0]

    # Identifying the optimal threshold (using Youden's Index: max of TPR - FPR)
    optimal_idx = np.argmax(tpr - fpr)
    prob_threshold = thresholds[optimal_idx]

    # Calculate the cutoff value in feature units. The chosen cut-off can lie
    # outside (0, 1) - the first ROC threshold is above every score - in which
    # case no finite feature value corresponds to it.
    try:
        feature_threshold = get_threshold_from_probability(
            prob_threshold, intercept=intercept, slope=slope
        )
    except (ValueError, ZeroDivisionError):
        feature_threshold = float("nan")

    if plot != "none":
        show_roc = plot != "threshold"
        show_threshold = plot in ("both", "threshold")
        num_panels = int(show_roc) + int(show_threshold)

        # squeeze=False always yields a 2-D axes array, so one and two panels are
        # handled the same way
        fig, axs = plt.subplots(1, num_panels, figsize=(8 * num_panels, 6), squeeze=False)
        panels = iter(axs[0])

        if show_roc:
            roc_ax = next(panels)
            roc_ax.plot(fpr, tpr, label=f"ROC curve (area = {roc_auc : 0.2f})")
            roc_ax.plot([0, 1], [0, 1], "k--")  # Random chance line
            # Mark the operating point selected above
            roc_ax.plot(fpr[optimal_idx], tpr[optimal_idx], marker="o", markersize=5, color="red")
            roc_ax.set_xlim([0.0, 1.0])
            roc_ax.set_ylim([0.0, 1.05])
            roc_ax.set_xlabel("False Positive Rate")
            roc_ax.set_ylabel("True Positive Rate")
            roc_ax.set_title("ROC curve")
            roc_ax.legend(loc="lower right")

        if show_threshold:
            thr_ax = next(panels)
            # Feature values in ascending order against their rank, with the cut-off as a line
            sorted_data = data.sort_values(col_name, ignore_index=True).reset_index()
            sns.scatterplot(ax=thr_ax, data=sorted_data, x="index", y=col_name, hue="Label text")
            thr_ax.axhline(y=feature_threshold, color="black", linestyle="--")
            thr_ax.set_title(f"Sorted {col_name} vs Index")

        fig.suptitle(f"Probability-based threshold with ROC\n({dataset_name}, {strain})")
        plt.tight_layout()  # Adjusts subplot params so that subplots fit into the figure area.
        plt.show()

    # feature threshold is meaningless for the multivariate case
    return prob_threshold, feature_threshold


def get_classifier_stats_prob(data, prob_column, prob_threshold):
    """Compare the non-responder rate on both sides of a probability cut-off.

    Args:
        data: DataFrame with a label column "y" (1 = non-responder) and the
            probability column ``prob_column``.
        prob_column: Name of the column holding predicted probabilities.
        prob_threshold: Cut-off; rows with probability >= cut-off are "over".

    Returns:
        tuple: ``(non_response_rate_over_thr, non_response_rate_under_thr)``, the
        mean of "y" among rows at/above and strictly below the cut-off. A side
        with no rows yields NaN.
    """
    is_over_thr = data[prob_column] >= prob_threshold
    is_under_thr = data[prob_column] < prob_threshold

    non_response_rate_over_thr = data.loc[is_over_thr, "y"].mean()
    non_response_rate_under_thr = data.loc[is_under_thr, "y"].mean()
    return non_response_rate_over_thr, non_response_rate_under_thr

In [None]:
# Run for immage and age to compare
# Each section below follows the same steps: build the ROC curve from the model's
# stored probabilities, pick the cut-off, then report the non-responder rate on
# either side of it. fpr/tpr/thresholds are overwritten by every section.

# IMMAGE
fpr, tpr, thresholds = roc_curve(data["y"], data[non_responder_col])
prob_threshold, feature_threshold = calc_and_plot_prob_threshold(
    data, log_regress_immage, fpr, tpr, thresholds, col_name=immage_col, plot="both"
)
non_response_rate_over_thr, non_response_rate_under_thr = get_classifier_stats_prob(
    data, non_responder_col, prob_threshold
)
print(
    f"Optimal threshold: {feature_threshold : 0.2f} (IMMAGE value), None-responder rate: over threshold: {non_response_rate_over_thr : 0.2f}, under threshold: {non_response_rate_under_thr : 0.2f}"
)

# Age
# NOTE: feature_threshold is reused here, so the IMMAGE cut-off is no longer
# available under that name after this section.
fpr, tpr, thresholds = roc_curve(data["y"], data[non_responder_col_age])
prob_threshold_age, feature_threshold = calc_and_plot_prob_threshold(
    data, log_regress_age, fpr, tpr, thresholds, col_name=age_col, plot="both"
)
non_response_rate_over_thr, non_response_rate_under_thr = get_classifier_stats_prob(
    data, non_responder_col_age, prob_threshold_age
)
print(
    f"Optimal threshold: {feature_threshold : 0.2f} (Age), None-responder rate: over threshold: {non_response_rate_over_thr : 0.2f}, under threshold: {non_response_rate_under_thr : 0.2f}"
)

# Combined
# Two features, so there is no single feature cut-off: only the ROC panel is
# drawn, the second return value is discarded, and the probability cut-off is
# what gets reported.
fpr, tpr, thresholds = roc_curve(data["y"], data[non_responder_col_combined])
prob_threshold_combined, _ = calc_and_plot_prob_threshold(
    data, log_regress_combined, fpr, tpr, thresholds, plot="ROC"
)
non_response_rate_over_thr, non_response_rate_under_thr = get_classifier_stats_prob(
    data, non_responder_col_combined, prob_threshold_combined
)
print(
    f"Optimal threshold: {prob_threshold_combined : 0.2f} (probability), None-responder rate: over threshold: {non_response_rate_over_thr : 0.2f}, under threshold: {non_response_rate_under_thr : 0.2f}"
)

In [None]:
# Sliding window instead of bins, plotting non-responder rate vs window start
def generate_windows_and_rates(data, feature_col, num_units, num_units_per_window):
    """Compute the non-responder rate inside overlapping windows over a feature.

    The range of ``feature_col`` is covered by ``num_units`` evenly spaced grid
    points. Every window starts on a grid point and ends ``num_units_per_window``
    points later; the start is inclusive and the end exclusive, so rows equal to
    the feature maximum fall outside every window.

    Args:
        data: DataFrame with ``feature_col`` and a label column "y".
        feature_col: Name of the feature the windows slide over.
        num_units: Number of grid points spanning the feature range.
        num_units_per_window: Number of grid steps covered by one window.

    Returns:
        tuple: ``(windows, window_size)``. ``windows`` has the columns "start",
        "end" and "rate" (mean of "y" inside the window, 0 for empty windows);
        ``window_size`` is the width of one window in feature units.

    Raises:
        ValueError: If ``num_units_per_window`` is not between 1 and
            ``num_units - 1``, since no window would fit on the grid.
    """
    if not 0 < num_units_per_window < num_units:
        raise ValueError(
            "num_units_per_window must be at least 1 and smaller than num_units "
            f"(got num_units={num_units}, num_units_per_window={num_units_per_window})"
        )

    feature = data[feature_col]
    feature_min, feature_max = feature.min(), feature.max()

    grid = np.linspace(start=feature_min, stop=feature_max, num=num_units)
    windows = pd.DataFrame(
        {
            "start": grid[:-num_units_per_window],
            "end": grid[num_units_per_window:],
        }
    )

    # num_units grid points are separated by num_units - 1 steps, so this is the
    # actual distance between every window's start and end
    window_size = (feature_max - feature_min) / (num_units - 1) * num_units_per_window

    rates = [
        data.loc[(feature >= start) & (feature < end), "y"].mean()
        for start, end in zip(windows["start"], windows["end"])
    ]
    # A window without any rows has an undefined mean; report it as a rate of 0
    windows["rate"] = pd.Series(rates, index=windows.index).fillna(0)

    return windows, window_size


def plot_sliding_windows(feature_cols, num_units=100, num_units_per_window=20):
    """Plot the sliding-window non-responder rate for each feature, one panel per feature.

    Reads the notebook-level ``data``, ``dataset_name`` and ``strain``.

    Args:
        feature_cols: Names of the feature columns to plot.
        num_units: Number of grid points spanning each feature's range.
        num_units_per_window: Number of grid steps covered by one window.
    """
    num_features = len(feature_cols)
    # squeeze=False always yields a 2-D axes array, so a single feature works as
    # well as several
    fig, axs = plt.subplots(1, num_features, figsize=(5 * num_features, 5), squeeze=False)

    for ax, feature_col in zip(axs[0], feature_cols):
        windows, window_size = generate_windows_and_rates(
            data, feature_col, num_units, num_units_per_window
        )
        sns.lineplot(data=windows, x="start", y="rate", ax=ax)
        # Reference line at a 50% non-responder rate
        ax.axhline(y=0.5, color="black", linestyle="--")
        ax.set_title(f"Window size: {window_size:.2f} {feature_col} units")
        ax.set_xlabel("Start")
        ax.set_ylabel("Rate")

    fig.suptitle(
        f"Sliding window performance\nrate of non-responders vs feature columns\n({dataset_name}, {strain})"
    )
    # Leave room above the panels for the three-line title
    plt.subplots_adjust(top=0.75)

    plt.show()


# Compare how the non-responder rate changes along IMMAGE and along age
plot_sliding_windows([immage_col, age_col])