In [None]:
# Inference first, then R2 calculation

In [None]:
import os
# Work from the parent of the directory the kernel started in, so the relative
# `data/` and `output/` paths below resolve. The starting directory is captured
# only the first time this cell runs: a plain `os.chdir('../')` would climb one
# more level on every re-execution of the cell.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))
data_dir = 'data'
out_dir = 'output'

In [None]:
# this is the inference code 
# this saves many figures per perturbation (you can comment out if you want)
from src.preprocess import get_data


import pytensor
pytensor.config.cxx = ""
pytensor.config.mode = "NUMBA"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.linalg import pinv
from scipy.sparse import issparse
from scipy.stats import nbinom
import pymc as pm
import os
import pickle
import anndata as ad
import scanpy as sc

def sample_zinb(mean, var, zero_prob, size):
    """Draw ``size`` counts from a zero-inflated negative binomial.

    The negative binomial is moment-matched to ``mean`` and ``var``
    (``p = mean / var``, ``r = mean**2 / (var - mean)``); afterwards each draw is
    independently replaced by 0 with probability ``zero_prob``.

    Parameters
    ----------
    mean : float
        Target mean of the negative-binomial component.
    var : float
        Target variance; if it is not larger than ``mean`` it is bumped to
        ``mean + 1e-3`` so the negative binomial stays well defined.
    zero_prob : float
        Probability of forcing a draw to zero.
    size : int
        Number of samples to draw.

    Returns
    -------
    numpy.ndarray
        Integer array of length ``size``.
    """
    if mean <= 0:
        # mean == 0 would give r = 0 and p = 0, which scipy's nbinom rejects as a
        # domain error; a count variable with zero mean is identically zero.
        return np.zeros(size, dtype=int)
    if var <= mean:
        var = mean + 1e-3
    p = mean / var
    r = mean**2 / (var - mean)
    nb_samples = nbinom.rvs(n=r, p=p, size=size)
    # Zero inflation: overwrite a random subset of the draws with 0.
    zeros = np.random.rand(size) < zero_prob
    nb_samples[zeros] = 0
    return nb_samples

def compute_zinb_covariance(X0, n_cells=None, seed=42):
    """Build a synthetic ZINB data set from per-gene statistics of ``X0``.

    Each gene's empirical variance is rescaled by a random factor ``10**e`` with
    ``e`` drawn from [-1, -0.5] for half of the genes and [0.5, 1] for the rest
    (randomly assigned), then ``n_cells`` samples per gene are drawn with
    ``sample_zinb``. Genes are sampled independently of each other.

    Parameters
    ----------
    X0 : array-like or scipy sparse matrix, shape (n_obs, n_genes)
    n_cells : int, optional
        Number of synthetic cells; defaults to ``n_obs``.
    seed : int
        Seed passed to ``np.random.seed`` (this resets NumPy's global RNG).

    Returns
    -------
    (Sigma_zinb, synthetic_data, scaling_factors)
        Covariance of the synthetic data, the synthetic matrix itself, and the
        per-gene variance scaling factors.
    """
    dense = X0.toarray() if issparse(X0) else X0
    np.random.seed(seed)
    n_obs, n_genes = dense.shape
    if n_cells is None:
        n_cells = n_obs

    # Per-gene moments and zero fraction of the reference data.
    gene_mean = dense.mean(axis=0)
    gene_var = dense.var(axis=0)
    gene_zero_frac = (dense == 0).mean(axis=0)

    # Exponents of the variance scaling: half shrink, half inflate, shuffled.
    n_low = n_genes // 2
    exps_low = np.random.uniform(-1, -.5, size=n_low)
    exps_high = np.random.uniform(.5, 1, size=n_genes - n_low)
    scaling_factors = 10**np.random.permutation(np.concatenate([exps_low, exps_high]))
    target_var = scaling_factors * gene_var

    synthetic_data = np.zeros((n_cells, n_genes))
    for g, (m, v, z) in enumerate(zip(gene_mean, target_var, gene_zero_frac)):
        synthetic_data[:, g] = sample_zinb(m, v, z, n_cells)

    return np.cov(synthetic_data, rowvar=False), synthetic_data, scaling_factors

def compute_average_response(X0, X1):
    """Return the per-column mean of ``X1`` minus the per-column mean of ``X0``."""
    baseline_mean = X0.mean(axis=0)
    perturbed_mean = X1.mean(axis=0)
    return perturbed_mean - baseline_mean

def compute_sparse_perturbation(Sigma_inv, delta_X, gene_indices, top_k=200):
    """Keep the ``top_k`` largest entries of ``u_hat = Sigma_inv @ delta_X``.

    The genes in ``gene_indices`` are always part of the selection. The remaining
    slots are filled with the highest-|u_hat| genes that are not already forced in,
    so forcing a gene displaces the weakest candidates rather than the strongest.
    (Previously every missing gene was written to slot 0, which evicted the
    top-ranked gene and let a second missing gene overwrite the first.)

    Parameters
    ----------
    Sigma_inv : ndarray, shape (G, G)
    delta_X : ndarray, shape (G,)
    gene_indices : iterable of int
        Indices that must be included in the selection.
    top_k : int
        Size of the selection. If more than ``top_k`` genes are forced, all forced
        genes are returned and nothing else.

    Returns
    -------
    (u_sparse, top_indices)
        ``u_sparse`` equals ``u_hat`` on the selected indices and 0 elsewhere;
        ``top_indices`` is the sorted list of selected indices.
    """
    u_hat = Sigma_inv @ delta_X
    ranked = np.argsort(np.abs(u_hat))[::-1]

    # De-duplicate the forced indices while keeping their order.
    forced = list(dict.fromkeys(int(i) for i in gene_indices))
    n_free = max(top_k - len(forced), 0)
    best_others = ranked[~np.isin(ranked, forced)][:n_free]

    top_indices = sorted(forced + best_others.tolist())
    u_sparse = np.zeros_like(u_hat)
    u_sparse[top_indices] = u_hat[top_indices]
    return u_sparse, top_indices

def run_mcmc_horseshoe_learnable_sigma(Sigma_sub, delta_X_sub, draws=1000, tune=1000,
                                       chains=8, cores=8, target_accept=0.95,
                                       random_seed=None):
    """Sample the posterior of ``u`` in ``delta_X_sub ~ Normal(Sigma_sub @ u, sigma_obs)``.

    Model, as built below:
      * ``lambda`` ~ HalfCauchy(1), one per gene (local scales)
      * ``log_tau`` ~ Normal(-4, 1), ``tau = exp(log_tau)`` (global scale)
      * ``z`` ~ Normal(0, 1), ``u = z * tau * lambda`` (non-centred horseshoe)
      * ``log_sigma_obs`` ~ Normal(-2, 2), ``sigma_obs = exp(log_sigma_obs)``

    Parameters
    ----------
    Sigma_sub : ndarray, shape (G, G)
    delta_X_sub : ndarray, shape (G,)
    draws, tune : int
        Posterior draws and tuning steps per chain.
    chains, cores : int
        Number of chains and of worker processes; the defaults (8, 8) are the
        values that used to be hard-coded. Lower ``cores`` on smaller machines.
    target_accept : float
        Target acceptance rate forwarded to the sampler.
    random_seed : int or None
        Forwarded to ``pm.sample``; ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    The object returned by ``pm.sample`` (its ``posterior['u']`` is read by callers).
    """
    n_genes = Sigma_sub.shape[0]
    with pm.Model():
        lambda_ = pm.HalfCauchy("lambda", beta=1.0, shape=n_genes)
        log_tau = pm.Normal("log_tau", mu=-4, sigma=1)
        tau = pm.Deterministic("tau", pm.math.exp(log_tau))
        z = pm.Normal("z", 0, 1, shape=n_genes)
        u = pm.Deterministic("u", z * tau * lambda_)
        log_sigma = pm.Normal("log_sigma_obs", mu=-2, sigma=2)
        sigma_obs = pm.Deterministic("sigma_obs", pm.math.exp(log_sigma))
        mu_x = pm.math.dot(Sigma_sub, u)
        pm.Normal("obs", mu=mu_x, sigma=sigma_obs, observed=delta_X_sub)
        trace = pm.sample(draws=draws, tune=tune, target_accept=target_accept,
                          max_treedepth=10, chains=chains, cores=cores,
                          random_seed=random_seed, progressbar=True)
    return trace

