In [None]:
# Reload edited project modules (e.g. druxai) automatically before each cell runs,
# so changes to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from druxai.models.NN_flexible import Interaction_Model
from druxai.utils.data import DrugResponseDataset
from druxai.utils.dataframe_utils import split_data_by_cell_line_ids, standardize_molecular_data_inplace
from torch.utils.data import DataLoader
import pandas as pd
from scipy.stats import spearmanr
from druxai.utils.plotting import plot_ordered_r_scores
from types import SimpleNamespace

In [None]:
# Checkpoint path.
# Always save the specific config with a run! This matters especially when the seed
# changes, because a different seed produces different data splits.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
ckpt_path = "/Users/niklaskiermeyer/Desktop/Codespace/DruxAI/results/training/ckpt_sgd_baseline.pt"

# Load every tensor onto the CPU first so the checkpoint opens regardless of the
# device it was saved from (e.g. a CUDA machine); the model is moved to the
# accelerator after its state dict is restored. torch.load unpickles the file,
# so only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location="cpu")
cfg = checkpoint["config"]

In [None]:
# Load the dataset, split it by cell line ID and standardize the molecular
# features using the split indices.
# NOTE(review): if this split is seeded, confirm it matches the seed of the
# training run that produced the checkpoint — otherwise the "validation" rows
# may have been seen during training.
data = DrugResponseDataset(cfg["DATA_PATH"])
train_id, val_id, test_id = split_data_by_cell_line_ids(data.targets)
standardize_molecular_data_inplace(data, train_id=train_id, val_id=val_id, test_id=test_id)

# Both loaders share every setting except the sampler, which is the list of row
# indices belonging to that split. Built in order: validation first, then train.
# NOTE(review): pin_memory only speeds up host→CUDA copies; presumably it has no
# effect with the MPS device used below — TODO confirm.
val_loader, train_loader = (
    DataLoader(
        data,
        sampler=split_ids,
        batch_size=128,
        shuffle=False,
        pin_memory=True,
        num_workers=6,
        persistent_workers=True,
    )
    for split_ids in (val_id, train_id)
)

In [None]:
# Hyperparameters for rebuilding the model. Only output_features, hidden_dims_*
# and dropout_* are passed to Interaction_Model; the remaining fields appear
# unused in the cells shown (presumably copied from the training setup).
# NOTE(review): these values are typed in by hand rather than read from the
# checkpoint — confirm they match the run behind `ckpt_path`; mismatched hidden
# sizes make load_state_dict fail.
config = SimpleNamespace(
    metric={"name": "r2_val", "goal": "maximize"},
    resume=False,
    patience=5,
    epochs=10,
    optimizer="sgd",
    scheduler="exponential",
    loss="huber",
    batch_size=64,
    learning_rate=0.1,
    output_features=10,
    hidden_dims_drug_nn=[512],
    hidden_dims_gene_expression_nn=[512],
    dropout_drug_nn=0.2,
    dropout_gene_expression_nn=0.2,
)

In [None]:
# Rebuild the architecture from `config` and restore the trained weights.
model = Interaction_Model(
    data,
    config.output_features,
    config.hidden_dims_drug_nn,
    config.hidden_dims_gene_expression_nn,
    config.dropout_drug_nn,
    config.dropout_gene_expression_nn,
)
model.load_state_dict(checkpoint["model"])

# Prefer Apple's MPS backend but fall back to CPU so the notebook runs elsewhere too.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()

# One entry per batch in each list; the entries are exploded into rows afterwards.
predictions = []
targets = []
cell_lines = []
drugs = []
keys = []
with torch.no_grad():
    for X, y, idx in val_loader:
        drug = X["drug_encoding"].to(device)
        molecular = X["gene_expression"].to(device)

        # Call the module (not .forward) so any registered hooks run; flatten to 1-D.
        predictions.append(model(drug, molecular).reshape(-1).tolist())
        # Targets are only compared against predictions later, so they stay on the CPU.
        targets.append(y.reshape(-1).tolist())

        # `idx` holds row positions into data.targets (it is used with .iloc);
        # fetch the metadata rows for the whole batch in a single lookup.
        batch_ids = idx.tolist()
        batch_rows = data.targets.iloc[batch_ids]
        cell_lines.append(batch_rows["cell_line"].to_list())
        drugs.append(batch_rows["DRUG"].to_list())
        keys.append(batch_ids)

In [None]:
# Assemble the validation predictions into one row per sample. The columns hold
# one list per batch, so each column is exploded back into individual values.
data_dict = {
    "ID": keys,
    "Prediction": predictions,
    "Target": targets,
    "cell_line": cell_lines,
    "Drug": drugs
}

# Exploding leaves the numeric columns as object dtype and repeats the batch
# index; cast them back to real numbers and give every row a unique index.
results_df_val = (
    pd.DataFrame(data_dict)
    .apply(pd.Series.explode)
    .astype({"ID": "int64", "Prediction": "float64", "Target": "float64"})
    .reset_index(drop=True)
)

# Overall Spearman rank correlation between predictions and targets.
print(f"Overall R-Score: {spearmanr(results_df_val['Prediction'], results_df_val['Target'])[0]} \n")
results_df_val.head()

In [None]:
# Per-group R scores on the validation split, once grouped by cell line and once
# by drug. plot_ordered_r_scores returns a table per grouping (and, judging by
# its name, presumably also draws the ordered scores).
group_by_features = ["cell_line", "Drug"]
grouped_dfs_val = {
    feature: plot_ordered_r_scores(results_df_val, feature)
    for feature in group_by_features
}

