In [None]:
import os
import pandas as pd
# Move one directory up (presumably from the notebook folder to the project
# root) so the relative paths below resolve. The move is skipped when the
# resources file is already reachable, so re-running this cell does not keep
# climbing one directory higher on every execution.
if not os.path.exists('resources/datasets.csv'):
    os.chdir('../')

# Define source paths
data_dir = 'data'    # directory holding the .h5ad dataset files
out_dir = 'output'
ds_info = pd.read_csv('resources/datasets.csv')
datasets = ds_info['file']    # dataset file names, joined with data_dir by later cells

In [None]:
# calculations for fig 3Q
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy.sparse import issparse
from scipy.linalg import pinv
import matplotlib.pyplot as plt

from src.preprocess import get_matched_data
from src.cov import cross_covariance_analysis

# === Datasets compared against the CRISPRi reference ===
reference_path = "TianKampmann2021_CRISPRi.h5ad"
data_paths = [
    "TianKampmann2019_day7neuron.h5ad",
    "TianKampmann2021_CRISPRa.h5ad",
]

# === One analysis per (dataset, reference) pair ===
for data_name_1 in data_paths:
    data_path_1, data_path_2 = (
        os.path.join(data_dir, file_name)
        for file_name in (data_name_1, reference_path)
    )
    print(f"Processing {data_path_1} vs {data_path_2}")

    # get_matched_data returns both AnnData objects plus one X0 matrix each
    # (presumably the control-cell expression -- confirm in src.preprocess).
    adata1, adata2, X0_1, X0_2 = get_matched_data(data_path_1, data_path_2)

    # Record the originating file on each AnnData object.
    adata2.uns["_file"] = data_path_2
    adata1.uns["_file"] = data_path_1

    # Cross-covariance R² analysis, run on copies so adata1/adata2 themselves
    # are not modified (per the original note, the call also saves plots).
    df_results = cross_covariance_analysis(
        adata1.copy(), X0_1, adata2.copy(), X0_2
    )


In [None]:

# === Reverse direction: CRISPRi analysed against the CRISPRa reference ===
reference_path = "TianKampmann2021_CRISPRa.h5ad"
data_paths = ["TianKampmann2021_CRISPRi.h5ad"]

# The reference file is identical for every iteration, so resolve it once.
data_path_2 = os.path.join(data_dir, reference_path)

# === Analysis Loop ===
for data_name_1 in data_paths:
    data_path_1 = os.path.join(data_dir, data_name_1)
    print(f"Processing {data_path_1} vs {data_path_2}")

    # Matched AnnData objects and their X0 matrices for this pair.
    adata1, adata2, X0_1, X0_2 = get_matched_data(data_path_1, data_path_2)

    # Tag both objects with the file they were loaded from.
    adata1.uns["_file"] = data_path_1
    adata2.uns["_file"] = data_path_2

    # Run the cross-covariance R² analysis on copies of the AnnData objects.
    df_results = cross_covariance_analysis(
        adata1.copy(), X0_1, adata2.copy(), X0_2
    )

In [None]:
# === Plotting Section ===
import matplotlib.pyplot as plt

load_dir = "r2_cross_covariance_acrosscelltype"
save_dir = "HS_figs"
os.makedirs(save_dir, exist_ok=True)

# One results CSV per comparison, named "R2_cross_<dataset 1>_vs_<dataset 2>.csv".
csv_files = [
    "R2_cross_TianKampmann2019_day7neuron_vs_TianKampmann2021_CRISPRi.csv",
    "R2_cross_TianKampmann2021_CRISPRa_vs_TianKampmann2021_CRISPRi.csv",
    "R2_cross_TianKampmann2021_CRISPRi_vs_TianKampmann2021_CRISPRa.csv",
]

