In [26]:
import json
import os
import warnings
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap, Normalize

import importlib
import utils  # Import your module
from utils import model_name_mapping, metrics_name_mapping

# After making changes to utils.py, reload it
importlib.reload(utils)

# filepath = Path(__file__).parent
filepath = Path(os.path.abspath(''))
print(filepath)

metrics_name_mapping = {
    "r2": "R²",
    "mae": "MAE",
    "rmse": "RMSE",
    "stgr": "STGR",
    "stgi": "STGI",
}

model_name_mapping = {
    "deepcdr": "DeepCDR",
    "graphdrp": "GraphDRP",
    "hidra": "HiDRA",
    "lgbm": "LGBM",
    "tcnns": "tCNNS",
    "uno": "UNO",
}

/nfs/ml_lab/projects/improve/data/experiments/cross-dataset-drp-paper


In [27]:
datadir = Path('splits_averaged')
outdir = filepath / 'results_for_paper_revision'
os.makedirs(outdir, exist_ok=True)

# file_format = 'eps'
# file_format = 'jpeg'
# file_format = 'png'
file_format = 'tiff'
dpi = 600

filename = 'all_models_scores.csv'
canc_col_name = 'improve_sample_id'
drug_col_name = 'improve_chem_id'

# datasets_order = ['CCLE', 'CTRPv2', 'GDSCv1', 'GDSCv2', 'gCSI']  # alphabetical order
datasets_order = ['gCSI', 'CCLE', 'GDSCv2', 'GDSCv1', 'CTRPv2']  # order by sample size
show_plot = True

all_scores = pd.read_csv(filepath / datadir / filename, sep=',')
all_scores.iloc[:3,:]

Unnamed: 0,met,split,value,src,trg,model
0,mse,0,0.006295,CCLE,CCLE,deepcdr
1,mse,1,0.00595,CCLE,CCLE,deepcdr
2,mse,2,0.005129,CCLE,CCLE,deepcdr


In [28]:
# Specify the metric
metric_name = "r2"

# Specify the models you want to include
models_to_include = []  # Replace with your desired models
# models_to_include = ["deepcdr", "graphdrp", "hidra", "lgbm", "uno"]  # Replace with your desired models
# models_to_include = ["graphdrp"]  # Replace with your desired models

# Statistical Tests (Reviewer 3, comment 1)

In [30]:
from scipy.stats import wilcoxon
from itertools import combinations

# Filtering for R² scores
r2_df = all_scores[all_scores['met'] == 'r2'][['src', 'trg', 'model', 'split', 'value']]
r2_df['model'] = r2_df['model'].map(model_name_mapping)

# Defining models and datasets
models = r2_df['model'].unique()
src_datasets = r2_df['src'].unique()
trg_datasets = r2_df['trg'].unique()

# Preparing results
results = []
skipped = []
alpha = 0.05 / 15  # Bonferroni correction for 15 pairwise comparisons

In [31]:
# Iterate over source-target pairs
for src in src_datasets:
    for trg in trg_datasets:
        # Data for this pair
        pair_df = r2_df[(r2_df['src'] == src) & (r2_df['trg'] == trg)]
        if pair_df.empty:
            continue

        model_pairs = combinations(models, 2)

        # Pairwise Wilcoxon tests
        for model1, model2 in model_pairs:
            scores1 = pair_df[pair_df['model'] == model1]['value'].values
            scores2 = pair_df[pair_df['model'] == model2]['value'].values
            
            if len(scores1) == 10 and len(scores2) == 10:  # Ensure 10 splits
                try:
                    stat, p = wilcoxon(scores1, scores2, alternative='two-sided')
                    mean_diff = np.mean(scores1 - scores2)
                    # if p < alpha:
                    results.append({
                        'src': src,
                        'trg': trg,
                        'model1': model1,
                        'model2': model2,
                        'median_model1': np.median(scores1),
                        'median_model2': np.median(scores2),
                        'mean_r2_diff': mean_diff,
                        'p_value': p,
                        'significant': p < alpha
                    })
                except:
                    print(f"Wilcoxon test failed for {model1} vs {model2} on {src} → {trg}: {e}")
            else:
                print(f"Skipping {model1} vs {model2} on {src} → {trg} due to insufficient splits: {len(scores1)} vs {len(scores2)}")
                skipped.append({
                    'src': src,
                    'trg': trg,
                    'model1': model1,
                    'model2': model2,
                    'len_model1': len(scores1),
                    'len_model2': len(scores2)
                })

