In [1]:
import json
import os
import warnings
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap, Normalize

import importlib
import utils  # Import your module
from utils import model_name_mapping, metrics_name_mapping

# Reload utils so edits made to utils.py are picked up without restarting the kernel.
importlib.reload(utils)
# The `from utils import ...` statement above ran BEFORE the reload, so those names
# would still point at the pre-reload objects. Re-import them from the refreshed
# module so the notebook actually sees the updated mappings.
from utils import model_name_mapping, metrics_name_mapping

# Root directory for inputs/outputs. Notebooks have no __file__, so the current
# working directory is used instead.
# filepath = Path(__file__).parent
filepath = Path(os.path.abspath(''))
print(filepath)

# ----------------------------------------------------------------
# Column/axis order of the datasets in tables and plots.
# datasets_order = ['CCLE', 'CTRPv2', 'GDSCv1', 'GDSCv2', 'gCSI']  # alphabetical order
datasets_order = ['gCSI', 'CCLE', 'GDSCv2', 'GDSCv1', 'CTRPv2']  # order by sample size

# metrics_name_mapping = {
#     "r2": "R²",
#     "mae": "MAE",
#     "rmse": "RMSE",
#     "stgr": "STGR",
#     "stgi": "STGI",
# }

# model_name_mapping = {
#     "deepcdr": "DeepCDR",
#     "graphdrp": "GraphDRP",
#     "hidra": "HiDRA",
#     "lgbm": "LGBM",
#     "tcnns": "tCNNS",
#     "uno": "UNO",
# }

/nfs/ml_lab/projects/improve/data/experiments/cross-dataset-drp-paper


In [None]:
# --- Input / output locations ---
datadir = Path('splits_averaged')
outdir = filepath / 'results_for_paper_revision'
outdir.mkdir(parents=True, exist_ok=True)

# --- Figure export settings (other formats tried: 'eps', 'jpeg', 'png') ---
file_format = 'tiff'
dpi = 600

# --- Scores file and identifier column names ---
filename = 'all_models_scores.csv'
canc_col_name = 'improve_sample_id'
drug_col_name = 'improve_chem_id'

# --- Plot options ---
# Datasets ordered by sample size (alternative: alphabetical
# ['CCLE', 'CTRPv2', 'GDSCv1', 'GDSCv2', 'gCSI']).
datasets_order = ['gCSI', 'CCLE', 'GDSCv2', 'GDSCv1', 'CTRPv2']
show_plot = True

# --- Load the per-split scores of all models and preview the first rows ---
scores_path = filepath / datadir / filename
all_scores = pd.read_csv(scores_path, sep=',')
all_scores.head(3)

Unnamed: 0,met,split,value,src,trg,model
0,mse,0,0.006295,CCLE,CCLE,deepcdr
1,mse,1,0.00595,CCLE,CCLE,deepcdr
2,mse,2,0.005129,CCLE,CCLE,deepcdr


In [3]:
# Metric key used to filter the `met` column of `all_scores` in the
# within-study tables and plots below.
metric_name = "r2"

# Models to include in the analysis.
# NOTE(review): this list is not referenced by any cell visible in this part of the
# notebook; presumably an empty list means "all models" -- confirm against later cells.
models_to_include = []  # Replace with your desired models
# models_to_include = ["deepcdr", "graphdrp", "hidra", "lgbm", "uno"]  # Replace with your desired models
# models_to_include = ["graphdrp"]  # Replace with your desired models

# Statistical Tests (Reviewer 3, comment 1)

In [4]:
# Output folder for everything produced in response to Reviewer 3, comment 1.
comment_outdir = outdir / 'reviewer3_comment1'
comment_outdir.mkdir(parents=True, exist_ok=True)

from scipy.stats import wilcoxon
from itertools import combinations

df = all_scores.copy()

# Keep only the R² rows. The explicit .copy() makes r2_df an independent frame so
# the column assignment below does not write into a slice of `df`
# (avoids pandas' SettingWithCopyWarning).
r2_df = df.loc[df['met'] == 'r2', ['src', 'trg', 'model', 'split', 'value']].copy()
r2_df['model'] = r2_df['model'].map(model_name_mapping)

