### Load needed libraries

In [None]:
import os
import shutil
import scanpy as sc
import pandas as pd
import numpy as np
import re
import seaborn as sns
import matplotlib
import copy
from matplotlib import pyplot as plt
from matplotlib import cm as cm
%matplotlib inline
from scipy import stats as sp_stats
import warnings
from datetime import datetime
from helper_functions import *

# Global configuration shared by every cell below: silence warnings, set scanpy's worker
# count and figure defaults, and tune matplotlib's PDF/grid settings.
warnings.filterwarnings("ignore")
sc.settings.n_jobs = 32
sc.set_figure_params(
    scanpy=True,
    dpi=100,
    dpi_save=500,
    frameon=False,
    vector_friendly=True,
    figsize=(10, 10),
    format='png',
)
# Embed fonts as TrueType (Type 42) in saved PDFs and turn off the background grid.
matplotlib.rcParams.update({'pdf.fonttype': 42, "axes.grid": False})

# Every input/ and output/ path below is resolved relative to the working directory.
pwd = os.getcwd()

### Load the needed data files

In [None]:
# Cluster order and colors
color_order = pd.read_csv(os.path.join(pwd, "input", "cluster_order_and_colors.csv"))

# From https://sea-ad-single-cell-profiling.s3.amazonaws.com/index.html#MTG/RNAseq/Supplementary%20Information/
pvalues = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "pvalues.h5ad"))
effect_sizes = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "effect_sizes.h5ad"))
effect_sizes_early = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "effect_sizes_early.h5ad"))
effect_sizes_late = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "effect_sizes_late.h5ad"))
gene_dynamic_space = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "gene_dynamic_space.h5ad"))
mean_expression = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "mean_expression.h5ad"))
fraction_expressed = sc.read_h5ad(os.path.join(pwd, "input", "Figure 5 and Extended Data Figure 10", "fraction_expressed.h5ad"))

# From https://sea-ad-single-cell-profiling.s3.amazonaws.com/index.html#MTG/RNAseq/
adata = sc.read_h5ad(os.path.join(pwd, "input", "SEAAD_MTG_RNAseq_final-nuclei.2024-02-13.h5ad"))

### Figure 5a

In [None]:
# Figure 5a: number of significant genes per supertype (one point each), grouped by subclass.
FDR = 0.01
# Rows of `pvalues` are genes (the per-column sum below counts genes), so this is the number of
# genes expected below the threshold by chance if p-values were uniformly distributed.
Expected = FDR * pvalues.shape[0]
# Stored on pvalues.var because the Extended Data Figure 10b scatterplot reuses it.
pvalues.var["Number of significant genes"] = (pvalues.layers["pvalues"] < FDR).sum(axis=0)

df = sc.get.var_df(pvalues, ["Subclass", "Number of significant genes"]).reset_index()

# Order subclasses on the x axis as they appear in the cluster order table.
df["Subclass"] = df["Subclass"].astype("category")
df["Subclass"] = df["Subclass"].cat.reorder_categories(color_order["subclass_label"].unique())

plt.rcParams["figure.figsize"] = (10,4)
# NOTE(review): `palette` is a plain list, so colors are assigned to supertypes ("index") in the
# order seaborn encounters them in df; this assumes pvalues.var is ordered like color_order — confirm.
ax = sns.swarmplot(
    data=df,
    x="Subclass",
    y="Number of significant genes",
    hue="index",
    palette=color_order["cluster_color"].to_list()
)
ax.set_xlabel("")
ax.set_ylabel("No. of significant genes")
plt.xticks(rotation=90, ha="right")
# Hide the per-supertype hue legend explicitly. Previously an empty string was passed as
# the legend labels, which is fragile across matplotlib versions.
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.remove()
plt.axhline(y=Expected, color="grey", linestyle="--")
plt.savefig(os.path.join(pwd, "output", "Figure 5a_swarmplot_Number of significant genes_by_Supertype_groupby_Subclass.pdf"), bbox_inches="tight")
plt.show();

### Extended Data Figure 10b