In [None]:
# Saving significant results
res_df = pd.DataFrame(results)
res_df.to_csv(outdir / 'wilcoxon_tests_r2_all_combos.csv', index=False)

# Generating boxplots for key pairs
for src in src_datasets:
    for trg in trg_datasets:
        pair_df = r2_df[(r2_df['src'] == src) & (r2_df['trg'] == trg)]
        res_df_src_trg_combo = res_df[(res_df['src'] == src) & (res_df['trg'] == trg)]
        res_df_src_trg_combo.to_csv(outdir / f'wilcoxon_tests_r2_{src}_{trg}_combos.csv', index=False)
        plt.figure(figsize=(8, 6))
        # sns.boxplot(x='model', y='value', data=pair_df, palette='Set3')
        # sns.boxplot(x='model', y='value', data=pair_df, hue='model', palette='Set3', legend=False)
        sns.boxplot(x='model', y='value', data=pair_df, hue='model', palette='Set3', legend=False,
            showmeans=True, meanprops={"marker":"o", "markerfacecolor":"black"})
        plt.title(f'R² Scores: {src} → {trg}')
        plt.xlabel('Model')
        plt.ylabel('R²')
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.savefig(outdir / f'boxplot_{src}_{trg}.png')
        plt.close()

print("Wilcoxon statistical tests completed. Results saved to 'wilcoxon_tests_r2_all_combos.csv'.")

Wilcoxon statistical tests completed. Results saved to 'wilcoxon_tests_r2_all_combos.csv'.


# Bubble Heatmap (Reviewer 3, comment 8)

# Within-study results

In [None]:
# Extract all within-study results (src == trg)
df = all_scores[
    (all_scores["met"] == metric_name) & 
    (all_scores["src"] == all_scores["trg"])  # src == trg
].reset_index(drop=True)
print(df.shape)
df.head()

df = df.groupby(["model", "src"]).agg(mean_splits=("value", "mean"), std_splits=("value", "std")).reset_index()

df_mean = df.sort_values(by=["src", "mean_splits"], ascending=[True, False]).reset_index(drop=True)  # compute mean
df_std = df.sort_values(by=["src", "std_splits"], ascending=[True, False]).reset_index(drop=True)    # compute std
# display(df.iloc[:7,:])

# Mean across splits
df_mean = df_mean.pivot(index="src", columns="model", values="mean_splits")#.reset_index(drop=False)
df_mean.index.name = None
df_mean.columns.name = None
df_mean = df_mean.T
df_mean = df_mean.round(3)
df_mean.index = df_mean.index.map(model_name_mapping)
df_mean = df_mean[datasets_order]
print('Mean across splits')
display(df_mean)

# Std across splits
df_std = df_std.pivot(index="src", columns="model", values="std_splits")#.reset_index(drop=False)
df_std.index.name = None
df_std.columns.name = None
df_std = df_std.T
df_std = df_std.round(3)
df_std.index = df_std.index.map(model_name_mapping)
df_std = df_std[datasets_order]
print('Std across splits')
display(df_std)

In [None]:
print('Mean across splits')
display(df_mean)

# Add new row to df_mean containing the mean of each column
datasets_mean = df_mean.mean(axis=0)
df_mean.loc['Mean across datasets'] = datasets_mean

# Add new column to df_mean containing the mean of each row
models_mean = df_mean.mean(axis=1)
df_mean['Mean across models'] = models_mean

# Assign NA to cell of (mean_dataset, mean_model)
df_mean.loc['Mean across datasets', 'Mean across models'] = np.nan
df_mean.to_csv(outdir / f'{metric_name}_mean_within_study_all_models.csv')

print('Mean across splits (including across models and datasets)')
display(df_mean)

In [None]:
print('Std across splits')
display(df_std)

# Add new row to df_mean containing the mean of each column
datasets_std = df_std.mean(axis=0)
df_std.loc['Mean across datasets'] = datasets_std

# Add new column to df_mean containing the mean of each row
models_std = df_std.mean(axis=1)
df_std['Mean across models'] = models_std

# Assign NA to cell of (mean_dataset, mean_model)
df_std.loc['Mean across datasets', 'Mean across models'] = np.nan
df_std.to_csv(outdir / f'{metric_name}_std_within_study_all_models.csv')

print('Std across splits (including across models and datasets)')
display(df_std)

In [None]:
datasets_mean

In [None]:
models_mean

In [None]:
# --------------------------------------------------------------------------------------
# Generalization scores (violin) from a single source to all targets (src != trg)
# --------------------------------------------------------------------------------------

if len(models_to_include) == 0:
    models_to_include = all_scores["model"].unique()

