### Notebook for Fig 3

In [None]:
import os
import pandas as pd

# Move up one directory so the relative paths used below resolve.
# Guarded so that re-running this cell does not climb one more directory each
# time: the jump is skipped once the dataset registry read by the next cell
# ('resources/datasets.csv') is already reachable from the working directory.
if not os.path.exists(os.path.join("resources", "datasets.csv")):
    os.chdir("../")

### Load dataset config

In [None]:
# Directories (relative to the working directory) for inputs and results
data_dir, output_dir = "data", "output"
# Dataset registry table; its 'file' column lists the entries that are joined
# onto data_dir by the cells below
ds_info = pd.read_csv("resources/datasets.csv")
datasets = ds_info["file"]

#### Run for each dataset (iteratively)

In [None]:
# Fig 3 base calculations
from src.r2 import full_analysis_with_nulls_soft_and_plots


# Run the base analysis on every registered dataset, one after another
for dataset in datasets:
    print(f'Processing {dataset}')
    full_analysis_with_nulls_soft_and_plots(os.path.join(data_dir, dataset), output_dir)

#### Or use the Slurm version (if applicable)

In [None]:
from src.slurm import slurm_script
import subprocess

# Setup slurm config (template shared by every job; the per-job fields
# 'job-name' and 'output' are added to a copy inside the loop)
slurm_config = {
    'account': 'account',      # replace with your account
    'partition': 'partition',     # replace with your partition
    'nodes': 1,
    'ntasks-per-node': 10,
    'mem': '100GB',
    'time': '48:00:00',
    'verbose': 'true',
}


