In [None]:
import os
import pandas as pd
import scanpy as sc
import numpy as np
from sklearn import preprocessing as sk_preprocessing
from sklearn import linear_model as sk_linear_model
from scipy import stats as sp_stats
from matplotlib import pyplot as plt
import seaborn as sns
import warnings

# scanpy's global setting for the number of parallel jobs.
sc.settings.n_jobs = 32
# Silences every warning for the rest of the session.
# NOTE(review): this is a blanket filter, so warnings raised by pandas, scikit-learn or
# seaborn further down are hidden as well.
warnings.filterwarnings("ignore")

# Working directory when this cell runs; the "input" folder is resolved against it below.
pwd = os.getcwd()

### Load the electrophysiology dataset

In [None]:
# From https://www.science.org/doi/full/10.1126/science.adf6484
# The fifth CSV column (index_col=4) becomes the index; it is the key used for the merge
# with the cell_id-indexed table further down.
ephys_csv = os.path.join(pwd, "input", "230216_hIVSCC_LIMS_ephys.csv")
ephys = pd.read_csv(ephys_csv, index_col=4)

# scANVI label table, indexed by its first column.
annotations_csv = "/allen/programs/celltypes/workgroups/hct/SEA-AD/RNAseq/scANVI/output/MTG_reference_patchseq/iterative_scANVI_results.2022-11-22.csv"
annotations = pd.read_csv(annotations_csv, index_col=0)

# AnnData object whose .obs supplies cell_id, sex and age_at_death.
patchseq_h5ad = "/allen/programs/celltypes/workgroups/hct/SEA-AD/RNAseq/scANVI/input/MTG_reference_patchseq.2022-10-03.h5ad"
patchseq = sc.read_h5ad(patchseq_h5ad)

# Supertypes treated as "affected": membership in this list defines the boolean
# `affected_supertype` label assigned during data wrangling. Grouped by subclass for
# readability; the concatenated order matches the original flat list.
affected_sst_supertypes = [
    "Sst_3", "Sst_19", "Sst_9", "Sst_13", "Sst_11",
    "Sst_20", "Sst_22", "Sst_23", "Sst_25", "Sst_2",
]
affected_pvalb_supertypes = [
    "Pvalb_6", "Pvalb_5", "Pvalb_8", "Pvalb_1", "Pvalb_3",
    "Pvalb_2", "Pvalb_15", "Pvalb_14", "Pvalb_12",
]
affected_supertypes = affected_sst_supertypes + affected_pvalb_supertypes

# Data wrangling
# Join the per-cell metadata with the scANVI labels on their shared index, then key the
# table by cell_id (also kept as a column) so it lines up with the ephys index.
cell_metadata = patchseq.obs.loc[:, ["cell_id", "sex", "age_at_death"]]
scanvi_labels = annotations.loc[:, ["subclass_scANVI", "supertype_scANVI"]]
tmp = cell_metadata.merge(scanvi_labels, left_index=True, right_index=True).reset_index()
tmp = tmp.set_index("cell_id", drop=False)

# Fix the category order of the subclasses (reorder_categories raises if the set of
# categories present differs from this list).
subclass_order = [
    "Lamp5_Lhx6", "Lamp5", "Pax6", "Sncg", "Vip", "Sst Chodl", "Sst", "Pvalb", "Chandelier",
    "L2/3 IT", "L4 IT", "L5 IT", "L6 IT", "L6 IT Car3", "L5 ET", "L6 CT", "L6b", "L5/6 NP",
]
tmp["subclass_scANVI"] = tmp["subclass_scANVI"].astype("category").cat.reorder_categories(subclass_order)

# Keep only Sst and Pvalb cells and drop the now-empty subclass categories.
is_sst_or_pvalb = tmp["subclass_scANVI"].isin(["Sst", "Pvalb"]).to_numpy()
tmp = tmp.loc[is_sst_or_pvalb].copy()
tmp["subclass_scANVI"] = tmp["subclass_scANVI"].cat.remove_unused_categories()

# Boolean label: does the cell's supertype appear in `affected_supertypes`?
tmp["affected_supertype"] = tmp["supertype_scANVI"].isin(affected_supertypes).to_numpy()

# Attach the electrophysiology measurements (inner join on the cell_id index).
tmp = tmp.merge(ephys, left_index=True, right_index=True)

# The ephys column name carries a trailing space. Re-add it under the clean name; the
# column moves to the end of the frame, which matters because features are later
# selected by column position.
tmp["cortex_layer"] = tmp.pop("cortex_layer ")

# Subset to only cells with metadata
# First drop recordings flagged by any of the three QC failure columns. A flag must equal
# False to pass, so rows where a flag is missing are dropped too.
qc_flags = ["failed_bad_rs", "failed_no_seal", "failed_electrode_0"]
passed_qc = tmp[qc_flags].eq(False).all(axis=1)
tmp = tmp.loc[passed_qc, :].copy()
# Then require a recorded sex and a strictly positive age at death.
tmp = tmp.loc[tmp["sex"].notna(), :].copy()
tmp = tmp.loc[tmp["age_at_death"].gt(0), :].copy()