filtered_data = all_scores[
    (all_scores["model"].isin(models_to_include)) & 
    (all_scores["met"] == metric_name) & 
    (all_scores["src"] == all_scores["trg"])
].reset_index(drop=True)

# Map the model names to their corresponding names using model_name_mapping
filtered_data['model'] = filtered_data['model'].map(model_name_mapping)

utils.boxplot_violinplot_within_study(
    df=filtered_data,
    metric_name=metric_name,
    models_to_include=models_to_include,
    outdir=outdir,
    file_format=file_format,
    dpi=dpi,
    ymin=0.35,
    ymax=0.9,
    datasets_order=datasets_order
)

del filtered_data

In [None]:
# --------------------------------------------------------------------------------------
# Generalization scores (boxplot) from a single source to all targets (src != trg)
# --------------------------------------------------------------------------------------

# Specify the source dataset for filtering
source_dataset = "CTRPv2"  # Replace with the specific source dataset

if len(models_to_include) == 0:
    models_to_include = all_scores["model"].unique()

# Filter data for specific source dataset (src = CTRPv2), and R^2 metric
filtered_data = all_scores[
    (all_scores["src"] == source_dataset) &
    (all_scores["src"] != all_scores["trg"]) &  # Exclude cases where src = trg
    (all_scores["met"] == metric_name) &
    (all_scores["model"].isin(models_to_include))  # Include only specified models
]

# Map the model names to their corresponding names using model_name_mapping
filtered_data['model'] = filtered_data['model'].map(model_name_mapping)

utils.boxplot_violinplot_cross_study(
    df=filtered_data, 
    source_dataset=source_dataset, 
    metric_name=metric_name, 
    models_to_include=models_to_include, 
    outdir=outdir,
    file_format=file_format,
    dpi=dpi,
    ymin=-0.15,
    ymax=0.7,
    datasets_order=[x for x in datasets_order if x != source_dataset]
)

del filtered_data

In [None]:
# Raw G matrices for all models (including std in parentheses) (Linear Scale)

G_palette = 'Blues'

G_mean_dfs = {}
G_std_dfs = {}

for model_name in model_name_mapping.keys():
    print(model_name)
    mean_csa_filename = f'{model_name}_{metric_name}_mean_csa_table.csv'
    std_csa_filename = f'{model_name}_{metric_name}_std_csa_table.csv'

    G_mean = pd.read_csv(filepath / datadir / mean_csa_filename, sep=',')
    G_std = pd.read_csv(filepath / datadir / std_csa_filename, sep=',')

    G_mean.set_index("src", inplace=True)
    G_std.set_index("src", inplace=True)

    # Save csv for llm
    G_mean.to_csv(outdir / f'{model_name}_{metric_name}_G_mean.csv')
    G_std.to_csv(outdir / f'{model_name}_{metric_name}_G_std.csv')

    G_mean_dfs[model_name] = G_mean
    G_std_dfs[model_name] = G_std

    utils.csa_heatmap(
        model_name=model_name, 
        metric_name=f"{metrics_name_mapping[metric_name]}",
        csa_metric_name='G',
        scores_csa_data=G_mean, 
        std_csa_data=G_std,
        vmin=0,
        vmax=1,
        outdir=outdir,
        file_format=file_format,
        dpi=dpi,
        palette=G_palette,
        decimal_digits=3,
        show=show_plot
    )

del model_name, G_mean, G_std

In [None]:
# # Print a single G matrix for a specific model
# model_id = 0
# model_name = list(G_mean_dfs.keys())[model_id]

# scores_csa_data = G_mean_dfs[model_name]
# scores_csa_data.index.name = None
# scores_csa_data.columns.name = None

# std_csa_data = G_std_dfs[model_name]
# std_csa_data.index.name = None
# std_csa_data.columns.name = None

# print(f'{model_name} r2 scores')
# print(scores_csa_data)

# print(f'{model_name} r2 stds')
# print(std_csa_data)

In [None]:
# # --------------------------------------------------------------------------------------
# ## CSA Scores with Standard Deviations (Discrete Levels Same Color)
# # --------------------------------------------------------------------------------------

# # Define discrete levels and custom colormap
# levels = [-1e6, 0, 0.25, 0.5, 0.7, 1]
# colors = ["#08306b", "#2171b5", "#6baed6", "#bdd7e7", "#eff3ff"]
# cmap = ListedColormap(colors)
# norm = BoundaryNorm(boundaries=levels, ncolors=len(colors))