for csv_file in csv_files:
    df = pd.read_csv(os.path.join(load_dir, csv_file))
    basename = csv_file.replace("R2_cross_", "").replace(".csv", "")

    # pd.read_csv infers an integer dtype for a column that only holds 1 and 2,
    # and an integer column never compares equal to the strings "1"/"2".
    # Comparing on the string form works for both integer and string columns;
    # without it every subset is empty and no figure is produced.
    dataset_labels = df["dataset"].astype(str)

    for ds in ["1", "2"]:
        df_subset = df[dataset_labels == ds]
        if df_subset.empty:
            continue

        # x-axis: R² under the Σ of the dataset the rows belong to ("true Σ");
        # y-axis: R² under the Σ of the other dataset ("cross Σ").
        xcol = "R2_Sigma1_dX" if ds == "1" else "R2_Sigma2_dX"
        ycol = "R2_Sigma2_dX" if ds == "1" else "R2_Sigma1_dX"

        plt.figure(figsize=(6, 6))
        plt.scatter(df_subset[xcol], df_subset[ycol], alpha=0.6)
        plt.plot([0, 1], [0, 1], 'k--')  # identity line: cross Σ as good as true Σ
        plt.xlabel("R² (true Σ)", fontsize=14)
        plt.ylabel("R² (cross Σ)", fontsize=14)
        plt.title(f"R²: {basename} (dataset {ds})", fontsize=15)
        plt.grid(True)
        plt.tight_layout()

        plot_filename = f"scatter_R2_{basename}_ds{ds}.svg"
        plt.savefig(os.path.join(save_dir, plot_filename))
        plt.show()
        # Release the figure so figures do not accumulate across the loop.
        plt.close()

In [None]:
# fig 3Q



import os
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

# === Set directories ===
load_dir = "r2_cross_covariance_acrosscelltype"
save_dir = "HS_figs"
os.makedirs(save_dir, exist_ok=True)

# === Files to plot (legend label -> results CSV) ===
csv_files = {
    "Day7 vs CRISPRi": "R2_cross_TianKampmann2019_day7neuron_vs_TianKampmann2021_CRISPRi.csv",
    "CRISPRa vs CRISPRi": "R2_cross_TianKampmann2021_CRISPRa_vs_TianKampmann2021_CRISPRi.csv",
    "CRISPRi vs CRISPRa": "R2_cross_TianKampmann2021_CRISPRi_vs_TianKampmann2021_CRISPRa.csv",
}

# === Colors for each comparison ===
colors = {
    "Day7 vs CRISPRi": "tab:blue",
    "CRISPRa vs CRISPRi": "tab:orange",
    "CRISPRi vs CRISPRa": "tab:green",
}

# === Initialize plot and data containers ===
plt.figure(figsize=(7, 7))
true_all = []   # x values (R2_Sigma1_dX) pooled over all comparisons
cross_all = []  # y values (R2_Sigma2_dX) pooled over all comparisons

# === Plot all points in one scatter plot and collect data ===
for label, file in csv_files.items():
    path = os.path.join(load_dir, file)
    if not os.path.exists(path):
        print(f"Missing file: {file}")
        continue

    df = pd.read_csv(path)
    if df.empty or not {"R2_Sigma1_dX", "R2_Sigma2_dX"}.issubset(df.columns):
        print(f"Skipping invalid or empty file: {file}")
        continue

    x_vals = df["R2_Sigma1_dX"].values
    y_vals = df["R2_Sigma2_dX"].values
    plt.scatter(x_vals, y_vals, alpha=0.5, label=label, s=20, color=colors[label])

    # Append to combined lists
    true_all.extend(x_vals)
    cross_all.extend(y_vals)

# === Reference line and labels ===
plt.plot([0, 1], [0, 1], 'k--', linewidth=1)
plt.xlabel("R² (true Σ)", fontsize=14)
plt.ylabel("R² (cross Σ)", fontsize=14)
plt.title("R² comparison across neuron datasets", fontsize=15)
plt.legend(fontsize=11)
plt.grid(True)
plt.tight_layout()

# === Save and show plot ===
plt.savefig(os.path.join(save_dir, "R2_neuron_comparison_all.svg"))
plt.show()

# === Compute total R² between cross and true Σ ===
# r2_score raises on NaN input, so score only the pairs where both values are
# present. The scatter above is unaffected: matplotlib simply omits NaN points.
r2_pairs = pd.DataFrame({"true": true_all, "cross": cross_all}).dropna()
if len(r2_pairs) > 1:
    overall_r2 = r2_score(r2_pairs["true"], r2_pairs["cross"])
    print(f"Overall R² (cross Σ vs true Σ across all comparisons): {overall_r2:.4f}")
else:
    print("Not enough data to compute overall R².")


In [None]:
# calculations for FIG 3 R


import os
import numpy as np
import pandas as pd
from scipy.sparse import issparse

