In [None]:
import pyarrow.parquet as pq
import pandas as pd
import numpy as np
import pathlib
import warnings

from sklearn.utils import shuffle, parallel_backend
from sklearn.exceptions import ConvergenceWarning
from joblib import dump

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import (
    StratifiedKFold,
    GridSearchCV,
)
from joblib import load
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    confusion_matrix,
    classification_report,
)
import toml
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
import ast

In [None]:
# Read the single-class logistic-regression config and pull out the run parameters.
toml_path = pathlib.Path("../1.train_models/single_class_config.toml")
# The `with` block closes the file on exit; no explicit close() is needed.
with open(toml_path, "r") as f:
    config = toml.load(f)

lr_params = config["logistic_regression_params"]
control = lr_params["control"]
treatment = lr_params["treatments"]


def _as_bool_flag(value):
    """Return a config flag as a Python literal.

    `ast.literal_eval` only accepts strings, so the existing config evidently
    stores these flags as strings such as "True". Values that are already
    native TOML booleans are passed through unchanged instead of making
    `literal_eval` raise "malformed node or string".
    """
    if isinstance(value, bool):
        return value
    return ast.literal_eval(value)


aggregation = _as_bool_flag(lr_params["aggregation"])
nomic = _as_bool_flag(lr_params["nomic"])
cell_type = lr_params["cell_type"]
print(aggregation, nomic, cell_type)

In [None]:
# if ((aggregation == True) and (nomic == True)):
#     data_split_path = pathlib.Path(f"../0.split_data/indexes/aggregated_sc_and_nomic_data_split_indexes.tsv")
# elif ((aggregation == True) and (nomic == False)):
#     data_split_path = pathlib.Path(f"../0.split_data/indexes/aggregated_sc_data_split_indexes.tsv")
# elif ((aggregation == False) and (nomic == True)):
#     data_split_path = pathlib.Path(f"../0.split_data/indexes/sc_and_nomic_data_split_indexes.tsv")
# elif (aggregation == False and nomic == False):
#     data_split_path = pathlib.Path(f"../0.split_data/indexes/sc_split_indexes.tsv")
# else:
#     print('Error')

In [None]:
# Location of the normalized single-cell feature table.
data_path = pathlib.Path("../../data/SHSY5Y_preprocessed_sc_norm.parquet")

# Read the entire parquet file into pandas. No rows are excluded here; rows are
# selected later via the split indexes and the control/treatment labels.
sc_table = pq.read_table(data_path)
data_df = sc_table.to_pandas()
# Release the intermediate Arrow table so only the pandas copy stays in memory.
del sc_table

# Load the nELISA (Nomic) panel exported for this cell type.
nomic_df_path = pathlib.Path(
    f"../../2.Nomic_nELISA_Analysis/Data/clean/Plate2/nELISA_plate_430420_{cell_type}.csv"
)
df_nomic = pd.read_csv(nomic_df_path)

# Drop every column whose name contains "[pgML]"
# (presumably the raw pg/mL concentration columns).
df_nomic = df_nomic.drop(columns=[col for col in df_nomic.columns if "[pgML]" in col])

# Drop the leading metadata columns in a single step. Positions 0-1 and 3-24 are
# removed; position 2 is deliberately kept. It is presumably the well-position
# key ("position_x") that the later merge joins on. TODO confirm.
# Positions past the end of the frame are ignored, matching slice semantics.
metadata_positions = [
    i for i in [*range(0, 2), *range(3, 25)] if i < df_nomic.shape[1]
]
df_nomic = df_nomic.drop(columns=df_nomic.columns[metadata_positions])

In [None]:
# Pick the split-index file that matches the feature set and, when requested,
# attach the Nomic features to data_df.
if aggregation not in (True, False) or nomic not in (True, False):
    # Fail loudly: continuing would leave data_split_indexes undefined,
    # or stale from an earlier run of this cell.
    raise ValueError(
        f"aggregation and nomic must be booleans, got {aggregation!r} and {nomic!r}"
    )

split_index_files = {
    (True, True): "aggregated_sc_and_nomic_data_split_indexes.tsv",
    (True, False): "aggregated_sc_data_split_indexes.tsv",
    (False, True): "sc_and_nomic_data_split_indexes.tsv",
    (False, False): "sc_split_indexes.tsv",
}
data_split_path = pathlib.Path("../0.split_data/indexes") / split_index_files[
    (bool(aggregation), bool(nomic))
]
data_split_indexes = pd.read_csv(data_split_path, sep="\t", index_col=0)

# NOTE(review): nothing in this cell aggregates data_df when aggregation is True.
# It is still the single-cell parquet loaded earlier. Confirm that the
# "aggregated" split indexes are meant to index into it.
if nomic:
    # Inner-join (pd.merge default) the Nomic features onto data_df by well,
    # then drop the duplicated key column.
    # NOTE(review): merge returns a fresh RangeIndex, so the split indexes must
    # have been generated on an identically merged frame. Re-running this cell
    # merges data_df a second time; re-run the loading cell first.
    data_df = pd.merge(
        data_df, df_nomic, left_on="Metadata_Well", right_on="position_x"
    )
    data_df = data_df.drop(columns=["position_x"])

In [None]:
# Keep only the rows of data_df whose index labels are listed in the split
# file's "labeled_data_index" column. This is a label-based .loc lookup.
labeled_row_labels = data_split_indexes["labeled_data_index"]
training_data = data_df.loc[labeled_row_labels]

