# scClone2DR Simulation Benchmark

This notebook performs a comprehensive benchmark comparison of the scClone2DR model against several baseline methods using simulated data.

## Overview

The benchmark evaluates different approaches for predicting drug response from single-cell data:

1. **scClone2DR** - The full single-cell clone to drug response model
2. **Factorization Machine (FM)** - Matrix factorization approach
3. **Neural Network (NN)** - Deep learning baseline
4. **Dual Bulk** - Bimodal aggregation method
5. **Bulk** - Simple bulk aggregation
6. **Baseline** - Basic model without clone information

## Workflow

1. **Data Generation**: Create simulated training data with known ground truth
2. **Train/Test Split**: Split data into training (50%) and testing (50%) sets
3. **Model Training**: Train each model variant with appropriate regularization
4. **Evaluation**: Compare models using:
   - Drug effect prediction accuracy (L¹ error)
   - Drug score predictions (MSE)
   - Fold change correlations (explained variance)

## Parameters

- `setting`: Difficulty level ("easy", "very_easy", or "hard")
- `n_steps`: Number of training steps (default: 2000)
- Regularization: L1 and L2 penalties (both 0.02 for the "easy" and "very_easy" settings, both 0.1 otherwise)

## Output

The notebook generates comparison plots showing:
- L¹ error curves for drug effect predictions
- Violin plots of drug scores with MSE comparison
- Scatter plots of predicted vs. true fold changes for each model

In [None]:
import sys
sys.path.append('../')
import scClone2DR as sccdr
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
# Number of optimisation steps, passed as `n_steps` to the scClone2DR-based
# `.train(...)` calls in the cells below.
n_steps = 2000
# Re-create the `np.float_` alias as `np.float64`.
# NOTE(review): presumably a compatibility shim for NumPy >= 2.0, where the
# `np.float_` alias was removed, so that a dependency still referring to it
# keeps working — confirm which dependency needs it.
np.float_ = np.float64

In [None]:
modelscClone2DR = sccdr.models.scClone2DR()

# Difficulty of the simulated benchmark. Each setting fixes the `R` entry of
# the simulation dimensions and the `neg_bin_n` argument of the simulator.
setting = "easy"
if setting == "easy":
    R, neg_bin = 20, 100
elif setting == "very_easy":
    # This branch must match the "very_easy" spelling used by the training
    # cells; otherwise that setting silently falls through to the hard values.
    R, neg_bin = 30, 10000
else:
    # "hard", and any unrecognised value.
    R, neg_bin = 5, 2

# Simulate the reference dataset (ground truth included) and attach the
# survival probabilities computed from the true parameters under 'pi'.
data_ref = modelscClone2DR.get_simulated_training_data(
    {'C': 24, 'R': R, 'N': 100, 'Kmax': 7, 'D': 30, 'theta_rna': 15},
    neg_bin_n=neg_bin,
    mode_nu="noise_correction",
    mode_theta="not shared decoupled",
)
data_ref['pi'] = modelscClone2DR.compute_survival_probas_subclone_features(data_ref, data_ref)

# 50/50 split over the index range [0, data_ref['N']): first half for
# training, second half for testing.
n_train = int(0.5 * data_ref['N'])
idxs_train = list(range(n_train))
idxs_test = list(range(n_train, data_ref['N']))

data_train, data_test = modelscClone2DR.get_data_split_simu(data_ref, idxs_train, idxs_test)

In [None]:
# Train scClone2DR on the training split. The same value is used for the L1
# and L2 penalties: 0.02 for the two easy settings, 0.1 for everything else.
penalty = 0.02 if setting in ("easy", "very_easy") else 0.1
params_svi = modelscClone2DR.train(
    data_train,
    penalty_l1=penalty,
    penalty_l2=penalty,
    n_steps=n_steps,
)

In [None]:
# Pass the trained parameters through the model's `convert_to_tensor` helper.
params_svi = modelscClone2DR.convert_to_tensor(params_svi)
# Store under 'pi' the survival probabilities computed from the reference data
# and the fitted parameters.
params_svi['pi'] = modelscClone2DR.compute_survival_probas_subclone_features(data_ref, params_svi)
# Compute the evaluation statistics for the fitted parameters.
# NOTE(review): this is evaluated on the full `data_ref` (train + test), not on
# `data_test` — confirm that this is intended. Presumably this call populates
# `modelscClone2DR.results`, which the plotting cells read — verify.
modelscClone2DR.compute_all_stats(data_ref, data_ref, params_svi)

# Bulk