# === Output directory for the per-pair R² tables ===
save_dir = "r2_null_cross_vs_true"
os.makedirs(save_dir, exist_ok=True)

# === Datasets ===
# Neuron datasets that serve as the second member of every pair.
neuron_targets = [
    f"{data_dir}/TianKampmann2021_CRISPRi.h5ad",
    f"{data_dir}/TianKampmann2021_CRISPRa.h5ad",
    f"{data_dir}/TianKampmann2019_day7neuron.h5ad"
]
# Every dataset file listed in the datasets table; each is paired with all targets.
all_paths = ds_info['file']

def compute_r2(Sigma, delta_X, gene_idx, epsilon=1e-8):
    """Score how well one column of ``Sigma`` explains ``delta_X``.

    The column ``Sigma[:, gene_idx]`` is rescaled by its least-squares
    coefficient against ``delta_X`` (fitted over all entries). R² is then
    evaluated only on the entries where ``delta_X`` is non-zero.

    Parameters
    ----------
    Sigma : 2-D array
        Matrix whose ``gene_idx``-th column is used as the predictor.
    delta_X : 1-D array
        Observed vector to be explained.
    gene_idx : int
        Column of ``Sigma`` to use.
    epsilon : float, optional
        Small constant added to both denominators to avoid division by zero.

    Returns
    -------
    float
        ``1 - SSE / SST`` over the non-zero entries of ``delta_X``, or
        ``np.nan`` when ``delta_X`` has no non-zero entry.
    """
    column = Sigma[:, gene_idx]

    # Least-squares scale for delta_X ≈ scale * column.
    scale = np.dot(column, delta_X) / (np.dot(column, column) + epsilon)

    nonzero = np.abs(delta_X) > 0
    if not np.any(nonzero):
        return np.nan

    observed = delta_X[nonzero]
    residual = observed - (scale * column)[nonzero]
    return 1.0 - np.sum(residual**2) / (np.sum(observed**2) + epsilon)

# === Iterate over all dataset pairs: data_path_1 vs neuron_target_2 ===
# Fixed-seed generator: the shuffled and random covariances below are
# stochastic, so without a seed every re-run of this cell would write
# different CSVs.
rng = np.random.default_rng(0)

# Explicit column order, so a pair with no usable perturbation still yields a
# CSV with a header row that the downstream cells can read and skip.
result_columns = ["perturbation", "R2_true", "R2_cross", "R2_meanfield", "R2_random"]

for data_path in all_paths:
    data_path_1 = os.path.join(data_dir, data_path)
    for data_path_2 in neuron_targets:
        print(f"Comparing {data_path_1} vs {data_path_2}")
        adata1, adata2, X0_1, X0_2 = get_matched_data(data_path_1, data_path_2, expression_threshold=1.0, min_samples=100)

        gene_names = np.array(adata1.var_names)
        X0_1_dense = X0_1.toarray() if issparse(X0_1) else X0_1
        X0_2_dense = X0_2.toarray() if issparse(X0_2) else X0_2

        # Baseline mean of dataset 1; identical for every perturbation, so it
        # is computed once per pair instead of once per perturbation.
        baseline_mean = X0_1_dense.mean(axis=0)

        # "True" Σ comes from dataset 1 itself, "cross" Σ from the neuron target.
        Sigma_true = np.cov(X0_1_dense, rowvar=False)
        Sigma_cross = np.cov(X0_2_dense, rowvar=False)

        # Mean-field from shuffled X0_2: every gene (column) is permuted
        # independently, which keeps each gene's values but breaks the
        # gene-gene pairing across rows.
        X0_shuffled = rng.permuted(X0_2_dense, axis=0)
        Sigma_meanfield = np.cov(X0_shuffled, rowvar=False)

        # Null model: covariance of Gaussian noise with the shape of X0_1.
        Sigma_random = np.cov(rng.standard_normal(X0_1_dense.shape), rowvar=False)

        results = []
        # Keep perturbations that are not the control and contain no "_".
        # Non-string labels (e.g. NaN) are skipped instead of raising a
        # TypeError on the "_" membership test.
        perturbations = [
            p for p in adata1.obs['perturbation'].unique()
            if isinstance(p, str) and p != 'control' and '_' not in p
        ]

        for pert in perturbations:
            # Only perturbations named after a gene present in the matrix
            # can be mapped to a column of Σ.
            if pert not in gene_names:
                continue
            gene_idx = np.where(gene_names == pert)[0][0]
            X1 = adata1[adata1.obs['perturbation'] == pert].X
            X1 = X1.toarray() if issparse(X1) else X1
            delta_X = X1.mean(axis=0) - baseline_mean

            results.append({
                "perturbation": pert,
                "R2_true": compute_r2(Sigma_true, delta_X, gene_idx),
                "R2_cross": compute_r2(Sigma_cross, delta_X, gene_idx),
                "R2_meanfield": compute_r2(Sigma_meanfield, delta_X, gene_idx),
                "R2_random": compute_r2(Sigma_random, delta_X, gene_idx),
            })

        df = pd.DataFrame(results, columns=result_columns)
        base_1 = os.path.basename(data_path_1).replace(".h5ad", "")
        base_2 = os.path.basename(data_path_2).replace(".h5ad", "")
        out_name = f"r2_compare_{base_1}_vs_{base_2}.csv"
        df.to_csv(os.path.join(save_dir, out_name), index=False)