# # Combine scores and stds for annotations
# combined_annotations = scores_csa_data.round(4).astype(str) + "\n(" + std_csa_data.round(4).astype(str) + ")"

# # Plot the combined heatmap
# plt.figure(figsize=(7, 5))
# sns.heatmap(
# scores_csa_data, 
#     annot=combined_annotations.values, 
#     fmt="", 
#     cmap=cmap, 
#     norm=norm, 
#     cbar_kws={'label': 'R² Score'}
# )

# # Customize colorbar ticks to align with levels
# colorbar = plt.gca().collections[0].colorbar
# colorbar.set_ticks(levels[1:])  # Exclude the placeholder -1e6
# colorbar.set_ticklabels(["< 0", "0-0.25", "0.25-0.5", "0.5-0.7", "> 0.7"])  # Custom labels

# plt.title("CSA Performance Scores with Standard Deviations (Discrete Levels)")
# plt.xlabel("Target Dataset")
# plt.ylabel("Source Dataset")
# plt.tight_layout()
# plt.show()

# # --------------------------------------------------------------------------------------
# ## CSA Scores with Standard Deviations (Discrete Levels Different Colors)
# # --------------------------------------------------------------------------------------

# # Define custom levels and pale colors
# levels = [-1e6, 0, 0.25, 0.5, 0.7, 1]  # Replace -float("inf") with a very small value
# colors = ["#dcd0ff", "#ffd1d1", "#ffebcc", "#ffffcc", "#d1ffd1"]  # Pale purple, red, orange, yellow, green
# cmap = ListedColormap(colors)
# norm = BoundaryNorm(boundaries=levels, ncolors=len(colors))

# # Plot heatmap
# plt.figure(figsize=(7, 5))
# sns.heatmap(scores_csa_data, annot=True, fmt=".2f", cmap=cmap, norm=norm, cbar_kws={'label': 'R² Score'})

# # Customize colorbar ticks to align with levels
# colorbar = plt.gca().collections[0].colorbar
# colorbar.set_ticks(levels[1:])  # Exclude the placeholder -1e6
# colorbar.set_ticklabels(["< 0", "0-0.25", "0.25-0.5", "0.5-0.7", "> 0.7"])  # Custom labels

# # Finalize plot
# plt.title("CSA Performance Scores with Standard Deviations (Discrete Levels)")
# plt.xlabel("Target Dataset")
# plt.ylabel("Source Dataset")
# plt.tight_layout()
# plt.show()

# Custom metrics (Ga, Gn, Gna)

In [None]:
# Use scores from a single model

# # Example CSA scores (R²)
# scores_csa_data = {
#     "CCLE": [0.7479, 0.5758, 0.4482, 0.3082, 0.0234],
#     "CTRPv2": [-0.4671, 0.8508, -0.1432, -0.0324, -1.1319],
#     "GDSCv1": [-0.2057, 0.1503, 0.734, 0.154, -0.4801],
#     "GDSCv2": [-0.2913, 0.2902, 0.2003, 0.7659, -0.7147],
#     "gCSI": [-0.3793, 0.2314, 0.4179, 0.3914, 0.733],
# }
# scores_csa_data = pd.DataFrame(scores_csa_data, index=["CCLE", "CTRPv2", "GDSCv1", "GDSCv2", "gCSI"])

# # Example CSA std deviations (optional for some metrics)
# std_csa_data = {
#     "CCLE": [0.0123, 0.0406, 0.0245, 0.0364, 0.1058],
#     "CTRPv2": [0.0249, 0.0031, 0.0135, 0.0253, 0.0901],
#     "GDSCv1": [0.078, 0.0301, 0.0065, 0.0389, 0.0304],
#     "GDSCv2": [0.0453, 0.0216, 0.0262, 0.0098, 0.0448],
#     "gCSI": [0.1053, 0.1546, 0.0903, 0.1144, 0.0314],
# }
# std_csa_data = pd.DataFrame(std_csa_data, index=["CCLE", "CTRPv2", "GDSCv1", "GDSCv2", "gCSI"])

model_name = "graphdrp"

mean_csa_filename = f'{model_name}_{metric_name}_mean_csa_table.csv'
std_csa_filename = f'{model_name}_{metric_name}_std_csa_table.csv'

scores = pd.read_csv(filepath / datadir / mean_csa_filename, sep=',')
stds = pd.read_csv(filepath / datadir / std_csa_filename, sep=',')

scores.set_index("src", inplace=True)
stds.set_index("src", inplace=True)

print(f'Model: {model_name}')
display(scores)
display(stds)

# Ga matrix

