# Assess generalizability of the model by using drug dose curve data

**NOTE:** To assess generalizability based on a drug dose response, we will be using Plates 1 and 2, split by heart number (all failing hearts).

## Import libraries

In [1]:
import pathlib
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import seaborn as sns
from joblib import load
from sklearn.metrics import precision_recall_curve

sys.path.append("../utils")
from eval_utils import generate_confusion_matrix_df, generate_f1_score_df
from training_utils import get_X_y_data

## Set paths and variables

In [2]:
# Where the single-cell profile parquet files for each plate live
data_dir = pathlib.Path("../3.process_cfret_features/data/single_cell_profiles")

# Where the saved model files live
models_dir = pathlib.Path("./models")

# Output folders: one for figures, one for predicted-probability tables
fig_dir = pathlib.Path("./figures")
prob_dir = pathlib.Path("./prob_data")

# Create each output folder if it is not already present
for output_dir in (fig_dir, prob_dir):
    output_dir.mkdir(exist_ok=True)

# Load the two saved Plate 4 models individually
# (the "final" and "shuffled" variants, going by their file names)
final_model = load(models_dir / "log_reg_fs_plate_4_final_downsample.joblib")
shuffled_model = load(models_dir / "log_reg_fs_plate_4_shuffled_downsample.joblib")

## Load in the Plate 4 feature-selected (fs) data to extract the column names used to filter the other plates

In [3]:
# Read only the parquet metadata of the Plate 4 feature-selected file (the data
# used with the model); its schema lists the columns needed to filter the other plates
plate_4_fs_path = data_dir / "localhost231120090001_sc_feature_selected.parquet"
parquet_metadata = pq.read_metadata(plate_4_fs_path)

# Every column name recorded in the file schema
all_column_names = parquet_metadata.schema.names

# Keep the non-metadata columns, i.e. drop any name starting with "Metadata_"
model_column_names = list(
    filter(lambda name: not name.startswith("Metadata_"), all_column_names)
)

print(len(model_column_names))
print(model_column_names)

