In [1]:
import json
import os
import warnings
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap, Normalize

import importlib
import utils  # Import your module
from utils import model_name_mapping, metrics_name_mapping

# After making changes to utils.py, reload it.
# NOTE: importlib.reload() only refreshes the `utils` module object. Names that
# were bound earlier with `from utils import ...` keep pointing at the
# pre-reload objects, so re-bind them from the freshly reloaded module.
importlib.reload(utils)
model_name_mapping = utils.model_name_mapping
metrics_name_mapping = utils.metrics_name_mapping

# filepath = Path(__file__).parent
# Absolute path of the current working directory (abspath of the empty string).
filepath = Path(os.path.abspath(''))
print(filepath)

# ----------------------------------------------------------------
# metrics_name_mapping = {
#     "r2": "R²",
#     "mae": "MAE",
#     "rmse": "RMSE",
#     "stgr": "STGR",
#     "stgi": "STGI",
# }

# model_name_mapping = {
#     "deepcdr": "DeepCDR",
#     "graphdrp": "GraphDRP",
#     "hidra": "HiDRA",
#     "lgbm": "LGBM",
#     "tcnns": "tCNNS",
#     "uno": "UNO",
# }

/nfs/lambda_stor_01/data/apartin/projects/IMPROVE/csa-paper-clean


In [2]:
# --- Input / output locations ---
scores_dir = Path('outputs/s1_scores')
stats_dir = Path('outputs/s4_stats')
stats_dir.mkdir(parents=True, exist_ok=True)
# Disabled alternative: outdir = Path('outputs/s4_stats/figures')

# --- Figure export settings ---
# Other formats that were tried and left disabled: 'eps', 'jpeg', 'tiff'
file_format = 'png'
dpi = 600

# --- Scores table and identifier column names ---
filename = 'all_models_scores.csv'
canc_col_name, drug_col_name = 'improve_sample_id', 'improve_chem_id'

# --- Dataset orderings ---
datasets_order_alpha = ['CCLE', 'CTRPv2', 'GDSCv1', 'GDSCv2', 'gCSI']
datasets_order_size = ['gCSI', 'CCLE', 'GDSCv2', 'GDSCv1', 'CTRPv2']
# Default order for within-study visuals (alias of the size-based list)
datasets_order = datasets_order_size
show_plot = True

# Read the per-split scores of all models and preview the first three rows
all_scores = pd.read_csv(scores_dir / filename)
all_scores.head(3)

Unnamed: 0,met,split,value,src,trg,model
0,mse,0,0.006295,CCLE,CCLE,deepcdr
1,mse,1,0.00595,CCLE,CCLE,deepcdr
2,mse,2,0.005129,CCLE,CCLE,deepcdr


In [3]:
# Metric analysed in this notebook; it is compared against the `met` column of
# the scores table and embedded in the names of the files read and written below.
metric_name = 'r2'

# Optional list of model ids to restrict the analysis to (empty list for now).
# Example values that were tried:
#   ["deepcdr", "deepttc", "graphdrp", "hidra", "lgbm", "uno"]
#   ["graphdrp"]
models_to_include = []

# Statistical Tests (Reviewer 3, comment 1)

In [4]:
# Set-up for the pairwise Wilcoxon analysis (Reviewer 3, comment 1):
# output directories, the R² score table with display model names, and the
# Bonferroni-style per-test significance level.
from itertools import combinations

# Reviewer 3 (Wilcoxon) output dirs
r3_fig_dir = stats_dir / 'reviewer3_comment1' / 'figures'
r3_fig_dir.mkdir(parents=True, exist_ok=True)

# Stats output dirs
all_wilcoxon_tests_outpath = stats_dir / 'reviewer3_comment1' / 'all_wilcoxon_tests'
all_wilcoxon_tests_outpath.mkdir(parents=True, exist_ok=True)

# Data
df = all_scores.copy()