def save_u_samples_summary_double_pert(data_path, output_dir="u_samples_summaries_double"):
    """Run the horseshoe inference for every double perturbation of one dataset.

    For each perturbation label of the form ``"<geneA>_<geneB>"`` whose two genes
    are both in ``adata.var_names``, the mean response relative to ``X0`` is
    computed, the problem is restricted to the genes chosen by
    ``compute_sparse_perturbation``, the posterior of ``u`` is sampled with
    ``run_mcmc_horseshoe_learnable_sigma`` and three kinds of PNG figures are
    written (PIP scatter, posterior mean ± std, posterior histogram per target).

    Parameters
    ----------
    data_path : str
        Path to the ``.h5ad`` file; its basename names the output files.
    output_dir : str
        Directory for figures and the pickled summary (created if missing).

    Returns
    -------
    pandas.DataFrame or None
        One row per (perturbation, selected gene) with columns ``Gene``,
        ``Perturbation``, ``IsTruePerturbation``, ``PIP``, ``U_Mean``, ``U_Std``
        and ``U_Samples`` (posterior samples, kept only for the targeted genes).
        ``None`` when ``get_data`` yields no AnnData.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Only the first two return values are used. Indexing instead of a fixed-length
    # unpack keeps this working whether `get_data` is the version imported from
    # src.preprocess or the one defined later in this notebook, which returns an
    # extra element.
    loaded = get_data(0, data_path)
    adata, X0 = loaded[0], loaded[1]
    if adata is None:
        return None

    gene_names = np.array(adata.var_names.tolist())
    Sigma = np.cov(X0, rowvar=False)
    Sigma_inv = pinv(Sigma)

    # Non-string labels (e.g. NaN) cannot be double perturbations; skip them
    # instead of failing on the `in` test.
    perturbations = [p for p in adata.obs['perturbation'].unique()
                     if isinstance(p, str) and '_' in p]
    all_rows = []
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")

    for pert in perturbations:
        pert_split = pert.split('_')
        if len(pert_split) != 2:
            continue
        g1, g2 = pert_split
        gene_indices = [np.where(gene_names == g)[0][0] for g in [g1, g2] if g in gene_names]
        if len(gene_indices) < 2:
            continue

        X1 = adata[adata.obs['perturbation'] == pert].X
        X1 = X1.toarray() if issparse(X1) else X1
        delta_X = compute_average_response(X0, X1)

        # Rank of every gene by |u_hat| (0 = largest), for the log lines below.
        u_hat = Sigma_inv @ delta_X
        abs_u = np.abs(u_hat)
        sorted_indices = np.argsort(-abs_u)
        rank_of = np.empty(len(sorted_indices), dtype=int)
        rank_of[sorted_indices] = np.arange(len(sorted_indices))

        print(f"↳ {pert}:")
        for g, idx in zip([g1, g2], gene_indices):
            print(f"   - Gene {g} ranked #{rank_of[idx]} with |u_hat| = {abs_u[idx]:.4f}")

        _, sparse_indices = compute_sparse_perturbation(Sigma_inv, delta_X, gene_indices, top_k=200)
        Sigma_sub = Sigma[np.ix_(sparse_indices, sparse_indices)]
        delta_X_sub = delta_X[sparse_indices]

        try:
            trace = run_mcmc_horseshoe_learnable_sigma(Sigma_sub, delta_X_sub)
        except Exception as e:
            print(f"[SKIP] {pert} due to MCMC error: {e}")
            continue

        # Rows are genes, columns are the pooled (chain, draw) samples.
        u_samples = trace.posterior['u'].stack(sample=("chain", "draw")).values
        u_mean_local = np.mean(u_samples, axis=1)
        u_std_local = np.std(u_samples, axis=1)
        # "Inclusion" = fraction of samples with |u| above the 0.05 threshold.
        pip_local = np.mean(np.abs(u_samples) > 0.05, axis=1)

        # Collect this perturbation's rows separately so the per-perturbation frame
        # does not require rescanning everything accumulated so far.
        pert_rows = []
        for j, idx in enumerate(sparse_indices):
            is_target = idx in gene_indices
            pert_rows.append({
                "Gene": gene_names[idx],
                "Perturbation": pert,
                "IsTruePerturbation": int(is_target),
                "PIP": pip_local[j],
                "U_Mean": u_mean_local[j],
                "U_Std": u_std_local[j],
                "U_Samples": u_samples[j].tolist() if is_target else None
            })
        all_rows.extend(pert_rows)

        if not pert_rows:
            continue
        df_pert = pd.DataFrame(pert_rows)

        true_gene_list = [gene_names[i] for i in gene_indices]
        gene_to_idx = dict(zip(df_pert["Gene"], df_pert.index))

        plt.figure(figsize=(12, 4))
        plt.scatter(range(len(df_pert)), df_pert["PIP"], alpha=0.5)
        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                plt.scatter(idx, df_pert.loc[idx, "PIP"], color='red', s=60, label=g)
        plt.title(f"PIP for {pert} — {dataset_name}")
        plt.xlabel("Sparse Gene Index")
        plt.ylabel("Posterior Inclusion Probability")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{dataset_name}_{pert}_PIP.png")
        plt.show()
        plt.close()

        plt.figure(figsize=(12, 4))
        x = range(len(df_pert))
        y = df_pert["U_Mean"]
        yerr = df_pert["U_Std"]
        plt.errorbar(x, y, yerr=yerr, fmt='o', alpha=0.7, capsize=3)
        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                plt.scatter(idx, df_pert.loc[idx, "U_Mean"], color='red', s=60, label=g)
        plt.title(f"Posterior mean ± std for {pert} — {dataset_name}")
        plt.xlabel("Sparse Gene Index")
        plt.ylabel("Posterior Mean of $u$")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{dataset_name}_{pert}_UMean.png")
        plt.show()
        plt.close()

        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                gene_samples = df_pert.loc[idx, "U_Samples"]
                plt.figure(figsize=(8, 4))
                plt.hist(gene_samples, bins=50, density=True, alpha=0.7)
                plt.axvline(0, color='black', linestyle='--')
                plt.yscale('log')
                plt.title(f"Posterior of $u_{{{g}}}$ — {pert} — {dataset_name}")
                plt.xlabel("Value of $u$")
                plt.ylabel("Density")
                plt.tight_layout()
                plt.savefig(f"{output_dir}/{dataset_name}_{pert}_{g}_posterior.png")
                plt.show()
                plt.close()

    summary_df = pd.DataFrame(all_rows)
    save_path = os.path.join(output_dir, os.path.basename(data_path).replace(".h5ad", "_usamples_double.pkl"))
    with open(save_path, "wb") as f:
        pickle.dump(summary_df, f)
    print(f"Saved summary + plots for: {data_path}")
    return summary_df

In [None]:
# Datasets (file names inside `data_dir`) that contain double perturbations.
double_pert_paths = [
    "TianKampmann2019_day7neuron.h5ad",
    "NormanWeissman2019_filtered.h5ad"
]

# Run the full inference + plotting pipeline once per dataset.
for path in (os.path.join(data_dir, file_name) for file_name in double_pert_paths):
    save_u_samples_summary_double_pert(
        data_path=path,
        output_dir="u_samples_summaries_double",
    )


In [None]:
# this produces the figures for fig4 supplement 
# runs inference and plots a few specific double perturbations


import pytensor
pytensor.config.cxx = ""
pytensor.config.mode = "NUMBA"

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.linalg import pinv
from scipy.sparse import issparse
import pymc as pm
import os
import pickle


def compute_average_response(X0, X1):
    """Return the per-column mean of ``X1`` minus the per-column mean of ``X0``."""
    baseline_mean = X0.mean(axis=0)
    perturbed_mean = X1.mean(axis=0)
    return perturbed_mean - baseline_mean

def compute_sparse_perturbation(Sigma_inv, delta_X, gene_indices, top_k=200):
    """Keep the ``top_k`` largest entries of ``u_hat = Sigma_inv @ delta_X``.

    The genes in ``gene_indices`` are always part of the selection. The remaining
    slots are filled with the highest-|u_hat| genes that are not already forced in,
    so forcing a gene displaces the weakest candidates rather than the strongest.
    (Previously every missing gene was written to slot 0, which evicted the
    top-ranked gene and let a second missing gene overwrite the first.)

    Parameters
    ----------
    Sigma_inv : ndarray, shape (G, G)
    delta_X : ndarray, shape (G,)
    gene_indices : iterable of int
        Indices that must be included in the selection.
    top_k : int
        Size of the selection. If more than ``top_k`` genes are forced, all forced
        genes are returned and nothing else.

    Returns
    -------
    (u_sparse, top_indices)
        ``u_sparse`` equals ``u_hat`` on the selected indices and 0 elsewhere;
        ``top_indices`` is the sorted list of selected indices.
    """
    u_hat = Sigma_inv @ delta_X
    ranked = np.argsort(np.abs(u_hat))[::-1]

    # De-duplicate the forced indices while keeping their order.
    forced = list(dict.fromkeys(int(i) for i in gene_indices))
    n_free = max(top_k - len(forced), 0)
    best_others = ranked[~np.isin(ranked, forced)][:n_free]

    top_indices = sorted(forced + best_others.tolist())
    u_sparse = np.zeros_like(u_hat)
    u_sparse[top_indices] = u_hat[top_indices]
    return u_sparse, top_indices

def run_mcmc_horseshoe_learnable_sigma(Sigma_sub, delta_X_sub, draws=1000, tune=1000,
                                       chains=8, cores=8, target_accept=0.95,
                                       random_seed=None):
    """Sample the posterior of ``u`` in ``delta_X_sub ~ Normal(Sigma_sub @ u, sigma_obs)``.

    Model, as built below:
      * ``lambda`` ~ HalfCauchy(1), one per gene (local scales)
      * ``log_tau`` ~ Normal(-4, 1), ``tau = exp(log_tau)`` (global scale)
      * ``z`` ~ Normal(0, 1), ``u = z * tau * lambda`` (non-centred horseshoe)
      * ``log_sigma_obs`` ~ Normal(-2, 2), ``sigma_obs = exp(log_sigma_obs)``

    Parameters
    ----------
    Sigma_sub : ndarray, shape (G, G)
    delta_X_sub : ndarray, shape (G,)
    draws, tune : int
        Posterior draws and tuning steps per chain.
    chains, cores : int
        Number of chains and of worker processes; the defaults (8, 8) are the
        values that used to be hard-coded. Lower ``cores`` on smaller machines.
    target_accept : float
        Target acceptance rate forwarded to the sampler.
    random_seed : int or None
        Forwarded to ``pm.sample``; ``None`` keeps the previous unseeded behaviour.

    Returns
    -------
    The object returned by ``pm.sample`` (its ``posterior['u']`` is read by callers).
    """
    n_genes = Sigma_sub.shape[0]
    with pm.Model():
        lambda_ = pm.HalfCauchy("lambda", beta=1.0, shape=n_genes)
        log_tau = pm.Normal("log_tau", mu=-4, sigma=1)
        tau = pm.Deterministic("tau", pm.math.exp(log_tau))
        z = pm.Normal("z", 0, 1, shape=n_genes)
        u = pm.Deterministic("u", z * tau * lambda_)
        log_sigma = pm.Normal("log_sigma_obs", mu=-2, sigma=2)
        sigma_obs = pm.Deterministic("sigma_obs", pm.math.exp(log_sigma))
        mu_x = pm.math.dot(Sigma_sub, u)
        pm.Normal("obs", mu=mu_x, sigma=sigma_obs, observed=delta_X_sub)
        trace = pm.sample(draws=draws, tune=tune, target_accept=target_accept,
                          max_treedepth=10, chains=chains, cores=cores,
                          random_seed=random_seed)
    return trace

def run_inference_for_selected_double_perturbations(data_path, selected_dps, output_dir="fig4supplement"):
    """Run the horseshoe inference for a hand-picked list of double perturbations.

    Same pipeline as the all-perturbations run, restricted to ``selected_dps``:
    mean response vs. ``X0``, gene subset from ``compute_sparse_perturbation``,
    posterior sampling with ``run_mcmc_horseshoe_learnable_sigma``, then SVG
    figures (PIP scatter, posterior mean ± std, posterior histogram per target).
    Labels that are absent from the data, malformed, or whose genes are not in
    ``adata.var_names`` are reported and skipped.

    Parameters
    ----------
    data_path : str
        Path to the ``.h5ad`` file; its basename names the output files.
    selected_dps : iterable of str
        Perturbation labels of the form ``"<geneA>_<geneB>"``.
    output_dir : str
        Directory for figures and the pickled summary (created if missing).

    Returns
    -------
    pandas.DataFrame or None
        One row per (perturbation, selected gene); ``None`` when ``get_data``
        yields no AnnData.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Only the first two return values are used. Indexing instead of a fixed-length
    # unpack keeps this working whether `get_data` returns three elements (as
    # unpacked here originally) or four (as the definition later in this notebook).
    loaded = get_data(0, data_path)
    adata, X0 = loaded[0], loaded[1]
    if adata is None:
        print(f"Could not load valid data from: {data_path}")
        return None

    gene_names = np.array(adata.var_names.tolist())
    Sigma = np.cov(X0, rowvar=False)
    Sigma_inv = pinv(Sigma)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    all_rows = []

    # Loop-invariant: the set of labels present in the data.
    available_perts = set(adata.obs['perturbation'].unique())

    for pert in selected_dps:
        if pert not in available_perts:
            print(f"Perturbation {pert} not found in dataset.")
            continue

        try:
            g1, g2 = pert.split("_")
        except ValueError:
            # Raised when the label does not split into exactly two parts.
            print(f"Invalid format for {pert}. Skipping.")
            continue

        gene_indices = [np.where(gene_names == g)[0][0] for g in [g1, g2] if g in gene_names]
        if len(gene_indices) < 2:
            print(f"One or both genes not found in var_names for {pert}")
            continue

        X1 = adata[adata.obs['perturbation'] == pert].X
        X1 = X1.toarray() if issparse(X1) else X1
        delta_X = compute_average_response(X0, X1)

        # Rank of every gene by |u_hat| (0 = largest), for the log lines below.
        u_hat = Sigma_inv @ delta_X
        abs_u = np.abs(u_hat)
        sorted_indices = np.argsort(-abs_u)
        rank_of = np.empty(len(sorted_indices), dtype=int)
        rank_of[sorted_indices] = np.arange(len(sorted_indices))

        print(f"↳ {pert}:")
        for g, idx in zip([g1, g2], gene_indices):
            print(f"   - Gene {g} ranked #{rank_of[idx]} with |u_hat| = {abs_u[idx]:.4f}")

        _, sparse_indices = compute_sparse_perturbation(Sigma_inv, delta_X, gene_indices, top_k=200)
        Sigma_sub = Sigma[np.ix_(sparse_indices, sparse_indices)]
        delta_X_sub = delta_X[sparse_indices]

        try:
            trace = run_mcmc_horseshoe_learnable_sigma(Sigma_sub, delta_X_sub)
        except Exception as e:
            print(f"MCMC failed for {pert}: {e}")
            continue

        # Rows are genes, columns are the pooled (chain, draw) samples.
        u_samples = trace.posterior['u'].stack(sample=("chain", "draw")).values
        u_mean_local = np.mean(u_samples, axis=1)
        u_std_local = np.std(u_samples, axis=1)
        # "Inclusion" = fraction of samples with |u| above the 0.05 threshold.
        pip_local = np.mean(np.abs(u_samples) > 0.05, axis=1)

        # Collect this perturbation's rows separately so the per-perturbation frame
        # does not require rescanning everything accumulated so far.
        pert_rows = []
        for j, idx in enumerate(sparse_indices):
            is_target = idx in gene_indices
            pert_rows.append({
                "Gene": gene_names[idx],
                "Perturbation": pert,
                "IsTruePerturbation": int(is_target),
                "PIP": pip_local[j],
                "U_Mean": u_mean_local[j],
                "U_Std": u_std_local[j],
                "U_Samples": u_samples[j].tolist() if is_target else None
            })
        all_rows.extend(pert_rows)

        if not pert_rows:
            continue
        df_pert = pd.DataFrame(pert_rows)

        true_gene_list = [gene_names[i] for i in gene_indices]
        gene_to_idx = dict(zip(df_pert["Gene"], df_pert.index))

        plt.figure(figsize=(12, 4))
        plt.scatter(range(len(df_pert)), df_pert["PIP"], alpha=0.5)
        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                plt.scatter(idx, df_pert.loc[idx, "PIP"], color='red', s=60, label=g)
        plt.title(f"PIP for {pert} — {dataset_name}")
        plt.xlabel("Sparse Gene Index")
        plt.ylabel("Posterior Inclusion Probability")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{dataset_name}_{pert}_PIP.svg")
        plt.close()

        plt.figure(figsize=(12, 4))
        x = range(len(df_pert))
        y = df_pert["U_Mean"]
        yerr = df_pert["U_Std"]
        plt.errorbar(x, y, yerr=yerr, fmt='o', alpha=0.7, capsize=3)
        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                plt.scatter(idx, df_pert.loc[idx, "U_Mean"], color='red', s=60, label=g)
        plt.title(f"Posterior mean ± std for {pert} — {dataset_name}")
        plt.xlabel("Sparse Gene Index")
        plt.ylabel("Posterior Mean of $u$")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/{dataset_name}_{pert}_UMean.svg")
        plt.close()

        for g in true_gene_list:
            idx = gene_to_idx.get(g)
            if idx is not None:
                gene_samples = df_pert.loc[idx, "U_Samples"]
                plt.figure(figsize=(8, 4))
                plt.hist(gene_samples, bins=50, density=True, alpha=0.7)
                plt.axvline(0, color='black', linestyle='--')
                plt.yscale('log')
                plt.title(f"Posterior of $u_{{{g}}}$ — {pert} — {dataset_name}")
                plt.xlabel("Value of $u$")
                plt.ylabel("Density")
                plt.tight_layout()
                plt.savefig(f"{output_dir}/{dataset_name}_{pert}_{g}_posterior.svg")
                plt.close()

    summary_df = pd.DataFrame(all_rows)
    save_path = os.path.join(output_dir, f"{dataset_name}_selected_usamples.pkl")
    with open(save_path, "wb") as f:
        pickle.dump(summary_df, f)
    print(f"Saved summary and plots to: {output_dir}")
    return summary_df