In [None]:
# Compute Ga using both implementations
Ga_bruteforce = utils.compute_aggregated_G_bruteforce(scores, normalize=False)
Ga_vectorized = utils.compute_aggregated_G_vectorized(scores, normalize=False)

# Compare results
print(f"Bruteforce Ga:\n{Ga_bruteforce}")
print(f"Vectorized Ga:\n{Ga_vectorized}")

# Check if all implementations are consistent
assert Ga_vectorized == Ga_bruteforce, "Mismatch between bruteforce and vectorized implementations!"

In [None]:
# Combine Ga scores from all models
Ga_list = []

for model_name in model_name_mapping.keys():
    print(model_name)
    
    mean_csa_filename = f'{model_name}_{metric_name}_mean_csa_table.csv'
    scores_csa_data = pd.read_csv(filepath / datadir / mean_csa_filename, sep=',')
    scores_csa_data.set_index('src', inplace=True)
    Ga = utils.compute_aggregated_G_vectorized(scores_csa_data, normalize=False)

    Ga['model'] = model_name
    Ga_list.append(Ga)

Ga_df = pd.DataFrame(Ga_list)
Ga_df.insert(loc=0, column='model', value=Ga_df.pop('model')) # place 'model' col at pos 0
Ga_df.set_index('model', inplace=True)
Ga_df.to_csv(outdir / 'Ga_table.csv')
print(Ga_df)

del model_name, scores_csa_data, Ga, Ga_list

In [None]:
# Plot Ga heatmap
utils.aggregated_G_heatmap(
    metric_name=f"Aggregated {metrics_name_mapping[metric_name]}",
    csa_metric_name='Ga',
    scores_aggregated_data=Ga_df.copy(),
    # palette="RdPu",
    palette=G_palette,
    vmin=0,
    vmax=0.6,
    outdir=outdir,
    file_format=file_format,
    dpi=dpi,
    show=show_plot
)

# Gn matrix
(this was previously Source-to-Target Generalization Ratio (STGR))

In [None]:
# Compute Gn using both implementations
Gn_bruteforce = utils.compute_Gn_bruteforce(scores)
Gn_vectorized = utils.compute_Gn_vectorized(scores)

# Convert bruteforce results to DataFrame for comparison
Gn_bruteforce_df = pd.DataFrame.from_dict(Gn_bruteforce, orient="index")

# Compare results
print(f'Bruteforce Gn:\n{Gn_bruteforce_df}')
print(f'\nVectorized Gn:\n{Gn_vectorized}')

# Check for equality
assert np.allclose(Gn_bruteforce_df, Gn_vectorized), "Gn results do not match!"
print("\nBoth implementations produce the same results!")

In [None]:
Gn_palette = 'Greens'

for model_name in model_name_mapping.keys():
    print(model_name)
    
    mean_csa_filename = f'{model_name}_{metric_name}_mean_csa_table.csv'
    scores_csa_data = pd.read_csv(filepath / datadir / mean_csa_filename, sep=',')
    scores_csa_data.set_index("src", inplace=True)
    Gn = utils.compute_Gn_vectorized(scores_csa_data)
    Gn.to_csv(outdir / f'{model_name}_{metric_name}_Gn_mean.csv') # Save csv for llm

    utils.csa_heatmap(
        model_name=model_name, 
        # metric_name="stgr",
        metric_name=f"Normalized {metrics_name_mapping[metric_name]}",
        csa_metric_name='Gn',
        scores_csa_data=Gn, 
        vmin=0,
        vmax=1,
        palette=Gn_palette,
        outdir=outdir,
        file_format=file_format,
        dpi=dpi,
        decimal_digits=3,
        show=show_plot
    )

del model_name, scores_csa_data, Gn

In [None]:
# # Print a single Gn matrix for a specific model
# model_id = 5
# model_name = list(G_mean_dfs.keys())[model_id]

# scores_csa_data = G_mean_dfs[model_name]

# Gn = utils.compute_Gn_vectorized(scores_csa_data)
# Gn.index.name = None
# Gn.columns.name = None

# print(f'{model_name} Gn')
# print(Gn)

# Gna matrix
(previously Source-to-Target Generalization Index (STGI))

In [None]:
# Compute Gna using both implementations
Gna_bruteforce = utils.compute_aggregated_G_bruteforce(scores, normalize=True)
Gna_vectorized = utils.compute_aggregated_G_vectorized(scores, normalize=True)

# Compare results
print("Bruteforce Gna:")
print(Gna_bruteforce)

print("\nVectorized Gna:")
print(Gna_vectorized)

