In [None]:
import os
import pandas as pd
# Move the working directory up one level so the relative paths below resolve.
# NOTE(review): this is not idempotent — re-running this cell moves up one more level each time.
os.chdir('../')

# Settings shared by the cells below.
data_dir = 'data'  # later cells join dataset file names onto this directory
ds_info = pd.read_csv('resources/datasets.csv')  # dataset table; the 'file' and 'name' columns are read below
datasets = ds_info['file']  # dataset file names (later cells strip ".h5ad" from them)
dataset_names = ds_info['name']  # dataset names — presumably equal to the file stems; verify against the CSV

In [None]:
# Egene analysis calculation

from src.preprocess import get_data

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import eigh
from scipy.sparse import issparse
import pickle


# === Parameters ===
top_k = 30  # number of leading covariance eigenvectors kept for the projections below
save_dir = "egene_analysis"  # output directory for the per-dataset pickles written by this cell
os.makedirs(save_dir, exist_ok=True)

# === Load your gene index helper ===
def get_gene_index(gene_name, gene_names):
    """Return the position of the first entry of ``gene_names`` equal to ``gene_name``.

    Parameters
    ----------
    gene_name : str
        Gene identifier to look up.
    gene_names : array-like of str
        Gene identifiers. Any sequence is accepted; it is coerced to a NumPy
        array so the comparison is element-wise (comparing a plain ``list``
        to a string yields a single ``False`` and would never match).

    Returns
    -------
    int or None
        Zero-based index of the first match, or ``None`` when there is none.
    """
    idx = np.where(np.asarray(gene_names) == gene_name)[0]
    return int(idx[0]) if idx.size > 0 else None


# === Loop over datasets ===
# For every dataset: diagonalise the covariance of X0, project each
# perturbation's mean shift onto the top_k leading eigenvectors, and pickle
# the squared projections (one record per perturbation).
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    print(f"\nProcessing {dataset_name}")

    # === Load data and get all perturbations
    # The third return value is not used here: the per-perturbation matrices
    # are sliced out of `adata` below.
    adata, X0, _ = get_data(0, data_path)

    # Densify once so np.cov / np.mean see a plain array (same guard as the
    # TRADE comparison cell applies to X0).
    X0 = X0.toarray() if issparse(X0) else np.asarray(X0)

    # Strip a trailing "g<digits>" suffix so labels differing only in that
    # suffix are pooled under one base name.
    adata.obs['perturbation_base'] = adata.obs['perturbation'].str.replace(r'g\d+$', '', regex=True)
    perturbations = adata.obs['perturbation_base'].unique()
    perturbations = [p for p in perturbations if p != 'control']

    # === Compute covariance of control
    Sigma = np.cov(X0, rowvar=False)
    eigvals, eigvecs = eigh(Sigma)
    sorted_idx = np.argsort(eigvals)[::-1]  # eigh returns ascending eigenvalues; flip to descending
    eigvecs_sorted = eigvecs[:, sorted_idx[:top_k]]
    eigvals_sorted = eigvals[sorted_idx[:top_k]]

    # Loop-invariant: the mean of X0 does not depend on the perturbation.
    control_mean = np.mean(X0, axis=0)

    alpha_squared_records = []

    for pert in perturbations:
        selected_pert_data = adata[adata.obs['perturbation_base'] == pert]
        if selected_pert_data.shape[0] < 2:
            print(f"Skipping {pert} (too few samples)")
            continue

        X1 = selected_pert_data.X.toarray() if issparse(selected_pert_data.X) else selected_pert_data.X
        delta_X = np.mean(X1, axis=0) - control_mean

        # Coordinates of the mean shift in the eigenvector basis, squared.
        alpha = eigvecs_sorted.T @ delta_X
        alpha_squared = alpha**2
        alpha_squared_records.append({
            "perturbation": pert,
            "alpha_squared": alpha_squared
        })


    # === Save projection data ===
    alpha_save_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared.pkl")
    with open(alpha_save_path, "wb") as f:
        pickle.dump(alpha_squared_records, f)

    print(f"Saved results for {dataset_name} to {alpha_save_path}")


In [None]:
# FIG 5B , FIG S5B heatmaps


import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

# === Parameters ===
top_k = 10  # Only use top 10 eigenmodes
save_dir = "egene_analysis"
os.makedirs(save_dir, exist_ok=True)


# === Loop through datasets ===
# One clustered heatmap per dataset: rows are eigenmodes, columns are
# perturbations, values are squared projections normalised per perturbation.
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    alpha_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared.pkl")

    if not os.path.exists(alpha_path):
        print(f"Skipping {dataset_name}, missing alpha_squared.pkl")
        continue

    with open(alpha_path, "rb") as f:
        alpha_squared_records = pickle.load(f)

    # Extract top-k components
    perturbations = []
    alpha_matrix_list = []
    for record in alpha_squared_records:
        alpha_sq = np.array(record["alpha_squared"])
        if len(alpha_sq) < top_k:
            continue
        leading = alpha_sq[:top_k]
        total = np.sum(leading)
        # A zero or non-finite column sum would turn into NaN after the
        # normalisation below, and the clustering rejects non-finite values.
        if not np.isfinite(total) or total <= 0:
            print(f"Skipping {record['perturbation']} in {dataset_name} (zero or non-finite projection)")
            continue
        alpha_matrix_list.append(leading)
        perturbations.append(record["perturbation"])

    if not alpha_matrix_list:
        print(f"No valid perturbations found for {dataset_name}")
        continue
    if len(alpha_matrix_list) < 2:
        # Column clustering needs at least two observations.
        print(f"Skipping {dataset_name}: clustering needs at least two perturbations")
        continue

    alpha_matrix = np.array(alpha_matrix_list).T  # shape: (top_k, num_perts)
    alpha_matrix_norm = alpha_matrix / np.sum(alpha_matrix, axis=0, keepdims=True)

    # === Clustered heatmap with dendrograms ===
    g = sns.clustermap(
        alpha_matrix_norm,
        cmap="viridis",
        row_cluster=True,
        col_cluster=True,
        xticklabels=perturbations,
        yticklabels=[f"PC{i+1}" for i in range(top_k)],
        figsize=(14, 8),
        dendrogram_ratio=(0.2, 0.2),  # show row and col dendrograms
        cbar_pos=(0.02, 0.8, 0.02, 0.18)  # standard colorbar
    )
    g.fig.suptitle(
        rf"{dataset_name}: Clustered Heatmap of Normalized $\alpha_{{ij}}^2 / \sum_i \alpha_{{ij}}^2$",
        y=1.02
    )
    plt.setp(g.ax_heatmap.get_xticklabels(), rotation=90)

    # === Save ===
    # Save and close the clustermap's own figure explicitly instead of
    # relying on it still being pyplot's "current" figure.
    heatmap_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared_clustermap_top{top_k}.svg")
    g.fig.savefig(heatmap_path, bbox_inches='tight')
    plt.close(g.fig)
    print(f"Saved clustered heatmap: {heatmap_path}")