# Submit jobs for all datasets
for dataset in datasets:
    path = os.path.join(data_dir, dataset)
    ds_base = dataset.replace(".h5ad", "")
    log_path = f'logs/r2/{ds_base}.log'
    # Slurm does not create missing directories for job log files, so make
    # sure the target directory exists before the job is submitted.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    # Work on a per-job copy so the shared template is never mutated
    job_config = {
        **slurm_config,
        'job-name': f'r2-{ds_base}',
        'output': log_path,
    }
    # Create tmp script
    script_path = slurm_script(
        job_config,
        conda_env='.conda/cipher',
        module='r2',
        input=path,
        output=output_dir,
        cache=False
    )
    if script_path is None:
        print(f'Slurm script creation failed, {dataset} skipped.')
        continue
    # Submit the job
    result = subprocess.run(["sbatch", script_path], capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the scheduler's error instead of printing an empty job id
        print(f'sbatch failed for {dataset} (exit code {result.returncode}): {result.stderr.strip()}')
        continue
    # Print job id
    print(f"{dataset}:", result.stdout)

In [None]:
# Plots for FIG S3C


import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.sparse import issparse
from scipy.linalg import pinv
from src.preprocess import get_data

def compute_covariance(X):
    """Return the covariance matrix of X, treating each column as a variable.

    Thin wrapper around ``np.cov`` with ``rowvar=False``, i.e. rows of X are
    observations and columns are the variables being compared.
    """
    covariance = np.cov(X, rowvar=False)
    return covariance

def compute_average_response(X0, X1):
    """Return the column-wise mean of X1 minus the column-wise mean of X0."""
    perturbed_mean = X1.mean(axis=0)
    baseline_mean = X0.mean(axis=0)
    return perturbed_mean - baseline_mean

def compute_r2_by_gene(dx, pred, epsilon=1e-8):
    """Score every entry by one minus its relative absolute prediction error.

    For each position ``g`` with ``|dx[g]| > 1e-6`` the score is
    ``1 - |dx[g] - pred[g]| / (|dx[g]| + epsilon)``; every other position is
    set to NaN because the relative error is not meaningful there.

    Parameters
    ----------
    dx : array-like
        Observed response per entry.
    pred : array-like
        Predicted response, same shape as ``dx``.
    epsilon : float, optional
        Small constant added to the denominator.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``dx``. Floating-point inputs keep their dtype;
        integer or boolean inputs are promoted to float so NaN can be stored.
    """
    dx = np.asarray(dx)
    pred = np.asarray(pred)
    # NaN cannot be represented in an integer/boolean array, so
    # np.full_like(dx, np.nan) would fail (or store garbage) for such inputs.
    if not np.issubdtype(dx.dtype, np.inexact):
        dx = dx.astype(float)
    valid = (np.abs(dx) > 1e-6)
    r2 = np.full_like(dx, np.nan)
    r2[valid] = 1.0 - np.abs(dx[valid] - pred[valid]) / (np.abs(dx[valid]) + epsilon)
    return r2

# === Config ===


save_dir = os.path.join(output_dir, "r2_by_gene_top10_per_pert_y01")
os.makedirs(save_dir, exist_ok=True)

for dataset in datasets:
    print(dataset)
    path = os.path.join(data_dir, dataset)
    base_name = os.path.basename(dataset).replace('.h5ad', '')
    adata, X0 = get_data(0, path)[:2]
    X0 = X0.toarray() if issparse(X0) else X0
    Sigma = compute_covariance(X0)
    # The baseline mean does not depend on the perturbation: compute it once
    # per dataset instead of once per perturbation.
    baseline_mean = X0.mean(axis=0)
    gene_names = np.array(adata.var_names.tolist())

    # Get top 10 perturbations by number of cells (excluding 'control' and unknown genes)
    pert_counts = adata.obs['perturbation'].value_counts()
    valid_perturbations = [p for p in pert_counts.index if p != 'control' and p in gene_names]
    top_perturbations = sorted(valid_perturbations, key=lambda p: pert_counts[p], reverse=True)[:10]

    for pert in top_perturbations:
        # Column of the covariance matrix that belongs to the perturbed gene
        gene_idx = np.where(gene_names == pert)[0][0]
        X1 = adata[adata.obs['perturbation'] == pert].X
        X1 = X1.toarray() if issparse(X1) else X1
        if X1.shape[0] == 0:
            continue

        dx = compute_average_response(X0, X1)
        # Skip perturbations with no measurable response
        if np.allclose(dx, 0, atol=1e-6):
            continue

        # Least-squares scale u such that u * Sigma[:, gene] best matches dx
        sigma_col = Sigma[:, gene_idx]
        epsilon = 1e-8
        u_opt = np.dot(sigma_col, dx) / (np.dot(sigma_col, sigma_col) + epsilon)
        pred = u_opt * sigma_col
        r2_vec = compute_r2_by_gene(dx, pred)

        df = pd.DataFrame({
            "gene": gene_names,
            "r2_score": r2_vec,
            "dx": dx,
            "pred": pred,
            "X0": baseline_mean
        }).dropna(subset=["r2_score", "X0"])

        if len(df) == 0:
            print(f"No valid data for {pert} in {base_name}")
            continue

        # Sanitise the shared file stem once so the CSV and the SVG agree: a
        # perturbation name containing "/" would otherwise point the CSV at a
        # non-existent sub-directory while the SVG name was already cleaned.
        file_stem = f"{base_name}_{pert}".replace("/", "_")
        df.to_csv(os.path.join(save_dir, f"{file_stem}_r2_by_gene.csv"), index=False)

        # Plot for this perturbation
        plt.figure(figsize=(7, 6))
        plt.scatter(df["X0"], df["r2_score"], alpha=0.3, s=8)
        plt.xlabel("Baseline Expression (X₀)", fontsize=16)
        plt.ylabel("Gene-wise R² Score", fontsize=16)
        plt.title(f"{base_name} - {pert}: R² vs X₀", fontsize=18)
        plt.grid(True)
        plt.ylim(0,1)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{file_stem}_r2_vs_X0.svg"))
        plt.close()

print("One scatter plot per perturbation saved in", save_dir)


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import ks_2samp

# ---------- Dataset Categorization ----------
# Boolean mask over the registry rows: True where the perturbation column is
# 'CRISPRa'; every other row falls into the second group.
is_crispr_a = ds_info['perturbation'].eq('CRISPRa')
crispr_a_datasets = ds_info['name'][is_crispr_a]
crispr_i_datasets = ds_info['name'][~is_crispr_a]

# csv_dir is where the per-dataset tables are read from, save_dir is where
# the figures are written (created on demand).
csv_dir, save_dir = (
    os.path.join(output_dir, subdir)
    for subdir in ("r2_histograms", "FIG3_r2_histograms")
)
os.makedirs(save_dir, exist_ok=True)

In [None]:
# FIG3 C/D and associated pvalues 
from src.r2 import plot_r2_comparison

# ---------- Generate the comparison plots ----------
# One kind="rand" comparison per perturbation group; the kind="null"
# variant is currently disabled.
for group_names, group_label in (
    (crispr_a_datasets, "CRISPRa"),
    (crispr_i_datasets, "CRISPRi"),
):
    plot_r2_comparison(group_names, group_label, csv_dir, save_dir, kind="rand")


In [None]:
# FIG S2A-F
from src.r2 import plot_r2_histograms_for_dataset


# ---------- One histogram figure per dataset ----------
dataset_names = ds_info['name']
for dataset in dataset_names:
    print(f'Processing {dataset}')
    plot_r2_histograms_for_dataset(dataset, csv_dir, save_dir)


In [None]:
# FIG 3 E/F as well as pvalue calculation for FIG3 E/F/M
from src.r2 import plot_r2_dataset_means_grouped, print_global_test
import numpy as np


# === Accumulate all values for global tests ===
r2_real_all_a, r2_null_all_a, r2_shuff_all_a = plot_r2_dataset_means_grouped(crispr_a_datasets, "CRISPRa", csv_dir, save_dir)
r2_real_all_i, r2_null_all_i, r2_shuff_all_i = plot_r2_dataset_means_grouped(crispr_i_datasets, "CRISPRi", csv_dir, save_dir)
# Combine data from both groups.
# np.concatenate is used instead of np.concat because the latter only exists
# in NumPy >= 2.0 and raises AttributeError on older installations.
r2_real_all = np.concatenate([r2_real_all_a, r2_real_all_i])
r2_null_all = np.concatenate([r2_null_all_a, r2_null_all_i])
r2_shuff_all = np.concatenate([r2_shuff_all_a, r2_shuff_all_i])
# === Combined 2-sided Wilcoxon tests ===
print("\nCombined 2-sided Wilcoxon signed-rank tests (all datasets):")

print_global_test("Meanfield Σ vs Real Σ", r2_real_all, r2_null_all)
print_global_test("Shuffled Σ vs Real Σ", r2_real_all, r2_shuff_all)


In [None]:
# Mean over all datasets of the element-wise ratio real R² / mean-field R²
fold_changes = r2_real_all / r2_null_all
meanfc = fold_changes.mean()
print(f"Real vs. Mean-field Fold Change: {meanfc:.2f}")

In [None]:
# FIG3L and pval calculation 

import os
from src.r2 import plot_r2_comparisons_from_saved_csvs

# Input tables and figure destination, both below output_dir
csv_dir, save_dir = (
    os.path.join(output_dir, subdir)
    for subdir in ("r2_histograms", "FIG3_r2_histograms")
)

# Only the "null" comparison is drawn here; the "rand" comparison is not run.
plot_r2_comparisons_from_saved_csvs(
    comparison="null",
    csv_dir=csv_dir,
    save_dir=save_dir,
)


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# FIG3 G-J and FIGS2 G-N
from src.r2 import plot_best_r2_perturbation

csv_dir = os.path.join(output_dir, 'r2_histograms')
save_dir = os.path.join(output_dir, 'dx_plots')
# Derive the covariance-summary location from output_dir like every other
# path in this cell, instead of hard-coding the 'output/' prefix (which would
# silently point at the wrong place once output_dir is changed).
cov_dir = os.path.join(output_dir, 'samples_summaries', 'u_samples_summaries')
plot_best_r2_perturbation(ds_info['name'], r2_dir=csv_dir, data_dir=data_dir, cov_dir=cov_dir, plot_dir=save_dir)

In [None]:
# ------------------ Additional analysis -------------------------

In [None]:
# run R2 calculation for all datasets over ALL genes (not just true perturbed genes)


import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from scipy.sparse import issparse
from tqdm import tqdm

def get_gene_index(gene_name, gene_names):
    """Return the position of the first entry of gene_names equal to gene_name.

    Returns a plain ``int``, or ``None`` when no entry matches.
    """
    matches = np.where(gene_names == gene_name)[0]
    if matches.size > 0:
        return int(matches[0])
    return None

def compute_covariance(X):
    """Return the covariance matrix of X with columns treated as variables.

    Delegates to ``np.cov`` with ``rowvar=False`` (rows are observations).
    """
    cov_matrix = np.cov(X, rowvar=False)
    return cov_matrix

def compute_average_response(X0, X1):
    """Return mean(X1) - mean(X0), with both means taken along axis 0."""
    mean_after = X1.mean(axis=0)
    mean_before = X0.mean(axis=0)
    return mean_after - mean_before

def compute_r2(dx, pred, epsilon=1e-8):
    """Return 1 - ||dx - pred||² / (||dx||² + epsilon) over the non-NaN entries.

    Positions where either ``dx`` or ``pred`` is NaN are ignored. If the
    remaining part of ``dx`` has zero norm (which includes the case where
    nothing remains), NaN is returned.
    """
    usable = ~(np.isnan(dx) | np.isnan(pred))
    observed = dx[usable]
    predicted = pred[usable]
    # Nothing to explain when the observed response vanishes
    if np.linalg.norm(observed) == 0:
        return np.nan
    residual_ss = np.sum((observed - predicted) ** 2)
    total_ss = np.sum(observed ** 2) + epsilon
    return 1.0 - residual_ss / total_ss

def _single_gene_r2(Sigma_used, gene_i, delta_X, epsilon):
    """R² of the best rescaling of column gene_i of Sigma_used as a predictor of delta_X."""
    col = Sigma_used[:, gene_i]
    # Least-squares scale factor for pred = u * col
    u = np.dot(col, delta_X) / (np.dot(col, col) + epsilon)
    return compute_r2(delta_X, u * col, epsilon)


def full_r2_matrix_analysis(data_path, save_dir="r2_genelevel", plot_dir="r2_genelevel_plots", seed=None):
    """Compute R² for every (perturbation, gene) pair of one dataset.

    For each non-control perturbation the average response is predicted from
    each single column of three covariance matrices: the covariance of X0,
    the covariance of X0 with every column shuffled independently, and the
    covariance of X0 with all entries shuffled. The resulting table is
    written to ``<save_dir>/<base_name>_r2_per_gene.csv`` and an aggregate
    histogram to ``<plot_dir>/<base_name>_r2_aggregate_hist.png``.

    Parameters
    ----------
    data_path : str
        Passed to ``get_data(0, data_path)``, which must return a 3-tuple
        whose first two items are an AnnData-like object (``var_names``,
        ``obs['perturbation']``, ``.X``) and the baseline matrix X0.
        NOTE(review): X0 is treated as cells x genes — confirm against get_data.
    save_dir, plot_dir : str
        Output directories; created if missing.
    seed : int or None, optional
        Seed for the shuffles. ``None`` (default) keeps them unseeded, as
        before; pass an integer for reproducible null distributions.
    """
    adata, X0, _ = get_data(0, data_path)
    gene_names = np.array(adata.var_names.tolist())
    X0_dense = X0.toarray() if issparse(X0) else X0
    rng = np.random.default_rng(seed)

    # === Covariances ===
    Sigma = compute_covariance(X0_dense)

    # Shuffle every column independently along axis 0 (returns a copy)
    X0_cell_shuffled = rng.permuted(np.asarray(X0_dense), axis=0)
    Sigma_S = compute_covariance(X0_cell_shuffled)

    # Full shuffle: permute all entries of X0, then restore the shape
    X0_full_shuffled = rng.permutation(np.asarray(X0_dense).ravel()).reshape(np.shape(X0_dense))
    Sigma_rand = compute_covariance(X0_full_shuffled)

    # === Output setup ===
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    base_name = os.path.basename(data_path).replace('.h5ad', '').replace('.pkl', '')
    perturbations = [p for p in adata.obs['perturbation'].value_counts().index if p != 'control']
    all_rows = []
    epsilon = 1e-8

    for pert in tqdm(perturbations, desc=f"{base_name}"):
        X1 = adata[adata.obs['perturbation'] == pert].X
        X1 = X1.toarray() if issparse(X1) else X1
        delta_X = compute_average_response(X0_dense, X1)
        # Index of the perturbed gene (None if it is not among gene_names);
        # looked up once per perturbation rather than once per gene.
        target_idx = get_gene_index(pert, gene_names)

        for gene_i in range(len(gene_names)):
            all_rows.append({
                "perturbation": pert,
                "gene": gene_names[gene_i],
                "gene_index": gene_i,
                "R2_real": _single_gene_r2(Sigma, gene_i, delta_X, epsilon),
                "R2_null": _single_gene_r2(Sigma_S, gene_i, delta_X, epsilon),
                "R2_shuffled_sigma": _single_gene_r2(Sigma_rand, gene_i, delta_X, epsilon),
                "is_true_target": int(gene_i == target_idx)
            })

    # Save full table; explicit columns keep the header (and the column
    # lookups below) valid even when there were no perturbations at all.
    df = pd.DataFrame(all_rows, columns=[
        "perturbation", "gene", "gene_index",
        "R2_real", "R2_null", "R2_shuffled_sigma", "is_true_target",
    ])
    df.to_csv(os.path.join(save_dir, f"{base_name}_r2_per_gene.csv"), index=False)

    # Aggregate histograms
    plt.figure(figsize=(7, 5))
    plt.hist(df["R2_real"].dropna(), bins=40, alpha=0.6, label="Real Σ", density=True)
    plt.hist(df["R2_null"].dropna(), bins=40, alpha=0.6, label="Shuffled X₀", density=True)
    plt.hist(df["R2_shuffled_sigma"].dropna(), bins=40, alpha=0.6, label="Shuffled Σ (X₀ full)", density=True)
    plt.xlabel("R² score", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title(f"Aggregate R² Distribution – {base_name}", fontsize=16)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, f"{base_name}_r2_aggregate_hist.png"), dpi=300)
    plt.close()