# Filtering for R² scores.
# A single .loc selection followed by .copy() yields an independent frame, so
# the column assignment below cannot trigger SettingWithCopyWarning.
r2_df = df.loc[df['met'] == 'r2', ['src', 'trg', 'model', 'split', 'value']].copy()
# Map raw model ids to display names; ids missing from the mapping are kept
# as-is instead of silently becoming NaN (same fallback as `.get(m, m)`).
r2_df['model'] = r2_df['model'].map(lambda m: model_name_mapping.get(m, m))

# Defining models and datasets
models = r2_df['model'].unique()
src_datasets = r2_df['src'].unique()
trg_datasets = r2_df['trg'].unique()

model_pairwise_pairs = list(combinations(models, 2))
print(f'Total model unique pairs: {len(model_pairwise_pairs)}')

# Preparing results
results = []
skipped = []
# 0.05 divided by the number of model pairs; max(..., 1) avoids a
# ZeroDivisionError when fewer than two models are present.
alpha = 0.05 / max(len(model_pairwise_pairs), 1)

Total model unique pairs: 21


In [5]:
# Read the Wilcoxon results table for the selected metric from disk
# (it is not computed in this notebook) and preview its first rows.
results_path = stats_dir.joinpath('reviewer3_comment1', f'wilcoxon_tests_{metric_name}_all_combos.csv')
res_df = pd.read_csv(results_path)
print(f'Loaded Wilcoxon results from: {results_path}')

display(res_df.head())

Loaded Wilcoxon results from: outputs/s4_stats/reviewer3_comment1/wilcoxon_tests_r2_all_combos.csv


Unnamed: 0,src,trg,model1,model2,median_model1,median_model2,mean_r2_diff,p_value,significant
0,CCLE,CCLE,DeepCDR,DeepTTC,0.761925,0.791364,-0.023587,0.019531,False
1,CCLE,CCLE,DeepCDR,GraphDRP,0.761925,0.748778,0.01962,0.037109,False
2,CCLE,CCLE,DeepCDR,HiDRA,0.761925,0.761346,0.009445,0.492188,False
3,CCLE,CCLE,DeepCDR,LGBM,0.761925,0.804346,-0.03549,0.001953,True
4,CCLE,CCLE,DeepCDR,tCNNS,0.761925,0.71913,0.061112,0.013672,False


In [6]:
# Generating boxplots: for every (source, target) dataset combination, write
# the matching slice of the Wilcoxon results to CSV and save a boxplot of the
# per-split R² scores of each model.
for src in src_datasets:
    for trg in trg_datasets:
        pair_df = r2_df[(r2_df['src'] == src) & (r2_df['trg'] == trg)]
        res_df_src_trg_combo = res_df[(res_df['src'] == src) & (res_df['trg'] == trg)]
        # Same directory as before, referenced via the variable created for it
        res_df_src_trg_combo.to_csv(
            all_wilcoxon_tests_outpath / f'wilcoxon_tests_r2_{src}_{trg}_combos.csv',
            index=False,
        )

        if pair_df.empty:
            # No scores for this combination: nothing to draw
            continue

        # Explicit figure/axes so each figure is created, saved and closed
        # independently of the global pyplot state.
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.boxplot(x='model', y='value', data=pair_df, hue='model', palette='Set3', legend=False,
                    showmeans=True, meanprops={'marker': 'o', 'markerfacecolor': 'black'}, ax=ax)
        ax.set_title(f'R² Scores: {src} → {trg}')
        ax.set_xlabel('Model')
        ax.set_ylabel('R²')
        ax.tick_params(axis='x', labelrotation=45)
        ax.grid(True)
        # Use the configured resolution and a tight bounding box so the rotated
        # tick labels are not clipped (matches the other saved figures).
        fig.savefig(r3_fig_dir / f'boxplot_{src}_{trg}.{file_format}', dpi=dpi, bbox_inches='tight')
        plt.close(fig)

In [7]:
# Save a selected subset of wilcoxon test results (used for S4 in the paper)

# Filter 1: rows flagged as significant in the loaded results table
# alpha = 0.05 / 21
sig_df = res_df[res_df['significant'].eq(True)]