In [None]:
#  FIG 3 R



import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# === Input tables and figure output ===
load_dir = "r2_null_cross_vs_true"
save_dir = "HS_figs"
os.makedirs(save_dir, exist_ok=True)

# === Base names of the datasets treated as neuron datasets ===
neuron_datasets = {
    "TianKampmann2019_day7neuron",
    "TianKampmann2021_CRISPRi",
    "TianKampmann2021_CRISPRa"
}

# === Classification logic ===
def classify_comparison(base1, base2):
    """Label a dataset pair by how its two members relate.

    Returns "cross_celltype" unless both names are in ``neuron_datasets``;
    otherwise "same_dataset_true" when the names are identical and
    "same_celltype_cross" when they differ.
    """
    both_neuron = base1 in neuron_datasets and base2 in neuron_datasets
    if not both_neuron:
        return "cross_celltype"
    return "same_dataset_true" if base1 == base2 else "same_celltype_cross"

# === Containers ===
# Per-file mean R² values, grouped by the role they play in the summary lines.
r2_by_type = {
    bucket: []
    for bucket in (
        "random_cross",
        "meanfield_cross",
        "meanfield_same",
        "cross_celltype",
        "same_celltype_cross",
        "same_dataset_true",
    )
}

# One [random, mean-field, cross, true] curve per results file.
r2_individual_curves = {"same": [], "cross": []}

# === Collect data ===
for fname in os.listdir(load_dir):
    if not fname.endswith(".csv"):
        continue

    # File names follow "r2_compare_<base1>_vs_<base2>.csv".
    base1, base2 = fname.replace("r2_compare_", "").replace(".csv", "").split("_vs_")
    df = pd.read_csv(os.path.join(load_dir, fname))
    if df.empty:
        continue

    ctype = classify_comparison(base1, base2)

    mean_random = df["R2_random"].mean()
    mean_meanfield = df["R2_meanfield"].mean()
    mean_cross = df["R2_cross"].mean()
    # The "true Σ" slot is only filled when a dataset is compared with itself.
    mean_true = df["R2_true"].mean() if ctype == "same_dataset_true" else np.nan
    curve = [mean_random, mean_meanfield, mean_cross, mean_true]

    if ctype == "cross_celltype":
        r2_by_type["cross_celltype"].append(mean_cross)
        r2_by_type["meanfield_cross"].append(mean_meanfield)
        r2_by_type["random_cross"].append(mean_random)
        r2_individual_curves["cross"].append(curve)
    elif ctype in ("same_dataset_true", "same_celltype_cross"):
        # Both members are neuron datasets: the bucket named after the class
        # gets the true-Σ mean (same dataset) or the cross-Σ mean (different).
        headline = mean_true if ctype == "same_dataset_true" else mean_cross
        r2_by_type[ctype].append(headline)
        r2_by_type["meanfield_same"].append(mean_meanfield)
        r2_individual_curves["same"].append(curve)

# === Labels and means ===
labels = [
    "Random Σ (cross)",
    "Mean-field Σ (cross)",
    "Full Σ (same/cross)",
    "True Σ (same dataset)"
]

