In [None]:
import pandas as pd
import numpy as np
import anndata as ad
from scipy.sparse import csr_matrix

def load_merfish_astrocytes(design_path, expression_path, expression_threshold=1.0, min_samples=100):
    """Load MERFISH expression data and its perturbation design into an AnnData.

    The function reads both CSV files, drops genes whose mean expression is
    below ``expression_threshold``, keeps only the design rows whose name is
    also a retained gene, labels every cell with its perturbation(s), removes
    labels that occur in fewer than ``min_samples`` cells, and extracts the
    expression matrix of the control cells.

    Labelling rules:

    * a cell carrying a ``Control_NT`` or ``Control_ST`` guide is ``"control"``
      (decided on the *unfiltered* design, so the control rows are still
      visible at that point);
    * a cell with no remaining valid perturbation is ``"control"``;
    * otherwise the label is the sorted target names joined with ``"+"``.

    NOTE(review): a cell whose only perturbations target genes that were
    filtered out also ends up as ``"control"`` — confirm this is intended,
    since such cells are included in ``X0``.

    Parameters
    ----------
    design_path : str or path-like
        CSV read with ``index_col=0``; expected layout is
        perturbations (rows) × cells (columns).
    expression_path : str or path-like
        CSV read with ``index_col=0``; expected layout is
        cells (rows) × genes (columns).
    expression_threshold : float
        Minimum mean expression (over all cells) for a gene to be kept.
    min_samples : int
        Minimum number of cells a perturbation label needs to be kept.

    Returns
    -------
    adata : anndata.AnnData
        Filtered data with the label stored in ``obs['perturbation']``.
    X0 : numpy.ndarray
        Dense expression matrix of the cells labelled ``"control"``.
    valid_perturbations : list
        Sorted design-row names that are also retained genes.

    Raises
    ------
    ValueError
        If the design and expression tables disagree on the number of cells.
    """
    control_guides = ("Control_NT", "Control_ST")

    # === Load data ===
    expr_df = pd.read_csv(expression_path, index_col=0)
    design_df = pd.read_csv(design_path, index_col=0)

    print("Expression shape (cells × genes):", expr_df.shape)
    print("Design shape (perturbations × cells):", design_df.shape)

    # Cells are matched between the two tables purely by position, so at the
    # very least the counts have to agree. A real exception is used instead of
    # ``assert`` because asserts are stripped when Python runs with -O.
    if design_df.shape[1] != expr_df.shape[0]:
        raise ValueError("Mismatch: design and expression should have same # of cells")

    # === Build preliminary AnnData (before gene filtering)
    obs_df_pre = pd.DataFrame(index=expr_df.index)
    adata_pre = ad.AnnData(csr_matrix(expr_df.values), obs=obs_df_pre)
    adata_pre.var_names = expr_df.columns
    adata_pre.obs_names_make_unique()

    # === Filter genes by expression
    gene_means = np.array(adata_pre.X.mean(axis=0)).flatten()
    keep_genes = gene_means >= expression_threshold
    # Materialise the subset: ``obs`` is written to below, and writing to a
    # view would only work through an implicit copy.
    adata = adata_pre[:, keep_genes].copy()

    print("After filtering: measured genes =", adata.shape[1])

    # === Flag control cells on the full design
    # This has to happen before the design is restricted to measured genes,
    # because the control rows are not gene names and would be dropped.
    present_controls = [name for name in control_guides if name in design_df.index]
    if present_controls:
        is_control_cell = (design_df.loc[present_controls] == 1).any(axis=0).to_numpy()
    else:
        is_control_cell = np.zeros(design_df.shape[1], dtype=bool)

    # === Restrict to only perturbations that target measured genes
    measured_genes = set(adata.var_names)
    perturbed_genes = set(design_df.index)
    valid_perturbations = sorted(perturbed_genes.intersection(measured_genes))

    print(f"Valid perturbations targeting measured genes: {len(valid_perturbations)}")

    # Cells × valid perturbations
    hits_df = design_df.loc[valid_perturbations].T
    hit_matrix = hits_df.to_numpy() == 1
    target_names = hits_df.columns.to_numpy()

    # === Build perturbation labels
    pert_labels = []
    for cell_is_control, row_hits in zip(is_control_cell, hit_matrix):
        targets = target_names[row_hits]
        if cell_is_control or targets.size == 0:
            pert_labels.append("control")
        else:
            pert_labels.append("+".join(sorted(targets)))

    # Add labels to obs (positional: i-th design column -> i-th expression row)
    adata.obs['perturbation'] = pert_labels

    # === Filter perturbations by min count
    pert_counts = adata.obs['perturbation'].value_counts()
    keep_perts = pert_counts[pert_counts >= min_samples].index
    adata = adata[adata.obs['perturbation'].isin(keep_perts)]

    # === Extract control matrix
    control_mask = adata.obs['perturbation'] == "control"
    X0 = adata[control_mask].X.toarray()

    print("Final AnnData shape:", adata.shape)
    print("Top perturbations:\n", adata.obs['perturbation'].value_counts().head())
    print("Shared genes:", list(set(valid_perturbations) & set(adata.var_names))[:10])

    return adata, X0, valid_perturbations