# === Run on multiple datasets ===
# `datasets` holds file names relative to data_dir (every other cell joins
# them onto data_dir before loading), so the same join is required here.
# Plots are sent below output_dir as well, next to the CSV tables, instead of
# the function's working-directory default.
for dataset in datasets:
    full_r2_matrix_analysis(
        os.path.join(data_dir, dataset),
        save_dir=os.path.join(output_dir, "r2_genelevel"),
        plot_dir=os.path.join(output_dir, "r2_genelevel_plots"),
    )


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

# Location of the per-gene R² tables. Kept in its own variable so the global
# `data_dir` (input datasets) is not overwritten, which would break any
# earlier cell that is re-run afterwards.
genelevel_dir = os.path.join(output_dir, "r2_genelevel")
save_dir = os.path.join(output_dir, "r2_genelevel_plots_true_only")
os.makedirs(save_dir, exist_ok=True)

# The registry column is 'name' (as used by the other cells), not 'names'
dataset_files = ds_info['name']

required_columns = ["R2_real", "R2_null", "R2_shuffled_sigma", "is_true_target"]


def _save_true_only_histogram(shuffled, meanfield, real_true, title, out_path, bins, figsize):
    """Overlay the three R² distributions as density histograms and save them as SVG."""
    plt.figure(figsize=figsize)
    plt.hist(shuffled, bins=bins, alpha=0.6, label="Shuffled Σ (full)", density=True)
    plt.hist(meanfield, bins=bins, alpha=0.6, label="Shuffled X₀ (meanfield)", density=True)
    plt.hist(real_true, bins=bins, alpha=0.6, label="Real Σ (true gene only)", density=True)

    plt.xlabel("R² score", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, format='svg')
    plt.close()