# Filter 2: at least one of the two models has a median R² above the threshold
th = 0.1  # Threshold for median R²
pos_median_df = sig_df[(sig_df['median_model1'] > th) | (sig_df['median_model2'] > th)]

# Filter 3: Exclude within-dataset pairs (source != target)
cross_df = pos_median_df[pos_median_df['src'] != pos_median_df['trg']]

# Filter 4 & 5: Select top 3 rows per source-target pair by absolute mean_r2_diff
selected_rows = []
src_trg_pairs = cross_df.groupby(['src', 'trg'])
for (src, trg), group in src_trg_pairs:
    # Sort by absolute mean_r2_diff (descending) and take top 3
    top_rows = group.sort_values(by='mean_r2_diff', key=abs, ascending=False).head(3)
    selected_rows.append(top_rows)

# Combine selected rows. pd.concat raises ValueError on an empty list, so fall
# back to an empty frame with the same columns when no row passes the filters.
if selected_rows:
    selected_df = pd.concat(selected_rows, ignore_index=True)
else:
    selected_df = cross_df.head(0).reset_index(drop=True)

# Save selected results to Step 4 stats
sel_dir = stats_dir / 'reviewer3_comment1'
sel_dir.mkdir(parents=True, exist_ok=True)
selected_df.to_csv(sel_dir / f'wilcoxon_tests_{metric_name}_selected.csv', index=False)

# Print summary for verification
print(f'Total selected rows: {len(selected_df)}')
print(f"Selected source-target pairs: {selected_df[['src', 'trg']].drop_duplicates().shape[0]}")
print(selected_df[['src', 'trg', 'model1', 'model2', 'median_model1', 'median_model2', 'mean_r2_diff', 'p_value']].head())

Total selected rows: 36
Selected source-target pairs: 12
      src   trg    model1 model2  median_model1  median_model2  mean_r2_diff  \
0    CCLE  gCSI  GraphDRP    UNO      -0.383007       0.200078     -0.567031   
1    CCLE  gCSI     tCNNS    UNO      -0.113735       0.200078     -0.506696   
2    CCLE  gCSI      LGBM    UNO      -0.119092       0.200078     -0.308608   
3  CTRPv2  CCLE     tCNNS    UNO       0.347740       0.627901     -0.306836   
4  CTRPv2  CCLE  GraphDRP  tCNNS       0.597406       0.347740      0.273507   

    p_value  
0  0.001953  
1  0.001953  
2  0.001953  
3  0.001953  
4  0.001953  


In [8]:
# Requires r2_df from the Wilcoxon set-up cell
# (columns: src, trg, model, split, value; R² rows only).

# Compute mean R² per model for each source-target pair
mean_r2_df = r2_df.groupby(['src', 'trg', 'model'])['value'].mean().reset_index(name='mean_r2')

# Identify pairs where every model has a mean R² below the threshold
th = 0.1  # Threshold for mean R²
negative_pairs = []
src_trg_groups = mean_r2_df.groupby(['src', 'trg'])
for (src, trg), group in src_trg_groups:
    if (group['mean_r2'] < th).all():
        negative_pairs.append((src, trg))

# Save these pairs to CSV for reference (file name kept for downstream use)
negative_pairs_df = pd.DataFrame(negative_pairs, columns=['src', 'trg'])
r2_stat_dir = stats_dir / 'reviewer2_comment8'
r2_stat_dir.mkdir(parents=True, exist_ok=True)
negative_pairs_df.to_csv(r2_stat_dir / 'negative_r2_pairs.csv', index=False)

# Print results. The message states the actual criterion (mean R² < th), which
# is not the same as "negative" because th is 0.1.
print(f'Source-target pairs with all mean R² below {th}: {len(negative_pairs)}')
print(negative_pairs_df)

# Identify valid pairs for boxplots (exclude the pairs found above and
# within-dataset pairs). A set gives constant-time membership tests.
negative_pair_set = set(negative_pairs)
valid_pairs = [
    (src, trg)
    for src, trg in r2_df[['src', 'trg']].drop_duplicates().itertuples(index=False, name=None)
    if src != trg and (src, trg) not in negative_pair_set
]
print(f'Valid source-target pairs for boxplots: {len(valid_pairs)}')
print(pd.DataFrame(valid_pairs, columns=['src', 'trg']))