# Models and datasets present in the R² scores
models = r2_df['model'].unique()
src_datasets = r2_df['src'].unique()
trg_datasets = r2_df['trg'].unique()

# Accumulators filled by the pairwise-test loop
results = []
skipped = []

# Bonferroni correction: divide by the number of pairwise model comparisons that
# are actually run for each source-target pair, i.e. C(n_models, 2). Deriving it
# from `models` keeps the threshold correct when the set of models changes
# (a hard-coded 15 is only right for exactly 6 models).
n_comparisons = len(models) * (len(models) - 1) // 2
alpha = 0.05 / n_comparisons if n_comparisons else 0.05

In [5]:
# Pairwise Wilcoxon signed-rank tests between models, run separately for every
# source -> target dataset pair. Successful tests are appended to `results`;
# model pairs without the expected number of splits are appended to `skipped`.

# All unordered model pairs; the same set is tested for every source-target pair.
model_pairs = list(combinations(models, 2))

for src in src_datasets:
    for trg in trg_datasets:
        # Data for this source-target pair
        pair_df = r2_df[(r2_df['src'] == src) & (r2_df['trg'] == trg)]
        if pair_df.empty:
            continue

        # The signed-rank test is a PAIRED test: element k of both score arrays must
        # come from the same split. Sort by split id so the pairing does not depend
        # on the row order of the input file.
        pair_df = pair_df.sort_values('split', kind='stable')

        for model1, model2 in model_pairs:
            scores1 = pair_df.loc[pair_df['model'] == model1, 'value'].to_numpy()
            scores2 = pair_df.loc[pair_df['model'] == model2, 'value'].to_numpy()

            # Require the full set of 10 splits for both models
            if len(scores1) != 10 or len(scores2) != 10:
                print(f"Skipping {model1} vs {model2} on {src} → {trg} due to insufficient splits: {len(scores1)} vs {len(scores2)}")
                skipped.append({
                    'src': src,
                    'trg': trg,
                    'model1': model1,
                    'model2': model2,
                    'len_model1': len(scores1),
                    'len_model2': len(scores2)
                })
                continue

            try:
                stat, p = wilcoxon(scores1, scores2, alternative='two-sided')
            except Exception as e:
                # Bind the exception so the message can report it (the previous bare
                # `except:` referenced an undefined `e` and raised NameError itself).
                print(f"Wilcoxon test failed for {model1} vs {model2} on {src} → {trg}: {e}")
                continue

            # All tests are recorded; the 'significant' flag applies the
            # Bonferroni-corrected threshold `alpha`.
            results.append({
                'src': src,
                'trg': trg,
                'model1': model1,
                'model2': model2,
                'median_model1': np.median(scores1),
                'median_model2': np.median(scores2),
                'mean_r2_diff': np.mean(scores1 - scores2),
                'p_value': p,
                'significant': p < alpha
            })

In [None]:
# Save ALL pairwise test results (significant or not; see the 'significant' column).
# Passing the column list keeps the frame well-formed even when `results` is empty,
# so the per-pair filtering below cannot fail with a KeyError.
result_columns = [
    'src', 'trg', 'model1', 'model2',
    'median_model1', 'median_model2',
    'mean_r2_diff', 'p_value', 'significant',
]
res_df = pd.DataFrame(results, columns=result_columns)
res_df.to_csv(comment_outdir / 'wilcoxon_tests_r2_all_combos.csv', index=False)