In [None]:
# Histogram: distribution of effect sizes for (gene, supertype) pairs with p < FDR.
FDR = 0.01
# NOTE(review): assumes effect_sizes and pvalues share the same gene x supertype layout,
# as the elementwise combination requires — not checked here.
significant = pvalues.layers["pvalues"] < FDR
# Select the significant entries directly. Multiplying by the mask and then dropping zeros
# would also discard significant entries whose effect size is exactly 0.
significant_effect_sizes = effect_sizes.layers["effect_sizes"][significant]
plt.rcParams["figure.figsize"] = (4,4)
ax = sns.histplot(
    significant_effect_sizes,
    color="grey",
    stat="proportion"
)
ax.set(xlim=(-5,5))
ax.set(xlabel="Significant effect sizes along CPS")
plt.savefig(os.path.join(pwd, "output", "Extended Data Figure 10b_histogram_Significant_effect_sizes.pdf"), bbox_inches="tight")
plt.show()

# Scatterplot: nuclei per supertype vs. number of significant genes.
# "Number of significant genes" is written to pvalues.var by the Figure 5a cell.
df = sc.get.var_df(pvalues, ["Subclass", "Number of significant genes"])
# value_counts() names its result "count" only in pandas >= 2.0 (older versions reuse the
# column name), so rename explicitly to guarantee the "count" column below.
nuclei_per_type = adata.obs["Supertype"].value_counts(sort=False).rename("count")
df = df.merge(nuclei_per_type, left_index=True, right_index=True)
# Leave out supertypes with 15000 or more nuclei.
df = df.loc[df["count"] < 15000, :]

ax = sns.scatterplot(
    data=df,
    x="count",
    y="Number of significant genes",
    color="grey",
)
ax.set(xlabel="No. of nuclei per type", ylabel="No. of significant genes")
plt.xticks(rotation=90, ha="right")
plt.savefig(os.path.join(pwd, "output", "Extended Data Figure 10b_scatterplot_Number of nuclei_versus_Number of significant genes.pdf"), bbox_inches="tight")
plt.show();

### Figure 5b

In [None]:
# Figure 5b scatterplot: per-gene mean early vs. mean late effect size. Genes are split into
# 8 angular sectors ("gene classes") by the lines late = +-3*early and late = +-early/3.
# Class 0 holds genes that fail the expression/detection/radius filters.
plt.rcParams["figure.figsize"] = (4,4)

effect_size_mask_radius = 0.2
mean_expression_cutoff = 0.01
fraction_expressed_cutoff = 0.01

# Average each gene (row) over the matrix columns (presumably supertypes — confirm).
mean_early = effect_sizes_early.layers["effect_sizes"].mean(axis=1)
mean_late = effect_sizes_late.layers["effect_sizes"].mean(axis=1)
mean_mean_expression = mean_expression.X.mean(axis=1)
mean_fraction_expressed = fraction_expressed.X.mean(axis=1)

# Keep genes above both expression cutoffs and outside a small circle around the origin.
expressed = mean_mean_expression > mean_expression_cutoff
detected = mean_fraction_expressed > fraction_expressed_cutoff
outside_radius = np.sqrt(np.square(mean_early) + np.square(mean_late)) > effect_size_mask_radius
good_genes = expressed & detected & outside_radius

# Half-planes on or above each dividing line.
mask_1 = mean_late >= 3 * mean_early
mask_2 = mean_late >= (1/3) * mean_early
mask_3 = mean_late >= (-1/3) * mean_early
mask_4 = mean_late >= -3 * mean_early

# Each class is the wedge between two adjacent lines; the eight wedges partition the plane.
sectors = {
    1: ~mask_1 & mask_2,   # early up, late up
    2: mask_1 & ~mask_2,   # early down, late down
    3: ~mask_2 & mask_3,   # early up, late ~flat
    4: mask_2 & ~mask_3,   # early down, late ~flat
    5: mask_1 & mask_4,    # early ~flat, late up
    6: ~mask_1 & ~mask_4,  # early ~flat, late down
    7: ~mask_3 & mask_4,   # early up, late down
    8: mask_3 & ~mask_4,   # early down, late up
}
gene_classes = pd.DataFrame(np.zeros((effect_sizes.shape[0], 1)), columns=["Gene class"], index=effect_sizes.obs_names)
for label, in_sector in sectors.items():
    gene_classes.loc[in_sector & good_genes, "Gene class"] = label