In [None]:
# Bulk baseline: a fresh scClone2DR instance trained on the bulk view of the data.
model_bulk = sccdr.models.scClone2DR()

# Bulk views of the full, training and test datasets (built in that order).
databulk, data_train_bulk, data_test_bulk = (
    model_bulk.get_bulk_from_data(dataset)
    for dataset in (data_ref, data_train, data_test)
)

# Same penalty for L1 and L2: 0.02 for the two easy settings, 0.1 otherwise.
penalty = 0.02 if setting in ("easy", "very_easy") else 0.1
params_svi_bulk = model_bulk.train(
    data_train_bulk,
    penalty_l1=penalty,
    penalty_l2=penalty,
    n_steps=n_steps,
)

In [None]:
# Compute the evaluation statistics of the bulk model, passing the scClone2DR
# model, the full reference data, its bulk view and the fitted bulk parameters.
# NOTE(review): presumably this populates `model_bulk.results`, which the
# plotting cells read — verify against the library.
model_bulk.compute_all_stats4bulk_or_bimodal(modelscClone2DR, data_ref, databulk, params_svi_bulk)

# Dual bulk (bimodal)

In [None]:
# Dual-bulk (bimodal) baseline: a fresh scClone2DR instance trained on the
# bimodal view of the data.
model_bimodal = sccdr.models.scClone2DR()

# Bimodal views of the full, training and test datasets (built in that order).
databimodal, data_train_bimodal, data_test_bimodal = (
    model_bimodal.get_bimodal_from_data(dataset)
    for dataset in (data_ref, data_train, data_test)
)

# Same penalty for L1 and L2: 0.02 for the two easy settings, 0.1 otherwise.
penalty = 0.02 if setting in ("easy", "very_easy") else 0.1
params_svi_bimodal = model_bimodal.train(
    data_train_bimodal,
    penalty_l1=penalty,
    penalty_l2=penalty,
    n_steps=n_steps,
)

In [None]:
# Compute the evaluation statistics of the bimodal model, passing the
# scClone2DR model, the full reference data, its bimodal view and the fitted
# bimodal parameters.
# NOTE(review): presumably this populates `model_bimodal.results`, which the
# plotting cells read — verify against the library.
model_bimodal.compute_all_stats4bulk_or_bimodal(modelscClone2DR, data_ref, databimodal, params_svi_bimodal)

# Baseline

In [None]:
# Baseline: a fresh scClone2DR instance trained on the "base" view of the data.
model_base = sccdr.models.scClone2DR()

# Base views of the full, training and test datasets (built in that order).
database, data_train_base, data_test_base = (
    model_base.get_base_from_data(dataset)
    for dataset in (data_ref, data_train, data_test)
)

# Same penalty for L1 and L2: 0.02 for the two easy settings, 0.1 otherwise.
penalty = 0.02 if setting in ("easy", "very_easy") else 0.1
params_svi_base = model_base.train(
    data_train_base,
    penalty_l1=penalty,
    penalty_l2=penalty,
    n_steps=n_steps,
)

In [None]:
# Pass the trained baseline parameters through the model's `convert_to_tensor` helper.
params_svi_base = model_base.convert_to_tensor(params_svi_base)
# Store under 'pi' the survival probabilities computed from the base view of
# the reference data (`database`) and the fitted baseline parameters.
params_svi_base['pi'] = model_base.compute_survival_probas_subclone_features(database, params_svi_base)
# Compute the evaluation statistics against the full reference data.
# NOTE(review): 'pi' above is computed from `database`, while the statistics
# are computed on `data_ref` — confirm this mix is intended. Presumably this
# call populates `model_base.results` — verify.
model_base.compute_all_stats(data_ref, data_ref, params_svi_base)

# Factorization machine

In [None]:
# Factorization-machine baseline, built with the cluster/clone label mappings
# of the scClone2DR model. It is trained on the training split and evaluated
# on the full reference data, with the true survival probabilities supplied
# under 'pi'.
modelFM = sccdr.models.FM(modelscClone2DR.cluster2clonelabel, modelscClone2DR.clonelabel2cat)
modelFM.train(data_train)
modelFM.eval(data_ref, true_params={'pi': data_ref['pi']})

# Same model constructed with `use_true_proportions=True`.
# NOTE(review): the results of this variant are not collected into `res` by
# the visualisation cells — confirm whether that is intended.
modelFM_trueprops = sccdr.models.FM(
    modelscClone2DR.cluster2clonelabel,
    modelscClone2DR.clonelabel2cat,
    use_true_proportions=True,
)
modelFM_trueprops.train(data_train)
modelFM_trueprops.eval(data_ref, true_params={'pi': data_ref['pi']})