# Both summary lines share the random-Σ start point and the true-Σ end point;
# only the two middle points differ between same and cross cell type.
means_same = [
    np.nanmean(r2_by_type[bucket])
    for bucket in ("random_cross", "meanfield_same", "same_celltype_cross", "same_dataset_true")
]
means_cross = [
    np.nanmean(r2_by_type[bucket])
    for bucket in ("random_cross", "meanfield_cross", "cross_celltype", "same_dataset_true")
]

# === Plot 4: Combined Averages + All Dataset Curves ===
plt.figure(figsize=(8, 6))

# Faint per-file curves first, so the bold mean lines are drawn on top.
for group, line_color in (("same", "salmon"), ("cross", "blueviolet")):
    for curve in r2_individual_curves[group]:
        plt.plot(labels, curve, color=line_color, alpha=0.2, linewidth=1)

# Mean lines
plt.plot(labels, means_same, marker="o", linewidth=2.5, color="salmon", label="Same Cell Type")
plt.plot(labels, means_cross, marker="s", linewidth=2.5, color="blueviolet", label="Cross Cell Type")

# Axes, legend and export
plt.ylabel("Average R²", fontsize=14)
plt.xticks(rotation=30, ha="right", fontsize=12)
plt.yticks(fontsize=12)
plt.title("R² – Same vs Cross Cell Type (with Dataset Curves)", fontsize=15)
plt.legend(fontsize=11)
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "R2_same_vs_cross_with_individuals.svg"))
plt.close()


In [None]:
# pvalues for FIG 3 R

import os
import pickle
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon
import matplotlib.pyplot as plt
import seaborn as sns

# === Where the R² tables live and where figures go ===
load_dir = "r2_null_cross_vs_true"
save_dir = "HS_figs"
os.makedirs(save_dir, exist_ok=True)

# === Neuron identifiers (dataset base names) ===
neuron_datasets = {
    "TianKampmann2019_day7neuron",
    "TianKampmann2021_CRISPRi",
    "TianKampmann2021_CRISPRa"
}

def classify_comparison(base1, base2):
    """Return the comparison class for a pair of dataset base names.

    A pair with any member outside ``neuron_datasets`` is "cross_celltype".
    A neuron-neuron pair is "same_dataset_true" when both names match and
    "same_celltype_cross" otherwise.
    """
    if not {base1, base2} <= neuron_datasets:
        return "cross_celltype"
    if base1 == base2:
        return "same_dataset_true"
    return "same_celltype_cross"

# === Collect curves per type ===
# Each curve is [random, mean-field, cross, true] mean R² for one results file.
r2_individual_curves = {"same": [], "cross": []}

for fname in os.listdir(load_dir):
    if not fname.endswith(".csv"):
        continue

    # File names follow "r2_compare_<base1>_vs_<base2>.csv".
    base1, base2 = fname.replace("r2_compare_", "").replace(".csv", "").split("_vs_")
    df = pd.read_csv(os.path.join(load_dir, fname))
    if df.empty:
        continue

    ctype = classify_comparison(base1, base2)
    # NOTE(review): unlike the FIG 3 R plotting cell, R2_true is kept for every
    # comparison here (not only same-dataset pairs), so the "Cross < True Σ"
    # test also includes same-cell-type cross pairs -- confirm this is intended.
    curve = [
        df["R2_random"].mean(),
        df["R2_meanfield"].mean(),
        df["R2_cross"].mean(),
        df["R2_true"].mean() if "R2_true" in df.columns else np.nan
    ]

    if ctype in {"same_dataset_true", "same_celltype_cross"}:
        r2_individual_curves["same"].append(curve)
    elif ctype == "cross_celltype":
        r2_individual_curves["cross"].append(curve)

# Force an (n_files, 4) shape. An empty category would otherwise become a 1-D
# empty array, and the column indexing used by the tests below ([:, 0], ...)
# would raise IndexError instead of reporting "not enough samples".
same_curves = np.array(r2_individual_curves["same"], dtype=float).reshape(-1, 4)
cross_curves = np.array(r2_individual_curves["cross"], dtype=float).reshape(-1, 4)

# === One-sided paired test helper ===
def print_wilcox_pval(label, a1, a2, alternative="greater"):
    """Print the Wilcoxon signed-rank p-value for the paired arrays a1, a2.

    Pairs in which either value is NaN are dropped first. When fewer than
    three pairs remain, a "not enough samples" line is printed and no test
    is run. ``alternative`` is forwarded unchanged to ``scipy.stats.wilcoxon``.
    Nothing is returned.
    """
    mask = ~(np.isnan(a1) | np.isnan(a2))

    if mask.sum() >= 3:
        _, p = wilcoxon(a1[mask], a2[mask], alternative=alternative)
        print(f"{label} (alt: {alternative}):")
        print(f"  ➤ p = {p:.6g}\n")
        return

    print(f"{label}: Not enough valid samples (n={mask.sum()})")