gene_classes["Mean early effect size"] = mean_early
gene_classes["Mean late effect size"] = mean_late
# The category order fixes the hue/palette order. Class 0 remains a (soon unused) category,
# presumably so that it takes the middle slot of the 9-colour diverging palette.
gene_classes["Gene class"] = (
    gene_classes["Gene class"]
    .astype("category")
    .cat.reorder_categories([5.0, 1.0, 3.0, 7.0, 0.0, 8.0, 4.0, 2.0, 6.0])
)
gene_classes = gene_classes.loc[gene_classes["Gene class"] != 0, :].copy()

# NOTE(review): seaborn treats `size=1` as a constant size *semantic*, not a marker size;
# if a fixed marker size was intended, `s=` is the matplotlib keyword — confirm.
ax = sns.scatterplot(
    data=gene_classes,
    x="Mean early effect size",
    y="Mean late effect size",
    hue="Gene class",
    alpha=0.2,
    size=1,
    rasterized=True,
    palette="RdBu"
)

# Sector boundaries, drawn in the same order as the masks above.
x = np.linspace(-2.1,2.1,100)
for slope in (3/1, 1/3, -3/1, -1/3):
    y = slope * x
    plt.plot(x, y, "--k")

# Circle of radius effect_size_mask_radius: genes inside it were excluded.
theta = np.linspace(0, 2*np.pi, 100)
r = effect_size_mask_radius
x = r*np.cos(theta)
y = r*np.sin(theta)
plt.plot(x, y, 'k--')

ax.set(xlim=(-2.1,2.1), ylim=(-2.1,2.1))
plt.legend(bbox_to_anchor=(1.04, 0.5), loc="center left")
plt.savefig(os.path.join(pwd, "output", "Figure 5b_scatterplot_early_versus_late_mean_effect_sizes.pdf"), bbox_inches="tight")
plt.show()



In [None]:
# Keep only nuclei whose "Neurotypical reference" flag is the string "False", and store
# "Used in analysis" as a string-valued categorical for use as a grouping key.
non_reference = adata.obs["Neurotypical reference"] == "False"
adata = adata[non_reference, :].copy()
used_in_analysis = adata.obs["Used in analysis"].astype("str")
adata.obs["Used in analysis"] = used_in_analysis.astype("category")

In [None]:
# Dynamic plots: for each gene class, draw the trajectory of its top-50 genes along CPS with
# delta_plot (from helper_functions). Each class is coloured from a 9-level RdBu colormap,
# indexed by its position in the "Gene class" category order.
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap(name, lut) returns the same resampled colormap on old and new releases.
cmap = plt.get_cmap('RdBu', 9)
color_list = [matplotlib.colors.rgb2hex(cmap(k)[:3]) for k in range(cmap.N)]

# |early * late| is large for genes far out along both axes: used to rank the diagonal
# classes (1, 2, 7, 8).
gene_classes["mult_effect_sizes"] = np.abs(gene_classes["Mean early effect size"] * gene_classes["Mean late effect size"])
# Rotating by 45 degrees maps the axis-aligned classes (3, 4, 5, 6) onto the diagonals, so the
# same product criterion can rank them.
gene_classes["rot_mean_early"] = gene_classes["Mean early effect size"] * np.cos(np.pi / 4) - gene_classes["Mean late effect size"] * np.sin(np.pi / 4)
gene_classes["rot_mean_late"] = gene_classes["Mean late effect size"] * np.cos(np.pi / 4) + gene_classes["Mean early effect size"] * np.sin(np.pi / 4)
gene_classes["rot_mult_effect_sizes"] = np.abs(gene_classes["rot_mean_early"] * gene_classes["rot_mean_late"])