# === USAGE ===
# Load the astrocyte dataset. min_samples=0 disables the per-perturbation
# cell-count filter, so every perturbation label is kept.
adata, X0, valid_genes = load_merfish_astrocytes(
    design_path="perturbation_design_astrocytes.csv",
    expression_path="merfish_perturbed_cells_astrocytes.csv",
    expression_threshold=1.0,
    min_samples=0,
)

# Quick sanity checks on what was loaded.
print("Example genes:", adata.var_names[:5].tolist())
print("Valid perturbed+measured genes:", valid_genes[:10])
print("X0 shape:", X0.shape)


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.linalg import pinv
from scipy.sparse import issparse, csr_matrix
from anndata import AnnData

def compute_covariance(X):
    """Return the covariance matrix of ``X``, treating columns as variables.

    ``X`` is passed to ``np.cov`` with ``rowvar=False``, i.e. each row is one
    observation and each column one variable.
    """
    covariance = np.cov(X, rowvar=False)
    return covariance

def compute_average_response(X0, X1):
    """Return the column-wise mean of ``X1`` minus the column-wise mean of ``X0``."""
    baseline_mean = X0.mean(axis=0)
    perturbed_mean = X1.mean(axis=0)
    return perturbed_mean - baseline_mean

def compute_soft_fraction(Sigma, u, threshold_mode='fraction_variance', threshold_value=0.7):
    """Fraction of the squared norm of ``u`` carried by the leading eigenmodes of ``Sigma``.

    ``Sigma`` is eigendecomposed with ``np.linalg.eigh``, the modes are ordered
    by decreasing eigenvalue, and ``u`` is projected onto them. A subset of the
    leading modes is selected according to ``threshold_mode`` and the returned
    fraction is ``sum(c_k**2 for selected k) / sum(c_k**2 for all k)``.

    Parameters
    ----------
    Sigma : array_like, shape (n, n)
        Symmetric matrix (only one triangle is read by ``eigh``).
    u : array_like, shape (n,)
        Vector to decompose.
    threshold_mode : {'fraction_variance', 'relative_max', 'elbow'}
        * ``'fraction_variance'`` — modes whose cumulative eigenvalue sum is
          ``<= threshold_value * total``. The mode that crosses the threshold
          is *not* included, so the selection can be empty (fraction 0.0)
          when the first mode alone exceeds it.
        * ``'relative_max'`` — modes with eigenvalue
          ``>= threshold_value * largest eigenvalue``.
        * ``'elbow'`` — modes up to the position where the second difference
          of the sorted eigenvalues is largest (``threshold_value`` unused).
    threshold_value : float
        Threshold used by the first two modes.

    Returns
    -------
    f_soft : float
        Selected fraction; NaN when ``u`` has zero projection on every mode.
    soft_indices : numpy.ndarray
        Indices of the selected modes in the descending-eigenvalue ordering.

    Raises
    ------
    ValueError
        For an unknown ``threshold_mode``, or for ``'elbow'`` with fewer than
        three eigenvalues (no second difference exists).
    """
    lambda_vals, V = np.linalg.eigh(Sigma)
    # eigh returns ascending eigenvalues; reorder to descending.
    order = np.argsort(lambda_vals)[::-1]
    lambda_vals = lambda_vals[order]
    V = V[:, order]

    # Squared coordinates of u in the eigenbasis.
    c2 = (V.T @ u) ** 2

    if threshold_mode == 'fraction_variance':
        total_var = np.sum(lambda_vals)
        cum_var = np.cumsum(lambda_vals)
        soft_indices = np.where(cum_var <= threshold_value * total_var)[0]
    elif threshold_mode == 'relative_max':
        lambda_max = np.max(lambda_vals)
        soft_indices = np.where(lambda_vals >= threshold_value * lambda_max)[0]
    elif threshold_mode == 'elbow':
        if lambda_vals.size < 3:
            # Without this guard np.argmax fails on an empty array with an
            # unrelated-looking message.
            raise ValueError(
                "threshold_mode='elbow' requires at least 3 eigenvalues, "
                f"got {lambda_vals.size}."
            )
        second_diffs = np.diff(lambda_vals, n=2)
        elbow_idx = np.argmax(second_diffs)
        soft_indices = np.arange(elbow_idx + 1)
    else:
        raise ValueError("Invalid threshold_mode.")

    total_weight = np.sum(c2)
    if total_weight == 0:
        # u is orthogonal to everything (e.g. the zero vector): the fraction
        # is undefined, so report NaN instead of emitting a 0/0 warning.
        return np.nan, soft_indices

    f_soft = np.sum(c2[soft_indices]) / total_weight
    return f_soft, soft_indices