Source-target pairs with all negative mean R²: 8
      src     trg
0    CCLE  CTRPv2
1    CCLE  GDSCv1
2    CCLE  GDSCv2
3  GDSCv1  CTRPv2
4    gCSI    CCLE
5    gCSI  CTRPv2
6    gCSI  GDSCv1
7    gCSI  GDSCv2
Valid source-target pairs for boxplots: 12
       src     trg
0     CCLE    gCSI
1   CTRPv2    CCLE
2   CTRPv2  GDSCv1
3   CTRPv2  GDSCv2
4   CTRPv2    gCSI
5   GDSCv1    CCLE
6   GDSCv1  GDSCv2
7   GDSCv1    gCSI
8   GDSCv2    CCLE
9   GDSCv2  CTRPv2
10  GDSCv2  GDSCv1
11  GDSCv2    gCSI


# Bubble Heatmap (Reviewer 2, comment 8)

In [9]:
# Within-study summaries are read from the Step 3 output folder; nothing is
# recomputed in this cell.
ws_dir = Path('outputs/s3_GaGnGna/within_study')
mean_path = ws_dir / f'{metric_name}_mean_within_study_all_models.csv'
std_path = ws_dir / f'{metric_name}_std_within_study_all_models.csv'

# Load both tables with the first CSV column as the row index
df_mean = pd.read_csv(mean_path, index_col=0)
df_std = pd.read_csv(std_path, index_col=0)

# Clear the axis names on both frames
for frame in (df_mean, df_std):
    frame.index.name = None
    frame.columns.name = None

# Show each table under its caption
for caption, frame in (
    ('Mean across splits (loaded from Step 3)', df_mean),
    ('Std across splits (loaded from Step 3)', df_std),
):
    print(caption)
    display(frame)

Mean across splits (loaded from Step 3)


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2,Mean across models
deepcdr,0.72,0.766,0.76,0.704,0.811,0.752
deepttc,0.759,0.789,0.775,0.737,0.849,0.782
graphdrp,0.736,0.746,0.765,0.733,0.855,0.767
hidra,0.711,0.756,0.768,0.722,0.832,0.758
lgbm,0.782,0.801,0.764,0.695,0.784,0.765
tcnns,0.591,0.705,0.648,0.575,0.639,0.631
uno,0.774,0.796,0.775,0.738,0.841,0.785
Mean across datasets,0.725,0.766,0.751,0.701,0.801,


Std across splits (loaded from Step 3)


Unnamed: 0,gCSI,CCLE,GDSCv2,GDSCv1,CTRPv2,Mean across models
deepcdr,0.02,0.023,0.007,0.008,0.005,0.013
deepttc,0.023,0.018,0.01,0.006,0.004,0.012
graphdrp,0.029,0.018,0.008,0.007,0.006,0.014
hidra,0.027,0.02,0.011,0.007,0.005,0.014
lgbm,0.02,0.011,0.008,0.006,0.003,0.01
tcnns,0.061,0.049,0.052,0.049,0.063,0.055
uno,0.025,0.012,0.007,0.007,0.006,0.011
Mean across datasets,0.029,0.021,0.015,0.013,0.013,


In [10]:
# --- Bubble Heatmap for Within-Dataset Results ---
# One bubble per (model, dataset): colour encodes the mean score, area encodes
# the inverse variance across splits. The figure and the two pivots used to
# draw it are written under stats_dir / 'reviewer2_comment8'.
import matplotlib.patheffects as path_effects

model_order = ['DeepCDR', 'DeepTTC', 'GraphDRP', 'HiDRA', 'LGBM', 'tCNNS', 'UNO']
model_order = [s.lower() for s in model_order]  # Convert to lowercase to match the dataset

