# Imports

In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import sys

sys.path.append('../../Benchmarks/y_pred results')
import metrics_tools as mt

# Benchmark

In [2]:
# Load the preprocessed prostate dataset that every later cell starts from.
adata_org = sc.read_h5ad('../../Datasets/preprocessed_datasets/prostate.h5ad')

# Work with a dense expression matrix from here on. Only convert when X
# actually exposes `toarray` (i.e. it is stored sparse); a file saved with a
# dense X would otherwise fail here with an AttributeError.
if hasattr(adata_org.X, 'toarray'):
    adata_org.X = adata_org.X.toarray()

In [3]:
# Time points that are held out and evaluated one at a time by the benchmark
# loop. Each name must match a value of `obs['time']` and the file stem of the
# corresponding `predictions/` and `ground_truth/` h5ad files.
time_points = [
    'T02_Cast_Day7',
    'T03_Cast_Day14',
    'T04_Cast_Day28',
    'T05_Regen_Day1',
    'T06_Regen_Day2',
    'T07_Regen_Day3',
    'T08_Regen_Day7',
    'T09_Regen_Day14',
    'T10_Regen_Day28',
]

In [4]:
# Metrics for scDisentangle predictions and two simple baselines.
#
# For each held-out time point the target cell type at that time point is the
# out-of-distribution (OOD) group. Three "predictions" of that group are scored
# against the ground truth with `metrics_tools`:
#   1. the scDisentangle prediction loaded from `predictions/<time_point>.h5ad`
#   2. pert_mean: target-type cells from all other time points (not OOD, not T00)
#   3. ctrl_mean: all cells of other types at the held-out time point

CELL_TYPE = 'Epi_Luminal_2Psca'     # cell type that is held out and evaluated
CONTROL_TIME = 'T00'                # time point used as the control condition
DEG_SIZES = (200, 100, 50, 20, 10)  # DEG cut-offs forwarded to metrics_tools

all_pred_correlations = {}
all_pred_distances = {}

# Baseline 1: "pert_mean" = L2 cells in all other timepoints (except OOD and T00)
all_pert_mean_correlations = {}
all_pert_mean_distances = {}

# Baseline 2: "ctrl_mean" = all non-L2 cells at this timepoint
all_ctrl_mean_correlations = {}
all_ctrl_mean_distances = {}

# Loop-invariant quantities, computed once instead of once per time point.
# `np.asarray(...).ravel()` gives a flat vector whether X is dense or sparse.
library_sizes = np.asarray(adata_org.X.sum(axis=1)).ravel()
is_target_type = (adata_org.obs['predType'] == CELL_TYPE).to_numpy()
obs_time = adata_org.obs['time']


def _normalized_subset(cell_mask, target_sum):
    """Return a normalized, log1p-transformed copy of the selected cells.

    Only the selected cells are copied. With an explicit `target_sum`,
    `normalize_total` rescales every cell on its own, so this should give the
    same values as normalizing the full dataset and subsetting afterwards,
    without copying the whole matrix for each time point.
    """
    subset = adata_org[cell_mask].copy()
    sc.pp.normalize_total(subset, target_sum=target_sum)
    sc.pp.log1p(subset)
    return subset


def _score(stim, ctrl, true_stim, true_ctrl, degs_indices):
    """Run both metric families on one prediction.

    Returns a `(correlations, distances)` tuple as produced by
    `mt.get_correlations` and `mt.get_distances`.
    """
    shared = dict(
        _pred_stim=stim,
        _true_stim=true_stim,
        _pred_ctrl=ctrl,
        _true_ctrl=true_ctrl,
        degs_indices=degs_indices,
    )
    # A fresh list per call, as in the original calls, so the metric helpers
    # can never share (or mutate) a common list object.
    correlations = mt.get_correlations(**shared, degs_list=list(DEG_SIZES))
    distances = mt.get_distances(**shared, degs_list=list(DEG_SIZES))
    return correlations, distances