for j, i in enumerate(gene_classes["Gene class"].cat.categories):
    if i == 0:
        # Class 0 sits in the middle of the category order (5, 1, 3, 7 | 0 | 8, 4, 2, 6).
        # Finish and save the figure holding the classes drawn so far (delta_plot presumably
        # draws onto the current figure), then continue with the remaining classes.
        plt.ylim((-0.015, 0.055))
        plt.savefig(os.path.join(pwd, "output", "Figure 5b_deltaplot_positive gene dynamics.pdf"), bbox_inches="tight")
        plt.show()
        continue

    ranking = "mult_effect_sizes" if i in [1, 2, 8, 7] else "rot_mult_effect_sizes"
    genes = gene_classes.loc[gene_classes["Gene class"] == i, ranking].sort_values(ascending=False).iloc[:50].index.to_list()

    print(genes)

    ax = delta_plot(
        adata=adata,
        genes=genes,
        groupby="Used in analysis",
        plotby="Neurotypical reference",
        donor="Donor ID",
        across="Continuous Pseudo-progression Score",
        highlight=[],
        # Categories are floats (e.g. 5.0); show them as integers in the title.
        title="Class " + str(int(i)) + " genes",
        colormap={"False": color_list[j]},
        normalize_to_start=True,
    )
plt.ylim((-0.05, 0.03))
plt.savefig(os.path.join(pwd, "output", "Figure 5b_deltaplot_negative gene dynamics.pdf"), bbox_inches="tight")
plt.show()

### Figure 5c and Extended Data Figure 10c

In [None]:
# Figure 5c / Extended Data Figure 10c: UMAPs of the gene dynamic space.
# This cell assumes scanpy saves figures as <pwd>/figures/umap<save>; each file is then moved
# into output/. os.replace (unlike os.rename) overwrites an existing destination on every
# platform, so re-running the cell does not fail on Windows.
plt.rcParams["figure.figsize"] = (6,6)
sc.pl.umap(
    gene_dynamic_space,
    color=None,
    size=20,
    legend_loc="on data",
    save="_gene_dynamic_space.pdf"
)
os.replace(os.path.join(pwd, "figures", "umap_gene_dynamic_space.pdf"), os.path.join(pwd, "output", "Figure 5c_umap_gene_dynamic_space.pdf"))

# One panel per (quantity, index) pair. Index 2 goes into Figure 5c and the others into
# Extended Data Figure 10c. File names match the previous per-quantity copies exactly.
plt.rcParams["figure.figsize"] = (4,4)
for i in [1, 2, 3, 4]:
    figure_prefix = "Figure 5c" if i == 2 else "Extended Data Figure 10c"
    for quantity in ["mean_expression", "effect_sizes_early", "effect_sizes_late"]:
        key = quantity + "_" + str(i)
        sc.pl.umap(
            gene_dynamic_space,
            color=[key],
            size=20,
            color_map="RdBu_r",
            vcenter=0,
            vmin=-2,
            vmax=2,
            sort_order=False,
            save="_gene_dynamic_space_" + key + ".pdf"
        )
        os.replace(
            os.path.join(pwd, "figures", "umap_gene_dynamic_space_" + key + ".pdf"),
            os.path.join(pwd, "output", figure_prefix + "_umap_gene_dynamic_space_" + key + ".pdf"),
        )

# Define gene lists (gene symbols taken from effect_sizes.obs_names)
gene_lists = {}
gene_names = effect_sizes.obs_names

# Electron transport chain components, based on GO:0022900.
# Names containing "-AS" / "-DT" (antisense / divergent transcripts sharing the prefix) are excluded.
gene_lists["Electron transport chain components"] = []
complexes = ["NDUF", "SDH", "UQC", "COX", "ATP5"]
not_lncrna = ~(gene_names.str.contains("-AS")) & ~(gene_names.str.contains("-DT"))
for prefix in complexes:
    gene_lists["Electron transport chain components"].extend(gene_names[(gene_names.str.startswith(prefix)) & not_lncrna])

# Ribosomal proteins, based on GO:0006412
ribosomal_pattern = re.compile(r"RP[SL]([0-9]{1,2}[ABLXY]?[0-9]?|A|P[0-9]{1})$")
gene_lists["Ribosomal proteins"] = [g for g in gene_names if ribosomal_pattern.match(g) is not None]