# Containers for aggregate histograms
all_real_true = []
all_meanfield = []
all_shuffled = []

for dataset in dataset_files:
    csv_path = os.path.join(genelevel_dir, f"{dataset}_r2_per_gene.csv")
    if not os.path.exists(csv_path):
        print(f"Missing: {csv_path}")
        continue

    df = pd.read_csv(csv_path)
    if not all(col in df.columns for col in required_columns):
        print(f"Missing columns in {dataset}, skipping.")
        continue

    # Real covariance: only rows flagged as the true target gene;
    # both null models: every (perturbation, gene) row.
    real_true = df[df["is_true_target"] == 1]["R2_real"].dropna()
    meanfield_all = df["R2_null"].dropna()
    shuffled_all = df["R2_shuffled_sigma"].dropna()

    # Accumulate for global plots
    all_real_true.extend(real_true)
    all_meanfield.extend(meanfield_all)
    all_shuffled.extend(shuffled_all)

    # Plot individual dataset histograms
    _save_true_only_histogram(
        shuffled_all, meanfield_all, real_true,
        title=f"R² Histogram – {dataset}",
        out_path=os.path.join(save_dir, f"{dataset}_r2_hist_true_only.svg"),
        bins=40,
        figsize=(7, 5),
    )
    print(f"Saved plot for {dataset}")