In [None]:
# === Call for selected perturbations ===
# Double perturbations shown in the Fig. 4 supplement.
# NOTE(review): 'PTPN12_SNAIL1' — the HGNC symbol is SNAI1. If the dataset labels
# this perturbation 'PTPN12_SNAI1', the entry below is reported as "not found"
# and skipped. Confirm against adata.obs['perturbation'].
selected_double_perturbations = ['ZC3HAV1_HOXC13', 'ZC3HAV1_CEBPE', 'SGK1_S1PR2', 'PTPN12_SNAIL1']
norman_path = os.path.join(data_dir, "NormanWeissman2019_filtered.h5ad")
run_inference_for_selected_double_perturbations(
    data_path=norman_path,
    selected_dps=selected_double_perturbations,
    output_dir="fig4supplement",
)

In [None]:
# uses the inference data generated above to produce rocauc curves for 
# FIG 4 H


import os
import glob
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, average_precision_score

# === Load summaries ===
def load_u_samples_summary(filepath):
    """Unpickle and return the object stored at ``filepath``.

    Note: ``pickle.load`` can execute arbitrary code; only use this on summary
    files produced by this notebook.
    """
    with open(filepath, "rb") as handle:
        summary = pickle.load(handle)
    return summary

summary_files = sorted(glob.glob("u_samples_summaries_double/*.pkl"))
print(f"Found {len(summary_files)} summaries.")