# Check if all implementations are consistent
# assert Gna_vectorized.equals(pd.Series(Gna_bruteforce)), "Mismatch between bruteforce and simpler implementations!"
assert Gna_vectorized == Gna_bruteforce, "Mismatch between bruteforce and vectorized implementations!"

In [None]:
# Combine Gna scores from all models
Gna_list = []

for model_name in model_name_mapping.keys():
    print(model_name)
    
    mean_csa_filename = f'{model_name}_{metric_name}_mean_csa_table.csv'
    scores_csa_data = pd.read_csv(filepath / datadir / mean_csa_filename, sep=',')
    scores_csa_data.set_index('src', inplace=True)
    Gna = utils.compute_aggregated_G_vectorized(scores_csa_data, normalize=True)

    Gna['model'] = model_name
    # print(stgi)
    Gna_list.append(Gna)

Gna_df = pd.DataFrame(Gna_list)
Gna_df.insert(loc=0, column='model', value=Gna_df.pop('model')) # place 'model' col at pos 0
Gna_df.set_index('model', inplace=True)
Gna_df.to_csv(outdir / 'Gna_table.csv')
print(Gna_df)

del model_name, scores_csa_data, Gna, Gna_list

In [None]:
# Plot Gna heatmap
utils.aggregated_G_heatmap(
    metric_name=f"Aggregated normalized {metrics_name_mapping[metric_name]}",
    csa_metric_name='Gna',
    scores_aggregated_data=Gna_df.copy(),
    # palette="RdPu",
    palette=Gn_palette,
    vmin=0,
    vmax=0.6,
    outdir=outdir,
    file_format=file_format,
    dpi=dpi,
    show=show_plot
)

In [None]:
# palette = "RdGn"

# if palette in plt.colormaps():  # Check if the palette is a valid Matplotlib colormap
#     cmap = plt.get_cmap(palette)
# else:
#     cmap = sns.color_palette(palette, as_cmap=True)  # Use Seaborn for custom palettes

# print(cmap)

In [None]:
# palette in plt.colormaps()

In [None]:
   available_colormaps = plt.colormaps()
   print(available_colormaps)

In [None]:
# import matplotlib
# matplotlib.__version__

# Source-to-Target Variability Ratio (STVR)

In [None]:
# # STVR (Source-to-Target Variability Ratio)
# """
# The STVR quantifies the relative variability of a model’s performance for a 
# source-target pair by comparing the variability (standard deviation) of 
# predictions to the average performance for that pair. It is designed to provide 
# insight into the stability of predictions across different source-target combinations.

# STVR evaluates stability per source-target pair, telling you how consistent or 
# erratic the model predictions are in that scenario.

# Key Characteristics
#     - Pairwise Metric: Evaluates the variability for each source-target combination,
#         enabling fine-grained analysis of prediction stability.
#     - Normalization: Normalizes the standard deviation of predictions by the mean 
#         performance, allowing for direct comparison across source-target pairs, 
#         regardless of scale.
#     - Interpretation
#         - > 1: Indicates high variability relative to the average performance, 
#             suggesting instability.
#         - 0 < STVR < 1: Indicates low variability relative to the average performance, 
#             suggesting stability.
#         - < 0: Reflects variability relative to poor performance
#     - Caveats
#         - Sensitive to both variability (numerator) and performance (denominator): 
#             Small mean performance values in the denominator can inflate the ratio,
#             potentially misrepresenting variability.
#         - Requires sufficient sample size for robust std computation.
#     - Edge Cases
#         - Zero mean performance: If mean_abs(src→trg)=0, assign a default value (e.g., 0) 
#             to avoid division by zero.

# Formula:
#     STVR[src][trg] = std_dev(src → trg) / mean_score(src → trg)
#     Where:
#         - std_dev(src→trg): Standard deviation of predictions for the source-target pair.
#         - mean_score(src→trg): Mean performance for the source-target pair.
# """

# def compute_stvr_vectorized(scores, stds):
#     """
#     Compute STVR (Source-to-Target Variability Ratio) using a vectorized approach.

#     Args:
#         scores (pd.DataFrame): DataFrame where each cell contains a performance score.
#         stds (pd.DataFrame): DataFrame where each cell contains the standard deviation 
#                              of scores for the corresponding source-target pair.

#     Returns:
#         pd.DataFrame: A DataFrame with STVR values for each source-target pair.
#     """
#     #mean_abs = scores.abs()  # Compute mean absolute score (element-wise)
#     #stvr = stds / mean_abs.replace(0, np.nan)  # Compute STVR, avoiding division by zero
#     stvr = stds / scores.replace(0, np.nan)  # Compute STVR, avoiding division by zero
#     stvr = stvr.fillna(0)  # Replace NaN values with 0 for consistency
#     return stvr