# Neural Network

In [None]:
# Neural-network baseline, built with the cluster/clone label mappings of the
# scClone2DR model. Training receives the training split together with
# `data_ref['beta']`; evaluation runs on the full reference data with the true
# 'pi' and 'beta' supplied.
modelNN = sccdr.models.NN(modelscClone2DR.cluster2clonelabel, modelscClone2DR.clonelabel2cat)
modelNN.train(data_train, data_ref['beta'])
modelNN.eval(data_ref, true_params={'pi': data_ref['pi'], 'beta': data_ref['beta']})

# Same model constructed with `use_true_proportions=True`.
# NOTE(review): the results of this variant are not collected into `res` by
# the visualisation cells — confirm whether that is intended.
modelNN_trueprops = sccdr.models.NN(
    modelscClone2DR.cluster2clonelabel,
    modelscClone2DR.clonelabel2cat,
    use_true_proportions=True,
)
modelNN_trueprops.train(data_train, data_ref['beta'])
modelNN_trueprops.eval(data_ref, true_params={'pi': data_ref['pi'], 'beta': data_ref['beta']})

# Visualizing results

In [None]:
import seaborn as sns

# Results of every benchmarked model, keyed by the short identifier used
# throughout the plotting cells.
res = {
    'us': modelscClone2DR.results,
    'base': model_base.results,
    'bulk': model_bulk.results,
    'bimodal': model_bimodal.results,
    'fm': modelFM.results,
    'nn': modelNN.results,
}
colors_models = sns.color_palette('Set2')

# Plotting order of the models, and the display name of each identifier.
models = ['us', 'fm', 'nn', 'bimodal', 'bulk', 'base']
model2name = {
    'us': 'scClone2DR',
    'base': 'Baseline',
    'bimodal': 'Dual bulk',
    'bulk': 'Bulk',
    'fm': 'FM',
    'nn': 'NN',
    'fm_true_props': 'FM true props',
    'nn_true_props': 'NN true props',
}

In [None]:
# For each model, plot the running mean absolute error on the drug effects as
# a function of a threshold on |1 - true drug effect|. Entries are visited
# from the largest to the smallest |1 - true effect|, so the value drawn at
# threshold t is the mean error over all entries whose |1 - true effect| >= t.
for key in models:
    truth = res[key]['true_drug_effects'].numpy()
    estimate = res[key]['drug_effects'].numpy()

    strength = np.abs(1. - truth)
    order = np.argsort(strength)[::-1]  # decreasing |1 - true effect|
    thresholds = strength[order]

    abs_err = np.abs(truth - estimate)[order]
    running_mean_err = np.cumsum(abs_err) / np.arange(1, len(thresholds) + 1)
    plt.plot(thresholds, running_mean_err, label=model2name[key])

plt.legend(loc=2)
plt.ylabel('$L^1$ error on the drug effects', fontsize=14)
plt.xlabel('Threshold', fontsize=14)
plt.show()

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import collections

# Long-format table of drug scores (estimated and ground truth) for every
# model, plus the per-model MSE between the two.
dic = {'method': [], "drug scores": [], "statistic": [], 'ground_truth': []}
mses = {}

for m in models:
    estimated_scores = np.array(res[m]['drug_scores'])
    true_scores = np.array(res[m]['true_drug_scores'])

    mses[model2name[m]] = np.mean((estimated_scores - true_scores) ** 2)

    # Estimated rows first, then ground-truth rows; every column receives the
    # same number of entries so the DataFrame columns stay aligned.
    n_est, n_true = len(estimated_scores), len(true_scores)
    dic['method'] += [model2name[m]] * (n_est + n_true)
    dic['drug scores'] += list(estimated_scores)
    dic['drug scores'] += list(true_scores)
    dic['ground_truth'] += [False] * n_est + [True] * n_true
    dic['statistic'] += ['estimated '] * n_est + ['ground truth'] * n_true

df = pd.DataFrame(dic)

# Theme and tick style.
sns.set_theme(style='white')
sns.set_style("ticks", {"xtick.major.size": 12, "ytick.major.size": 12})

# Two stacked panels: split violins on top, MSE bars below.
fig, (ax_violin, ax_bar) = plt.subplots(nrows=2, ncols=1, figsize=(10, 8), gridspec_kw={'height_ratios': [3, 1]})

# Fixed colour per model, shared by the violin and bar panels (and read by the
# fold-change scatter cell).
colors = {
    'us': "#009E73",
    'fm': "#0072B2",
    'nn': "#56B4E9",
    'bimodal': "#D55E00",
    'bulk': "#E69F00",
    'base': "#F0E442",
}