fontsize = 8
# Bubble areas passed to scatter(s=...). NOTE: bubble_size_max is used as the
# span added on top of bubble_size_min, so areas range from bubble_size_min to
# bubble_size_min + bubble_size_max.
bubble_size_min = 400
bubble_size_max = 600
# Other colormaps that were tried: 'PiYG', 'viridis', 'plasma', 'Reds'
cmap = 'Oranges'

# Ensure Reviewer 2 figures directory exists
r2_fig_dir = stats_dir / 'reviewer2_comment8' / 'figures'
r2_fig_dir.mkdir(parents=True, exist_ok=True)
outpath = r2_fig_dir / f'bubble_heatmap_within_dataset.{file_format}'

# Load precomputed within-study summaries from Step 3
ws_dir = Path('outputs/s3_GaGnGna/within_study')
mean_path = ws_dir / f'{metric_name}_mean_within_study_all_models.csv'
std_path  = ws_dir / f'{metric_name}_std_within_study_all_models.csv'

_df_mean = pd.read_csv(mean_path, index_col=0)
_df_std  = pd.read_csv(std_path,  index_col=0)

# Drop summary row/col if present
_df_mean = _df_mean.drop(index=['Mean across datasets'], columns=['Mean across models'], errors='ignore')
_df_std  = _df_std.drop(index=['Mean across datasets'], columns=['Mean across models'], errors='ignore')

# Ensure lowercase model ids
_df_mean.index = [m.lower() for m in _df_mean.index]
_df_std.index  = [m.lower() for m in _df_std.index]

# Pivots for plotting (models x datasets) in desired order
df_mean_pivot = _df_mean.loc[model_order, datasets_order]
df_std_pivot  = _df_std.loc[model_order, datasets_order]

# Colorbar range from means
min_value = df_mean_pivot.min().min()
max_value = df_mean_pivot.max().max()

# Bubble sizes from inverse variance.
std_values = df_std_pivot.astype(float)
zero_std = std_values.eq(0)
inv_var = 1 / (std_values.mask(zero_std) ** 2)  # NaN where std is zero or missing
finite_min, finite_max = inv_var.min().min(), inv_var.max().max()
# A zero std means no spread at all, so it gets the largest observed inverse
# variance (largest bubble); a missing std gets the smallest one.
inv_var = inv_var.mask(zero_std, finite_max).fillna(finite_min)
inv_var_min, inv_var_max = inv_var.min().min(), inv_var.max().max()
inv_var_span = inv_var_max - inv_var_min
if np.isfinite(inv_var_span) and inv_var_span > 0:
    df_size_pivot = bubble_size_min + bubble_size_max * (inv_var - inv_var_min) / inv_var_span
else:
    # All inverse variances are equal (or undefined): scaling would divide by
    # zero and produce NaN sizes, so draw every bubble at the minimum size.
    df_size_pivot = pd.DataFrame(float(bubble_size_min), index=inv_var.index, columns=inv_var.columns)

# Create bubble heatmap
plt.figure(figsize=(8, 6))
sns.set_style('whitegrid')
# Row-major flattening of the grid matches the (models x datasets) pivots
x, y = np.meshgrid(range(len(datasets_order)), range(len(model_order)))
x = x.flatten()
y = y.flatten()
means = df_mean_pivot.values.flatten()
sizes = df_size_pivot.values.flatten()

scatter = plt.scatter(x, y, s=sizes, c=means, cmap=cmap,
    vmin=min_value, vmax=max_value,
    edgecolors='black', linewidth=0.5)
plt.colorbar(scatter, label='Mean R²')

# Add text annotations for mean R² values
for i, dataset in enumerate(datasets_order):
    for j, model in enumerate(model_order):
        mean_r2 = df_mean_pivot.loc[model, dataset]
        if not np.isnan(mean_r2):
            text_color = 'white' if mean_r2 < 0 else 'black'
            plt.text(
                i, j, f'{mean_r2:.2f}',
                ha='center', va='center',
                fontsize=fontsize, color=text_color,
                weight='bold',
                path_effects=[path_effects.withStroke(linewidth=0.5, foreground='white')]
            )