# Encode the outcome and the covariates as numeric codes for the regression.
# Each categorical column is converted in place and gets a sibling "<name>_codes" column
# holding its integer category codes.
for categorical_column in ("affected_supertype", "sex"):
    tmp[categorical_column] = tmp[categorical_column].astype("category")
    tmp[categorical_column + "_codes"] = tmp[categorical_column].cat.codes

# Age is cut into 5 equal-width bins; the bin codes are then rescaled to [0, 1].
age_bins = pd.cut(tmp["age_at_death"], 5)
tmp["age_at_death_binned"] = age_bins
tmp["age_at_death_binned_codes"] = sk_preprocessing.minmax_scale(age_bins.cat.codes)

tmp["paradigm"] = tmp["paradigm"].astype("category")
tmp["paradigm_codes"] = tmp["paradigm"].cat.codes

### Fit the logistic regression

In [None]:
# Fit one logistic regression per (electrophysiology feature, subclass) pair. Each model
# predicts the affected-supertype label from the min-max scaled feature plus the sex,
# binned-age and paradigm codes. The feature's coefficient and the model's accuracy on
# its own training cells are recorded.
covariates = ["sex_codes", "age_at_death_binned_codes", "paradigm_codes"]

# NOTE(review): features are chosen by column position, so this slice silently changes
# meaning if the column order of `tmp` changes upstream. Presumably these 46 columns are
# the ephys measurements; confirm against the input CSV.
feature_columns = tmp.columns[15:61]

records = []
for feature in feature_columns:
    # Restrict to cells with a complete row for this feature and all covariates.
    complete_cases = tmp.loc[
        :, ["affected_supertype_codes", feature, *covariates, "subclass_scANVI"]
    ].dropna()

    for subclass in ["Sst", "Pvalb"]:
        in_subclass = complete_cases["subclass_scANVI"] == subclass
        y = complete_cases.loc[in_subclass, "affected_supertype_codes"].copy()
        # Explicit copy: the rescaling below must write to an independent frame.
        X = complete_cases.loc[in_subclass, [feature, *covariates]].copy()
        X[feature] = sk_preprocessing.minmax_scale(X[feature])

        model = sk_linear_model.LogisticRegression(max_iter=100, class_weight="balanced")
        model.fit(X, y)
        # Accuracy on the same cells the model was fitted on (there is no held-out split).
        score = model.score(X, y)

        records.append(
            {
                "feature": feature,
                # First column of X is the feature, so coef_[0][0] is its coefficient.
                "effect size": model.coef_[0][0],
                "score": score,
                "celltype": subclass,
            }
        )

# Build the table once rather than concatenating onto an empty frame inside the loop.
# That pattern is quadratic and depends on pandas ignoring the empty frame when inferring
# column dtypes. Rows also get a unique 0..n-1 index instead of every row being labelled 0.
model_output = pd.DataFrame(records, columns=["feature", "effect size", "score", "celltype"])

# Determine significance after multiple hypothesis testing
# Each model's accuracy is expressed as a distance above chance (0.5) in units of the
# standard deviation of all scores, turned into an upper-tail half-normal p-value, and
# then adjusted with scipy's false discovery rate control (default method).
score_spread = model_output["score"].std()
standardized_score = (model_output["score"] - 0.5) / score_spread
raw_pvalues = 1 - sp_stats.halfnorm.cdf(standardized_score)
model_output["pvalue_adj"] = sp_stats.false_discovery_control(raw_pvalues)

# Keep the models passing the 5% threshold, most significant first.
is_significant = model_output["pvalue_adj"] < 0.05
final = model_output.loc[is_significant, :].sort_values(by="pvalue_adj")
display(final)

# Plot significant values
# One figure per significant (feature, subclass) row: a box plot of the feature by
# subclass, split by affected status, with the individual cells overlaid as points.
plt.rcParams["figure.figsize"] = (4, 6)
shared_plot_args = dict(data=tmp, x="subclass_scANVI", hue="affected_supertype")
for row_position, feature_name in enumerate(final["feature"]):
    # Show the statistics for this row above its figure.
    display(final.iloc[row_position])
    sns.boxplot(y=feature_name, showfliers=False, **shared_plot_args)
    sns.stripplot(y=feature_name, dodge=True, alpha=0.5, color="black", **shared_plot_args)
    plt.xticks(rotation=45, ha="right")
    # Place the legend outside the axes, to the right of the plot.
    plt.legend(loc="upper left", bbox_to_anchor=(1.05, 1), title="Abundance change")
    plt.show()

# Round the reported statistics to two decimals, then write the table next to the notebook.
for rounded_column in ("effect size", "score"):
    final[rounded_column] = np.round(final[rounded_column], 2)
final.to_csv("Supplementary Table 8.csv")