ax = sns.violinplot(data=df, x="method", y="drug scores", hue="statistic", split=True, inner="quart", density_norm="width", legend=False, ax=ax_violin)
# Recolour the violin halves: the model colour for one half, gray for the other.
# NOTE(review): assumes the PolyCollections come back as two consecutive halves
# per method (estimated, then ground truth), in the order of `models` — confirm
# against the seaborn version in use.
for ind, violin in enumerate(ax.findobj(collections.PolyCollection)):
    violin.set_facecolor("gray" if ind % 2 != 0 else colors[models[ind // 2]])

# Proxy artist so the legend can explain the gray halves.
ax_violin.plot([], [], linewidth=10, c="gray", label='Ground truth')
ax_violin.set_ylabel('Drug Scores', fontsize=19)
ax_violin.set_xlabel('')
ax_violin.legend(loc=1, fontsize=19, ncol=2)
# The method names are shown on the bar panel below, so hide them here.
ax_violin.set_xticks([])
ax_violin.set_xticklabels([])

# MSE bar plot on the bottom; `mses` was filled in the order of `models`, so
# bars and colours line up.
methods = list(mses.keys())
mse_values = list(mses.values())
ax_bar.bar(methods, mse_values, color=[colors[mod] for mod in models])
ax_bar.set_ylabel('MSE', fontsize=19)
ax_bar.tick_params(axis='x', rotation=30, labelsize=19)
ax_bar.set_xticks(range(len(methods)))
ax_bar.set_xticklabels(methods)

# Finalise the layout, add the panel label (in data coordinates of the violin
# axes) and show the figure.
plt.tight_layout()
ax_violin.text(-1.25, 1.45, "(b)", fontsize=20)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.gridspec import GridSpec
from sklearn.metrics import explained_variance_score


size_dots = 2

custom_font_size = 20
custom_fontsize_labelsaxis = 16
fontsizelegend = 18
tit = ['scClone2DR', 'FM', 'NN', 'Dual bulk', 'Bulk', 'Baseline',  'FM model (true props)', 'NN model (true props)']


def _plot_fold_change_panel(ax, fc_true, fc_pred, title, color):
    """Draw one predicted-vs-true fold-change scatter panel on *ax*.

    Both axes share the same limits, spanning the joint range of `fc_true`
    and `fc_pred`. The squared Pearson correlation between the two (rounded
    to 3 decimals) is written near the lower-left corner.
    """
    corr = np.round(np.corrcoef(fc_true, fc_pred)[0, 1] ** 2, 3)
    ax.scatter(fc_true, fc_pred, label='{0}'.format(corr), s=size_dots, c=color)
    ax.set_title(title, fontsize=custom_font_size)
    ax.set_xlabel('Fold change (ground truth)', fontsize=custom_fontsize_labelsaxis)
    ax.set_ylabel('Fold change (predicted)', fontsize=custom_fontsize_labelsaxis)

    lower = min([np.min(fc_pred), np.min(fc_true)])
    upper = max([np.max(fc_pred), np.max(fc_true)])
    ax.set_xlim(lower, upper)
    ax.set_ylim(lower, upper)

    # NOTE(review): the label reads "Explained variance" but the number shown
    # is the squared Pearson correlation, not sklearn's
    # `explained_variance_score` — confirm which quantity the figure should
    # report.
    ax.text(lower + 0.02, lower + 0.02, "Explained variance: {0}".format(corr), fontsize=14)


# 2 x 3 grid, one panel per model, filled row by row.
fig = plt.figure(figsize=(12, 8))
gs = GridSpec(2, 3, width_ratios=[1, 1, 1])

panel_models = ['us', 'fm', 'nn', 'bimodal', 'bulk', 'base']
panel_axes = []
for idx, key in enumerate(panel_models):
    panel_ax = fig.add_subplot(gs[idx // 3, idx % 3])
    # The ground-truth fold changes are taken from the scClone2DR results for
    # every panel. NOTE(review): presumably they are identical across models —
    # verify.
    _plot_fold_change_panel(
        panel_ax,
        res['us']['fold_change_true'],
        res[key]['fold_change_pred'],
        tit[idx],
        colors[key],
    )
    panel_axes.append(panel_ax)

# Adjust layout
plt.tight_layout()

# Panel label, positioned in data coordinates of the first panel.
panel_axes[0].text(-0.3, 0.083, "(a)", fontsize=20)

# Show the plot
plt.show()