# Suffixes to strip from the file name to obtain the dataset label. The inference
# cell writes "<dataset>_usamples_double.pkl"; the older "_usamples_with_lr.pkl"
# form is still recognised.
_SUMMARY_SUFFIXES = ("_usamples_double.pkl", "_usamples_with_lr.pkl")

dfs = []
for file in summary_files:
    df = load_u_samples_summary(file)
    dataset_label = os.path.basename(file)
    for suffix in _SUMMARY_SUFFIXES:
        if dataset_label.endswith(suffix):
            dataset_label = dataset_label[:-len(suffix)]
            break
    df['Dataset'] = dataset_label
    dfs.append(df)

if not dfs:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(
        "No summary pickles found in 'u_samples_summaries_double/'; run the inference cell first."
    )

all_df = pd.concat(dfs, ignore_index=True)
print(f"Loaded {len(all_df)} rows total.")

# === ROC AUC for max(|U ± std|) — Per Dataset ===
# Score of a gene: the larger of |mean + std| and |mean - std| of its posterior u.
datasets = all_df['Dataset'].unique()
roc_data_all = []

fig_ind, ax_ind = plt.subplots(figsize=(10, 8))

for dataset in datasets:
    df_subset = all_df[all_df['Dataset'] == dataset]
    labels = df_subset['IsTruePerturbation']
    if labels.sum() == 0:
        print(f"⚠️ Skipping {dataset} (no true perturbations)")
        continue

    mean_u = df_subset['U_Mean']
    std_u = df_subset['U_Std']
    max_up_down = np.maximum(np.abs(mean_u + std_u), np.abs(mean_u - std_u))

    fpr, tpr, _ = roc_curve(labels, max_up_down)
    roc_auc = roc_auc_score(labels, max_up_down)
    ap_score = average_precision_score(labels, max_up_down)

    ax_ind.plot(fpr, tpr, label=f"{dataset} (AUC = {roc_auc:.2f})", linestyle='-.')

    roc_data_all.append(dict(
        Dataset=dataset,
        Metric="Max_Abs_U_plus_minus_Std",
        ROC_AUC=roc_auc,
        AveragePrecision=ap_score,
    ))