# def compute_stvr_bruteforce(scores, stds):
#     """
#     Compute STVR (Source-to-Target Variability Ratio) for each source-target pair.
#     This implementation avoids vectorized operations.

#     Args:
#         scores (pd.DataFrame): DataFrame where each cell contains a performance score.
#         stds (pd.DataFrame): DataFrame where each cell contains the standard deviation 
#                              of scores for the corresponding source-target pair.

#     Returns:
#         pd.DataFrame: A DataFrame with STVR values for each source-target pair.
#     """
#     stvr = pd.DataFrame(index=scores.index, columns=scores.columns)

#     # Iterate through each source-target pair
#     for src in scores.index:
#         for trg in scores.columns:
#             # Get mean absolute score and std deviation
#             #mean_abs = abs(scores.loc[src, trg])
#             mena_score = scores.loc[src, trg]
#             std_dev = stds.loc[src, trg]
            
#             # Compute STVR
#             #stvr.loc[src, trg] = std_dev / mean_abs if mean_abs != 0 else 0
#             stvr.loc[src, trg] = std_dev / mena_score if mena_score != 0 else 0

#     return stvr


# stvr_bruteforce = compute_stvr_bruteforce(scores, stds).apply(pd.to_numeric)
# stvr_vectorized = compute_stvr_vectorized(scores, stds).apply(pd.to_numeric)

# print(stvr_bruteforce)
# print(stvr_vectorized)

# assert np.allclose(stvr_bruteforce.values, stvr_vectorized.values,
#         atol=1e-6,  # Adjust the absolute tolerance for small rounding errors
#         rtol=1e-5,   # Adjust the relative tolerance
#         equal_nan=True), "STVR results do not match!"

# Source-to-Target Variability Index (STVI)

In [None]:
# # STVI (Source-to-Target Variability Index)
# """
# STVI summarizes the variability of a model’s performance considering all 
# source-target pairs and normalizing the overall variability (standard deviation) 
# by the overall mean absolute performance.

# STVI aggregates variability globally, helping you compare which models are more 
# stable across the board, but not telling you which datasets or source-target 
# pairs contribute to that stability or instability.
# """

# def compute_stvi_vectorized(scores, stds):
#     """
#     Compute STVI (Source-to-Target Variability Index) using a vectorized approach.

#     Args:
#         scores (pd.DataFrame): DataFrame where each cell contains a performance score.
#         stds (pd.DataFrame): DataFrame where each cell contains the standard deviation 
#                              of scores for the corresponding source-target pair.

#     Returns:
#         float: The STVI value for the entire CSA study.
#     """
#     mean_abs = scores.abs().mean().mean()  # Compute overall mean absolute score
#     overall_std_dev = stds.values.std()  # Compute overall standard deviation of stds

#     # Compute STVI
#     return overall_std_dev / mean_abs if mean_abs != 0 else 0


# def compute_stvi_bruteforce(scores, stds):
#     """
#     Compute STVI (Source-to-Target Variability Index) for the entire CSA study.
#     This implementation avoids vectorized operations for a fully brute-force calculation.

#     Args:
#         scores (pd.DataFrame): DataFrame where each cell contains a performance score.
#         stds (pd.DataFrame): DataFrame where each cell contains the standard deviation 
#                              of scores for the corresponding source-target pair.

#     Returns:
#         float: The STVI value for the entire CSA study.
#     """
#     all_scores = []
#     all_stds = []

#     # Flatten all scores and std deviations
#     for src in mean_csa_data.index:
#         for trg in mean_csa_data.columns:
#             all_scores.append(abs(mean_csa_data.loc[src, trg]))  # Collect absolute values of scores
#             all_stds.append(std_csa_data.loc[src, trg])  # Collect standard deviations

#     # Compute mean absolute score and overall standard deviation
#     mean_abs = sum(all_scores) / len(all_scores)
#     overall_std_dev = (sum((x - np.mean(all_stds)) ** 2 for x in all_stds) / len(all_stds)) ** 0.5

#     # Compute STVI
#     return overall_std_dev / mean_abs if mean_abs != 0 else 0


# stvi_bruteforce = compute_stvi_bruteforce(mean_csa_data, std_csa_data)
# stvi_vectorized = compute_stvi_vectorized(mean_csa_data, std_csa_data)

# print(stvi_vectorized)
# print(stvi_bruteforce)