# === Plot total histogram over all datasets ===
_save_true_only_histogram(
    all_shuffled, all_meanfield, all_real_true,
    title="Aggregate R² Distribution – All Datasets",
    out_path=os.path.join(save_dir, "ALL_r2_hist_true_only.svg"),
    bins=60,
    figsize=(8, 6),
)
print("Saved aggregate histogram.")


In [None]:
# === Plot total histogram over all datasets (log-scaled density axis) ===
plt.figure(figsize=(8, 6))
plt.hist(all_shuffled, bins=60, alpha=0.6, label="Shuffled Σ (full)", density=True)
plt.hist(all_meanfield, bins=60, alpha=0.6, label="Shuffled X₀ (meanfield)", density=True)
plt.hist(all_real_true, bins=60, alpha=0.6, label="Real Σ (true gene only)", density=True)

plt.xlabel("R² score", fontsize=14)
plt.ylabel("Density", fontsize=14)
plt.title("Aggregate R² Distribution – All Datasets", fontsize=16)
plt.legend()
plt.tight_layout()
plt.yscale('log')
# Saved under its own name: writing to "ALL_r2_hist_true_only.svg" would
# silently overwrite the linear-scale figure produced by the previous cell.
plt.savefig(os.path.join(save_dir, "ALL_r2_hist_true_only_log.svg"), format='svg')
plt.close()
print("Saved aggregate histogram.")