import pandas as pd
import numpy as np
import anndata as ad
from scipy.sparse import csr_matrix

def load_merfish_astrocytes(design_path, expression_path, expression_threshold=1.0, min_samples=100):
    """Load MERFISH expression data and its perturbation design into an AnnData.

    The function reads both CSV files, drops genes whose mean expression is
    below ``expression_threshold``, keeps only the design rows whose name is
    also a retained gene, labels every cell with its perturbation(s), removes
    labels that occur in fewer than ``min_samples`` cells, and extracts the
    expression matrix of the control cells.

    Labelling rules:

    * a cell carrying a ``Control_NT`` or ``Control_ST`` guide is ``"control"``
      (decided on the *unfiltered* design, so the control rows are still
      visible at that point);
    * a cell with no remaining valid perturbation is ``"control"``;
    * otherwise the label is the sorted target names joined with ``"+"``.

    NOTE(review): a cell whose only perturbations target genes that were
    filtered out also ends up as ``"control"`` — confirm this is intended,
    since such cells are included in ``X0``.

    Parameters
    ----------
    design_path : str or path-like
        CSV read with ``index_col=0``; expected layout is
        perturbations (rows) × cells (columns).
    expression_path : str or path-like
        CSV read with ``index_col=0``; expected layout is
        cells (rows) × genes (columns).
    expression_threshold : float
        Minimum mean expression (over all cells) for a gene to be kept.
    min_samples : int
        Minimum number of cells a perturbation label needs to be kept.

    Returns
    -------
    adata : anndata.AnnData
        Filtered data with the label stored in ``obs['perturbation']``.
    X0 : numpy.ndarray
        Dense expression matrix of the cells labelled ``"control"``.
    valid_perturbations : list
        Sorted design-row names that are also retained genes.

    Raises
    ------
    ValueError
        If the design and expression tables disagree on the number of cells.
    """
    control_guides = ("Control_NT", "Control_ST")

    # === Load data ===
    expr_df = pd.read_csv(expression_path, index_col=0)
    design_df = pd.read_csv(design_path, index_col=0)

    print("Expression shape (cells × genes):", expr_df.shape)
    print("Design shape (perturbations × cells):", design_df.shape)

    # Cells are matched between the two tables purely by position, so at the
    # very least the counts have to agree. A real exception is used instead of
    # ``assert`` because asserts are stripped when Python runs with -O.
    if design_df.shape[1] != expr_df.shape[0]:
        raise ValueError("Mismatch: design and expression should have same # of cells")

    # === Build preliminary AnnData (before gene filtering)
    obs_df_pre = pd.DataFrame(index=expr_df.index)
    adata_pre = ad.AnnData(csr_matrix(expr_df.values), obs=obs_df_pre)
    adata_pre.var_names = expr_df.columns
    adata_pre.obs_names_make_unique()

    # === Filter genes by expression
    gene_means = np.array(adata_pre.X.mean(axis=0)).flatten()
    keep_genes = gene_means >= expression_threshold
    # Materialise the subset: ``obs`` is written to below, and writing to a
    # view would only work through an implicit copy.
    adata = adata_pre[:, keep_genes].copy()

    print("After filtering: measured genes =", adata.shape[1])

    # === Flag control cells on the full design
    # This has to happen before the design is restricted to measured genes,
    # because the control rows are not gene names and would be dropped.
    present_controls = [name for name in control_guides if name in design_df.index]
    if present_controls:
        is_control_cell = (design_df.loc[present_controls] == 1).any(axis=0).to_numpy()
    else:
        is_control_cell = np.zeros(design_df.shape[1], dtype=bool)

    # === Restrict to only perturbations that target measured genes
    measured_genes = set(adata.var_names)
    perturbed_genes = set(design_df.index)
    valid_perturbations = sorted(perturbed_genes.intersection(measured_genes))

    print(f"Valid perturbations targeting measured genes: {len(valid_perturbations)}")

    # Cells × valid perturbations
    hits_df = design_df.loc[valid_perturbations].T
    hit_matrix = hits_df.to_numpy() == 1
    target_names = hits_df.columns.to_numpy()

    # === Build perturbation labels
    pert_labels = []
    for cell_is_control, row_hits in zip(is_control_cell, hit_matrix):
        targets = target_names[row_hits]
        if cell_is_control or targets.size == 0:
            pert_labels.append("control")
        else:
            pert_labels.append("+".join(sorted(targets)))

    # Add labels to obs (positional: i-th design column -> i-th expression row)
    adata.obs['perturbation'] = pert_labels

    # === Filter perturbations by min count
    pert_counts = adata.obs['perturbation'].value_counts()
    keep_perts = pert_counts[pert_counts >= min_samples].index
    adata = adata[adata.obs['perturbation'].isin(keep_perts)]

    # === Extract control matrix
    control_mask = adata.obs['perturbation'] == "control"
    X0 = adata[control_mask].X.toarray()

    print("Final AnnData shape:", adata.shape)
    print("Top perturbations:\n", adata.obs['perturbation'].value_counts().head())
    print("Shared genes:", list(set(valid_perturbations) & set(adata.var_names))[:10])

    return adata, X0, valid_perturbations