# === Finalize individual plot ===
ax_ind.plot([0, 1], [0, 1], 'k:', lw=1)  # chance diagonal
ax_ind.set_xlabel('False Positive Rate', fontsize=16)
ax_ind.set_ylabel('True Positive Rate', fontsize=16)
ax_ind.set_title('ROC Curves per Dataset: max(|U ± std|)', fontsize=18)
ax_ind.legend(fontsize=8)
ax_ind.grid(True)
fig_ind.tight_layout()
fig_ind.savefig("roc_curve_individual_maxUstd.svg")
plt.show()

# === Combined ROC Curve (Separate Plot) ===
# Same score as above, computed over the rows of all datasets pooled together.
labels_all = all_df['IsTruePerturbation']
mean_all = all_df['U_Mean']
std_all = all_df['U_Std']
max_up_down_all = np.maximum(np.abs(mean_all + std_all), np.abs(mean_all - std_all))

fpr_comb, tpr_comb, _ = roc_curve(labels_all, max_up_down_all)
roc_auc_comb = roc_auc_score(labels_all, max_up_down_all)
ap_score_comb = average_precision_score(labels_all, max_up_down_all)

# Save combined stats
roc_data_all.append(dict(
    Dataset="Combined",
    Metric="Max_Abs_U_plus_minus_Std",
    ROC_AUC=roc_auc_comb,
    AveragePrecision=ap_score_comb,
))

# Plot combined
fig_comb, ax_comb = plt.subplots(figsize=(8, 6))
ax_comb.plot(fpr_comb, tpr_comb, color='black', label=f"Combined (AUC = {roc_auc_comb:.2f})", linewidth=2)
ax_comb.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
ax_comb.set_xlabel('False Positive Rate', fontsize=14)
ax_comb.set_ylabel('True Positive Rate', fontsize=14)
ax_comb.set_title('ROC Curve (Combined): max(|U ± std|)', fontsize=16)
ax_comb.legend(fontsize=12)
ax_comb.grid(True)
fig_comb.tight_layout()
fig_comb.savefig("roc_curve_combined_maxUstd.svg")
plt.show()

# === Save summary ===
roc_summary_full_df = pd.DataFrame(roc_data_all)
roc_summary_full_df.to_csv("rocauc_summary_full_maxUstd_per_dataset.csv", index=False)
print("Saved summary to 'rocauc_summary_full_maxUstd_per_dataset.csv'")


In [None]:
# R2 (FIG3 plots and supps) BELOW HERE

In [None]:
# This code calculates several R2 metrics for each double perturbation
# FIG 3N,O

import anndata as ad
import numpy as np
import scanpy as sc
import logging
import matplotlib.pyplot as plt
from scipy.sparse import issparse
from scipy.linalg import pinv
import os

# ---------- Helper Functions ----------

def is_outlier(adata, column, nmads):
    """Flag values of ``adata.obs[column]`` farther than ``nmads`` MADs from the median.

    Returns a boolean mask that is True where the value lies strictly outside
    ``median ± nmads * MAD`` (MAD = median absolute deviation from the median).
    """
    values = adata.obs[column]
    center = np.median(values)
    spread = np.median(np.abs(values - center))
    half_width = nmads * spread
    upper = center + half_width
    lower = center - half_width
    return (values > upper) | (values < lower)

def quality_control_filter(adata, percent_threshold=20, nmads=5, mt_nmads=5, mt_per=20):
    """Annotate QC metrics on ``adata`` and return a filtered view.

    Steps: make gene names unique strings; tag mitochondrial (``MT-``), ribosomal
    (``RPS``/``RPL``) and hemoglobin genes; compute scanpy QC metrics; mark MAD
    outliers; then keep cells with more than 200 detected genes, genes detected in
    at least 3 of those cells, and cells that are neither general nor
    mitochondrial outliers.

    Note: the input ``adata`` is modified in place (``var_names``, ``var`` and
    ``obs`` columns) before the filtered object is returned.

    Parameters
    ----------
    adata : AnnData
    percent_threshold : int
        ``n`` for the "percent of counts in the top n genes" metric.
    nmads : float
        MAD multiple for the count/gene/top-n outlier tests.
    mt_nmads : float
        MAD multiple for the mitochondrial-fraction outlier test.
    mt_per : float
        Absolute cap on ``pct_counts_mt``.

    Returns
    -------
    AnnData
        Filtered subset of ``adata``.
    """
    adata.var_names = adata.var_names.astype(str)
    adata.var_names_make_unique()
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    adata.var['ribo'] = adata.var_names.str.startswith(('RPS', 'RPL'))
    adata.var['hb'] = adata.var_names.str.contains('^HB[^(P)]')

    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt', 'ribo', 'hb'],
                               inplace=True, percent_top=[percent_threshold], log1p=True)

    # The metric's column name embeds the requested n; a hard-coded "20" would
    # raise KeyError for any other percent_threshold.
    top_genes_col = f'pct_counts_in_top_{percent_threshold}_genes'

    adata.obs['outlier'] = (is_outlier(adata, 'log1p_total_counts', nmads) |
                             is_outlier(adata, 'log1p_n_genes_by_counts', nmads) |
                             is_outlier(adata, top_genes_col, nmads))
    adata.obs['mt_outlier'] = is_outlier(adata, 'pct_counts_mt', mt_nmads) | (adata.obs['pct_counts_mt'] > mt_per)

    adata = adata[adata.obs['n_genes_by_counts'] > 200]
    # Number of remaining cells in which each gene is detected.
    gene_counts = np.sum(adata.X > 0, axis=0)
    genes_to_keep = np.array(gene_counts).flatten() >= 3
    adata = adata[:, genes_to_keep]
    adata = adata[(~adata.obs.outlier) & (~adata.obs.mt_outlier)]

    return adata

def get_double_perts(adata, min_cells=50):
    """List double perturbations that have enough cells for all three conditions.

    A label is treated as a double perturbation when it contains exactly one
    ``_`` (``"<geneA>_<geneB>"``). It is kept only if the double perturbation and
    both single perturbations each have more than ``min_cells`` cells in
    ``adata.obs['perturbation']``.

    Only ``adata.obs`` is read; ``adata`` is neither copied nor modified.

    Parameters
    ----------
    adata : object with an ``obs`` DataFrame containing a ``perturbation`` column
    min_cells : int
        Strict lower bound on the cell count of each of the three conditions.

    Returns
    -------
    list of str
        Qualifying labels, in order of first appearance.
    """
    pert_labels = adata.obs['perturbation']
    # Exactly one underscore <=> splitting on "_" yields two parts. Missing labels
    # compare as not-double instead of raising.
    is_double = (pert_labels.str.count('_') == 1).fillna(False).astype(bool)
    double_perts = pert_labels[is_double].unique()
    pert_counts = pert_labels.value_counts()

    dp_list = []
    for dp in double_perts:
        g1, g2 = dp.split('_')
        smallest_group = min(pert_counts.get(g1, 0), pert_counts.get(g2, 0), pert_counts.get(dp, 0))
        if smallest_group > min_cells:
            dp_list.append(dp)
    return dp_list