# For every source-target pair: save its subset of test results and a boxplot of
# the per-split R² scores of each model.
for src in src_datasets:
    for trg in trg_datasets:
        pair_df = r2_df[(r2_df['src'] == src) & (r2_df['trg'] == trg)]
        if pair_df.empty:
            # No scores for this combination: skip instead of writing an empty
            # CSV and a blank figure (the test loop skips these pairs as well).
            continue

        res_df_src_trg_combo = res_df[(res_df['src'] == src) & (res_df['trg'] == trg)]
        res_df_src_trg_combo.to_csv(comment_outdir / f'wilcoxon_tests_r2_{src}_{trg}_combos.csv', index=False)

        fig, ax = plt.subplots(figsize=(8, 6))
        # Black dots mark the mean of each box (showmeans/meanprops are forwarded
        # to matplotlib's boxplot).
        sns.boxplot(x='model', y='value', data=pair_df, hue='model', palette='Set3', legend=False,
            showmeans=True, meanprops={"marker":"o", "markerfacecolor":"black"}, ax=ax)
        ax.set_title(f'R² Scores: {src} → {trg}')
        ax.set_xlabel('Model')
        ax.set_ylabel('R²')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True)
        # Use the notebook-wide dpi and a tight bounding box, like the other figures
        # saved by this notebook, so rotated tick labels are not clipped.
        fig.savefig(comment_outdir / f'boxplot_{src}_{trg}.{file_format}', dpi=dpi, bbox_inches='tight')
        plt.close(fig)  # release the figure; many are created in this loop

print("Wilcoxon statistical tests completed. Results saved to 'wilcoxon_tests_r2_all_combos.csv'.")

Wilcoxon statistical tests completed. Results saved to 'wilcoxon_tests_r2_all_combos.csv'.


# Bubble Heatmap (Reviewer 2, comment 8)

In [7]:
# Within-study (src == trg) summary tables for `metric_name`: mean and std across
# splits per (model, dataset), each extended with marginal averages and saved to CSV.
# Note: This is copied from stage4_generate_paper_plots.ipynb
# -----------------------------------------------------------
comment_outdir = outdir / 'reviewer2_comment8'
comment_outdir.mkdir(parents=True, exist_ok=True)

# Extract all within-study results (src == trg)
df = all_scores[
    (all_scores["met"] == metric_name) & 
    (all_scores["src"] == all_scores["trg"])  # src == trg
].reset_index(drop=True)
print(df.shape)
# NOTE: not the last expression of the cell, so this displays nothing in the notebook.
df.head()

# Aggregate over splits: one row per (model, dataset) with the mean and std of the metric
df = df.groupby(["model", "src"]).agg(mean_splits=("value", "mean"), std_splits=("value", "std")).reset_index()

# Sorted copies of the aggregated table (by dataset, then descending mean / std).
# NOTE(review): both copies are pivoted below, which re-orders rows/columns by label,
# so these sorts likely have no effect on the final tables -- confirm before removing.
df_mean = df.sort_values(by=["src", "mean_splits"], ascending=[True, False]).reset_index(drop=True)  # sorted by mean
df_std = df.sort_values(by=["src", "std_splits"], ascending=[True, False]).reset_index(drop=True)    # sorted by std
# display(df.iloc[:7,:])

# Mean across splits: reshape to a models (rows) x datasets (columns) table
df_mean = df_mean.pivot(index="src", columns="model", values="mean_splits")#.reset_index(drop=False)
df_mean.index.name = None
df_mean.columns.name = None
df_mean = df_mean.T
# Values are rounded here, i.e. BEFORE the marginal means further below are computed
df_mean = df_mean.round(3)
df_mean.index = df_mean.index.map(model_name_mapping)  # display names for the models
df_mean = df_mean[datasets_order]                      # fixed dataset column order
print('Mean across splits')
display(df_mean)

# Std across splits: same reshaping as for the mean table
df_std = df_std.pivot(index="src", columns="model", values="std_splits")#.reset_index(drop=False)
df_std.index.name = None
df_std.columns.name = None
df_std = df_std.T
# Values are rounded here, i.e. BEFORE the marginal means further below are computed
df_std = df_std.round(3)
df_std.index = df_std.index.map(model_name_mapping)  # display names for the models
df_std = df_std[datasets_order]                      # fixed dataset column order
print('Std across splits')
display(df_std)

# ---------------------------------------------------------
# Add new row to df_mean containing the mean of each column (average over models)
datasets_mean = df_mean.mean(axis=0)
df_mean.loc['Mean across datasets'] = datasets_mean

# Add new column to df_mean containing the mean of each row (average over datasets).
# The row added above is included in this computation; its cell is blanked out below.
models_mean = df_mean.mean(axis=1)
df_mean['Mean across models'] = models_mean