def run_analysis_on_merfish(
    design_path="perturbation_design_astrocytes.csv",
    expression_path="merfish_perturbed_cells_astrocytes.csv",
    expression_threshold=1.0,
    min_samples=0,
    save_dir="merfish_r2_histograms",
    base_name="merfish_astrocytes",
    seed=None,
):
    """Score how well the control covariance predicts single-gene perturbation responses.

    For every perturbation label that is exactly the name of a measured gene
    (multi-gene labels such as ``"A+B"`` are skipped), the mean response
    ``delta_X = mean(perturbed) - mean(control)`` is fitted by a scaled copy
    of that gene's covariance column, ``delta_X ≈ u * Sigma[:, gene]``, and the
    resulting R² is recorded for three covariance matrices:

    * ``R2_real`` — covariance of the control cells;
    * ``R2_null`` — covariance after shuffling each gene column independently;
    * ``R2_rand`` — covariance after shuffling all entries of the control matrix.

    ``f_soft`` is ``compute_soft_fraction`` applied to the real covariance and
    that gene's covariance column. Two SVG histograms and a CSV of the results
    are written into ``save_dir``.

    Parameters
    ----------
    design_path, expression_path : str
        Inputs forwarded to ``load_merfish_astrocytes``.
    expression_threshold : float
        Gene filter forwarded to ``load_merfish_astrocytes``.
    min_samples : int
        Label-count filter forwarded to ``load_merfish_astrocytes``.
    save_dir : str
        Output directory (created if missing).
    base_name : str
        Prefix of the output files and part of the plot titles.
    seed : int or None
        When given, the shuffles use ``np.random.RandomState(seed)`` and are
        reproducible; ``None`` keeps using the global ``np.random`` state.

    Returns
    -------
    pandas.DataFrame
        One row per analysed perturbation with columns ``perturbation``,
        ``R2_real``, ``R2_null``, ``R2_rand`` and ``f_soft``.

    Raises
    ------
    ValueError
        If fewer than two control cells are available.
    """
    adata, X0, _ = load_merfish_astrocytes(
        design_path,
        expression_path,
        expression_threshold=expression_threshold,
        min_samples=min_samples,
    )
    X0_dense = X0.toarray() if issparse(X0) else X0

    # np.cov degenerates (or silently changes orientation) with fewer than two
    # observations, so fail early with an actionable message.
    n_controls = X0_dense.shape[0]
    if n_controls < 2:
        raise ValueError(
            "Need at least 2 control cells to estimate a covariance matrix, "
            f"found {n_controls}."
        )

    # The module-level functions and RandomState share the shuffle/permutation API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    Sigma_real = compute_covariance(X0_dense)

    # Null 1: shuffle every gene column on its own (keeps marginals, breaks
    # gene-gene correlations).
    X0_gene_shuffled = X0_dense.copy()
    for g in range(X0_gene_shuffled.shape[1]):
        rng.shuffle(X0_gene_shuffled[:, g])
    Sigma_null = compute_covariance(X0_gene_shuffled)

    # Null 2: shuffle all entries across both cells and genes.
    X0_full_shuffled = rng.permutation(X0_dense.flatten()).reshape(X0_dense.shape)
    Sigma_rand = compute_covariance(X0_full_shuffled)

    # First-occurrence position of every gene name (O(1) lookups in the loop).
    gene_index = {}
    for position, gene in enumerate(adata.var_names.tolist()):
        gene_index.setdefault(gene, position)

    epsilon = 1e-8

    def predict_r2(Sigma, gene_idx, delta_X):
        """R² of the least-squares fit ``delta_X ≈ u * Sigma[:, gene_idx]``."""
        sigma_col = Sigma[:, gene_idx]
        u_opt = np.dot(sigma_col, delta_X) / (np.dot(sigma_col, sigma_col) + epsilon)
        pred = u_opt * sigma_col
        # Only genes with a non-zero observed response enter the score.
        valid = np.abs(delta_X) > 0
        if not np.any(valid):
            return np.nan
        residual = np.sum((delta_X[valid] - pred[valid]) ** 2)
        return 1.0 - residual / (np.sum(delta_X[valid] ** 2) + epsilon)

    def finite_only(values):
        """Drop NaN/inf entries so the histogram range is always well defined."""
        arr = np.asarray(values, dtype=float)
        return arr[np.isfinite(arr)]

    perturbations = [p for p in adata.obs['perturbation'].unique() if p != 'control']

    R2_real, R2_null, R2_rand = [], [], []
    f_soft_scores = []
    pert_names = []

    for pert in perturbations:
        gene_idx = gene_index.get(pert)
        if gene_idx is None:
            continue

        X1 = adata[adata.obs['perturbation'] == pert].X
        X1 = X1.toarray() if issparse(X1) else X1

        delta_X = compute_average_response(X0_dense, X1)

        R2_real.append(predict_r2(Sigma_real, gene_idx, delta_X))
        R2_null.append(predict_r2(Sigma_null, gene_idx, delta_X))
        R2_rand.append(predict_r2(Sigma_rand, gene_idx, delta_X))
        f_soft, _ = compute_soft_fraction(Sigma_real, Sigma_real[:, gene_idx])
        f_soft_scores.append(f_soft)
        pert_names.append(pert)

    os.makedirs(save_dir, exist_ok=True)

    plt.figure(figsize=(8, 6))
    plt.hist(finite_only(R2_real), bins=30, alpha=0.7, label='Real Σ', density=True)
    plt.hist(finite_only(R2_null), bins=30, alpha=0.7, label='Shuffled X₀', density=True)
    plt.hist(finite_only(R2_rand), bins=30, alpha=0.7, label='Fully Shuffled X₀', density=True)
    plt.xlabel("R² Score", fontsize=18)
    plt.ylabel("Density", fontsize=18)
    plt.yscale('log')
    plt.title(f"R² Distributions: {base_name}", fontsize=20)
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"{base_name}_r2_histogram.svg"))
    plt.close()

    plt.figure(figsize=(8, 6))
    plt.hist(finite_only(f_soft_scores), bins=30, alpha=0.8, density=True)
    plt.xlabel("f_soft", fontsize=18)
    plt.ylabel("Density", fontsize=18)
    plt.title(f"f_soft Distribution: {base_name}", fontsize=20)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"{base_name}_fsoft_histogram.svg"))
    plt.close()

    # The CSV keeps the raw scores, including any NaN, so nothing is hidden.
    df = pd.DataFrame({
        "perturbation": pert_names,
        "R2_real": R2_real,
        "R2_null": R2_null,
        "R2_rand": R2_rand,
        "f_soft": f_soft_scores
    })
    df.to_csv(os.path.join(save_dir, f"{base_name}_results.csv"), index=False)
    print(f"Completed analysis for {base_name}")
    return df

# === USAGE ===
# Run the analysis; besides returning the per-perturbation results table kept
# here as result_df, the call writes the histograms and a results CSV to disk.
result_df = run_analysis_on_merfish()