def get_data(selected_index, data_path):
    """Load, filter and slice one dataset for the double-perturbation analysis.

    Pipeline: read the ``.h5ad`` file, apply ``quality_control_filter``, keep cells
    with at least 500 total counts and perturbations with at least 100 cells,
    list the usable double perturbations with ``get_double_perts``, then keep the
    genes whose mean in control cells is >= 1.0 plus every gene targeted by a
    listed double perturbation.

    NOTE(review): this definition returns four values and replaces any
    ``get_data`` imported earlier in the notebook; callers that unpack three
    values need to be run before this cell — confirm the intended execution order.

    Parameters
    ----------
    selected_index : int
        Position in the double-perturbation list of the perturbation to return.
    data_path : str
        Path to the ``.h5ad`` file.

    Returns
    -------
    (adata, X0_full, X0, X1)
        ``adata``: filtered AnnData; ``X0_full``: dense matrix of all control
        cells; ``X0`` / ``X1``: the first ``n`` control / selected-perturbation
        cells, with ``n`` the smaller of the two group sizes.
        ``(None, None, None, None)`` if there is no double perturbation, the
        index is out of range, or no control cells remain.
    """
    adata = ad.read_h5ad(data_path)
    print(f"Original data shape: {adata.shape}")
    adata = quality_control_filter(adata)

    # np.asarray(...).ravel() yields a 1-D vector for dense arrays and for any
    # sparse container (the `.A1` shortcut only exists on np.matrix results).
    total_counts = np.asarray(adata.X.sum(axis=1)).ravel()
    adata = adata[total_counts >= 500].copy()

    pert_counts = adata.obs['perturbation'].value_counts()
    valid_perts = pert_counts[pert_counts >= 100].index.tolist()
    adata = adata[adata.obs['perturbation'].isin(valid_perts)]
    print(f"Number of valid perturbations: {len(valid_perts)}")

    dp_list = get_double_perts(adata)
    if len(dp_list) == 0:
        print(f"No double perturbations in {data_path}, skipping.")
        return None, None, None, None

    if selected_index >= len(dp_list):
        print(f"Selected index {selected_index} exceeds available double perts ({len(dp_list)}), skipping.")
        return None, None, None, None

    selected_pert_name = dp_list[selected_index]
    print('selected_pert_name =', selected_pert_name)

    # Column indices of every gene targeted by a listed double perturbation.
    gene_names = adata.var_names.to_numpy()
    gene_map = {g: i for i, g in enumerate(gene_names)}
    double_pert_genes = set()
    for dp in dp_list:
        g1, g2 = dp.split('_')
        if g1 in gene_map:
            double_pert_genes.add(gene_map[g1])
        if g2 in gene_map:
            double_pert_genes.add(gene_map[g2])

    control_data0 = adata[adata.obs['perturbation'] == 'control']
    if control_data0.shape[0] == 0:
        # Without control cells there is no baseline to compare against.
        print(f"No control cells left in {data_path}, skipping.")
        return None, None, None, None
    X_dense = control_data0.X.toarray() if issparse(control_data0.X) else control_data0.X
    gene_means = X_dense.mean(axis=0)
    valid_genes = np.where(gene_means >= 1.0)[0]
    # Explicit integer dtype: an empty target set would otherwise become a float
    # array and turn the concatenated index array into floats.
    forced_genes = np.array(sorted(double_pert_genes), dtype=int)
    all_valid_genes = np.unique(np.concatenate([valid_genes, forced_genes]))
    adata = adata[:, all_valid_genes]

    control_data0 = adata[adata.obs['perturbation'] == 'control']
    selected_pert_data = adata[adata.obs['perturbation'] == selected_pert_name]
    # Truncate both groups to the same number of cells.
    n_samples = min(control_data0.shape[0], selected_pert_data.shape[0])
    control_data = control_data0[:n_samples]
    selected_pert_data = selected_pert_data[:n_samples]

    X0_full = control_data0.X.toarray() if issparse(control_data0.X) else control_data0.X
    X0 = control_data.X.toarray() if issparse(control_data.X) else control_data.X
    X1 = selected_pert_data.X.toarray() if issparse(selected_pert_data.X) else selected_pert_data.X

    print('Shapes X0, X1:', X0.shape, X1.shape)
    return adata, X0_full, X0, X1