# Assign NA to cell of (mean_dataset, mean_model)
df_mean.loc['Mean across datasets', 'Mean across models'] = np.nan
df_mean.to_csv(comment_outdir / f'{metric_name}_mean_within_study_all_models.csv')

print('Mean across splits (including across models and datasets)')
display(df_mean)

# Add new row to df_std containing the mean of each column (average std over models)
datasets_std = df_std.mean(axis=0)
df_std.loc['Mean across datasets'] = datasets_std

# Add new column to df_std containing the mean of each row (average std over datasets).
# The row added above is included in this computation; its cell is blanked out below.
models_std = df_std.mean(axis=1)
df_std['Mean across models'] = models_std

# Assign NA to cell of (mean_dataset, mean_model)
df_std.loc['Mean across datasets', 'Mean across models'] = np.nan
df_std.to_csv(comment_outdir / f'{metric_name}_std_within_study_all_models.csv')

print('Std across splits (including across models and datasets)')
display(df_std)

(350, 6)
Mean across splits


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2
DeepCDR,0.72,0.766,0.76,0.704,0.811
DeepTTC,0.759,0.789,0.775,0.737,0.849
GraphDRP,0.736,0.746,0.765,0.733,0.855
HiDRA,0.711,0.756,0.768,0.722,0.832
LGBM,0.782,0.801,0.764,0.695,0.784
tCNNS,0.591,0.705,0.648,0.575,0.639
UNO,0.774,0.796,0.775,0.738,0.841


Std across splits


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2
DeepCDR,0.02,0.023,0.007,0.008,0.005
DeepTTC,0.023,0.018,0.01,0.006,0.004
GraphDRP,0.029,0.018,0.008,0.007,0.006
HiDRA,0.027,0.02,0.011,0.007,0.005
LGBM,0.02,0.011,0.008,0.006,0.003
tCNNS,0.061,0.049,0.052,0.049,0.063
UNO,0.025,0.012,0.007,0.007,0.006


Mean across splits (including across models and datasets)


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2,Mean across models
DeepCDR,0.72,0.766,0.76,0.704,0.811,0.7522
DeepTTC,0.759,0.789,0.775,0.737,0.849,0.7818
GraphDRP,0.736,0.746,0.765,0.733,0.855,0.767
HiDRA,0.711,0.756,0.768,0.722,0.832,0.7578
LGBM,0.782,0.801,0.764,0.695,0.784,0.7652
tCNNS,0.591,0.705,0.648,0.575,0.639,0.6316
UNO,0.774,0.796,0.775,0.738,0.841,0.7848
Mean across datasets,0.724714,0.765571,0.750714,0.700571,0.801571,


Std across splits (including across models and datasets)


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2,Mean across models
DeepCDR,0.02,0.023,0.007,0.008,0.005,0.0126
DeepTTC,0.023,0.018,0.01,0.006,0.004,0.0122
GraphDRP,0.029,0.018,0.008,0.007,0.006,0.0136
HiDRA,0.027,0.02,0.011,0.007,0.005,0.014
LGBM,0.02,0.011,0.008,0.006,0.003,0.0096
tCNNS,0.061,0.049,0.052,0.049,0.063,0.0548
UNO,0.025,0.012,0.007,0.007,0.006,0.0114
Mean across datasets,0.029286,0.021571,0.014714,0.012857,0.013143,


In [None]:
# --- Bubble Heatmap for Within-Dataset Results ---
# One bubble per (model, dataset): colour encodes the mean of `metric_name` across
# splits, size encodes the inverse variance across splits (larger = more stable).
import matplotlib.patheffects as path_effects

model_order = ['DeepCDR', 'DeepTTC', 'GraphDRP', 'HiDRA', 'LGBM', 'tCNNS', 'UNO']
model_order = [s.lower() for s in model_order]  # Convert to lowercase to match the dataset

fontsize = 8
bubble_size_min = 400  # marker area (points^2) of the least stable cell
bubble_size_max = 600  # marker area (points^2) of the most stable cell
cmap = 'Oranges'       # sequential map; other candidates tried: 'PiYG', 'viridis', 'plasma', 'Reds'