In [None]:
# FIG 5D


import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# === Paths ===
alpha_dir = "egene_analysis"  # directory scanned for *_alpha_squared.pkl files
r2_dir = "output/r2_histograms"  # directory scanned for *_results.csv files
output_dir = "FIG_TRADE"  # destination for the figures and the summary table
os.makedirs(output_dir, exist_ok=True)

# === Load mean and median participation ratios ===
# For each pickle of per-perturbation records, compute the participation ratio
# PR = (sum_i a_i)^2 / sum_i a_i^2 with a = record["alpha_squared"],
# then keep the per-dataset mean and median.
pr_stats = {}
for file in os.listdir(alpha_dir):
    if file.endswith("_alpha_squared.pkl"):
        dataset = file.replace("_alpha_squared.pkl", "")
        alpha_path = os.path.join(alpha_dir, file)

        with open(alpha_path, "rb") as f:
            alpha_squared_records = pickle.load(f)

        participation_ratios = []
        for rec in alpha_squared_records:
            alpha_sq = np.array(rec["alpha_squared"])
            numerator = np.sum(alpha_sq) ** 2
            denominator = np.sum(alpha_sq ** 2)
            # NOTE(review): an all-zero alpha_sq makes this 0/0 (NaN) — confirm that cannot occur
            PR = numerator / denominator
            participation_ratios.append(PR)

        # Datasets with no records are left out of pr_stats entirely.
        if participation_ratios:
            pr_stats[dataset] = {
                "mean_pr": np.mean(participation_ratios),
                "median_pr": np.median(participation_ratios)
            }

# === Load mean R² ===
# Mean of the "R2_real" column (NaNs dropped) for every *_results.csv that has that column.
r2_means = {}
for file in os.listdir(r2_dir):
    if file.endswith("_results.csv"):
        dataset = file.replace("_results.csv", "")
        df = pd.read_csv(os.path.join(r2_dir, file))
        if "R2_real" in df:
            r2_means[dataset] = np.mean(df["R2_real"].dropna())

# === Merge summaries ===
# Inner join on the dataset key: only datasets present in both dicts are kept.
rows = []
for dataset in pr_stats:
    if dataset in r2_means:
        rows.append({
            "dataset": dataset,
            "mean_pr": pr_stats[dataset]["mean_pr"],
            "median_pr": pr_stats[dataset]["median_pr"],
            "mean_r2": r2_means[dataset]
        })

# NOTE(review): when `rows` is empty this frame has no columns, so the column
# lookups in the plotting calls further down would raise KeyError.
summary_df = pd.DataFrame(rows)

# === Function to plot and fit line ===
def plot_with_fit(x, y, labels, xlabel, ylabel, title, filename, color_palette):
    """Scatter ``y`` against ``x`` with one labelled point per entry and a linear fit.

    Parameters
    ----------
    x, y : array-like of float
        Point coordinates (same length). Converted to NumPy arrays so points
        are addressed by position regardless of any pandas index.
    labels : iterable of str
        One label per point, used for the legend and the text annotation.
    xlabel, ylabel, title : str
        Axis labels and figure title.
    filename : str
        File name (inside ``output_dir``) the figure is saved to.
    color_palette : sequence of colors
        Colors assigned to points in order; reused cyclically when there are
        more points than colors.

    Returns
    -------
    None
        The figure is written to disk and closed. Nothing is written when
        there are no points.
    """
    x_vals = np.asarray(x, dtype=float)
    y_vals = np.asarray(y, dtype=float)
    point_labels = list(labels)

    if x_vals.size == 0:
        print(f"No data to plot for {filename}; skipping")
        return

    # Ordinary least-squares line through the points; the R² reported in the
    # legend is that of this fit on the same points.
    X = x_vals.reshape(-1, 1)
    reg = LinearRegression()
    reg.fit(X, y_vals)
    Y_pred = reg.predict(X)
    r2 = r2_score(y_vals, Y_pred)
    slope = reg.coef_[0]

    plt.figure(figsize=(8, 6))
    for i, label in enumerate(point_labels):
        # Modulo keeps the color lookup valid when points outnumber colors.
        color = color_palette[i % len(color_palette)]
        plt.scatter(x_vals[i], y_vals[i], color=color, s=80, label=label)
        plt.text(x_vals[i], y_vals[i], label, fontsize=9, ha='left', va='bottom')

    # Draw the fit left-to-right so the dashed line is traced once.
    order = np.argsort(x_vals)
    plt.plot(x_vals[order], Y_pred[order], color='black', linestyle='--',
             label=f"Fit (R² = {r2:.2f}, Slope = {slope:.2f})")
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title(title, fontsize=15)
    plt.grid(True)
    plt.legend(fontsize=9, loc='best')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()

# === Generate color palette ===
# tab10 for up to ten datasets, otherwise tab20.
# NOTE(review): tab20 has 20 colors; with more than 20 rows the slices below are shorter than the data.
palette = plt.cm.tab10.colors if len(summary_df) <= 10 else plt.cm.tab20.colors

# === Plot Mean PR vs Mean R² ===
# x is the per-dataset mean R², y the per-dataset mean participation ratio.
plot_with_fit(
    x=summary_df["mean_r2"],
    y=summary_df["mean_pr"],
    labels=summary_df["dataset"],
    xlabel="Mean R² (Real Σ)",
    ylabel="Mean Participation Ratio",
    title="Mean PR vs Mean R² Across Datasets",
    filename="mean_PR_vs_mean_R2.svg",
    color_palette=palette[:len(summary_df)]
)

# === Plot Median PR vs Mean R² ===
# Same x axis, with the median participation ratio on y.
plot_with_fit(
    x=summary_df["mean_r2"],
    y=summary_df["median_pr"],
    labels=summary_df["dataset"],
    xlabel="Mean R² (Real Σ)",
    ylabel="Median Participation Ratio",
    title="Median PR vs Mean R² Across Datasets",
    filename="median_PR_vs_mean_R2.svg",
    color_palette=palette[:len(summary_df)]
)

# === Save data table ===
# Persist the merged per-dataset summary next to the figures.
summary_df.to_csv(os.path.join(output_dir, "PR_vs_R2_summary.csv"), index=False)


In [None]:
# FIG 5 A, FIG S5A


import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

# === Parameters ===
save_dir = "egene_analysis"  # the pickles are read from here and the histograms written here
os.makedirs(save_dir, exist_ok=True)