# Highlight each gene list on the UMAP ("True" = member of the list).
for list_name, members in gene_lists.items():
    member_set = set(members)
    column = "gene_list_" + str(list_name)
    gene_dynamic_space.obs[column] = [str(g in member_set) for g in gene_dynamic_space.obs_names]
    sc.pl.umap(
        gene_dynamic_space,
        color=[column],
        groups=["True"],
        ncols=2, size=20,
        save="_gene_dynamic_space_family_" + str(list_name) + ".pdf"
    )
    os.replace(
        os.path.join(pwd, "figures", "umap_gene_dynamic_space_family_" + str(list_name) + ".pdf"),
        os.path.join(pwd, "output", "Figure 5c_umap_gene_dynamic_family_" + str(list_name) + ".pdf"),
    )

### Figure 5d

In [None]:
# Electron transport chain components shown in Figure 5d (sorted alphabetically).
genes = sorted([
    'ATP5F1D', 'ATP5F1E', 'ATP5IF1', 'ATP5MC3', 'ATP5MD', 'ATP5ME', 'ATP5MF', 'ATP5MPL', 'ATP5PD', 'ATP5PF',
    'COX14', 'COX4I1', 'COX5B', 'COX6A1', 'COX6B1', 'COX6C', 'COX7A1', 'COX7A2', 'COX7B', 'COX7C',
    'NDUFA1', 'NDUFA11', 'NDUFA13', 'NDUFA3', 'NDUFA4', 'NDUFA6',
    'NDUFB10', 'NDUFB11', 'NDUFB2', 'NDUFB3', 'NDUFB4', 'NDUFB8', 'NDUFB9',
    'NDUFC1', 'NDUFC2', 'NDUFS5',
])

# Deltaplots, one line per subclass, coloured by the subclass colour from the cluster table.
subclass_colors = (
    color_order.loc[:, ["subclass_label", "subclass_color"]]
    .drop_duplicates()
    .set_index("subclass_label")["subclass_color"]
    .to_dict()
)

# NOTE(review): adata was already restricted to "Neurotypical reference" == "False" in an
# earlier cell, so this filter is a no-op there, but .copy() still duplicates the whole matrix.
non_reference = adata.obs["Neurotypical reference"] == "False"
sub = adata[non_reference, :].copy()

ax = delta_plot(
    adata=sub,
    genes=genes,
    groupby="Class",
    groupby_subset=None,
    plotby="Subclass",
    donor="Donor ID",
    across="Continuous Pseudo-progression Score",
    highlight=[],
    colormap=subclass_colors,
    title="Expression of ETC components",
    legend=False,
    save=os.path.join(pwd, "output", "Figure 5d_deltaplot_{title}_groupby Class_plotby Subclass.pdf")
)
plt.show();

In [None]:
# Heatmap: per-"Class" mean values of the Figure 5d ETC genes in the gene dynamic space.
# Build a new key list instead of extending `genes` in place. In-place extension appended
# another "Class" key on every re-run, which produced duplicate columns and broke the groupby.
keys = genes + ["Class"]
df = sc.get.var_df(
    gene_dynamic_space,
    keys
)
# Label each var row by its name suffix ("_early" / "_late"). Rows ending in "_mean" are dropped;
# any row without one of these suffixes keeps the default label "mean".
df["type"] = "mean"
df.loc[df.index.str.endswith("_early"), "type"] = "early"
df.loc[df.index.str.endswith("_late"), "type"] = "late"
df = df.loc[~df.index.str.endswith("_mean"), :].copy()

df = df.groupby(["Class", "type"]).mean()
plt.rcParams["figure.figsize"] = (2.5,10)
sns.heatmap(
    data=df.T,
    cmap="RdBu_r",
    xticklabels=True,
    yticklabels=True,
    center=0,
    vmin=-2,
    vmax=2
)
plt.savefig(os.path.join(pwd, "output", "Figure 5d_heatmap_effect_sizes_of_Electron transport chain components.pdf"), bbox_inches="tight")
plt.show()


### Clean up

In [None]:
# Remove the scratch directory scanpy saved figures into (its files were moved to output/).
# Skip quietly if it does not exist, e.g. when the figure cells were not run, instead of raising.
figures_dir = os.path.join(pwd, "figures")
if os.path.isdir(figures_dir):
    shutil.rmtree(figures_dir)