# Extract within-study results (src == trg)
df_within = all_scores[
    (all_scores['met'] == metric_name) &
    (all_scores['src'] == all_scores['trg'])
].reset_index(drop=True)

# Compute mean and std across splits
df_within_agg = df_within.groupby(['model', 'src']).agg(
    mean_splits=('value', 'mean'),
    std_splits=('value', 'std')
).reset_index()

# Get min and max values for colorbar
min_value = df_within_agg['mean_splits'].min()
max_value = df_within_agg['mean_splits'].max()

# Inverse variance (1/std^2) drives the bubble size
df_within_agg['inv_variance'] = 1 / (df_within_agg['std_splits'] ** 2)

# Min-max scale the inverse variance to [bubble_size_min, bubble_size_max].
# The range is taken over finite values only so a zero std (infinite inverse
# variance) cannot break the scaling; such cells are clipped to the maximum size.
inv_var_finite = df_within_agg['inv_variance'][np.isfinite(df_within_agg['inv_variance'])]
inv_var_min, inv_var_max = inv_var_finite.min(), inv_var_finite.max()
inv_var_span = inv_var_max - inv_var_min
if inv_var_span > 0:
    inv_var_scaled = ((df_within_agg['inv_variance'] - inv_var_min) / inv_var_span).clip(0, 1)
else:
    # All cells equally stable (or no finite value): use a uniform mid-size bubble
    inv_var_scaled = pd.Series(0.5, index=df_within_agg.index)
# Multiply by the size RANGE (max - min), not by max, so sizes really end at bubble_size_max
df_within_agg['bubble_size'] = bubble_size_min + (bubble_size_max - bubble_size_min) * inv_var_scaled

# Pivot for plotting
df_mean_pivot = df_within_agg.pivot(index='model', columns='src', values='mean_splits')
df_size_pivot = df_within_agg.pivot(index='model', columns='src', values='bubble_size')

# Reorder indices and columns
df_mean_pivot = df_mean_pivot.loc[model_order, datasets_order]
df_size_pivot = df_size_pivot.loc[model_order, datasets_order]

# Create bubble heatmap: datasets on x, models on y
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(8, 6))
x, y = np.meshgrid(range(len(datasets_order)), range(len(model_order)))
x = x.flatten()
y = y.flatten()
means = df_mean_pivot.values.flatten()
sizes = df_size_pivot.values.flatten()

# Plot scatter with the colorbar range fitted to the observed means
scatter = ax.scatter(x, y, s=sizes, c=means, cmap=cmap,
    vmin=min_value, vmax=max_value,
    edgecolors='black', linewidth=0.5)
fig.colorbar(scatter, ax=ax, label='Mean R²')

# Add text annotations for mean R² values
for i, dataset in enumerate(datasets_order):
    for j, model in enumerate(model_order):
        mean_r2 = df_mean_pivot.loc[model, dataset]
        if np.isnan(mean_r2):
            continue
        # Pick the text colour from the bubble's actual fill colour: white on dark
        # fills, black on light fills. (The sign of R² says nothing about darkness
        # for a sequential colormap.)
        red, green, blue, _ = scatter.to_rgba(mean_r2)
        luminance = 0.299 * red + 0.587 * green + 0.114 * blue
        text_color = 'white' if luminance < 0.5 else 'black'
        halo_color = 'black' if text_color == 'white' else 'white'
        ax.text(
            i, j, f'{mean_r2:.2f}',
            ha='center', va='center',
            fontsize=fontsize, color=text_color,
            weight='bold',  # Bold for better contrast
            path_effects=[path_effects.withStroke(linewidth=0.5, foreground=halo_color)]
        )

ax.set_xticks(range(len(datasets_order)))
ax.set_xticklabels(datasets_order, rotation=45, ha='right')
ax.set_yticks(range(len(model_order)))
ax.set_yticklabels([model_name_mapping.get(m.lower(), m) for m in model_order])
ax.set_xlabel('Dataset')
ax.set_ylabel('Model')
ax.set_title('Within-Dataset R² Performance (Bubble Size: Inverse Variance)')
fig.tight_layout()