# === Loop through datasets ===
# One participation-ratio histogram per dataset that has an alpha_squared pickle.
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    alpha_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared.pkl")

    if not os.path.exists(alpha_path):
        print(f"Skipping {dataset_name}, missing alpha_squared.pkl")
        continue

    with open(alpha_path, "rb") as f:
        alpha_squared_records = pickle.load(f)

    participation_ratios = []

    # PR = (sum_i a_i)^2 / sum_i a_i^2 over all stored components of each record.
    for record in alpha_squared_records:
        alpha_squared = np.array(record["alpha_squared"])
        numerator = np.sum(alpha_squared) ** 2
        denominator = np.sum(alpha_squared**2)
        # NOTE(review): 0/0 (NaN) if a record is all zeros — confirm that cannot occur
        PR = numerator / denominator
        participation_ratios.append(PR)

    if not participation_ratios:
        print(f"No valid alpha_squared data for {dataset_name}")
        continue

    # === Plot histogram for this dataset ===
    plt.figure(figsize=(8, 4))
    plt.hist(participation_ratios, bins=20, color='skyblue', edgecolor='k', alpha=0.8, density = True)
    plt.xlabel("Participation Ratio")
    plt.ylabel("Density")
    plt.title(f"Distribution of PR — {dataset_name}")
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()

    # === Save plot ===
    hist_path = os.path.join(save_dir, f"{dataset_name}_participation_ratio_hist.svg")
    plt.savefig(hist_path)
    plt.close()  # release the figure before the next dataset
    print(f"Saved PR histogram: {hist_path}")


In [None]:
# FIG S5C

import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

# === Parameters ===
save_dir = "egene_analysis"  # where the alpha_squared pickles are read from
output_dir = "FIG_TRADE"  # where the pooled histograms are written
os.makedirs(output_dir, exist_ok=True)

# === Define Tian datasets ===
# Dataset names (file stems) whose PRs are additionally pooled into `tian_prs`.
tian_datasets = {
    "TianKampmann2021_CRISPRa",
    "TianKampmann2021_CRISPRi",
    "TianKampmann2019_day7neuron"
}

# === Collect PRs ===
all_prs = []  # one PR per perturbation, pooled over every dataset found
tian_prs = []  # subset of all_prs coming from tian_datasets

for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    alpha_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared.pkl")

    if not os.path.exists(alpha_path):
        print(f"Skipping {dataset_name}, missing alpha_squared.pkl")
        continue

    with open(alpha_path, "rb") as f:
        alpha_squared_records = pickle.load(f)

    # PR = (sum_i a_i)^2 / sum_i a_i^2 for each record's squared projections.
    for record in alpha_squared_records:
        alpha_squared = np.array(record["alpha_squared"])
        numerator = np.sum(alpha_squared) ** 2
        denominator = np.sum(alpha_squared ** 2)
        PR = numerator / denominator

        all_prs.append(PR)
        if dataset_name in tian_datasets:
            tian_prs.append(PR)