plt.xticks(range(len(datasets_order)), datasets_order, rotation=45, ha='right')
plt.yticks(range(len(model_order)), [model_name_mapping.get(m.lower(), m) for m in model_order])
plt.xlabel('Dataset')
plt.ylabel('Model')
plt.title('Within-Dataset R² Performance (Bubble Size: Inverse Variance)')
plt.tight_layout()

# Optional: save derived pivots
r2_derived_dir = stats_dir / 'reviewer2_comment8' / 'derived'
r2_derived_dir.mkdir(parents=True, exist_ok=True)
df_mean_pivot.to_csv(r2_derived_dir / f'{metric_name}_bubble_mean_pivot.csv')
df_size_pivot.to_csv(r2_derived_dir / f'{metric_name}_bubble_size_pivot.csv')

plt.savefig(outpath, dpi=dpi, bbox_inches='tight')
plt.close()

print(f'Bubble heatmap saved in: {outpath}')
print(f'Value range: min={min_value:.2f}, max={max_value:.2f}')

Bubble heatmap saved in: outputs/s4_stats/reviewer2_comment8/figures/bubble_heatmap_within_dataset.png
Value range: min=0.57, max=0.85


In [11]:
# Grouped bar plot of within-study performance: mean of the selected metric
# across splits per (model, dataset), with +/- one std as error bars.
# Uses model_order and r2_fig_dir defined in the bubble-heatmap cell.

# Extract within-study results (src == trg)
df_within = all_scores[
    (all_scores['met'] == metric_name) &
    (all_scores['src'] == all_scores['trg'])
].reset_index(drop=True)

# Compute mean and std across splits
df_within_agg = df_within.groupby(['model', 'src']).agg(
    mean_splits=('value', 'mean'),
    std_splits=('value', 'std')
).reset_index()

# Reorder data for plotting (values outside the given orders become NaN)
df_within_agg['src'] = pd.Categorical(df_within_agg['src'], categories=datasets_order, ordered=True)
df_within_agg['model'] = pd.Categorical(df_within_agg['model'], categories=model_order, ordered=True)

# Create grouped bar plot
plt.figure(figsize=(10, 6))
sns.set_style('whitegrid')
bar_plot = sns.barplot(
    data=df_within_agg,
    x='src',
    y='mean_splits',
    hue='model',
    palette='tab10'
)

# Add error bars in one vectorised call. Each bar's centre is the dataset
# position (category code) plus the model's offset inside the 0.8-wide group.
bar_width = 0.8 / len(model_order)  # Width per bar
placed = df_within_agg.dropna(subset=['model', 'src'])
if not placed.empty:
    dataset_pos = placed['src'].cat.codes.to_numpy()
    model_pos = placed['model'].cat.codes.to_numpy()
    x_pos = dataset_pos + (model_pos - (len(model_order) - 1) / 2) * bar_width
    plt.errorbar(
        x=x_pos,
        y=placed['mean_splits'].to_numpy(),
        yerr=placed['std_splits'].to_numpy(),
        fmt='none', capsize=3, color='black',
    )

# Customize plot
plt.xlabel('Dataset')
plt.ylabel('Mean R²')
plt.title('Within-Dataset R² Performance with Error Bars')
plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

# Set y-axis limits dynamically: always include [0, 1], widened if the error
# bars reach beyond that interval.
r2_min = (df_within_agg['mean_splits'] - df_within_agg['std_splits']).min()
r2_max = (df_within_agg['mean_splits'] + df_within_agg['std_splits']).max()
plt.ylim(min(0, r2_min - 0.05), max(1.0, r2_max + 0.05))

# Save with the notebook-wide resolution setting instead of a hard-coded value
outpath = r2_fig_dir / f'bar_plot_within_r2.{file_format}'
plt.savefig(outpath, dpi=dpi, bbox_inches='tight')
plt.close()

print(f'Bar plot saved in: {outpath}')
print(f'R² range (with error bars): min={r2_min:.2f}, max={r2_max:.2f}')

Bar plot saved in: outputs/s4_stats/reviewer2_comment8/figures/bar_plot_within_r2.png
R² range (with error bars): min=0.53, max=0.86
