In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import subprocess
import tempfile
from scipy.stats import norm

import pyarrow.feather as feather
from concurrent.futures import ProcessPoolExecutor, as_completed

import matplotlib.pyplot as plt

# Functions

## Heterogeneity helpers: Stouffer meta-analysis, I², DerSimonian–Laird τ², prediction interval

In [None]:
def stouffer_meta(pvals, weights=None):
    """
    Combine one-sided p-values with (weighted) Stouffer's method.

    Parameters
    ----------
    pvals : array-like of float
        Per-replicate p-values. Entries that are non-finite, <= 0 or >= 1 are
        dropped before combining (note: this drops p == 1 and p == 0 as well).
    weights : array-like of float, optional
        Per-replicate weights aligned with ``pvals``; defaults to all ones.

    Returns
    -------
    float
        Combined p-value ``sf(sum(w*z) / sqrt(sum(w**2)))`` with ``z = isf(p)``,
        or NaN when no p-value survives the filter.
    """
    p = np.asarray(pvals)
    keep = np.isfinite(p) & (p > 0) & (p < 1)
    n_keep = keep.sum()
    if n_keep == 0:
        return np.nan

    z = norm.isf(p[keep])
    w = np.ones(n_keep) if weights is None else np.asarray(weights)[keep]

    combined_z = (w * z).sum() / np.sqrt((w ** 2).sum())
    return norm.sf(combined_z)

def i_squared(z_vals):
    """
    Higgins' I² heterogeneity statistic for a set of z-scores.

    Each z-score is treated as an estimate with unit variance, so Cochran's Q is
    the unweighted sum of squared deviations from the mean z.

    Parameters
    ----------
    z_vals : array-like of float
        Per-replicate z-scores; non-finite entries (NaN, ±inf) are ignored.

    Returns
    -------
    float
        I² = max(0, (Q - (k - 1)) / Q), or NaN when fewer than two finite
        z-scores remain. Returns 0.0 when all finite z-scores are identical
        (Q == 0) instead of dividing by zero.
    """
    z_vals = np.asarray(z_vals, dtype=float)
    z = z_vals[np.isfinite(z_vals)]
    k = len(z)
    if k < 2:
        return np.nan
    Q = np.sum((z - z.mean()) ** 2)
    if Q <= 0:
        # Identical z-scores: no heterogeneity. Guarding avoids a 0-division
        # RuntimeWarning that previously returned 0 only via max(0, -inf).
        return 0.0
    return max(0.0, (Q - (k - 1)) / Q)

def tau_squared_dl(z_vals, weights=None):
    """
    DerSimonian–Laird estimate of the between-replicate variance tau².

    Parameters
    ----------
    z_vals : array-like of float
        Per-replicate z-scores; non-finite entries are ignored.
    weights : array-like of float, optional
        Per-replicate weights aligned with ``z_vals`` (default: all ones).

    Returns
    -------
    float
        ``max(0, (Q - (k - 1)) / (sum(w) - sum(w²)/sum(w)))`` where Q is the
        weighted sum of squared deviations from the weighted mean, or NaN when
        fewer than two finite z-scores remain.
    """
    arr = np.asarray(z_vals)
    finite = np.isfinite(arr)
    z = arr[finite]
    k = len(z)
    if k < 2:
        return np.nan

    w = np.ones(k) if weights is None else np.asarray(weights)[finite]
    w_total = np.sum(w)
    weighted_mean = np.sum(w * z) / w_total
    cochran_q = np.sum(w * (z - weighted_mean) ** 2)
    scale = w_total - (np.sum(w ** 2) / w_total)
    return max(0, (cochran_q - (k - 1)) / scale)

def prediction_interval(z_vals, weights=None):
    """
    Stouffer meta z-score plus an approximate 95% prediction interval for the
    z-score of a new replicate.

    Parameters
    ----------
    z_vals : array-like of float
        Per-replicate z-scores; non-finite entries are ignored.
    weights : array-like of float, optional
        Per-replicate weights aligned with ``z_vals`` (default: all ones).

    Returns
    -------
    tuple of float
        ``(z_meta, PI_low, PI_high)``. ``z_meta`` is the Stouffer statistic
        ``sum(w*z) / sqrt(sum(w**2))``. The interval is centred on the weighted
        mean z, ``sum(w*z) / sum(w)``, with half-width ``1.96 * sqrt(tau² + 1)``,
        where tau² is the DerSimonian–Laird estimate (``tau_squared_dl``) and 1 is
        the unit variance of a single z-score. All three are NaN when fewer than
        two finite z-scores remain.
    """
    z_vals = np.asarray(z_vals, dtype=float)
    valid = np.isfinite(z_vals)
    z = z_vals[valid]
    k = len(z)
    if k < 2:
        return np.nan, np.nan, np.nan
    w = np.ones(k) if weights is None else np.asarray(weights, dtype=float)[valid]

    z_meta = np.sum(w * z) / np.sqrt(np.sum(w ** 2))
    # The Stouffer statistic grows like sqrt(k), so it is not on the scale of a
    # single replicate's z. A prediction interval for a new replicate must be
    # centred on the (weighted) mean z instead.
    z_bar = np.sum(w * z) / np.sum(w)
    tau2 = tau_squared_dl(z, weights=w)
    se_pred = np.sqrt(tau2 + 1)
    PI_low = z_bar - 1.96 * se_pred
    PI_high = z_bar + 1.96 * se_pred
    return z_meta, PI_low, PI_high

In [None]:
def fisher_z(r):
    """
    Fisher z-transform of correlation coefficients.

    Parameters
    ----------
    r : array-like of float
        Correlations.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``z`` with ``0.5 * log((1 + r) / (1 - r))`` where ``r`` is finite and
        ``|r| < 1`` (NaN elsewhere), and the boolean mask of those entries.
    """
    r = np.asarray(r, dtype=float)
    ok = np.isfinite(r) & (np.abs(r) < 1)
    z = np.full(r.shape, np.nan)
    r_ok = r[ok]
    z[ok] = 0.5 * np.log((1 + r_ok) / (1 - r_ok))
    return z, ok

def fisher_meta_rho(rhos, weights=None):
    """
    Fixed-effect combination of correlations on the Fisher z scale.

    Parameters
    ----------
    rhos : array-like of float
        Per-replicate correlations; entries that are non-finite or have
        ``|rho| >= 1`` are dropped (see ``fisher_z``).
    weights : array-like of float, optional
        Per-replicate weights aligned with ``rhos`` (default: all ones).

    Returns
    -------
    (float, float)
        ``(meta_rho, z_bar)``: the weighted mean Fisher z and its
        back-transform ``tanh(z_bar)``; both NaN when no correlation is usable.
    """
    rhos = np.asarray(rhos, dtype=float)
    z_all, ok = fisher_z(rhos)
    if not ok.any():
        return np.nan, np.nan

    z = z_all[ok]
    w = np.ones_like(z) if weights is None else np.asarray(weights, dtype=float)[ok]
    z_bar = np.sum(w * z) / np.sum(w)
    return np.tanh(z_bar), z_bar

def i_squared_from_z(z_vals, weights=None):
    """
    I² heterogeneity statistic for transformed effects (e.g. Fisher z of rho).

    Parameters
    ----------
    z_vals : array-like of float
        Per-replicate transformed effects; non-finite entries are ignored.
    weights : array-like of float, optional
        Inverse-variance weights aligned with ``z_vals``. For Fisher z these
        would typically be ``n - 3`` per replicate. With the default (None) every
        effect is treated as unit variance, which is the original behaviour.

    Returns
    -------
    float
        ``max(0, (Q - (k - 1)) / Q)`` with Cochran's Q (weighted if ``weights``
        is given). Returns NaN when fewer than two finite values remain, and 0.0
        when Q == 0 (all values identical) instead of dividing by zero.
    """
    z_vals = np.asarray(z_vals, dtype=float)
    valid = np.isfinite(z_vals)
    z = z_vals[valid]
    k = len(z)
    if k < 2:
        return np.nan
    if weights is None:
        Q = np.sum((z - z.mean()) ** 2)
    else:
        w = np.asarray(weights, dtype=float)[valid]
        z_bar = np.sum(w * z) / np.sum(w)
        Q = np.sum(w * (z - z_bar) ** 2)
    if Q <= 0:
        # No spread at all: no heterogeneity (and avoid a 0/0 division).
        return 0.0
    return max(0.0, (Q - (k - 1)) / Q)

def tau_squared_dl_z(z_vals, weights=None):
    """
    DerSimonian–Laird tau² for transformed effects (e.g. Fisher z of rho).

    Parameters
    ----------
    z_vals : array-like of float
        Per-replicate transformed effects; non-finite entries are ignored.
    weights : array-like of float, optional
        Per-replicate weights aligned with ``z_vals`` (default: all ones).

    Returns
    -------
    float
        ``max(0, (Q - (k - 1)) / (sum(w) - sum(w²)/sum(w)))``, or NaN when fewer
        than two finite values remain.
    """
    values = np.asarray(z_vals, dtype=float)
    finite = np.isfinite(values)
    z = values[finite]
    k = len(z)
    if k < 2:
        return np.nan

    if weights is None:
        w = np.ones(k)
    else:
        w = np.asarray(weights, dtype=float)[finite]

    w_total = np.sum(w)
    centre = np.sum(w * z) / w_total
    cochran_q = np.sum(w * (z - centre) ** 2)
    scale = w_total - (np.sum(w ** 2) / w_total)
    return max(0, (cochran_q - (k - 1)) / scale)


In [None]:
def to_dense_df(X, var_names, obs_names):
    """
    Wrap an expression matrix in a dense cells × genes DataFrame.

    ``X`` may be anything with a ``toarray()`` method (e.g. scipy sparse), an
    object exposing a dense ``.A`` attribute, or anything ``np.asarray`` accepts.
    Rows are labelled by ``obs_names`` and columns by ``var_names``.
    """
    def _as_ndarray(mat):
        # Checked in this order: sparse first, then matrix-like, then generic.
        if hasattr(mat, "toarray"):
            return mat.toarray()
        if hasattr(mat, "A"):
            return mat.A
        return np.asarray(mat)

    return pd.DataFrame(_as_ndarray(X), columns=var_names, index=obs_names)

## R wrapper: run MorphoGAM on one replicate via Rscript

In [None]:
def run_morphogam_axis_single(expr_df, axis_vec, morphogam_script,
                              design="y ~ s(t, bs='cr')",
                              idx=None, outdir=None):
    """
    Run MorphoGAM on a single replicate by exchanging temporary Feather files
    with an R script.

    Parameters
    ----------
    expr_df : pandas.DataFrame
        Expression matrix, cells × genes (index = cell names, columns = genes).
    axis_vec : pandas.Series
        Axis position per cell (index = cell names, values = t).
    morphogam_script : str or Path
        R script invoked as
        ``Rscript <script> --expr <feather> --axis <feather> --out <csv>``.
        The expression Feather is genes × cells with a leading ``gene`` column;
        the axis Feather has columns ``cell`` and ``t``.
    design : str or None
        Model formula. NOTE: currently NOT forwarded to the R script (the
        ``--design`` flag is commented out), so this argument has no effect.
    idx : optional
        Suffix for the output name ``results_{idx}.csv``. Any '/' is replaced
        by '_' so cell-type names cannot turn it into a missing sub-directory.
    outdir : str or Path, optional
        If given, the CSV is written (and kept) there; the directory is created
        if needed. Otherwise the CSV lives in the temporary directory.

    Returns
    -------
    pandas.DataFrame or None
        Contents of the output CSV, or None if Rscript exited non-zero or
        wrote no output.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        expr_path = tmp / "expr.feather"
        axis_path = tmp / "axis.feather"

        if idx is None:
            out_name = "results.csv"
        else:
            safe_idx = str(idx).replace("/", "_")
            out_name = f"results_{safe_idx}.csv"
        if outdir is None:
            out_dir = tmp
        else:
            out_dir = Path(outdir)
            out_dir.mkdir(parents=True, exist_ok=True)
        out_csv = out_dir / out_name
        # A CSV left over from an earlier run must not be mistaken for fresh output.
        if out_csv.exists():
            out_csv.unlink()

        # genes × cells with an explicit "gene" column. rename_axis guarantees the
        # column name even when the gene index already carries its own name.
        expr_gc = expr_df.T.rename_axis("gene").reset_index()
        axis_df = pd.DataFrame({"cell": axis_vec.index, "t": axis_vec.values})

        feather.write_feather(expr_gc, expr_path)
        feather.write_feather(axis_df, axis_path)

        cmd = [
            "Rscript", str(morphogam_script),
            "--expr", str(expr_path),
            "--axis", str(axis_path),
            "--out", str(out_csv),
            # "--design", design,  # not forwarded yet; see docstring
        ]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            print(f"[ERROR] MorphoGAM failed: {e}")
            return None

        if not out_csv.exists():
            print("[WARN] MorphoGAM produced no output.")
            return None

        # Read before leaving the context: out_csv may live in the temp directory.
        return pd.read_csv(out_csv)

## Main Pipeline

In [None]:
def _worker_morphogam_job(job):
    """
    Process-pool entry point: run MorphoGAM for one (cell type, replicate) job.

    ``job`` is a dict with keys ``celltype``, ``rep``, ``expr_df``, ``axis_vec``,
    ``morphogam_script``, ``design`` and, optionally, ``outdir``.

    Returns ``(celltype, rep, result)``, where ``result`` is whatever
    ``run_morphogam_axis_single`` returns (a DataFrame or None).
    """
    ct, rep = job["celltype"], job["rep"]
    expr_df, axis_vec = job["expr_df"], job["axis_vec"]
    script, design = job["morphogam_script"], job["design"]
    outdir = job.get("outdir")

    n_cells, n_genes = expr_df.shape[0], expr_df.shape[1]
    print(f"[worker] MorphoGAM for {ct}/{rep} "
          f"({n_cells} cells × {n_genes} genes)")

    result = run_morphogam_axis_single(
        expr_df,
        axis_vec,
        script,
        design=design,
        outdir=outdir,
        idx=f"{ct}_{rep}",
    )
    return ct, rep, result

In [None]:
def run_morphogam_pipeline_mp(
    adata_path,
    morphogam_script,
    celltypes=None,
    celltype_field="cell_type",
    replicate_field="slide_id",
    axis_field="axis",
    design="y ~ s(t, bs='cr')",
    n_workers=4,
    min_cells=20,
    layer=None,
    outdir=None,
):
    """
    Run MorphoGAM for every (cell type, replicate) pair of an AnnData file in parallel.

    Parameters
    ----------
    adata_path : str or Path
        ``.h5ad`` file to read.
    morphogam_script : str or Path
        R script passed to ``run_morphogam_axis_single``.
    celltypes : iterable, optional
        Cell types to process; defaults to all unique values of
        ``obs[celltype_field]``.
    celltype_field, replicate_field, axis_field : str
        ``obs`` columns holding the cell type, replicate id and axis position.
    design : str or None
        Forwarded to the workers (see ``run_morphogam_axis_single``).
    n_workers : int
        ``ProcessPoolExecutor`` size.
    min_cells : int
        Replicates with fewer cells (after dropping cells whose axis value is
        missing) are skipped.
    layer : str, optional
        Expression layer to use instead of ``X``.
    outdir : str or Path, optional
        Where each replicate's CSV is kept.

    Returns
    -------
    dict
        ``{celltype: {replicate: DataFrame or None}}``. Skipped replicates are
        absent; jobs that raised map to None.

    Raises
    ------
    ValueError
        If ``axis_field`` is not an ``obs`` column.
    """
    adata = ad.read_h5ad(adata_path)
    if axis_field not in adata.obs:
        raise ValueError(f"axis vector missing in adata.obs['{axis_field}']")

    if celltypes is None:
        celltypes = np.unique(adata.obs[celltype_field])

    # One job per (cell type, replicate). Dense matrices are built here and
    # shipped to the worker processes.
    jobs = []
    for ct in celltypes:
        ad_ct = adata[adata.obs[celltype_field] == ct]
        ct_obs = ad_ct.obs
        has_axis = ~ct_obs[axis_field].isna()
        for rep in np.unique(ct_obs[replicate_field]):
            ad_rep = ad_ct[(ct_obs[replicate_field] == rep) & has_axis]
            if ad_rep.n_obs < min_cells:
                continue
            X = ad_rep.layers[layer] if layer is not None else ad_rep.X
            jobs.append({
                "celltype": ct,
                "rep": rep,
                "expr_df": to_dense_df(X, ad_rep.var_names, ad_rep.obs_names),
                "axis_vec": ad_rep.obs[axis_field],
                "morphogam_script": morphogam_script,
                "design": design,
                "outdir": outdir,
            })

    results = {ct: {} for ct in celltypes}

    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        pending = {pool.submit(_worker_morphogam_job, job): job for job in jobs}
        for future in as_completed(pending):
            job = pending[future]
            ct, rep = job["celltype"], job["rep"]
            try:
                ct_done, rep_done, res_df = future.result()
                results[ct_done][rep_done] = res_df
            except Exception as e:
                print(f"[ERROR] MorphoGAM job {ct}/{rep} failed: {e}")
                results[ct][rep] = None

    return results


In [None]:
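# NOTE: plot_top_gene_fits (further down) calls run_morphogam_plot_single, so this
# definition is required if that plot is enabled; it is commented out for now.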
# def run_morphogam_plot_single(expr_df, axis_vec, genes,
#                               morphogam_plot_script,
#                               family="nb"):
#     """
#     Call morphogam_plot_axis.R to get smooth fits for selected genes.
#     Returns a DataFrame with columns: gene, t, fitted.
#     """
#     with tempfile.TemporaryDirectory() as tf:
#         tf = Path(tf)
#         expr_path = tf / "expr.feather"
#         axis_path = tf / "axis.feather"
#         genes_path = tf / "genes.txt"
#         out_csv = tf / "fits.csv"

#         expr_gc = expr_df.T.reset_index().rename(columns={"index": "gene"})
#         axis_df = pd.DataFrame({"cell": axis_vec.index, "t": axis_vec.values})

#         feather.write_feather(expr_gc, expr_path)
#         feather.write_feather(axis_df, axis_path)
#         Path(genes_path).write_text("\n".join(genes))

#         cmd = [
#             "Rscript", morphogam_plot_script,
#             "--expr", str(expr_path),
#             "--axis", str(axis_path),
#             "--genes", str(genes_path),
#             "--out", str(out_csv),
#             "--family", family,
#         ]
#         subprocess.run(cmd, check=True)
#         if not out_csv.exists():
#             return None
#         return pd.read_csv(out_csv)

## Meta Analysis

In [None]:
def _bh_fdr(pvals):
    """
    Benjamini–Hochberg adjusted p-values.

    Non-finite inputs stay NaN and are not counted in the number of tests m.
    """
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    finite = np.isfinite(p)
    m = int(finite.sum())
    if m == 0:
        return adjusted
    p_fin = p[finite]
    order = np.argsort(p_fin)
    stepped = p_fin[order] * m / np.arange(1, m + 1)
    # Step-up monotonicity: adjusted value at rank i is min over j >= i of p_(j) * m / j.
    stepped = np.minimum.accumulate(stepped[::-1])[::-1]
    q_fin = np.empty(m)
    q_fin[order] = np.minimum(stepped, 1.0)
    adjusted[finite] = q_fin
    return adjusted


def _unique_gene_table(df):
    """Index one replicate's table by gene, dropping genes that appear on more than one row."""
    return df.loc[~df["gene"].duplicated(keep=False)].set_index("gene")


def meta_combine_morphogam(results_dict):
    """
    Combine per-replicate MorphoGAM results into one meta-analysis table per cell type.

    Parameters
    ----------
    results_dict : dict
        ``{celltype: {replicate: DataFrame or None}}``. Each DataFrame needs a
        ``gene`` column and may carry ``pv.t`` and/or ``rho``. None entries are
        skipped; a gene listed on more than one row of a replicate is ignored for
        that replicate.

    Returns
    -------
    dict
        ``{celltype: DataFrame}``. Cell types without any replicate table are
        omitted. Each table has one row per gene, sorted by gene name, with:

        - ``meta_p_t``: Stouffer-combined pv.t
        - ``I2_t``, ``tau2_t``, ``tau_t``: heterogeneity of the pv.t z-scores
        - ``zmeta_t``, ``PI_low_t``, ``PI_high_t``: output of ``prediction_interval``
        - ``meta_rho``: Fisher-z fixed-effect mean of rho (unit weights)
        - ``I2_rho``, ``tau2_rho``, ``tau_rho``: heterogeneity on the Fisher-z scale
        - ``fdr_t``: Benjamini–Hochberg adjusted ``meta_p_t`` within the cell type
    """
    columns = [
        "gene",
        "meta_p_t", "I2_t", "tau2_t", "tau_t",
        "zmeta_t", "PI_low_t", "PI_high_t",
        "meta_rho", "I2_rho", "tau2_rho", "tau_rho",
    ]
    meta_out = {}

    for ct, rep_dict in results_dict.items():
        print(f"\n=== Meta-analysis (MorphoGAM): {ct} ===")
        # Index each replicate by gene once, so per-gene lookups are hash lookups
        # instead of a boolean scan of every replicate table for every gene.
        tables = [_unique_gene_table(df) for df in rep_dict.values() if df is not None]
        if not tables:
            continue

        # Sorted for a reproducible row order (str set iteration order is not stable).
        all_genes = sorted(set().union(*(set(t.index) for t in tables)), key=str)
        rows = []

        for gene in all_genes:
            pvals_t, rhos = [], []
            for tbl in tables:
                if gene not in tbl.index:
                    continue
                if "pv.t" in tbl.columns:
                    pvals_t.append(tbl.at[gene, "pv.t"])
                if "rho" in tbl.columns:
                    rhos.append(tbl.at[gene, "rho"])

            # Keep the gene if any replicate supplied a p-value OR a rho.
            if not pvals_t and not rhos:
                continue

            pvals_t = np.asarray(pvals_t, dtype=float)
            rhos = np.asarray(rhos, dtype=float)

            # --- p-value meta (Stouffer) ---
            if pvals_t.size:
                # p == 0 / p == 1 map to ±inf; the heterogeneity helpers drop non-finite z.
                z_t = norm.isf(pvals_t)
                meta_p_t = stouffer_meta(pvals_t)
                I2_t = i_squared(z_t)
                tau2_t = tau_squared_dl(z_t)
                tau_t = np.sqrt(tau2_t) if np.isfinite(tau2_t) else np.nan
                zmeta_t, PI_low_t, PI_high_t = prediction_interval(z_t)
            else:
                meta_p_t = I2_t = tau2_t = tau_t = np.nan
                zmeta_t = PI_low_t = PI_high_t = np.nan

            # --- rho meta (Fisher z) ---
            if rhos.size:
                meta_rho, _ = fisher_meta_rho(rhos)
                # Heterogeneity needs the per-replicate Fisher z values, not their
                # mean: a single value always has k < 2 and would give NaN.
                z_rho, ok_rho = fisher_z(rhos)
                z_rho = z_rho[ok_rho]
                # NOTE(review): these treat each Fisher z as unit variance; its true
                # variance is roughly 1/(n - 3), so without per-replicate cell counts
                # I2_rho / tau2_rho are only rough relative measures.
                I2_rho = i_squared_from_z(z_rho)
                tau2_rho = tau_squared_dl_z(z_rho)
                tau_rho = np.sqrt(tau2_rho) if np.isfinite(tau2_rho) else np.nan
            else:
                meta_rho = I2_rho = tau2_rho = tau_rho = np.nan

            rows.append([
                gene,
                meta_p_t, I2_t, tau2_t, tau_t,
                zmeta_t, PI_low_t, PI_high_t,
                meta_rho, I2_rho, tau2_rho, tau_rho,
            ])

        df_meta = pd.DataFrame(rows, columns=columns)
        # Proper Benjamini–Hochberg FDR (the old rank/len was only a rank quantile).
        df_meta["fdr_t"] = _bh_fdr(df_meta["meta_p_t"].to_numpy())

        meta_out[ct] = df_meta

    return meta_out


## Plotting

In [None]:
def plot_mean_vs_meta(df_meta, rep_dict, ct, outdir):
    """
    Scatter -log10(mean replicate pv.t) against -log10(meta pv.t) for one cell type.

    Parameters
    ----------
    df_meta : pandas.DataFrame
        Meta table with columns ``gene`` and ``meta_p_t``.
    rep_dict : dict
        ``{replicate: DataFrame or None}``; frames need ``gene`` and ``pv.t``.
    ct : str
        Cell-type label used in the title and file name.
    outdir : str or Path
        Output directory (created if missing). Writes
        ``{ct}_mean_vs_meta_morphogam.png``.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Per-replicate pv.t indexed by gene. Genes listed more than once in a
    # replicate are dropped so the column-wise concat cannot hit duplicate labels.
    rep_p = {
        rep: df.loc[~df["gene"].duplicated(keep=False)].set_index("gene")["pv.t"]
        for rep, df in rep_dict.items()
        if df is not None and "pv.t" in df.columns
    }
    if not rep_p:
        print(f"[WARN] No pv.t in replicate results for {ct}")
        return
    rep_mean = pd.concat(rep_p, axis=1).mean(axis=1)

    merged = df_meta.set_index("gene").join(rep_mean.rename("rep_mean_p")).reset_index()
    # +1e-12 keeps p == 0 finite on the log scale (caps both axes at 12).
    x = -np.log10(merged["rep_mean_p"] + 1e-12)
    y = -np.log10(merged["meta_p_t"] + 1e-12)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(x, y, s=10, alpha=0.5)
    # Identity line spanning the data instead of a fixed [0, 20].
    both = pd.concat([x, y])
    both = both[np.isfinite(both)]
    lim = both.max() if len(both) else 1.0
    ax.plot([0, lim], [0, lim], "--", color="red")
    ax.set_xlabel("-log10(mean replicate pv.t)")
    ax.set_ylabel("-log10(meta pv.t)")
    ax.set_title(f"{ct}: mean vs meta (MorphoGAM)")
    fig.tight_layout()
    fig.savefig(outdir / f"{ct}_mean_vs_meta_morphogam.png", dpi=200)
    plt.close(fig)


def plot_meta_vs_I2(df_meta, ct, outdir):
    """
    Scatter I² of the pv.t z-scores (x) against -log10(meta pv.t) (y).

    Saves ``{outdir}/{ct}_meta_vs_I2_morphogam.png``; ``outdir`` is created if needed.
    """
    target_dir = Path(outdir)
    target_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 6))
    neg_log_p = -np.log10(df_meta["meta_p_t"] + 1e-12)  # offset keeps p == 0 finite
    ax.scatter(df_meta["I2_t"], neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("I² (pv.t)")
    ax.set_ylabel("-log10(meta pv.t)")
    ax.set_title(f"{ct}: meta vs I² (MorphoGAM)")
    fig.tight_layout()
    fig.savefig(target_dir / f"{ct}_meta_vs_I2_morphogam.png", dpi=200)
    plt.close(fig)


def plot_meta_vs_tau(df_meta, ct, outdir):
    """
    Scatter tau (DerSimonian–Laird SD of the pv.t z-scores) against -log10(meta pv.t).

    Saves ``{outdir}/{ct}_meta_vs_tau_morphogam.png``; ``outdir`` is created if needed.
    """
    target_dir = Path(outdir)
    target_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 6))
    neg_log_p = -np.log10(df_meta["meta_p_t"] + 1e-12)  # offset keeps p == 0 finite
    ax.scatter(df_meta["tau_t"], neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("tau (absolute heterogeneity)")
    ax.set_ylabel("-log10(meta pv.t)")
    ax.set_title(f"{ct}: meta vs tau (MorphoGAM)")
    fig.tight_layout()
    fig.savefig(target_dir / f"{ct}_meta_vs_tau_morphogam.png", dpi=200)
    plt.close(fig)


def plot_PIwidth_vs_meta(df_meta, ct, outdir):
    """
    Scatter prediction-interval width (PI_high_t - PI_low_t) against -log10(meta pv.t).

    Saves ``{outdir}/{ct}_PIwidth_vs_meta_morphogam.png``; ``outdir`` is created if needed.
    """
    target_dir = Path(outdir)
    target_dir.mkdir(parents=True, exist_ok=True)
    width = df_meta["PI_high_t"].sub(df_meta["PI_low_t"])

    fig, ax = plt.subplots(figsize=(7, 6))
    neg_log_p = -np.log10(df_meta["meta_p_t"] + 1e-12)  # offset keeps p == 0 finite
    ax.scatter(width, neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("Prediction interval width")
    ax.set_ylabel("-log10(meta pv.t)")
    ax.set_title(f"{ct}: PI width vs meta (MorphoGAM)")
    fig.tight_layout()
    fig.savefig(target_dir / f"{ct}_PIwidth_vs_meta_morphogam.png", dpi=200)
    plt.close(fig)


def plot_rho_mean_vs_meta(df_meta, rep_dict, ct, outdir):
    """
    Scatter the plain mean of replicate rho against the Fisher-z meta rho.

    Parameters
    ----------
    df_meta : pandas.DataFrame
        Meta table for this cell type with columns ``gene`` and ``meta_rho``.
    rep_dict : dict
        ``{replicate: DataFrame or None}``; frames without a ``rho`` column are skipped.
    ct : str
        Cell-type label used in the title, file name and warnings.
    outdir : str or Path
        Output directory (created if missing). Writes ``{ct}_rho_mean_vs_meta.png``.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Genes listed more than once in a replicate are dropped so concat cannot
    # fail on duplicate index labels.
    rep_rho = {
        rep: df.loc[~df["gene"].duplicated(keep=False)].set_index("gene")["rho"]
        for rep, df in rep_dict.items()
        if df is not None and "rho" in df.columns
    }
    if not rep_rho:
        print(f"[WARN] No rho in replicate results for {ct}")
        return

    rho_mean = pd.concat(rep_rho, axis=1).mean(axis=1)

    merged = df_meta.set_index("gene").join(rho_mean.rename("rho_mean")).reset_index()
    merged = merged.dropna(subset=["meta_rho", "rho_mean"])
    if merged.empty:
        # np.max on an empty column would raise below.
        print(f"[WARN] No genes with both mean and meta rho for {ct}")
        return

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(merged["rho_mean"], merged["meta_rho"], s=10, alpha=0.5)
    lim = max(np.max(np.abs(merged["rho_mean"])), np.max(np.abs(merged["meta_rho"])))
    ax.plot([-lim, lim], [-lim, lim], "--")
    ax.set_xlabel("Mean replicate ρ")
    ax.set_ylabel("Meta ρ (Fisher)")
    ax.set_title(f"{ct}: directionality consistency (ρ)")
    fig.tight_layout()
    fig.savefig(outdir / f"{ct}_rho_mean_vs_meta.png", dpi=200)
    plt.close(fig)

def plot_meta_p_vs_meta_rho(df_meta, ct, outdir):
    """
    Scatter meta rho (x) against -log10(meta pv.t) (y), with a vertical line at rho = 0.

    Rows whose ``meta_p_t`` or ``meta_rho`` is NaN or ±inf are left out.
    Saves ``{outdir}/{ct}_meta_p_vs_meta_rho.png``; ``outdir`` is created if needed.
    """
    target_dir = Path(outdir)
    target_dir.mkdir(parents=True, exist_ok=True)

    clean = (
        df_meta
        .replace([np.inf, -np.inf], np.nan)
        .dropna(subset=["meta_p_t", "meta_rho"])
    )
    neg_log_p = -np.log10(clean["meta_p_t"] + 1e-12)  # offset keeps p == 0 finite

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.scatter(clean["meta_rho"], neg_log_p, s=10, alpha=0.5)
    ax.axvline(0, lw=1, color="black")
    ax.set_xlabel("Meta ρ (directionality)")
    ax.set_ylabel("-log10(meta pv.t)")
    ax.set_title(f"{ct}: meta p vs directionality (ρ)")
    fig.tight_layout()
    fig.savefig(target_dir / f"{ct}_meta_p_vs_meta_rho.png", dpi=200)
    plt.close(fig)


def plot_top_gene_fits(df_meta, results_dict, ct,
                       adata_path, morphogam_plot_script,
                       celltype_field="cell_type",
                       replicate_field="slide_id",
                       axis_field="axis",
                       top_n=6,
                       family="nb",
                       outdir="plots_morphogam",
                       layer=None):
    """
    Plot smooth MorphoGAM fits along the axis for the top genes of one cell type.

    The ``top_n`` genes with the smallest ``meta_p_t`` are fitted on a single
    representative replicate (the first one in ``results_dict`` with a result),
    using ``run_morphogam_plot_single``. That function must be defined in the
    notebook; in this version its definition is commented out, so calling this
    raises NameError until it is restored.

    Parameters
    ----------
    df_meta : pandas.DataFrame
        Meta table with ``gene`` and ``meta_p_t``.
    results_dict : dict
        Per-replicate results for THIS cell type: ``{replicate: DataFrame or None}``.
    ct : str
        Cell type (value of ``obs[celltype_field]``), also used in the title/file name.
    adata_path : str or Path
        ``.h5ad`` file to read the expression data from.
    morphogam_plot_script : str or Path
        R script forwarded to ``run_morphogam_plot_single``.
    celltype_field, replicate_field, axis_field : str
        ``obs`` columns for cell type, replicate and axis position.
    top_n : int
        Number of genes to plot.
    family : str
        Forwarded to ``run_morphogam_plot_single``.
    outdir : str or Path
        Output directory (created if missing); writes ``{ct}_top{n}_morphogam_fits.png``.
    layer : str, optional
        Expression layer to use instead of ``X`` (should match the layer used
        for the main pipeline run).
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    top_genes = df_meta.sort_values("meta_p_t").head(top_n)["gene"].tolist()
    if not top_genes:
        # Also avoids a ZeroDivisionError in the subplot-grid arithmetic below.
        print(f"[WARN] No genes to plot for {ct}.")
        return

    # Pick the first replicate with a non-None result (cheap check before I/O).
    rep_candidates = [r for r, df in results_dict.items() if df is not None]
    if len(rep_candidates) == 0:
        print(f"[WARN] No replicates with results for {ct} to plot fits.")
        return
    rep0 = rep_candidates[0]

    adata = ad.read_h5ad(adata_path)
    obs = adata.obs
    # Same cell selection as the pipeline: this cell type and replicate,
    # excluding cells without an axis value.
    keep = (
        (obs[celltype_field] == ct)
        & (obs[replicate_field] == rep0)
        & obs[axis_field].notna()
    )
    ad_rep = adata[keep]

    X = ad_rep.X if layer is None else ad_rep.layers[layer]
    expr_df = to_dense_df(X, ad_rep.var_names, ad_rep.obs_names)
    axis_vec = ad_rep.obs[axis_field]

    fits_df = run_morphogam_plot_single(
        expr_df, axis_vec, top_genes,
        morphogam_plot_script, family=family
    )
    if fits_df is None or fits_df.empty:
        print(f"[WARN] No fitted curves returned for {ct}.")
        return

    # Plot: one subplot per gene, at most three per row
    n = len(top_genes)
    ncols = min(3, n)
    nrows = int(np.ceil(n / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                             squeeze=False)
    for i, g in enumerate(top_genes):
        ax = axes[i // ncols][i % ncols]
        d = fits_df[fits_df["gene"] == g]
        if d.empty:
            continue
        ax.plot(d["t"], d["fitted"])
        ax.set_title(g)
        ax.set_xlabel("axis (t)")
        ax.set_ylabel("fitted expression")

    # Hide unused subplots
    for j in range(n, nrows * ncols):
        fig.delaxes(axes[j // ncols][j % ncols])

    fig.suptitle(f"{ct}: top {n} MorphoGAM fits")
    fig.tight_layout()
    fig.savefig(outdir / f"{ct}_top{n}_morphogam_fits.png", dpi=200)
    plt.close(fig)

# Run

In [None]:
# parameters
# Paths default to the original Anvil layout. Set BICAN_BG_ROOT / MORPHOGAM_R_HOME in
# the environment to run elsewhere without editing the notebook.
PROJECT_ROOT = os.environ.get("BICAN_BG_ROOT", "/home/x-aklein2/projects/aklein/BICAN/BG")

AD_PATH = os.path.join(PROJECT_ROOT, "data", "BICAN_BG_CPS.h5ad")
R_SCRIPT = os.path.join(PROJECT_ROOT, "spatial_analysis", "scripts", "morpho_gam.R")
R_SCRIPT_PLOT = os.path.join(PROJECT_ROOT, "spatial_analysis", "scripts", "morphogam_plot.R")
R_HOME = os.environ.get(
    "MORPHOGAM_R_HOME",
    "/anvil/projects/x-mcb130189/aklein/SPIDA/.pixi/envs/dist/lib/R",
)
OUTDIR = os.path.join(PROJECT_ROOT, "data", "CPS", "morphogam_dsid")
image_path = os.path.join(PROJECT_ROOT, "images", "CPS", "morphogam_dsid")

# obs columns: replicate id, cell type, axis position
REP_KEY = "dataset_id"
CELLTYPE_KEY = "Subclass"
AXIS = "MS_NORM"

layer = "counts"   # expression layer passed to MorphoGAM (instead of X)
min_cells = 50     # replicates with fewer cells (non-missing axis) are skipped
CELLTYPES = None   # None = all cell types; or a list of specific cell types

# NOTE(review): COORDS_MODE and XY_KEYS are not referenced in the cells shown here.
COORDS_MODE = "axis_only"
XY_KEYS = ("CENTER_X", "CENTER_Y")
# DESIGN = "y ~ s(t, bs='cr')"  # along-curve only
# The design is currently not forwarded to the R script by run_morphogam_axis_single.
DESIGN = None

N_WORKERS = 1

In [None]:
# Turn the configured output locations into Path objects and make sure they exist.
OUTDIR = Path(OUTDIR)
image_path = Path(image_path)
for _out in (OUTDIR, image_path):
    _out.mkdir(parents=True, exist_ok=True)
del _out

# Rscript subprocesses launched later inherit this environment variable.
os.environ["R_HOME"] = R_HOME

In [None]:
# 1. Run MorphoGAM for every (cell type, replicate) pair; parameters come from the config cell.
results = run_morphogam_pipeline_mp(
    adata_path=AD_PATH,
    morphogam_script=R_SCRIPT,
    celltypes=CELLTYPES,
    celltype_field=CELLTYPE_KEY,
    replicate_field=REP_KEY,
    axis_field=AXIS,
    layer=layer,
    min_cells=min_cells,
    design=DESIGN,
    n_workers=N_WORKERS,
    outdir=OUTDIR,
)

In [None]:
# Combine every successful per-replicate result into one long table, tagged with
# its cell type and replicate. Failed jobs (None) are skipped, and the frames in
# `results` are not modified (assign returns copies), so re-running is safe.
all_results = [
    res.assign(celltype=ct_key, replicate=rep_key)
    for ct_key, ct_res in results.items()
    for rep_key, res in ct_res.items()
    if res is not None
]

if all_results:
    pd.concat(all_results).to_csv(Path(OUTDIR) / "results_all.csv", index=False)
else:
    print("[WARN] No MorphoGAM results to combine.")

In [None]:
# 2. Meta-analysis across replicates: one table per cell type ({celltype: DataFrame}).
meta = meta_combine_morphogam(results_dict=results)

In [None]:
# 3. Save the meta table and summary plots for each cell type.
#    Tables go to OUTDIR and figures to image_path/<celltype>/ (the directories set
#    up above) instead of paths relative to the kernel's working directory.
for ct, df_meta in meta.items():
    # Cell-type names may contain '/', which would otherwise act as a path separator.
    ct_label = str(ct).replace("/", "_")
    outdir = Path(image_path) / ct_label
    outdir.mkdir(parents=True, exist_ok=True)
    df_meta.to_csv(Path(OUTDIR) / f"morphogam_meta_{ct_label}.csv", index=False)

    # Summary plots: ct_label is used for titles/file names, results[ct] for data.
    plot_mean_vs_meta(df_meta, results[ct], ct_label, outdir)
    plot_meta_vs_I2(df_meta, ct_label, outdir)
    plot_meta_vs_tau(df_meta, ct_label, outdir)
    plot_PIwidth_vs_meta(df_meta, ct_label, outdir)
    plot_rho_mean_vs_meta(df_meta, results[ct], ct_label, outdir)
    plot_meta_p_vs_meta_rho(df_meta, ct_label, outdir)

    # Smooth fits for the top genes. This needs run_morphogam_plot_single, whose
    # definition is currently commented out; restore it before enabling this call.
    # plot_top_gene_fits(df_meta, results[ct], ct,
    #                    AD_PATH, R_SCRIPT_PLOT,
    #                    celltype_field=CELLTYPE_KEY,
    #                    replicate_field=REP_KEY,
    #                    axis_field=AXIS,
    #                    top_n=6,
    #                    outdir=outdir)