In [None]:
# Restrict to the two classes of this binary comparison: the control and
# treatment labels read from the config file.
keep_labels = [control, treatment]
in_comparison = training_data["oneb_Metadata_Treatment_Dose_Inhibitor_Dose"].isin(
    keep_labels
)
training_data = training_data.loc[in_comparison]

In [None]:
# Balance the classes: randomly downsample the control rows so that exactly as
# many control rows as treatment rows remain. A fixed seed keeps the evaluation
# reproducible across runs.
DOWNSAMPLE_SEED = 0
label_col = training_data["oneb_Metadata_Treatment_Dose_Inhibitor_Dose"]
# number of treatment rows
trt_wells = int((label_col == treatment).sum())
# number of control rows
dmso_wells = int((label_col == control).sum())
if dmso_wells > trt_wells:
    # Sample only the *surplus* control rows and drop them, leaving trt_wells
    # control rows. Dropping trt_wells rows would leave dmso_wells - trt_wells
    # rows, which does not balance the classes.
    dmso_holdout = training_data[label_col == control].sample(
        n=dmso_wells - trt_wells, random_state=DOWNSAMPLE_SEED
    )
    training_data = training_data.drop(dmso_holdout.index)

In [None]:
# Model directory (relative to ../1.train_models) for this feature set and
# control/treatment pair.
if aggregation not in (True, False) or nomic not in (True, False):
    # Raising avoids continuing with model_path undefined, or stale from a
    # previous run.
    raise ValueError(
        f"aggregation and nomic must be booleans, got {aggregation!r} and {nomic!r}"
    )
feature_set_dirs = {
    (True, True): "aggregated_with_nomic",
    (True, False): "aggregated",
    (False, True): "sc_with_nomic",
    (False, False): "sc",
}
feature_set_dir = feature_set_dirs[(bool(aggregation), bool(nomic))]
model_path = pathlib.Path(
    f"models/single_class/{cell_type}/{feature_set_dir}/{control}__{treatment}"
)

In [None]:
# Model variants to evaluate ("shuffled_baseline" is presumably the
# label-shuffled control model) and the feature set they were trained on.
model_types = ["final", "shuffled_baseline"]
feature_types = ["CP"]
# Only the first label that appears in training_data is kept. In the evaluation
# loop below it is used only in the printout, figure title and figure filename.
observed_labels = training_data["oneb_Metadata_Treatment_Dose_Inhibitor_Dose"].unique()
phenotypic_classes = [observed_labels[0]]

In [None]:
# Ground-truth labels used for scoring.
labeled_data = training_data["oneb_Metadata_Treatment_Dose_Inhibitor_Dose"]
# Any column whose name matches the regex "Metadata" is treated as metadata.
# This also catches the "oneb_Metadata_..." label column.
metadata = training_data.filter(regex="Metadata")
# Everything else is the feature matrix passed to the models.
# NOTE(review): column order must match what the models were trained on; confirm.
data_x = training_data.drop(columns=metadata.columns)

In [None]:
# Output directory for this run's figures, mirroring the model directory layout.
if aggregation not in (True, False) or nomic not in (True, False):
    # Raising avoids writing figures into a stale figure_path left over from an
    # earlier run, or failing later with NameError.
    raise ValueError(
        f"aggregation and nomic must be booleans, got {aggregation!r} and {nomic!r}"
    )
figure_set_dirs = {
    (True, True): "aggregated_with_nomic",
    (True, False): "aggregated",
    (False, True): "sc_with_nomic",
    (False, False): "sc",
}
figure_set_dir = figure_set_dirs[(bool(aggregation), bool(nomic))]
figure_path = pathlib.Path(
    f"./figures/single_class/{cell_type}/{figure_set_dir}/{control}__{treatment}"
)
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Evaluate each saved model on the labeled data: report accuracy and weighted F1,
# then save and show a labelled confusion-matrix heatmap per model.
for model_type, feature_type, phenotypic_class in itertools.product(
    model_types, feature_types, phenotypic_classes
):
    print(model_type, feature_type, phenotypic_class)
    # Load the fitted estimator saved by the training notebook.
    # joblib files are pickles, so only load trusted files.
    model = load(f"../1.train_models/{model_path}/{model_type}__{feature_type}.joblib")

    predictions = model.predict(data_x)
    # Class probabilities. Kept as a variable but not used in this loop.
    probabilities = model.predict_proba(data_x)

    accuracy = accuracy_score(labeled_data, predictions)
    f1 = f1_score(labeled_data, predictions, average="weighted")
    print(f"accuracy: {accuracy:.3f}  weighted F1: {f1:.3f}")

    # confusion_matrix's default label order is the sorted union of true and
    # predicted labels. Reproduce it explicitly so the axes can be labelled.
    class_labels = np.union1d(labeled_data.to_numpy(), predictions)
    cm = confusion_matrix(labeled_data, predictions, labels=class_labels)

    # Create an explicit figure per model. With non-interactive backends
    # plt.show() does not close the current figure, so the next heatmap would
    # otherwise be drawn over this one, with an extra colorbar.
    fig, ax = plt.subplots()
    sns.heatmap(
        cm,
        annot=True,
        fmt="g",
        xticklabels=class_labels,
        yticklabels=class_labels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(
        f"Confusion matrix of {model_type} model on {feature_type} features for {phenotypic_class}"
    )
    # bbox_inches="tight" keeps the long class tick labels from being clipped.
    fig.savefig(
        f"{figure_path}/{model_type}__{feature_type}__{phenotypic_class}.png",
        bbox_inches="tight",
    )
    plt.show()
    plt.close(fig)