# === Plotting function ===
def plot_pr_histogram(prs, label, fname):
    """Save a density histogram of participation ratios.

    Parameters
    ----------
    prs : sequence of float
        Participation ratios to bin (40 bins, normalised to a density).
    label : str
        Text appended to the figure title.
    fname : str
        File name, inside ``output_dir``, that the figure is saved to.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.hist(prs, bins=40, density=True, alpha=0.85, color="steelblue", edgecolor="k")
    ax.set_title(f"Participation Ratio Distribution — {label}")
    ax.set_xlabel("Participation Ratio")
    ax.set_ylabel("Density")
    ax.grid(True, linestyle='--', alpha=0.4)
    fig.tight_layout()
    target = os.path.join(output_dir, fname)
    fig.savefig(target)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# === Plot: All Datasets ===
# Histogram of every PR collected above, pooled over datasets.
plot_pr_histogram(all_prs, "All Datasets", "PR_hist_all.svg")

# === Plot: Tian Datasets Combined ===
# NOTE(review): tian_prs is empty when none of the Tian pickles were found — confirm the plot is still wanted then.
plot_pr_histogram(tian_prs, "Tian Datasets (CRISPRa/i + Day7Neuron)", "PR_hist_Tian.svg")


In [None]:
# participation ratio stats

import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

# === Parameters ===
save_dir = "egene_analysis"  # where the alpha_squared pickles are read from
os.makedirs(save_dir, exist_ok=True)


# Participation ratios pooled over every dataset (one entry per perturbation).
all_participation_ratios = []

# === Loop through datasets ===
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    alpha_path = os.path.join(save_dir, f"{dataset_name}_alpha_squared.pkl")

    if not os.path.exists(alpha_path):
        print(f"Skipping {dataset_name}, missing alpha_squared.pkl")
        continue

    with open(alpha_path, "rb") as f:
        alpha_squared_records = pickle.load(f)

    participation_ratios = []

    # PR = (sum_i a_i)^2 / sum_i a_i^2 for each record's squared projections.
    for record in alpha_squared_records:
        alpha_squared = np.array(record["alpha_squared"])
        numerator = np.sum(alpha_squared) ** 2
        denominator = np.sum(alpha_squared**2)
        PR = numerator / denominator
        participation_ratios.append(PR)

    if not participation_ratios:
        print(f"No valid alpha_squared data for {dataset_name}")
        continue

    # Append to global list
    all_participation_ratios.extend(participation_ratios)


# === Compute and report summary statistics ===
# Mean and standard error (sample std with ddof=1, divided by sqrt(n)) over the pooled PRs.
# NOTE(review): with zero or one pooled value these statistics are NaN.
all_participation_ratios = np.array(all_participation_ratios)
mean_pr = np.mean(all_participation_ratios)
stderr_pr = np.std(all_participation_ratios, ddof=1) / np.sqrt(len(all_participation_ratios))

print("\n📊 Overall Participation Ratio Summary")
print(f"Average PR across all datasets: {mean_pr:.2f}")
print(f"Standard Error of the Mean (SEM): {stderr_pr:.2f}")
print(f"Total perturbations analyzed: {len(all_participation_ratios)}")


In [None]:
# GO heatmap analysis


import os
import numpy as np
import pandas as pd
import anndata as ad
import pickle
from scipy.linalg import eigh
from scipy.sparse import issparse
from gprofiler import GProfiler
from collections import defaultdict

# === Parameters ===
top_k = 30  # number of leading eigenvectors ("eigengenes") analysed per dataset
top_n_genes = 200  # genes with the largest absolute loading sent to enrichment
save_dir = "egene_analysis"
os.makedirs(save_dir, exist_ok=True)

gp = GProfiler(return_dataframe=True)

# === Loop over datasets ===
# Per dataset: eigendecompose the covariance of X0, save the leading
# eigenvectors and gene names, run one enrichment query per eigenvector, and
# save both the raw top terms and a term-by-eigenvector score matrix.
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    dataset_name = os.path.basename(data_path).replace(".h5ad", "")
    print(f"\n🔍 Processing {dataset_name}")

    try:
        adata, X0, _ = get_data(0, data_path)
    except Exception as e:
        # Best-effort: a dataset that fails to load is reported and skipped.
        print(f"Failed to load {dataset_name}: {e}")
        continue

    if "perturbation" not in adata.obs:
        print(f"No 'perturbation' column in {dataset_name}")
        continue

    # Densify so np.cov sees a plain array (same guard as the TRADE
    # comparison cell applies to X0).
    X0 = X0.toarray() if issparse(X0) else np.asarray(X0)

    # Compute covariance
    Sigma = np.cov(X0, rowvar=False)
    eigvals, eigvecs = eigh(Sigma)

    # Sort by descending eigenvalue
    sorted_idx = np.argsort(eigvals)[::-1]
    eigvals_sorted = eigvals[sorted_idx[:top_k]]
    eigvecs_sorted = eigvecs[:, sorted_idx[:top_k]]

    gene_names = np.array(adata.var_names)

    # === Save eigenvectors and gene names ===
    np.save(os.path.join(save_dir, f"{dataset_name}_eigvecs_sorted.npy"), eigvecs_sorted)
    np.save(os.path.join(save_dir, f"{dataset_name}_gene_names.npy"), gene_names)

    # === Get top contributing genes for each eigengene ===
    # Iterate over the columns that actually exist: the slice above yields
    # fewer than top_k columns when the covariance is smaller than top_k.
    eigengene_to_genes = {}
    for i in range(eigvecs_sorted.shape[1]):
        loading_vector = eigvecs_sorted[:, i]
        top_indices = np.argsort(-np.abs(loading_vector))[:top_n_genes]
        top_gene_names = gene_names[top_indices].tolist()
        eig_label = f"Eig {i+1}"
        eigengene_to_genes[eig_label] = top_gene_names

    # === Run GO enrichment ===
    # One gProfiler query per eigengene; a failed query is reported and
    # recorded as "no result".
    enrichment_results = {}
    for eig_label, gene_list in eigengene_to_genes.items():
        try:
            result_df = gp.profile(
                organism="hsapiens",
                query=gene_list,
                sources=["GO:BP", "GO:MF", "REAC"],
                user_threshold=0.05,
                no_evidences=False
            )
        except Exception as e:
            print(f"gProfiler failed on {eig_label} in {dataset_name}: {e}")
            result_df = pd.DataFrame()

        if not result_df.empty:
            # Keep the five terms with the smallest p-value.
            top_terms = result_df[['name', 'p_value']].sort_values(by='p_value').head(5)
            enrichment_results[eig_label] = top_terms
        else:
            enrichment_results[eig_label] = None

    # === Build GO term matrix ===
    # Rows: term names; columns: eigengenes; values: -log10(p), 0 where the
    # term was not among that eigengene's kept terms.
    go_term_dict = defaultdict(dict)
    for eig, df in enrichment_results.items():
        if df is None:
            continue
        for _, row in df.iterrows():
            term = row["name"]
            score = -np.log10(row["p_value"] + 1e-300)  # offset avoids log10(0)
            go_term_dict[term][eig] = score

    go_term_df = pd.DataFrame(go_term_dict).T.fillna(0)

    # === Save results ===
    with open(os.path.join(save_dir, f"{dataset_name}_go_enrichment.pkl"), "wb") as f:
        pickle.dump(enrichment_results, f)

    go_term_df.to_csv(os.path.join(save_dir, f"{dataset_name}_go_matrix.csv"))

    print(f"Finished {dataset_name}: saved eigvecs, enrichment, and matrix")


In [None]:
# FIG 5 C


import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# === Parameters ===
save_dir = "egene_analysis"  # the GO matrices are read from here and the heatmaps written here
top_terms_to_plot = 6  # number of term rows kept per heatmap


# === Loop over datasets and generate plots ===
# NOTE(review): this iterates `dataset_names` (the 'name' column), while the cell
# that writes *_go_matrix.csv names files after the data file stem — confirm the two agree.
for dataset_name in dataset_names:
    go_matrix_path = os.path.join(save_dir, f"{dataset_name}_go_matrix.csv")

    if not os.path.exists(go_matrix_path):
        print(f"Missing matrix: {go_matrix_path}")
        continue

    go_term_df = pd.read_csv(go_matrix_path, index_col=0)

    if go_term_df.empty:
        print(f"Empty GO matrix for {dataset_name}, skipping")
        continue

    # === Select top enriched terms (by max enrichment across any eigengene) ===
    top_terms = go_term_df.max(axis=1).sort_values(ascending=False).head(top_terms_to_plot).index
    filtered_df = go_term_df.loc[top_terms]

    # === Plot ===
    plt.figure(figsize=(12, 8))
    sns.heatmap(filtered_df, cmap="magma", cbar_kws={"label": "-log10(p-value)"})
    plt.title(f"{dataset_name} — GO Enrichment Across Eigengenes")
    plt.xlabel("Eigengene")
    plt.ylabel("GO Term")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    plot_path = os.path.join(save_dir, f"{dataset_name}_go_heatmap.svg")
    plt.savefig(plot_path)
    plt.close()  # release the figure before the next dataset

    print(f"Saved heatmap: {plot_path}")


In [None]:
# Combined-dataset GO analysis



import os
import numpy as np
import pandas as pd
from gprofiler import GProfiler
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt

# === Parameters ===
top_n_genes = 200  # genes with the largest absolute loading taken per eigenvector
top_terms_to_plot = 25  # terms kept per query and rows kept in the final heatmap
top_k = 30  # maximum number of eigenvectors considered per dataset
save_dir = "egene_analysis"  # saved eigenvectors are read from here; outputs are written here
os.makedirs(save_dir, exist_ok=True)

gp = GProfiler(return_dataframe=True)

# === Dataset groups ===
# Group name -> dataset names whose top-loading genes are intersected per eigenvector.
# NOTE(review): "All" uses `dataset_names` (the 'name' column); the saved .npy files are
# named after data file stems — confirm the two agree.
groups = {
    "Tian_CRISPR_combined": [
        "TianKampmann2021_CRISPRi",
        "TianKampmann2021_CRISPRa",
        "TianKampmann2019_day7neuron"
    ],
    "All": dataset_names
}

# === Load top loading genes per PC per dataset ===
def get_top_genes_by_pc(dataset):
    """Map each saved eigenvector of ``dataset`` to its set of top-loading genes.

    Reads ``<dataset>_eigvecs_sorted.npy`` and ``<dataset>_gene_names.npy``
    from ``save_dir``.

    Parameters
    ----------
    dataset : str
        Dataset name used as the file-name prefix.

    Returns
    -------
    dict[str, set] or None
        Keys are ``"Eig 1"``, ``"Eig 2"``, ... (at most ``top_k`` of them);
        each value is the set of the ``top_n_genes`` gene names with the
        largest absolute loading. ``None`` if either file is missing.
    """
    eigvec_path = os.path.join(save_dir, f"{dataset}_eigvecs_sorted.npy")
    genes_path = os.path.join(save_dir, f"{dataset}_gene_names.npy")

    if not os.path.exists(eigvec_path) or not os.path.exists(genes_path):
        print(f"Missing eigenvectors or gene names for {dataset}")
        return None

    loadings = np.load(eigvec_path)
    # NOTE: allow_pickle=True unpickles the file's contents; only load files this notebook wrote.
    names = np.load(genes_path, allow_pickle=True)

    n_modes = min(top_k, loadings.shape[1])
    return {
        f"Eig {mode + 1}": set(names[np.argsort(-np.abs(loadings[:, mode]))[:top_n_genes]])
        for mode in range(n_modes)
    }

# === Run GO enrichment on intersected genes per PC ===
def run_go_heatmap_for_group(group_name, dataset_list):
    """Enrich the genes shared across a group's datasets, per eigenvector, and plot a heatmap.

    For each eigenvector label of the first dataset, the top-loading gene
    sets of all datasets that have that label are intersected; intersections
    with at least five genes are sent to gProfiler. The resulting
    ``-log10(p)`` scores are assembled into a term-by-eigenvector matrix,
    reduced to the ``top_terms_to_plot`` strongest terms, plotted and saved.

    Parameters
    ----------
    group_name : str
        Label used in messages, the title and the output file names.
    dataset_list : iterable of str
        Dataset names passed to ``get_top_genes_by_pc``.

    Returns
    -------
    None
        Outputs are written to ``save_dir``. Nothing is written when a
        dataset is missing its saved files, the group is empty, or no
        enrichment is found.
    """
    print(f"\n🔍 Processing group: {group_name}")
    dataset_list = list(dataset_list)
    if not dataset_list:
        # Without this guard the [0] lookup below raises IndexError.
        print(f"Skipping {group_name} due to missing data")
        return

    all_pc_to_genes = [get_top_genes_by_pc(ds) for ds in dataset_list]

    if any(x is None for x in all_pc_to_genes):
        print(f"Skipping {group_name} due to missing data")
        return

    # Labels come from the first dataset; a dataset with fewer saved
    # eigenvectors simply lacks the higher labels and is left out of that
    # label's intersection below.
    all_pcs = all_pc_to_genes[0].keys()
    go_term_dict = defaultdict(dict)

    for pc in all_pcs:
        gene_sets = [d[pc] for d in all_pc_to_genes if pc in d]
        if len(gene_sets) < 2:
            continue
        shared_genes = set.intersection(*gene_sets)
        if len(shared_genes) < 5:
            continue

        try:
            result_df = gp.profile(
                organism="hsapiens",
                query=list(shared_genes),
                sources=["GO:BP", "GO:MF", "REAC"],
                user_threshold=0.05,
                no_evidences=False
            )
        except Exception as e:
            # Best-effort: report the failed query and move on to the next label.
            print(f"gProfiler failed on {group_name} {pc}: {e}")
            continue

        if not result_df.empty:
            top_terms = result_df[['name', 'p_value']].sort_values(by='p_value').head(top_terms_to_plot)
            for _, row in top_terms.iterrows():
                term = row["name"]
                score = -np.log10(row["p_value"] + 1e-300)  # offset avoids log10(0)
                go_term_dict[term][pc] = score

    # === Assemble and save GO matrix ===
    go_df = pd.DataFrame(go_term_dict).T.fillna(0)

    if go_df.empty:
        print(f"No significant enrichment for {group_name}")
        return

    # Reduce to top overall terms
    top_rows = go_df.max(axis=1).sort_values(ascending=False).head(top_terms_to_plot).index
    filtered_df = go_df.loc[top_rows]

    # Plot — size scales with the matrix, with a floor so a matrix with very
    # few rows or columns still leaves room for the tick labels.
    fig_width = max(4, 1 + 0.4 * len(filtered_df.columns))
    fig_height = max(2, 0.5 * len(filtered_df))
    plt.figure(figsize=(fig_width, fig_height))
    sns.heatmap(filtered_df, cmap="magma", cbar_kws={"label": "-log10(p-value)"})
    plt.title(f"{group_name} — Shared Top Gene GO Enrichment per PC")
    plt.xlabel("Eigengene")
    plt.ylabel("GO Term")
    plt.tight_layout()

    plot_path = os.path.join(save_dir, f"{group_name}_GO_PC_heatmap.svg")
    plt.savefig(plot_path)
    plt.close()

    # Save matrix
    csv_path = os.path.join(save_dir, f"{group_name}_GO_PC_matrix.csv")
    filtered_df.to_csv(csv_path)

    print(f"Saved: {plot_path}, {csv_path}")

# === Run for all groups ===
# One heatmap and one CSV per entry of `groups`, written by the function itself.
for group_name, dataset_list in groups.items():
    run_go_heatmap_for_group(group_name, dataset_list)


In [None]:
# TRADE comparison

import os
import pickle
import numpy as np
from scipy.linalg import pinv
from scipy.sparse import issparse
from scipy.stats import entropy, skew, kurtosis
import pandas as pd

def analyze_gene_contributions(adata, X0_dense, Sigma, gene_names):
    """Compute per-perturbation summary metrics of gene contributions.

    For every perturbation label (other than ``'control'``) that matches an
    entry of ``gene_names``, the mean shift ``delta_X`` relative to
    ``X0_dense`` is mapped through the pseudo-inverse of ``Sigma`` to
    ``u_star``. ``|Sigma| @ |u_star|`` (normalised) is the global contribution
    vector; the row of ``|Sigma|`` for the perturbed gene times ``|u_star|``
    (normalised) is the target-gene contribution vector.

    Parameters
    ----------
    adata : object
        Must expose ``obs['perturbation']`` and support boolean-mask indexing
        returning an object with an ``X`` matrix (dense or scipy-sparse).
    X0_dense : ndarray
        Dense matrix whose column means are subtracted from each
        perturbation's column means.
    Sigma : ndarray
        Square matrix whose pseudo-inverse and absolute value are used.
    gene_names : array-like of str
        Column labels; coerced to an array so a plain list works too.

    Returns
    -------
    list of dict
        One dict per analysed perturbation with the keys produced by
        ``compute_macro_metrics`` plus ``"Perturbation"`` and ``"self_rank"``.
        Perturbations that match no gene name, or that have no rows, are skipped.
    """
    # A plain list compared to a string yields a scalar False, which would make
    # every lookup fail silently; an array compares element-wise.
    gene_names = np.asarray(gene_names)

    def get_gene_index(gene_name):
        """Position of the first match of ``gene_name`` in ``gene_names``, else None."""
        idx = np.where(gene_names == gene_name)[0]
        return int(idx[0]) if idx.size > 0 else None

    def compute_average_response(X0, X1):
        """Difference of column means: X1 minus X0."""
        return X1.mean(axis=0) - X0.mean(axis=0)

    def compute_macro_metrics(frac_vector_global, frac_row, delta_X, u_star, gene_index):
        """Summarise how concentrated the contribution vectors are (``gene_index`` is unused)."""
        sorted_contribs = np.sort(frac_vector_global)[::-1]
        cumsum = np.cumsum(sorted_contribs)
        thresholds = [0.5, 0.75, 0.9, 0.95, 0.99]
        # Number of largest contributions needed for the cumulative sum to reach each threshold.
        genes_needed = {f"genes_for_{int(t*100)}pct": int(np.searchsorted(cumsum, t) + 1) for t in thresholds}
        # exp(entropy) of each contribution vector.
        eff_size_global = np.exp(entropy(frac_vector_global + 1e-12))
        eff_size_target = np.exp(entropy(frac_row + 1e-12))
        cos_sim = (u_star @ delta_X) / (np.linalg.norm(u_star) * np.linalg.norm(delta_X) + 1e-12)
        skw = skew(frac_vector_global)
        krt = kurtosis(frac_vector_global)
        # Log-log slope of the ten largest contributions against their rank;
        # NaN when fewer than ten are available.
        top_contribs = sorted_contribs[:10]
        log_ranks = np.log(np.arange(1, 11))
        log_contribs = np.log(top_contribs + 1e-12)
        slope = np.polyfit(log_ranks, log_contribs, 1)[0] if len(log_contribs) == 10 else np.nan
        return {
            **genes_needed,
            "eff_size_global": eff_size_global,
            "eff_size_target": eff_size_target,
            "eff_size_diff": eff_size_target - eff_size_global,
            "cosine_similarity": cos_sim,
            "skewness": skw,
            "kurtosis": krt,
            "zipf_slope": slope
        }

    def compute_self_rank(frac_row, gene_index):
        """1-based rank of ``gene_index`` when ``frac_row`` is sorted descending."""
        ranked = np.argsort(frac_row)[::-1]
        rank = np.where(ranked == gene_index)[0]
        return int(rank[0]) + 1 if rank.size > 0 else np.nan

    Sigma_inv = pinv(Sigma)
    abs_Sigma = np.abs(Sigma)
    perturbation_counts = adata.obs['perturbation'].value_counts()
    perturbations = [p for p in perturbation_counts.index if p != 'control']

    metrics = []

    for pert in perturbations:
        gene_index = get_gene_index(pert)
        if gene_index is None:
            continue

        X1 = adata[adata.obs['perturbation'] == pert].X
        # value_counts on a categorical column also lists unobserved
        # categories (count 0); an empty selection would give a NaN mean.
        if X1.shape[0] == 0:
            continue
        X1 = X1.toarray() if issparse(X1) else X1
        delta_X = compute_average_response(X0_dense, X1)
        u_star = Sigma_inv @ delta_X

        abs_u = np.abs(u_star)
        contrib_sums = abs_Sigma @ abs_u
        frac_vector_global = contrib_sums / (np.sum(contrib_sums) + 1e-8)

        row = abs_Sigma[gene_index, :] * abs_u
        frac_row = row / (np.sum(row) + 1e-8)

        metrics_dict = compute_macro_metrics(frac_vector_global, frac_row, delta_X, u_star, gene_index)
        metrics_dict["Perturbation"] = pert
        metrics_dict["self_rank"] = compute_self_rank(frac_row, gene_index)

        metrics.append(metrics_dict)

    return metrics

# Output directory for the per-dataset metric pickles written below.
save_dir = "contribution_outputs"
os.makedirs(save_dir, exist_ok=True)


# For every dataset: load it, build the covariance of the densified X0, compute the
# per-perturbation metrics and pickle the resulting list of dicts.
for ds in datasets:
    data_path = os.path.join(data_dir, ds)
    print(data_path)
    base = os.path.basename(data_path).replace(".h5ad", "")
    # `get_data` is imported in an earlier cell; only its first two return values are used here.
    adata, X0, _ = get_data(0, data_path)
    gene_names = np.array(adata.var_names)
    X0_dense = X0.toarray() if issparse(X0) else X0  # densify when X0 is scipy-sparse
    Sigma = np.cov(X0_dense, rowvar=False)

    metrics = analyze_gene_contributions(adata, X0_dense, Sigma, gene_names)

    with open(os.path.join(save_dir, f"{base}_macro_contribution.pkl"), "wb") as f:
        pickle.dump(metrics, f)


In [None]:
# FIG 5 E/F FIG S5D


import os
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# === Setup ===
save_dir = "contribution_outputs"  # directory scanned for *_macro_contribution.pkl files
output_dir = "FIG_TRADE"  # destination for the plots
os.makedirs(output_dir, exist_ok=True)

# === Load all metric files ===
# Each pickle holds a list of per-perturbation dicts; tag every row with the
# dataset name taken from the file name and stack them into one frame.
all_metrics = []
for file in os.listdir(save_dir):
    if file.endswith("_macro_contribution.pkl"):
        with open(os.path.join(save_dir, file), "rb") as f:
            data = pickle.load(f)
            df = pd.DataFrame(data)
            df["dataset"] = file.replace("_macro_contribution.pkl", "")
            all_metrics.append(df)

# NOTE: pd.concat raises ValueError when no pickle was found.
combined_df = pd.concat(all_metrics, ignore_index=True)

# === Descriptive labels ===
# Metric column -> text used for the plot title and x-axis label.
label_dict = {
    "eff_size_global": "Number of Genes Driving Global Transcriptomic Change",
    "eff_size_target": "Genes contributing to True Perturbed Gene's response",
    "self_rank": "Rank of how much True Perturbed Gene affects its own change",
    "genes_for_50pct": "Genes Explaining 50% of Total Response",
    "genes_for_75pct": "Genes Explaining 75% of Total Response",
    "genes_for_90pct": "Genes Explaining 90% of Total Response",
    "genes_for_95pct": "Genes Explaining 95% of Total Response",
    "genes_for_99pct": "Genes Explaining 99% of Total Response"
}

metrics_to_plot = list(label_dict.keys())

# === Plotting function ===
def plot_metric_distribution(df, metric, log_scale=False):
    """Save a density histogram (with KDE) of one metric column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the ``metric`` column; NaNs are dropped.
    metric : str
        Column to plot; also used in the output file name
        ``plot_<metric>.svg`` inside ``output_dir``.
    log_scale : bool, default False
        Use 30 log-spaced bins and a logarithmic x axis. Falls back to the
        linear layout when the column has no positive value.
    """
    label = label_dict.get(metric, metric)
    plt.figure(figsize=(8, 6))

    values = df[metric].dropna()
    positive = values[values > 0]
    use_log = log_scale and not positive.empty
    if use_log:
        # The first edge sits exactly on the smallest positive value. Offsetting
        # it upwards (min + 1e-3) puts the minimum below the first bin, and
        # observations outside the bin range are not counted.
        bins = np.logspace(np.log10(positive.min()), np.log10(values.max() + 1), 30)
    else:
        bins = 25

    sns.histplot(values, bins=bins, kde=True, stat="density", color="steelblue", edgecolor=None, alpha=0.8)
    plt.title(f"Distribution of {label}", fontsize=14)
    plt.xlabel(label, fontsize=12)
    plt.ylabel("Density", fontsize=12)
    if use_log:
        plt.xscale("log")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"plot_{metric}.svg"))
    plt.close()

# === Generate all plots ===
# The two effective-size metrics and self_rank are drawn with log-spaced bins; the rest linearly.
# NOTE(review): the next cell writes the same plot_<metric>.svg file names and would overwrite these.
for metric in metrics_to_plot:
    logscale = metric in ["eff_size_global", "eff_size_target", "self_rank"]
    plot_metric_distribution(combined_df, metric, log_scale=logscale)


In [None]:
# FIG 5 E/F FIG S5D


import os
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# === Setup ===
save_dir = "contribution_outputs"  # directory scanned for *_macro_contribution.pkl files
output_dir = "FIG_TRADE"  # destination for the plots
os.makedirs(output_dir, exist_ok=True)

# === Load all metric files ===
# Each pickle holds a list of per-perturbation dicts; tag every row with the
# dataset name taken from the file name and stack them into one frame.
all_metrics = []
for file in os.listdir(save_dir):
    if file.endswith("_macro_contribution.pkl"):
        with open(os.path.join(save_dir, file), "rb") as f:
            data = pickle.load(f)
            df = pd.DataFrame(data)
            df["dataset"] = file.replace("_macro_contribution.pkl", "")
            all_metrics.append(df)

# NOTE: pd.concat raises ValueError when no pickle was found.
combined_df = pd.concat(all_metrics, ignore_index=True)

# === Descriptive labels ===
# Metric column -> text used for the plot title and x-axis label.
# Only the uncommented entries are plotted by this cell.
label_dict = {
    "eff_size_global": "Number of Genes Driving Global Transcriptomic Change",
    "eff_size_target": "Genes Contributing to True Perturbed Gene's Response",
    "self_rank": "Rank of True Perturbed Gene by Its Own Response Influence",
    # "genes_for_50pct": "Genes Explaining 50% of Total Response",
    # "genes_for_75pct": "Genes Explaining 75% of Total Response",
    # "genes_for_90pct": "Genes Explaining 90% of Total Response",
    # "genes_for_95pct": "Genes Explaining 95% of Total Response",
    "genes_for_99pct": "Genes Explaining 99% of Total Response"
}

metrics_to_plot = list(label_dict.keys())

# === Plotting function ===
def freedman_diaconis_bins(data):
    """Return a histogram bin count derived from the interquartile range.

    The bin width is ``2 * IQR / (n ** (1/3) + 1e-8)`` and the count is the
    data range divided by that width, rounded up. When the width is zero
    (IQR of zero) a fixed count of 25 is returned instead.
    """
    samples = np.asarray(data)
    upper_q, lower_q = np.percentile(samples, [75, 25])
    spread = upper_q - lower_q
    width = 2 * spread / (len(samples) ** (1 / 3) + 1e-8)
    if width == 0:
        # Degenerate spread: fall back to a fixed bin count.
        return 25
    span = samples.max() - samples.min()
    return int(np.ceil(span / width))

from scipy.stats import kurtosis

def plot_metric_distribution(df, metric, num_bins=150):
    """Save an annotated density histogram of one metric column.

    A logarithmic x axis with log-spaced bins is chosen automatically when the
    values are strongly skewed (skewness > 2) or span a wide range
    (max / min > 100), provided at least one value is positive. The mean and
    the (non-excess) kurtosis are printed in a box on the plot.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the ``metric`` column; NaNs are dropped.
    metric : str
        Column to plot; also used in the output file name
        ``plot_<metric>.svg`` inside ``output_dir``.
    num_bins : int, default 150
        Number of bins (linear case) or of log-spaced bin edges (log case).
    """
    values = df[metric].dropna()
    label = label_dict.get(metric, metric)

    # Determine log-scale use
    skewness = values.skew()
    value_range = values.max() / (values.min() + 1e-5)
    positive = values[values > 0]
    use_log = (skewness > 2 or value_range > 100) and not positive.empty

    # Choose bins
    if use_log:
        # The first edge sits exactly on the smallest positive value. Offsetting
        # it upwards (min + 1e-3) puts the minimum below the first bin, and
        # observations outside the bin range are not counted.
        bins = np.logspace(np.log10(positive.min()), np.log10(values.max() + 1), num_bins)
    else:
        bins = num_bins

    # Compute statistics
    mean_val = values.mean()
    kurt = kurtosis(values, fisher=False)  # fisher=False: Pearson definition (normal == 3)

    # Plot
    plt.figure(figsize=(8, 6))
    sns.histplot(values, bins=bins, kde=False, stat="density", color="steelblue", edgecolor="black", alpha=0.85)

    # Annotations
    annotation_text = f"Mean = {mean_val:.1f}\nKurtosis = {kurt:.1f}"
    plt.annotate(annotation_text, xy=(0.98, 0.95), xycoords='axes fraction',
                 ha='right', va='top', fontsize=11, bbox=dict(facecolor='white', edgecolor='gray', alpha=0.7))

    # Labels and formatting
    plt.xlabel(label, fontsize=12)
    plt.ylabel("Density", fontsize=12)
    plt.title(f"Distribution of {label}", fontsize=14)
    if use_log:
        plt.xscale("log")
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"plot_{metric}.svg"))
    plt.close()




# === Generate all plots ===
# One plot_<metric>.svg per key of label_dict, using the default number of bins.
for metric in metrics_to_plot:
    plot_metric_distribution(combined_df, metric)


In [None]:
# get all u vectors for further analysis


import os
import numpy as np
import pickle
from scipy.linalg import pinv
from scipy.sparse import issparse

from anndata import read_h5ad


# Output directory for the per-dataset pickles of u* vectors.
save_dir = "u_vectors"
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): hard-coded list; presumably mirrors resources/datasets.csv — confirm.
data_paths = [
    "TianKampmann2019_day7neuron.h5ad",
    "ReplogleWeissman2022_rpe1.h5ad",
    "ReplogleWeissman2022_K562_essential.h5ad",
    "GSE264667_jurkat_raw_singlecell_01.h5ad",
    "GSE264667_hepg2_raw_singlecell_01.h5ad",
    "NormanWeissman2019_filtered.h5ad",
    "FrangiehIzar2021_RNA.h5ad",
    "TianKampmann2021_CRISPRi.h5ad",
    "TianKampmann2021_CRISPRa.h5ad",
    "TianKampmann2019_iPSC.h5ad"
]

for data_path in data_paths:
    print(f"Processing: {data_path}")
    base = os.path.basename(data_path).replace(".h5ad", "")

    try:
        # The entries above are bare file names, whereas the earlier
        # eigen-analysis cell passes get_data a path joined with data_dir.
        # Use the name as given when it exists, otherwise look under data_dir.
        resolved_path = data_path if os.path.exists(data_path) else os.path.join(data_dir, data_path)
        adata, X0, _ = get_data(0, resolved_path)
    except Exception as e:
        # Best effort: report the failure and move on to the next dataset.
        print(f"Failed on {base}: {e}")
        continue

    # Covariance of X0 with columns treated as variables, and its
    # Moore-Penrose pseudo-inverse.
    Sigma = np.cov(X0, rowvar=False)
    Sigma_inv = pinv(Sigma)
    # The mean of X0 is identical for every perturbation; compute it once per dataset.
    control_mean = np.mean(X0, axis=0)

    # Not used further in this cell; kept in case later cells rely on it.
    gene_names = np.array(adata.var_names)
    pert_list = [p for p in adata.obs['perturbation'].unique() if p != "control"]

    results = []

    for pert in pert_list:
        X1 = adata[adata.obs['perturbation'] == pert].X
        if X1.shape[0] < 2:
            # Fewer than two rows for this perturbation: skip it.
            continue
        X1 = X1.toarray() if issparse(X1) else X1
        # Mean shift of the perturbed rows relative to X0.
        delta_X = np.mean(X1, axis=0) - control_mean
        u_star = Sigma_inv @ delta_X

        results.append({
            "perturbation": pert,
            "u_star": u_star,
        })

    with open(os.path.join(save_dir, f"{base}_u_vectors.pkl"), "wb") as f:
        pickle.dump(results, f)

    print(f"Saved: {base}_u_vectors.pkl")


In [None]:
# entropy of u


import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.stats import entropy

# === Parameters ===
save_dir = "u_vectors"          # where the *_u_vectors.pkl files are read from
fig_dir = "u_entropy_figures"   # where the SVG figures are written
os.makedirs(fig_dir, exist_ok=True)

# === Collect entropies and vectors ===
all_entropies = []
entropies_by_dataset = {}
u_vectors_by_dataset = {}

# os.listdir returns entries in arbitrary order; sort so runs are reproducible.
for file in sorted(os.listdir(save_dir)):
    if not file.endswith("_u_vectors.pkl"):
        continue
    dataset = file.replace("_u_vectors.pkl", "")
    # NOTE: pickle.load can execute arbitrary code; only read files that were
    # written by the u-vector cell of this notebook.
    with open(os.path.join(save_dir, file), "rb") as f:
        records = pickle.load(f)

    entropies = []
    vectors = []

    for record in records:
        # Shannon entropy of |u| normalised to sum to one.
        u = np.abs(record["u_star"])
        p = u / (np.sum(u) + 1e-8)
        ent = entropy(p)
        if not np.isfinite(ent):
            # An all-zero or NaN-containing u gives a NaN entropy, which would
            # make the entropy ranking of the example panels meaningless.
            print(f"Skipping {record['perturbation']} in {dataset} (undefined entropy)")
            continue
        entropies.append(ent)
        vectors.append((record["perturbation"], record["u_star"], ent))

    entropies_by_dataset[dataset] = entropies
    u_vectors_by_dataset[dataset] = vectors
    all_entropies.extend(entropies)

# === Plot overall entropy histogram ===
fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(all_entropies, bins=30, kde=True, stat="density", color="teal", ax=ax)
ax.set_xlabel("Entropy of |u| Distribution", fontsize=14)
ax.set_ylabel("Density", fontsize=14)
ax.set_title("Distribution of Entropy Across All Perturbations", fontsize=16)
fig.tight_layout()
fig.savefig(os.path.join(fig_dir, "overall_entropy_histogram.svg"))
plt.close(fig)

# === Plot per-dataset entropy histograms ===
for dataset, entropies in entropies_by_dataset.items():
    fig, ax = plt.subplots(figsize=(7, 5))
    sns.histplot(entropies, bins=25, kde=True, stat="density", color="steelblue", ax=ax)
    ax.set_xlabel("Entropy of |u|", fontsize=13)
    ax.set_ylabel("Density", fontsize=13)
    ax.set_title(f"{dataset}: Entropy of Perturbation Vectors", fontsize=14)
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, f"{dataset}_entropy_histogram.svg"))
    plt.close(fig)

# === Plot examples of high and low entropy perturbations ===
for dataset, records in u_vectors_by_dataset.items():
    # Three low and three high examples are shown, so at least six are needed.
    if len(records) < 6:
        continue

    records_sorted = sorted(records, key=lambda x: x[2])  # sort by entropy
    low_ents = records_sorted[:3]
    high_ents = records_sorted[-3:]

    # Top row: lowest-entropy vectors; bottom row: highest-entropy vectors.
    fig, axs = plt.subplots(2, 3, figsize=(15, 6), sharey=True)
    for row, (row_label, examples) in enumerate((("Low Entropy", low_ents),
                                                 ("High Entropy", high_ents))):
        for col, (pert, uvec, ent) in enumerate(examples):
            axs[row, col].plot(uvec)
            axs[row, col].set_title(f"{row_label}\n{pert}\nH={ent:.2f}")

    fig.suptitle(f"{dataset} — Examples of High vs Low Entropy Perturbations", fontsize=16)
    # Leave the top 5% of the canvas for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(fig_dir, f"{dataset}_u_vector_entropy_examples.svg"))
    plt.close(fig)


In [None]:
# entropy of u continued

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from scipy.stats import entropy

# === Paths ===
save_dir = "u_vectors"          # where the *_u_vectors.pkl files are read from
fig_dir = "u_entropy_figures"   # where the SVG figure is written
os.makedirs(fig_dir, exist_ok=True)

# === Collect entropies by dataset ===
entropies_by_dataset = {}

# os.listdir returns entries in arbitrary order. Sorting keeps the colour and
# legend position assigned to each dataset the same from run to run.
for file in sorted(os.listdir(save_dir)):
    if not file.endswith("_u_vectors.pkl"):
        continue
    dataset = file.replace("_u_vectors.pkl", "")
    # NOTE: pickle.load can execute arbitrary code; only read files that were
    # written by the u-vector cell of this notebook.
    with open(os.path.join(save_dir, file), "rb") as f:
        records = pickle.load(f)

    entropies = []
    for record in records:
        # Shannon entropy of |u| normalised to sum to one.
        u = np.abs(record["u_star"])
        p = u / (np.sum(u) + 1e-8)
        ent = entropy(p)
        entropies.append(ent)

    entropies_by_dataset[dataset] = entropies

# === Overlaid entropy density curves (one KDE per dataset) ===
fig, ax = plt.subplots(figsize=(10, 6))

colors = sns.color_palette("hls", len(entropies_by_dataset))

for color, (dataset, entropies) in zip(colors, entropies_by_dataset.items()):
    sns.kdeplot(entropies, fill=True, alpha=0.3, label=dataset, color=color, linewidth=2, ax=ax)

ax.set_xlabel("Entropy of |u| Distribution", fontsize=14)
ax.set_ylabel("Density", fontsize=14)
ax.set_title("Entropy of Perturbation Vectors Across Datasets", fontsize=16)
ax.legend(fontsize=9, loc="upper right", ncol=1)
ax.set_yscale('log')
ax.grid(True, linestyle="--", alpha=0.4)
fig.tight_layout()
fig.savefig(os.path.join(fig_dir, "entropy_histograms_overlay.svg"))
plt.close(fig)