outpath = comment_outdir / f'bubble_heatmap_within_dataset.{file_format}'
fig.savefig(outpath, dpi=dpi, bbox_inches='tight')
plt.close(fig)

print(f'Bubble heatmap saved in: {outpath}')
print(f'Value range: min={min_value:.2f}, max={max_value:.2f}')

Bubble heatmap saved in: /nfs/ml_lab/projects/improve/data/experiments/cross-dataset-drp-paper/results_for_paper_revision/reviewer2_comment8/bubble_heatmap_within_dataset.png
Value range: min=0.57, max=0.85


In [None]:
# Grouped bar plot of the within-study (src == trg) mean of `metric_name` per
# dataset and model, with +/- one std across splits as error bars.
# Relies on `model_order` (lowercase model keys) defined for the bubble heatmap.

# Extract within-study results (src == trg)
df_within = all_scores[
    (all_scores['met'] == metric_name) &
    (all_scores['src'] == all_scores['trg'])
].reset_index(drop=True)

# Compute mean and std across splits
df_within_agg = df_within.groupby(['model', 'src']).agg(
    mean_splits=('value', 'mean'),
    std_splits=('value', 'std')
).reset_index()

# Ordered categoricals fix the order of the groups (datasets) and of the bars
# inside each group (models)
df_within_agg['src'] = pd.Categorical(df_within_agg['src'], categories=datasets_order, ordered=True)
df_within_agg['model'] = pd.Categorical(df_within_agg['model'], categories=model_order, ordered=True)

# Create grouped bar plot
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
bar_plot = sns.barplot(
    data=df_within_agg,
    x='src',
    y='mean_splits',
    hue='model',
    palette='tab10',
    ax=ax
)

# Add error bars. seaborn spreads the hue levels over a total width of 0.8 centred
# on each integer x position, so bar j of group i is centred at
# i + (j - (n_models - 1) / 2) * bar_width.
bar_width = 0.8 / len(model_order)  # Width per bar
for i, dataset in enumerate(datasets_order):
    for j, model in enumerate(model_order):
        subset = df_within_agg[(df_within_agg['model'] == model) & (df_within_agg['src'] == dataset)]
        if subset.empty:
            continue
        mean = subset['mean_splits'].iloc[0]
        std = subset['std_splits'].iloc[0]
        x_pos = i + (j - (len(model_order) - 1) / 2) * bar_width
        ax.errorbar(x=x_pos, y=mean, yerr=std, fmt='none', capsize=3, color='black')

# Customize plot
ax.set_xlabel('Dataset')
ax.set_ylabel('Mean R²')
ax.set_title('Within-Dataset R² Performance with Error Bars')
# Show display names (e.g. 'DeepCDR') instead of the raw lowercase keys in the
# legend, consistent with the labels used in the bubble heatmap.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, [model_name_mapping.get(lbl, lbl) for lbl in labels],
    title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()

# Set y-axis limits dynamically: always cover [0, 1], extended when mean +/- std
# falls outside that range
r2_min = (df_within_agg['mean_splits'] - df_within_agg['std_splits']).min()
r2_max = (df_within_agg['mean_splits'] + df_within_agg['std_splits']).max()
ax.set_ylim(min(0, r2_min - 0.05), max(1.0, r2_max + 0.05))

# Save with the notebook-wide dpi setting (previously a hard-coded 300)
outpath = comment_outdir / f'bar_plot_within_r2.{file_format}'
fig.savefig(outpath, dpi=dpi, bbox_inches='tight')
plt.close(fig)

print(f'Bar plot saved in: {outpath}')
print(f'R² range (with error bars): min={r2_min:.2f}, max={r2_max:.2f}')

Bar plot saved in: /nfs/ml_lab/projects/improve/data/experiments/cross-dataset-drp-paper/results_for_paper_revision/reviewer2_comment8/bar_plot_within_r2.png
R² range (with error bars): min=0.53, max=0.86


# The Added Value of Gn and Gna (Reviewer 2, comment 4)