625
['Cytoplasm_AreaShape_BoundingBoxArea', 'Cytoplasm_AreaShape_Compactness', 'Cytoplasm_AreaShape_Eccentricity', 'Cytoplasm_AreaShape_Extent', 'Cytoplasm_AreaShape_FormFactor', 'Cytoplasm_AreaShape_MajorAxisLength', 'Cytoplasm_AreaShape_MinorAxisLength', 'Cytoplasm_AreaShape_Perimeter', 'Cytoplasm_AreaShape_Solidity', 'Cytoplasm_AreaShape_Zernike_0_0', 'Cytoplasm_AreaShape_Zernike_1_1', 'Cytoplasm_AreaShape_Zernike_2_0', 'Cytoplasm_AreaShape_Zernike_2_2', 'Cytoplasm_AreaShape_Zernike_3_1', 'Cytoplasm_AreaShape_Zernike_3_3', 'Cytoplasm_AreaShape_Zernike_4_0', 'Cytoplasm_AreaShape_Zernike_4_2', 'Cytoplasm_AreaShape_Zernike_5_1', 'Cytoplasm_AreaShape_Zernike_5_3', 'Cytoplasm_AreaShape_Zernike_6_0', 'Cytoplasm_AreaShape_Zernike_6_2', 'Cytoplasm_AreaShape_Zernike_7_1', 'Cytoplasm_AreaShape_Zernike_7_3', 'Cytoplasm_AreaShape_Zernike_8_0', 'Cytoplasm_AreaShape_Zernike_8_2', 'Cytoplasm_AreaShape_Zernike_8_4', 'Cytoplasm_AreaShape_Zernike_8_6', 'Cytoplasm_AreaShape_Zernike_9_1', 'Cytoplasm_Ar

## Load in Plates 1 and 2, concat vertically, and drop any rows where there are NaNs in the feature columns from the model

In [4]:
# Normalized single-cell profile files for Plate 1 and Plate 2
plate_1_path = data_dir / "localhost220512140003_KK22-05-198_sc_normalized.parquet"
plate_2_path = (
    data_dir / "localhost220513100001_KK22-05-198_FactinAdjusted_sc_normalized.parquet"
)

# Read each plate into its own data frame
plate_1_df = pd.read_parquet(plate_1_path)
plate_2_df = pd.read_parquet(plate_2_path)

# Stack the two data frames row-wise, then discard any row that has a NaN in
# one of the feature columns the model uses
# (row labels from each plate are kept as-is, so the index is not unique here)
concatenated_df = pd.concat([plate_1_df, plate_2_df], axis=0).dropna(
    subset=model_column_names
)

print(concatenated_df.shape)
concatenated_df.head()

(40645, 2022)


Unnamed: 0,Metadata_WellRow,Metadata_WellCol,Metadata_heart_number,Metadata_treatment,Metadata_dose,Metadata_dose_unit,Metadata_Nuclei_Location_Center_X,Metadata_Nuclei_Location_Center_Y,Metadata_Cells_Location_Center_X,Metadata_Cells_Location_Center_Y,...,Nuclei_Texture_Variance_Hoechst_3_02_256,Nuclei_Texture_Variance_Hoechst_3_03_256,Nuclei_Texture_Variance_Mitochondria_3_00_256,Nuclei_Texture_Variance_Mitochondria_3_01_256,Nuclei_Texture_Variance_Mitochondria_3_02_256,Nuclei_Texture_Variance_Mitochondria_3_03_256,Nuclei_Texture_Variance_PM_3_00_256,Nuclei_Texture_Variance_PM_3_01_256,Nuclei_Texture_Variance_PM_3_02_256,Nuclei_Texture_Variance_PM_3_03_256
0,A,1,3,drug_x,5.0,uM,220.491103,103.878114,238.995447,104.179514,...,-0.134972,-0.135704,-0.359869,-0.358051,-0.363576,-0.358844,-0.226031,-0.225456,-0.226659,-0.225719
1,A,1,3,drug_x,5.0,uM,77.765782,148.827081,92.415993,130.301826,...,0.349497,0.343469,-0.355964,-0.35421,-0.354343,-0.349994,-0.17011,-0.171281,-0.167815,-0.170193
2,A,1,3,drug_x,5.0,uM,180.302253,171.314685,199.191002,189.827676,...,-0.224982,-0.233478,-0.367058,-0.365434,-0.36926,-0.363866,-0.200103,-0.202045,-0.201089,-0.197016
3,A,1,3,drug_x,5.0,uM,419.328201,196.226094,411.240573,178.878571,...,-0.19548,-0.193045,-0.36079,-0.360166,-0.363218,-0.35698,-0.233946,-0.234798,-0.234972,-0.233913
4,A,1,3,drug_x,5.0,uM,765.912882,208.396664,777.621076,212.232104,...,-0.071024,-0.076007,-0.362509,-0.360042,-0.363735,-0.360015,-0.215742,-0.215914,-0.217449,-0.216141


## Filter the concat data to only include metadata and filtered feature columns

In [5]:
# Boolean mask over the columns: True where the name starts with "Metadata_"
all_columns = concatenated_df.columns
is_metadata = all_columns.str.startswith("Metadata_")

# Metadata column names, in their original order
metadata_columns = all_columns[is_metadata].tolist()

# All remaining (non-metadata) column names
feature_columns = all_columns[~is_metadata].tolist()

# Columns of the data frame that also appear in the model's column list
filtered_feature_columns = all_columns[all_columns.isin(model_column_names)].tolist()

# Keep the metadata columns followed by the model's feature columns
concatenated_df = concatenated_df.loc[:, metadata_columns + filtered_feature_columns]

concatenated_df

Unnamed: 0,Metadata_WellRow,Metadata_WellCol,Metadata_heart_number,Metadata_treatment,Metadata_dose,Metadata_dose_unit,Metadata_Nuclei_Location_Center_X,Metadata_Nuclei_Location_Center_Y,Metadata_Cells_Location_Center_X,Metadata_Cells_Location_Center_Y,...,Nuclei_Texture_InfoMeas2_PM_3_03_256,Nuclei_Texture_InverseDifferenceMoment_ER_3_01_256,Nuclei_Texture_InverseDifferenceMoment_ER_3_03_256,Nuclei_Texture_InverseDifferenceMoment_Mitochondria_3_00_256,Nuclei_Texture_InverseDifferenceMoment_Mitochondria_3_02_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_01_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_03_256,Nuclei_Texture_SumVariance_ER_3_01_256,Nuclei_Texture_SumVariance_Mitochondria_3_03_256,Nuclei_Texture_SumVariance_PM_3_01_256
0,A,1,3,drug_x,5.0,uM,220.491103,103.878114,238.995447,104.179514,...,-3.228784,0.767772,0.163633,1.895080,1.437233,0.799532,0.566632,-0.299555,-0.336081,-0.214493
1,A,1,3,drug_x,5.0,uM,77.765782,148.827081,92.415993,130.301826,...,-0.664698,-0.352643,-0.413803,1.633057,1.705154,-0.159620,-0.384002,-0.017753,-0.326756,-0.172044
2,A,1,3,drug_x,5.0,uM,180.302253,171.314685,199.191002,189.827676,...,0.333811,0.421589,0.789205,1.605760,1.458741,0.828915,0.968036,-0.295674,-0.342126,-0.189797
3,A,1,3,drug_x,5.0,uM,419.328201,196.226094,411.240573,178.878571,...,-2.050903,0.596287,0.540479,2.103799,1.668091,1.228932,1.350530,-0.300708,-0.332476,-0.221203
4,A,1,3,drug_x,5.0,uM,765.912882,208.396664,777.621076,212.232104,...,-1.140967,1.001414,1.257948,1.597395,1.656660,0.805992,0.780507,-0.315316,-0.337728,-0.203520
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9801,H,12,9,DMSO,0.0,,142.307651,819.552072,162.786064,832.028852,...,1.161073,0.614588,1.077641,0.971698,1.189705,0.563754,0.835118,-0.396330,-0.165633,-0.145272
9802,H,12,9,DMSO,0.0,,459.882963,858.093333,421.880624,862.924633,...,0.295518,-1.421432,-0.530030,-0.067212,0.518285,-0.563723,-0.262270,0.259311,-0.156618,-0.143351
9803,H,12,9,DMSO,0.0,,567.998175,864.539538,577.212095,845.292466,...,0.492783,-1.071568,-1.748143,-0.178144,0.260716,-0.746079,-0.404571,0.590393,-0.128399,-0.133301
9804,H,12,9,DMSO,0.0,,884.287577,900.474949,874.227928,889.731832,...,-1.210057,-0.967518,-1.032382,-0.075027,0.142006,0.150642,-0.063021,0.049348,-0.172297,-0.182117


## Create a dictionary that splits the concatenated data frame by heart number

In [6]:
# Split the combined plate data into one data frame per heart number, so the
# model can be applied to each heart separately

# Heart number recorded for every row
heart_numbers = concatenated_df["Metadata_heart_number"]

# One data frame per heart number (3, 8 and 9)
three_df = concatenated_df.loc[heart_numbers == 3]
eight_df = concatenated_df.loc[heart_numbers == 8]
nine_df = concatenated_df.loc[heart_numbers == 9]

# Dictionary keyed by heart; each entry holds that heart's data frame under "data_df"
plate_1_2_dfs_dict = {
    "heart_3": {"data_df": three_df},
    "heart_8": {"data_df": eight_df},
    "heart_9": {"data_df": nine_df},
}

## Extract predicted probabilities from each model (final and shuffled) for each heart number

In [7]:
# Apply every saved model in `models_dir` to each per-heart data frame, collect
# the predicted class probabilities together with the cell metadata, and save
# the combined table as a CSV in `prob_dir`.

# Suffix appended to a class label to name its probability column
PROBA_SUFFIX = "_probas"

# The label encoder does not depend on the model file or the data split, so it
# is loaded once up front instead of once per loop iteration
le = load(pathlib.Path("./encoder_results/label_encoder_log_reg_fs_plate_4.joblib"))

# Get unique cell types and their corresponding encoded values
unique_labels = le.classes_
encoded_values = le.transform(unique_labels)

# Create a dictionary mapping encoded values to original labels
label_dict = dict(zip(encoded_values, unique_labels))

# One probability data frame per (model, heart) pair; they are concatenated a
# single time after the loops rather than growing a data frame inside the loop
prob_dfs = []

# Path.iterdir() yields entries in arbitrary order; sorting makes the row order
# of the saved CSV reproducible across machines and runs
for model_path in sorted(models_dir.iterdir()):
    if model_path.is_dir() or model_path.suffix != ".joblib":
        continue  # Skip directories or files that are not model files

    # Model file stems look like "log_reg_fs_plate_4_final_downsample"; the
    # sixth "_"-separated token ("final" / "shuffled") is the model type
    model_type = model_path.stem.split("_")[5]

    # Load the model once per file; it is reused for every heart below
    model = load(model_path)

    for data, info in plate_1_2_dfs_dict.items():
        print(f"Extracting {model_type} probabilities from {data} data...")

        # Reset the index so it lines up with the fresh 0..n-1 index of the
        # probability data frame when the two are joined below
        data_df = info["data_df"].reset_index(drop=True)

        # Load in X data to get predicted probabilities; the second return
        # value of the project helper is not needed here and is discarded
        X, _ = get_X_y_data(df=data_df, label="Metadata_heart_number")

        # Predict class probabilities for morphology feature data
        predicted_probs = model.predict_proba(X)

        # Store probabilities in a data frame with one column per class,
        # named "<original label>_probas"
        prob_df = pd.DataFrame(predicted_probs, columns=model.classes_)
        prob_df.columns = [label_dict[col] + PROBA_SUFFIX for col in prob_df.columns]

        # Predicted label = name of the highest-probability column with the
        # "_probas" suffix stripped (vectorized instead of a row-wise apply)
        prob_df["predicted_label"] = prob_df.idxmax(axis=1).str[: -len(PROBA_SUFFIX)]

        # Select the metadata columns from the data; kept under its own name so
        # the `metadata_columns` list defined earlier is not overwritten
        metadata_df = data_df.filter(like="Metadata_")

        # Combine metadata columns with predicted probabilities based on index
        prob_df = prob_df.join(metadata_df)

        # Record which model produced these probabilities
        prob_df["model_type"] = model_type

        prob_dfs.append(prob_df)

# Concatenate everything once; fall back to an empty data frame when no model
# files were found (pd.concat raises on an empty list)
combined_prob_df = (
    pd.concat(prob_dfs, ignore_index=True) if prob_dfs else pd.DataFrame()
)

# Save combined prob data
combined_prob_df.to_csv(prob_dir / "combined_plates_1_2_predicted_proba.csv", index=False)

Extracting shuffled probabilities from heart_3 data...
Extracting shuffled probabilities from heart_8 data...
Extracting shuffled probabilities from heart_9 data...
Extracting final probabilities from heart_3 data...
Extracting final probabilities from heart_8 data...
Extracting final probabilities from heart_9 data...