In [None]:
# Average the per-group R scores to get one summary value per grouping.
mean_r_score = {
    grouping: score_table["R Score"].mean()
    for grouping, score_table in grouped_dfs_val.items()
}

print("Mean R Score stratified by Cell Line and Drug")
mean_r_score

In [None]:
# Per-cell-line R scores on the validation split (table returned by plot_ordered_r_scores).
grouped_dfs_val["cell_line"]

In [None]:
# Validation R score for a single drug of interest.
frame = grouped_dfs_val["Drug"]
frame.loc[frame["Group"].eq("POZIOTINIB")]

In [None]:
# All validation score tables, keyed by grouping column ("cell_line", "Drug").
grouped_dfs_val

## Predictions on the train set

In [None]:
# Rebuild the architecture exactly as for the validation run and restore the weights.
# `config` is a SimpleNamespace, so its fields are read as attributes.
model = Interaction_Model(
    data,
    config.output_features,
    config.hidden_dims_drug_nn,
    config.hidden_dims_gene_expression_nn,
    config.dropout_drug_nn,
    config.dropout_gene_expression_nn,
)
model.load_state_dict(checkpoint["model"])

# Prefer Apple's MPS backend but fall back to CPU so the notebook runs elsewhere too.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
model.eval()

# One entry per batch in each list; the entries are exploded into rows afterwards.
predictions = []
targets = []
cell_lines = []
drugs = []
keys = []
with torch.no_grad():
    for X, y, idx in train_loader:
        drug = X["drug_encoding"].to(device)
        molecular = X["gene_expression"].to(device)

        # Call the module (not .forward) so any registered hooks run; flatten to 1-D.
        predictions.append(model(drug, molecular).reshape(-1).tolist())
        # Targets are only compared against predictions later, so they stay on the CPU.
        targets.append(y.reshape(-1).tolist())

        # `idx` holds row positions into data.targets (it is used with .iloc);
        # fetch the metadata rows for the whole batch in a single lookup.
        batch_ids = idx.tolist()
        batch_rows = data.targets.iloc[batch_ids]
        cell_lines.append(batch_rows["cell_line"].to_list())
        drugs.append(batch_rows["DRUG"].to_list())
        keys.append(batch_ids)

In [None]:
# Assemble the train predictions into one row per sample. The columns hold one
# list per batch, so each column is exploded back into individual values.
data_dict = {
    "ID": keys,
    "Prediction": predictions,
    "Target": targets,
    "cell_line": cell_lines,
    "Drug": drugs
}

# Exploding leaves the numeric columns as object dtype and repeats the batch
# index; cast them back to real numbers and give every row a unique index.
results_df_train = (
    pd.DataFrame(data_dict)
    .apply(pd.Series.explode)
    .astype({"ID": "int64", "Prediction": "float64", "Target": "float64"})
    .reset_index(drop=True)
)

# Overall Spearman rank correlation between predictions and targets.
print(f"Overall R-Score: {spearmanr(results_df_train['Prediction'], results_df_train['Target'])[0]} \n")
print(results_df_train.head())

# Per-group R scores on the train split, once per cell line and once per drug.
group_by_features = ["cell_line", "Drug"]
grouped_dfs_train = {
    feature: plot_ordered_r_scores(results_df_train, feature)
    for feature in group_by_features
}

In [None]:
# Average the per-group train R scores to get one summary value per grouping.
mean_r_score = {
    grouping: score_table["R Score"].mean()
    for grouping, score_table in grouped_dfs_train.items()
}

mean_r_score

In [None]:
# Per-cell-line R scores on the train split (table returned by plot_ordered_r_scores).
grouped_dfs_train["cell_line"]

In [None]:
# Per-drug R scores on the train split (table returned by plot_ordered_r_scores).
grouped_dfs_train["Drug"]

## Correlation between observation count per cell line and prediction quality

Is there a correlation between how many observations a cell line has and how well its responses are predicted, i.e. do sparsely represented cell lines score worse than well-represented ones?
The same question applies to drugs.



In [None]:
# Does prediction quality depend on how many distinct drugs a cell line was
# measured with? Rank the validation cell lines by that count and compute the
# Pearson correlation of the rank with each cell line's R score.
val_cell_line_ranks = (
    data.targets.iloc[val_id]
    .groupby("cell_line")["DRUG"]
    .nunique()
    .rank()                   # rank() is order-independent, so no prior sort is needed
    .rename("n_drugs_rank")   # the values are ranks of drug counts, not drug names
    .to_frame()
)

# Left join on the cell line: every scored cell line keeps its row.
merged_df = grouped_dfs_val["cell_line"].set_index("Group").join(val_cell_line_ranks)
correlation = merged_df["R Score"].corr(merged_df["n_drugs_rank"])

print("Val Correlation between R Score and rank:", correlation)

In [None]:
# Same analysis on the train split: rank cell lines by the number of distinct
# drugs they were measured with and compute the Pearson correlation of that rank
# with each cell line's R score.
train_cell_line_ranks = (
    data.targets.iloc[train_id]
    .groupby("cell_line")["DRUG"]
    .nunique()
    .rank()                   # rank() is order-independent, so no prior sort is needed
    .rename("n_drugs_rank")   # the values are ranks of drug counts, not drug names
    .to_frame()
)

# Left join on the cell line: every scored cell line keeps its row.
merged_df = grouped_dfs_train["cell_line"].set_index("Group").join(train_cell_line_ranks)
correlation = merged_df["R Score"].corr(merged_df["n_drugs_rank"])

print("Train Correlation between R Score and rank:", correlation)