# === Output ===
# Curve columns: 0 = random Σ, 1 = mean-field Σ, 2 = cross Σ, 3 = true Σ.
# Every test is one-sided: is the first column smaller than the second?
same_tests = (
    ("Same: Random < Mean-field", 0, 1),
    ("Same: Mean-field < Cross", 1, 2),
    ("Same: Cross < True Σ", 2, 3),
)
cross_tests = (
    ("Cross: Random < Mean-field", 0, 1),
    ("Cross: Mean-field < Cross", 1, 2),
)

print("\nWilcoxon Signed-Rank Tests (Same Cell Type):")
for test_label, col_low, col_high in same_tests:
    print_wilcox_pval(test_label, same_curves[:, col_low], same_curves[:, col_high], alternative="less")

print("\nWilcoxon Signed-Rank Tests (Cross Cell Type):")
for test_label, col_low, col_high in cross_tests:
    print_wilcox_pval(test_label, cross_curves[:, col_low], cross_curves[:, col_high], alternative="less")


In [None]:
# FIG S3E


import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

# === Input tables and figure output ===
load_dir = "r2_null_cross_vs_true"
save_dir = "HS_figs"
os.makedirs(save_dir, exist_ok=True)

# === Keep only pairs "<non-neuron dataset> vs CRISPRi" ===
target_dataset = "TianKampmann2021_CRISPRi"
exclude_datasets = {
    "TianKampmann2021_CRISPRi",
    "TianKampmann2021_CRISPRa",
    "TianKampmann2019_day7neuron"
}

# === Pool per-perturbation (R2_true, R2_cross) pairs over all kept files ===
r2_true_all = []
r2_cross_all = []

for fname in os.listdir(load_dir):
    if not fname.endswith(".csv"):
        continue

    # File names follow "r2_compare_<base1>_vs_<base2>.csv".
    stem = fname.replace("r2_compare_", "").replace(".csv", "")
    base1, base2 = stem.split("_vs_")
    wanted = base2 == target_dataset and base1 not in exclude_datasets
    if not wanted:
        continue

    df = pd.read_csv(os.path.join(load_dir, fname))
    if df.empty or not {"R2_true", "R2_cross"}.issubset(df.columns):
        continue

    # Keep only rows where both R² values are present.
    complete = df[["R2_true", "R2_cross"]].dropna()
    r2_true_all.extend(complete["R2_true"])
    r2_cross_all.extend(complete["R2_cross"])

r2_true_all = np.array(r2_true_all)
r2_cross_all = np.array(r2_cross_all)

# === Build grid for 2D KDE ===
# A 2-D Gaussian KDE needs a non-singular 2x2 sample covariance, which takes
# at least three points. With fewer (e.g. no matching results file) report it
# clearly instead of failing inside gaussian_kde.
if r2_true_all.size < 3:
    print(f"Not enough R² pairs for a 2D KDE (n={r2_true_all.size}); skipping contour plot.")
else:
    xy = np.vstack([r2_true_all, r2_cross_all])
    kde = gaussian_kde(xy)

    # Evaluate the density on a 200 x 200 grid over the unit square.
    x_grid = np.linspace(0, 1, 200)
    y_grid = np.linspace(0, 1, 200)
    X, Y = np.meshgrid(x_grid, y_grid)
    Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

    # === Plot contour ===
    plt.figure(figsize=(7, 6))
    contour = plt.contourf(X, Y, Z, levels=50, cmap="coolwarm")
    plt.plot([0, 1], [0, 1], 'k--', lw=1)  # identity line: cross Σ as good as true Σ

    plt.xlabel("R² (true Σ)", fontsize=14)
    plt.ylabel("R² (cross Σ)", fontsize=14)
    plt.title("Contour: R² Cross vs True Σ (Non-neuron → CRISPRi)", fontsize=15)
    plt.colorbar(contour, label="Density")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "R2_true_vs_cross_contour_non_neuron_to_CRISPRi.svg"))
    plt.close()