def full_double_perturbation_analysis(data_path, save_dir="r2_histograms", seed=None):
    """Score how well covariance columns explain each double-perturbation response.

    For every two-gene perturbation ``A_B`` with both genes among the kept genes,
    the mean response ``delta_X_AB`` (vs. all control cells) is fitted by least
    squares with:
      * the two columns of the control covariance for A and B ("real"),
      * the same columns of a covariance built from column-wise shuffled control
        data ("null"),
      * the single-gene fits of the A-only and B-only responses, separately
        ("A", "B") and summed ("additive").
    Each prediction is scored against ``delta_X_AB`` with
    ``1 - ||y - y_hat|| / sqrt(||y||^2 + 1e-8)`` on the genes whose mean changed.
    Scores are saved to ``R2_<dataset>_doublepert.npz`` and plotted as a histogram
    PNG in ``save_dir``.

    Perturbations are skipped (with a message) when the label does not have
    exactly two parts, a gene is missing, one of the three conditions has no
    cells, or no gene changed.

    Parameters
    ----------
    data_path : str
        Path to the ``.h5ad`` file.
    save_dir : str
        Output directory (created if missing).
    seed : int or None
        Seed for the shuffled null model; ``None`` uses NumPy's global RNG state,
        as before.

    Returns
    -------
    dict or None
        Lists under ``perturbations``, ``R2_real``, ``R2_null``, ``R2_A``,
        ``R2_B``, ``R2_additive``; ``None`` if ``get_data`` yields no AnnData.
    """
    adata, X0, _, _ = get_data(0, data_path)
    if adata is None:
        print('No double perts in dataset ', data_path)
        return None

    epsilon = 1e-8

    def get_gene_index(gene_name, gene_names):
        indices = np.where(gene_names == gene_name)[0]
        return int(indices[0]) if indices.size > 0 else None

    def compute_covariance(X):
        return np.cov(X, rowvar=False)

    def compute_average_response(X0, X1):
        return X1.mean(axis=0) - X0.mean(axis=0)

    def to_dense(X):
        return X.toarray() if issparse(X) else X

    def R2(y_true, y_pred):
        # Relative L2 error turned into a score: 1 is a perfect fit.
        return 1.0 - np.sqrt(np.sum((y_true - y_pred) ** 2)) / np.sqrt(np.sum(y_true ** 2) + epsilon)

    gene_names = np.array(adata.var_names.tolist())
    X0_dense = to_dense(X0)
    Sigma = compute_covariance(X0_dense)

    # Null model: shuffle every gene independently across cells, which keeps the
    # marginals but destroys gene-gene covariance. Work on the dense copy so the
    # in-place column shuffle is valid.
    rng = np.random if seed is None else np.random.RandomState(seed)
    X0_shuffled = X0_dense.copy()
    for g in range(X0_shuffled.shape[1]):
        rng.shuffle(X0_shuffled[:, g])
    Sigma_S = compute_covariance(X0_shuffled)

    perturbations = [p for p in adata.obs['perturbation'].unique() if "_" in p]
    R2_real, R2_null, R2_A_list, R2_B_list, R2_additive, pert_names = [], [], [], [], [], []

    for pert in perturbations:
        parts = pert.split("_")
        if len(parts) != 2:
            print(f"Skipping {pert} (not a two-gene label)")
            continue
        pert_A, pert_B = parts
        gene_idx_A = get_gene_index(pert_A, gene_names)
        gene_idx_B = get_gene_index(pert_B, gene_names)

        if gene_idx_A is None or gene_idx_B is None:
            print(f"Skipping {pert} (missing gene index)")
            continue

        # Load expression for double and singles
        X1_AB = to_dense(adata[adata.obs['perturbation'] == pert].X)
        X1_A = to_dense(adata[adata.obs['perturbation'] == pert_A].X)
        X1_B = to_dense(adata[adata.obs['perturbation'] == pert_B].X)

        # A mean over zero cells is NaN and would poison the fits below.
        if min(X1_AB.shape[0], X1_A.shape[0], X1_B.shape[0]) == 0:
            print(f"Skipping {pert} (no cells for the double or one of its single perturbations)")
            continue

        # ΔX
        delta_X_AB = compute_average_response(X0_dense, X1_AB)
        delta_X_A = compute_average_response(X0_dense, X1_A)
        delta_X_B = compute_average_response(X0_dense, X1_B)

        # Fit using real covariance
        Sigma_AB = Sigma[:, [gene_idx_A, gene_idx_B]]
        u_AB, _, _, _ = np.linalg.lstsq(Sigma_AB, delta_X_AB, rcond=None)
        pred_AB = Sigma_AB @ u_AB

        Sigma_A = Sigma[:, [gene_idx_A]]
        Sigma_B = Sigma[:, [gene_idx_B]]

        u_A, _, _, _ = np.linalg.lstsq(Sigma_A, delta_X_A, rcond=None)
        u_B, _, _, _ = np.linalg.lstsq(Sigma_B, delta_X_B, rcond=None)

        pred_A = Sigma_A @ u_A
        pred_B = Sigma_B @ u_B
        pred_add = pred_A + pred_B

        # Fit using null covariance
        Sigma_AB_null = Sigma_S[:, [gene_idx_A, gene_idx_B]]
        u_null, _, _, _ = np.linalg.lstsq(Sigma_AB_null, delta_X_AB, rcond=None)
        pred_null = Sigma_AB_null @ u_null

        # Filter for informative genes (those whose mean actually changed)
        mean_diff = np.abs(X1_AB.mean(axis=0) - X0_dense.mean(axis=0))
        valid_idx = np.where(mean_diff > 0)[0]
        if len(valid_idx) == 0:
            print(f"Skipping {pert} (no gene change)")
            continue

        # Filter all deltas and predictions
        delta = delta_X_AB[valid_idx]
        pred_AB = pred_AB[valid_idx]
        pred_null = pred_null[valid_idx]
        pred_A = pred_A[valid_idx]
        pred_B = pred_B[valid_idx]
        pred_add = pred_add[valid_idx]

        # Compute all R²s (every prediction is scored against the double response)
        R2_real.append(R2(delta, pred_AB))
        R2_null.append(R2(delta, pred_null))
        R2_A_list.append(R2(delta, pred_A))
        R2_B_list.append(R2(delta, pred_B))
        R2_additive.append(R2(delta, pred_add))
        pert_names.append(pert)

    # === Save & Plot ===
    os.makedirs(save_dir, exist_ok=True)
    base = os.path.basename(data_path).replace('.h5ad', '')
    np.savez(
        os.path.join(save_dir, f"R2_{base}_doublepert.npz"),
        perturbations=np.array(pert_names),
        R2_real=np.array(R2_real),
        R2_null=np.array(R2_null),
        R2_A=np.array(R2_A_list),
        R2_B=np.array(R2_B_list),
        R2_additive=np.array(R2_additive)
    )

    # Histogram with 5 R² distributions (scores outside [0, 1] fall outside the bins)
    plt.figure(figsize=(10, 7))
    bins = np.linspace(0, 1, 40)

    plt.hist(R2_real, bins=bins, alpha=0.7, label='Double Pert. R²', density=True)
    plt.hist(R2_additive, bins=bins, alpha=0.7, label='Additive (A + B) R²', density=True)
    plt.hist(R2_A_list, bins=bins, alpha=0.5, label='Single A R²', density=True)
    plt.hist(R2_B_list, bins=bins, alpha=0.5, label='Single B R²', density=True)
    plt.hist(R2_null, bins=bins, alpha=0.3, label='Null R² (shuffled)', density=True)

    plt.xlabel("R² Score", fontsize=18)
    plt.ylabel("Density", fontsize=18)
    plt.title(f"Double Perturbation R² — {base}", fontsize=20)
    plt.legend(fontsize=12)
    plt.yscale('log')
    plt.grid(True)
    plt.tight_layout()

    save_path = os.path.join(save_dir, f"double_r2_hist_{base}.png")
    plt.savefig(save_path, dpi=300)
    plt.close()
    print(f"Saved 5-R² histogram to {save_path}")

    return {
        "perturbations": pert_names,
        "R2_real": R2_real,
        "R2_null": R2_null,
        "R2_A": R2_A_list,
        "R2_B": R2_B_list,
        "R2_additive": R2_additive
    }


datapaths = [
    "NormanWeissman2019_filtered.h5ad",
    "TianKampmann2019_day7neuron.h5ad",
]

# First pass: run the analysis with the function's default output directory.
for path in datapaths:
    path = os.path.join(data_dir, path)
    result = full_double_perturbation_analysis(path)
import numpy as np
import matplotlib.pyplot as plt
import os


save_dir = os.path.join(out_dir, "r2_histograms")
os.makedirs(save_dir, exist_ok=True)

# === Collect R² mean per dataset ===
# Results are read from `save_dir` when already present there, otherwise the
# analysis is (re)run so its outputs land in `save_dir`.
mean_r2_list = []
for path in datapaths:
    base = os.path.basename(path).replace(".h5ad", "")
    npz_path = os.path.join(save_dir, f"R2_{base}_doublepert.npz")

    if not os.path.exists(npz_path):
        print(f"Running analysis on {base}...")
        # `datapaths` holds bare file names; the data files live under `data_dir`.
        result = full_double_perturbation_analysis(os.path.join(data_dir, path), save_dir=save_dir)
        if result is None:
            continue
    else:
        result = dict(np.load(npz_path, allow_pickle=True))

    mean_r2 = {
        key: np.mean(result[key]) for key in
        ["R2_null", "R2_A", "R2_B", "R2_additive", "R2_real"]
    }
    mean_r2_list.append(mean_r2)

# === Compute average across datasets ===
conditions = ["R2_null", "R2_A", "R2_B", "R2_additive", "R2_real"]
condition_labels = ["Shuffled", "Single A", "Single B", "Additive", "True Σ"]

mean_r2_avg = {
    cond: np.mean([m[cond] for m in mean_r2_list]) for cond in conditions
}
yvals = [mean_r2_avg[c] for c in conditions]

# === Bar plot of averaged R²s ===
plt.figure(figsize=(8, 6))
plt.bar(condition_labels, yvals, color='skyblue')
plt.ylabel("Mean R² (avg. across datasets)", fontsize=16)
plt.title("Average R² per Method", fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=12)
plt.grid(True, axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "mean_r2_avg_barplot.svg"))
plt.close()