for time_point in time_points:
    print(time_point)

    # DEG indices for this timepoint and cell type. Looked up first so that a
    # time point without DEGs is skipped before any file is read or any data
    # is copied. Only a missing key / unknown gene means "no DEGs"; any other
    # error is a real problem and is allowed to propagate.
    try:
        degs = adata_org.uns['rank_genes_groups_time'][time_point][CELL_TYPE]
        degs_indices = [adata_org.var_names.get_loc(gene) for gene in degs]
    except KeyError as e:
        print('Skipping', time_point, 'no DEGs found:', e)
        continue

    # Load scDisentangle predictions and GT (Luminal L1/L2 + T00 controls)
    pred = sc.read_h5ad(f'predictions/{time_point}.h5ad')
    gt = sc.read_h5ad(f'ground_truth/{time_point}.h5ad')

    # Use the same library size as in prediction script: the median total
    # count over the training cells, i.e. everything except the OOD group.
    at_time_point = (obs_time == time_point).to_numpy()
    is_held_out = at_time_point & is_target_type
    data_median = np.median(library_sizes[~is_held_out])

    # Restrict to Luminal_2Psca cells
    pred = pred[pred.obs['predType'] == CELL_TYPE].copy()
    gt = gt[gt.obs['predType'] == CELL_TYPE].copy()

    # Split into stim (time_ood) and ctrl (T00)
    pred_stim = pred[pred.obs['time_pred'] == time_point]
    pred_ctrl = pred[pred.obs['time_pred'] == CONTROL_TIME]

    gt_stim = gt[gt.obs['time'] == time_point]
    gt_ctrl = gt[gt.obs['time'] == CONTROL_TIME]

    # === 1) Metrics for scDisentangle predictions ===
    pred_correlations, pred_distances = _score(
        pred_stim, pred_ctrl, gt_stim, gt_ctrl, degs_indices
    )
    all_pred_correlations[time_point] = pred_correlations
    all_pred_distances[time_point] = pred_distances

    # === 2) Baseline 1: pert_mean (L2 in all other timepoints except this and T00) ===
    pert_mean_mask = (
        is_target_type
        & ~obs_time.isin([time_point, CONTROL_TIME]).to_numpy()
    )
    pert_mean_stim = _normalized_subset(pert_mean_mask, data_median)

    # The baselines have no control prediction of their own, so the ground
    # truth control is used on both sides.
    pert_mean_correlations, pert_mean_distances = _score(
        pert_mean_stim, gt_ctrl, gt_stim, gt_ctrl, degs_indices
    )
    all_pert_mean_correlations[time_point] = pert_mean_correlations
    all_pert_mean_distances[time_point] = pert_mean_distances

    # === 3) Baseline 2: ctrl_mean (all non-L2 cells at this timepoint) ===
    ctrl_mean_mask = at_time_point & ~is_target_type
    ctrl_mean_stim = _normalized_subset(ctrl_mean_mask, data_median)

    ctrl_mean_correlations, ctrl_mean_distances = _score(
        ctrl_mean_stim, gt_ctrl, gt_stim, gt_ctrl, degs_indices
    )
    all_ctrl_mean_correlations[time_point] = ctrl_mean_correlations
    all_ctrl_mean_distances[time_point] = ctrl_mean_distances
    

T02_Cast_Day7


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T03_Cast_Day14


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T04_Cast_Day28


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T05_Regen_Day1


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T06_Regen_Day2


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T07_Regen_Day3


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T08_Regen_Day7


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T09_Regen_Day14


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


T10_Regen_Day28


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [5]:
import os

# Write one CSV per (method, held-out time point), in the same CSV format as
# `compute_metrics.py`.
data_name = 'Prostate'

# (output directory name, correlation results, distance results) per method.
method_results = [
    ('SCDISENTANGLE', all_pred_correlations, all_pred_distances),
    ('Pert Mean', all_pert_mean_correlations, all_pert_mean_distances),
    ('STIM baseline for the ctrl_mean_stim', all_ctrl_mean_correlations, all_ctrl_mean_distances),
]

for method_name, corr_by_cov, dist_by_cov in method_results:
    save_dir = f"../../Benchmarks/y_pred results/results/{data_name}/{method_name}"
    os.makedirs(save_dir, exist_ok=True)

    for ood_cov, corr_metrics in corr_by_cov.items():
        # Correlations first, then distances (when present for this time
        # point), combined into a single mapping.
        merged_metrics = dict(corr_metrics)
        merged_metrics.update(dist_by_cov.get(ood_cov, {}))

        # Transposed so that rows = metric names, cols = DEG sizes (200,100,50,20,10)
        metrics_table = pd.DataFrame(merged_metrics).T
        metrics_table.index.name = 'Metric'

        # "_1" is a dummy seed so the file matches the <ood_cov>_<seed>.csv pattern
        csv_path = f"{save_dir}/{ood_cov}_1.csv"
        metrics_table.to_csv(csv_path)

# Displays the directory of the last method written.
save_dir

'../../Benchmarks/y_pred results/results/Prostate/STIM baseline for the ctrl_mean_stim'