# assert np.isclose(stvi_bruteforce, stvi_vectorized, equal_nan=True), "STVI results do not match!"

# Runtime Analysis

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Load the data
# df = pd.read_csv('all_models_runtimes.csv')

# # 1. Box Plot of Total Minutes by Model
# plt.figure(figsize=(12, 6))
# sns.boxplot(x='model', y='tot_mins', data=df)
# plt.title('Box Plot of Total Minutes by Model')
# plt.xticks(rotation=45)
# plt.ylabel('Total Minutes')
# plt.xlabel('Model')
# plt.tight_layout()
# plt.show()

# # 2. Bar Plot of Average Total Minutes by Source
# avg_tot_mins_src = df.groupby('src')['tot_mins'].mean().reset_index()
# plt.figure(figsize=(12, 6))
# sns.barplot(x='src', y='tot_mins', data=avg_tot_mins_src)
# plt.title('Average Total Minutes by Source')
# plt.ylabel('Average Total Minutes')
# plt.xlabel('Source')
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.show()

# # 3. Line Plot of Total Minutes by Stage
# avg_tot_mins_stage = df.groupby('stage')['tot_mins'].mean().reset_index()
# plt.figure(figsize=(12, 6))
# sns.lineplot(x='stage', y='tot_mins', data=avg_tot_mins_stage, marker='o')
# plt.title('Average Total Minutes by Stage')
# plt.ylabel('Average Total Minutes')
# plt.xlabel('Stage')
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.show()

# # 4. Heatmap of Total Minutes by Source and Target
# heatmap_data = df.pivot_table(values='tot_mins', index='src', columns='trg', aggfunc='mean')
# plt.figure(figsize=(12, 8))
# sns.heatmap(heatmap_data, annot=True, fmt=".1f", cmap='viridis')
# plt.title('Heatmap of Average Total Minutes by Source and Target')
# plt.ylabel('Source')
# plt.xlabel('Target')
# plt.tight_layout()
# plt.show()

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Load the data
# df = pd.read_csv('all_models_runtimes.csv')

# # Group by src, stage, and model, and calculate the mean and standard deviation of tot_mins
# stage_model_src_stats = df.groupby(['src', 'stage', 'model'])['tot_mins'].agg(['mean', 'std', 'count']).reset_index()

# # Calculate the standard error of the mean (sem)
# stage_model_src_stats['sem'] = stage_model_src_stats['std'] / stage_model_src_stats['count'] ** 0.5

# # Define a color palette
# colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

# # Create separate plots for each stage
# stages = stage_model_src_stats['stage'].unique()

# for stage in stages:
#     plt.figure(figsize=(14, 7))
#     stage_data = stage_model_src_stats[stage_model_src_stats['stage'] == stage]
    
#     bar_plot = sns.barplot(x='src', y='mean', hue='model', data=stage_data, palette=colors, errorbar=None)
    
#     # Add error bars for each bar
#     for index, bar in enumerate(bar_plot.patches):
#         height = bar.get_height()
#         sem = stage_data['sem'].iloc[index]
        
#         plt.errorbar(x=bar.get_x() + bar.get_width() / 2, 
#                      y=height, 
#                      yerr=sem, 
#                      fmt='none', 
#                      c='black', 
#                      capsize=5, 
#                      elinewidth=1)

#     plt.title(f'Distribution of Total Minutes for Stage {stage} with Error Bars')
#     plt.ylabel('Average Total Minutes')
#     plt.xlabel('Source')
#     plt.xticks(rotation=45)
#     plt.legend(title='Model')
#     plt.tight_layout()
#     plt.show()

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Load the data
# df = pd.read_csv('all_models_runtimes.csv')

# # Group by src, stage, and model, and calculate the mean and standard deviation of tot_mins
# stage_model_src_stats = df.groupby(['src', 'stage', 'model'])['tot_mins'].agg(['mean', 'std', 'count']).reset_index()

# # Create separate box plots for each stage
# stages = stage_model_src_stats['stage'].unique()

# # Define a color palette
# colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

# for stage in stages:
#     plt.figure(figsize=(14, 7))
#     stage_data = df[df['stage'] == stage]  # Filter data for the current stage
    
#     # Create a box plot
#     sns.boxplot(x='src', y='tot_mins', hue='model', data=stage_data, palette=colors)
    
#     plt.title(f'Distribution of Total Minutes for Stage: {stage}')
#     plt.ylabel('Total Minutes')
#     plt.xlabel('Source')
#     plt.xticks(rotation=45)
#     plt.legend(title='Model')
#     plt.tight_layout()
#     plt.show()