# === Line + point plot of averaged R²s ===
plt.figure(figsize=(8, 6))
plt.plot(condition_labels, yvals, marker='o', linewidth=2, markersize=8, color='crimson')
plt.ylabel("Mean R² (avg. across datasets)", fontsize=16)
plt.title("Average R² per Method", fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=12)
plt.grid(True, axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "mean_r2_avg_lineplot.svg"))
plt.close()


In [None]:
# R2 histograms for individual datasets 
# Fig S3D


import os
import numpy as np
import matplotlib.pyplot as plt

# === Paths ===
# Dataset file names; only the stem is used to locate the cached R² arrays.
datapaths = [
    "NormanWeissman2019_filtered.h5ad",
    "TianKampmann2019_day7neuron.h5ad",
]
save_dir = os.path.join(out_dir, "r2_histograms")
os.makedirs(save_dir, exist_ok=True)

# (npz key, legend label, alpha) in draw order. The order determines which
# colour of the default cycle each series receives.
hist_series = [
    ("R2_real", 'Double Pert. R²', 0.7),
    ("R2_additive", 'Additive R²', 0.7),
    ("R2_A", 'Single A R²', 0.5),
    ("R2_B", 'Single B R²', 0.5),
    ("R2_null", 'Shuffled R²', 0.3),
]

# Shared bin edges so the five distributions are directly comparable.
# Values outside [0, 1] fall outside the edges and are not counted.
bins = np.linspace(0, 1, 40)

# === Plot one histogram per dataset ===
for p in datapaths:
    base = os.path.basename(p).replace(".h5ad", "")
    npz_path = os.path.join(save_dir, f"R2_{base}_doublepert.npz")
    if not os.path.exists(npz_path):
        print(f"Missing: {npz_path}")
        continue

    # Read every array up front and release the archive handle immediately.
    # NOTE(review): allow_pickle=True unpickles object arrays; only load
    # caches produced by this notebook.
    with np.load(npz_path, allow_pickle=True) as npz:
        data = dict(npz)

    fig, ax = plt.subplots(figsize=(10, 7))
    for key, label, alpha in hist_series:
        ax.hist(data[key], bins=bins, alpha=alpha, label=label, density=True)

    ax.set_xlabel("R² Score", fontsize=16)
    ax.set_ylabel("Density", fontsize=16)
    ax.set_title(f"Double Perturbation R² — {base}", fontsize=18)
    ax.legend(fontsize=12)
    # Log scale keeps the sparse tails visible next to the dominant bins.
    ax.set_yscale("log")
    ax.grid(True)
    fig.tight_layout()

    fig.savefig(os.path.join(save_dir, f"double_r2_hist_{base}.svg"))
    plt.close(fig)
    print(f"Saved SVG for {base}")


In [None]:
# Kolmogorov–Smirnov p-values comparing the pooled R² distributions shown in the histograms (Fig. 3N)


import numpy as np
import os
from scipy.stats import ks_2samp

# === Paths and setup ===
datapaths = [
    "NormanWeissman2019_filtered.h5ad",
    "TianKampmann2019_day7neuron.h5ad",
]
# The cached R² arrays live under out_dir, the same directory the plotting
# cells read from and write to. A bare "r2_histograms" would be resolved
# against the working directory and miss them.
save_dir = os.path.join(out_dir, "r2_histograms")

# Display label -> key of the corresponding array in the cached .npz file.
label_to_key = {
    "Shuffled": "R2_null",
    "Single A": "R2_A",
    "Single B": "R2_B",
    "Additive": "R2_additive",
    "True Σ": "R2_real",
}

# === Accumulate R² data across datasets ===
r2_all = {label: [] for label in label_to_key}

# === Load data ===
for path in datapaths:
    base = os.path.basename(path).replace(".h5ad", "")
    npz_path = os.path.join(save_dir, f"R2_{base}_doublepert.npz")

    if not os.path.exists(npz_path):
        raise FileNotFoundError(f"{npz_path} not found. You must run the analysis first.")

    # Read every array up front and release the archive handle immediately.
    # NOTE(review): allow_pickle=True unpickles object arrays; only load
    # caches produced by this notebook.
    with np.load(npz_path, allow_pickle=True) as npz:
        data = dict(npz)

    # Pool the values of all datasets per condition.
    for label, key in label_to_key.items():
        r2_all[label].extend(data[key])

# === Define comparison pairs ===
comparison_pairs = [
    ("Shuffled", "Single A"),
    ("Shuffled", "Single B"),
    ("Shuffled", "Additive"),
    ("Single A", "Single B"),
    ("Additive", "True Σ"),
]

# === Run KS tests ===
# Two-sample KS test on the pooled (unpaired) R² values of each pair.
print("Kolmogorov–Smirnov test p-values:")
for name1, name2 in comparison_pairs:
    r2_1 = np.array(r2_all[name1])
    r2_2 = np.array(r2_all[name2])
    stat, pval = ks_2samp(r2_1, r2_2)
    print(f"{name1} vs {name2}: KS stat = {stat:.4f}, p = {pval:.3e}")


In [None]:
# One-sided Wilcoxon signed-rank p-values on the per-perturbation R² values (Fig. 3O)


import os
import numpy as np
from scipy.stats import wilcoxon

# === Setup ===
datapaths = [
    "NormanWeissman2019_filtered.h5ad",
    "TianKampmann2019_day7neuron.h5ad",
]
# The cached R² arrays live under out_dir, the same directory the plotting
# cells read from and write to. A bare "r2_histograms" would be resolved
# against the working directory and miss them.
save_dir = os.path.join(out_dir, "r2_histograms")

# Column order of the stacked matrix; the indices in comparison_pairs below
# refer to positions in this list.
r2_keys = ["R2_null", "R2_A", "R2_B", "R2_additive", "R2_real"]

# === Collect all perturbation-level R²s ===
per_dataset_blocks = []

for path in datapaths:
    base = os.path.basename(path).replace(".h5ad", "")
    npz_path = os.path.join(save_dir, f"R2_{base}_doublepert.npz")

    if not os.path.exists(npz_path):
        raise FileNotFoundError(f"{npz_path} not found. Run the analysis to generate it.")

    # Read every array up front and release the archive handle immediately.
    # NOTE(review): allow_pickle=True unpickles object arrays; only load
    # caches produced by this notebook.
    with np.load(npz_path, allow_pickle=True) as npz:
        result = dict(npz)

    # One row per perturbation, one column per condition. column_stack needs
    # all five arrays to have the same length, which the paired test below
    # requires anyway.
    per_dataset_blocks.append(np.column_stack([result[key] for key in r2_keys]))

# === Convert to array ===
all_pert_r2s = np.vstack(per_dataset_blocks)

# === Define comparisons and expected direction (expect second > first) ===
comparison_pairs = [
    ("Shuffled", "Single A", 0, 1),
    ("Shuffled", "Single B", 0, 2),
    ("Shuffled", "Additive", 0, 3),
    ("Additive", "True Σ", 3, 4),
]

# === Compute Wilcoxon one-sided p-values (alternative = 'less') ===
# alternative="less" tests whether the first column tends to be smaller than
# the second across the paired perturbations.
print("Wilcoxon signed-rank test (one-sided, alternative='less') p-values:")
for name1, name2, idx1, idx2 in comparison_pairs:
    stat, pval = wilcoxon(all_pert_r2s[:, idx1], all_pert_r2s[:, idx2], alternative="less")
    print(f"{name1} < {name2}: p = {pval:.6g}")