In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import plotly.express as px
import seaborn as sns
import sklearn.metrics as skm
import colorcet as cc
import sklearn as sk
import sklearn.decomposition as decomp
import sklearn.pipeline as pipe
import sklearn.neighbors as nbr
import sklearn.base as skbase
import sklearn.model_selection as skms
import pickle
import os
import joblib
import itertools

# Global seaborn look for every matplotlib figure in this notebook.
sns.set_style("whitegrid")
# Reversed CET-L17 colormap passed as `cmap` to later scanpy plots.
# `with_extremes` returns a modified *copy*, so the colormap object owned by
# colorcet is not mutated in place; values below `vmin` render light grey.
custom_cm = cc.m_CET_L17_r.with_extremes(under="lightgrey")

In [None]:
# Make sure the directories this notebook reads from / writes to exist.
for directory in ("figures", "pickles", "data"):
    os.makedirs(directory, exist_ok=True)

In [None]:
# Load the merged dataset and display its summary.
# NOTE(review): a later cell writes back to this same path
# ("data/merged.h5ad"), so re-running the notebook loads already-processed
# data rather than the original input.
merged_data = sc.read_h5ad("data/merged.h5ad")
merged_data

In [None]:
# Violin plots of three keys (presumably per-cell QC metrics), one panel
# each; the figure is also saved with the "_preproc.pdf" suffix.
sc.pl.violin(
    merged_data,
    keys=["sum", "detected", "subsets_mito_percent"],
    multi_panel=True,
    save="_preproc.pdf",
)

In [None]:
# "sum" against "detected", coloured by "subsets_mito_percent".
sc.pl.scatter(
    merged_data,
    x="sum",
    y="detected",
    color="subsets_mito_percent",
    color_map="viridis",
    save="_sum_vs_detected",
)

In [None]:
# Work with plain strings; after the cast, missing symbols are matched via
# the literal "nan".
merged_data.var["gene_symbol"] = merged_data.var["gene_symbol"].astype("str")

# Genes without a symbol fall back to their var index entry.
no_symbol = merged_data.var["gene_symbol"] == "nan"
merged_data.var.loc[no_symbol, "gene_symbol"] = merged_data.var.loc[no_symbol].index

# Repeated symbols (every occurrence after the first) get the last six
# characters of their "gene_ids" entry appended, separated by "-".
dupes = merged_data.var["gene_symbol"].duplicated()
id_suffix = merged_data.var.loc[dupes, "gene_ids"].astype(str).str[-6:]
merged_data.var.loc[dupes, "gene_symbol"] = (
    merged_data.var.loc[dupes, "gene_symbol"].str.cat(id_suffix, sep="-")
)

In [None]:
# Highest-expressed genes, labelled by the cleaned "gene_symbol" column;
# save=True writes the figure under scanpy's default file name.
sc.pl.highest_expr_genes(merged_data, gene_symbols="gene_symbol", save=True)

In [None]:
# Flag the 2000 top highly variable genes in `merged_data.var`.
# NOTE(review): the "seurat_v3" flavor is documented as expecting raw counts
# -- confirm that X is still un-normalised at this point.
sc.pp.highly_variable_genes(merged_data, n_top_genes=2000, flavor="seurat_v3")

# Diagnostic plot of the selection, saved under scanpy's default file name.
sc.pl.highly_variable_genes(merged_data, save=True)

In [None]:
# Build three derived layers while leaving `merged_data.X` untouched:
#   "norm"              -> total-count normalised X
#   "log"               -> log1p of "norm"
#   "norm_scaled_genes" -> per-gene scaled "log"
# Each scanpy call works on a full copy (copy=True); only the needed matrix
# is taken from the copy.
normalized = sc.pp.normalize_total(
    merged_data,
    copy=True,
    exclude_highly_expressed=True,
    key_added="norm_factor",
)
merged_data.layers["norm"] = normalized.X
# With copy=True the normalisation factor is written to the copy's `.obs`;
# carry it over explicitly so `key_added` is not silently discarded.
merged_data.obs["norm_factor"] = normalized.obs["norm_factor"]
del normalized

merged_data.layers["log"] = sc.pp.log1p(
    merged_data,
    copy=True,
    layer="norm",
).layers["norm"]

merged_data.layers["norm_scaled_genes"] = sc.pp.scale(
    merged_data,
    copy=True,
    layer="log",
).layers["log"]

In [None]:
class ScPCA(skbase.TransformerMixin, skbase.BaseEstimator):
    """Stateless scikit-learn step wrapping ``sc.pp.pca``.

    Parameters
    ----------
    layer
        Forwarded to ``sc.pp.pca`` as ``layer``.
    n_comps
        Forwarded to ``sc.pp.pca`` as ``n_comps``.
    mask
        Forwarded to ``sc.pp.pca`` as ``mask_var``.
    """

    def __init__(self, layer=None, n_comps=None, mask=None):
        self.layer = layer
        self.n_comps = n_comps
        self.mask = mask

    def fit(self, X, y=None):
        """Nothing is learned; return ``self`` unchanged."""
        return self

    def transform(self, X):
        """Return the result of ``sc.pp.pca`` run on ``X`` with ``copy=True``."""
        pca_kwargs = dict(
            n_comps=self.n_comps,
            mask_var=self.mask,
            layer=self.layer,
            copy=True,
        )
        return sc.pp.pca(X, **pca_kwargs)

class ScNeighbors(skbase.TransformerMixin, skbase.BaseEstimator):
    """Stateless scikit-learn step wrapping ``sc.pp.neighbors``.

    Parameters
    ----------
    n_neighbors
        Forwarded to ``sc.pp.neighbors`` as ``n_neighbors`` (default 15).
    n_pcs
        Forwarded to ``sc.pp.neighbors`` as ``n_pcs``.
    """

    def __init__(self, n_neighbors=15, n_pcs=None):
        self.n_neighbors = n_neighbors
        self.n_pcs = n_pcs

    def fit(self, X, y=None):
        """Nothing is learned; return ``self`` unchanged."""
        return self

    def transform(self, X):
        """Return the result of ``sc.pp.neighbors`` run on ``X`` with ``copy=True``."""
        neighbor_kwargs = dict(
            n_neighbors=self.n_neighbors,
            n_pcs=self.n_pcs,
            copy=True,
        )
        return sc.pp.neighbors(X, **neighbor_kwargs)
    
class ScLeiden(skbase.TransformerMixin, skbase.BaseEstimator):
    """Stateless scikit-learn step wrapping ``sc.tl.leiden`` (igraph flavor).

    Parameters
    ----------
    resolution
        Forwarded to ``sc.tl.leiden`` as ``resolution`` (default 1).
    """

    def __init__(self, resolution=1):
        self.resolution = resolution

    def fit(self, X, y=None):
        """Nothing is learned; return ``self`` unchanged."""
        return self

    def transform(self, X):
        """Return the result of ``sc.tl.leiden`` run on ``X`` with ``copy=True``."""
        leiden_kwargs = dict(
            resolution=self.resolution,
            flavor="igraph",
            copy=True,
        )
        return sc.tl.leiden(X, **leiden_kwargs)

class ScScore(skbase.TransformerMixin, skbase.BaseEstimator):
    """Scoring step: silhouette score of the "leiden" labels in PCA space.

    Defines only ``fit`` and ``score`` (no ``transform``), so it is meant to
    be the last step of a pipeline.
    """

    def fit(self, X, y=None):
        """Nothing is learned; return ``self`` unchanged."""
        return self

    def score(self, X, y=None, sample_weight=None):
        """Return ``silhouette_score`` of ``X.obs["leiden"]`` on ``X.obsm["X_pca"]``.

        ``y`` and ``sample_weight`` are accepted for scikit-learn API
        compatibility and are ignored.
        """
        return skm.silhouette_score(
            X.obsm["X_pca"],
            labels=X.obs["leiden"],
        )

In [None]:
# Individual steps of the clustering workflow.
pca = ScPCA(layer="norm_scaled_genes", mask="highly_variable")
neighbors = ScNeighbors()
scleid = ScLeiden()
scscorer = ScScore()

# make_pipeline names each step after its lower-cased class name, which is
# what the "<step>__<param>" keys of the grid below rely on.
workflow = pipe.make_pipeline(pca, neighbors, scleid, scscorer)
param_grid = dict(
    scpca__n_comps=range(25, 35),
    scneighbors__n_neighbors=range(40, 60),
    scleiden__resolution=np.linspace(0.1, 2, 10),
)

# Hold out 20 % of the data and prepare a shuffled K-fold splitter, both
# seeded with 0.
X_train, X_test = skms.train_test_split(merged_data, test_size=0.2, random_state=0)
kfold = skms.KFold(shuffle=True, random_state=0)


In [None]:
# Load a stored search object from disk (later cells read its
# `best_params_` and `cv_results_`).
# NOTE(review): no visible cell creates this pickle, so a fresh
# Restart & Run All depends on a file produced elsewhere.
# SECURITY: unpickling executes arbitrary code -- only load pickles from a
# trusted source.
with open("pickles/gridsearch_1000", mode= "br") as f:
    grids = pickle.load(f)

In [None]:
# Display the best parameter combination found by the loaded search.
grids.best_params_

In [None]:
# Tidy table of the search results: one row per evaluated candidate.
# Column order matters, since later cells derive the parameter columns
# from it.
_cv_keys = {
    "n_neighbors": "param_scneighbors__n_neighbors",
    "n_comps": "param_scpca__n_comps",
    "resolution": "param_scleiden__resolution",
    "iter": "iter",
    "mean_test_score": "mean_test_score",
}
grids_df = pd.DataFrame(
    {column: grids.cv_results_[key] for column, key in _cv_keys.items()}
)
# Round the resolution values to two decimals.
grids_df["resolution"] = grids_df["resolution"].round(2)

In [None]:
# For every pair of tuned parameters, draw one heatmap of the mean test score
# per search iteration (side by side in a single figure).
combos = itertools.combinations(
    grids_df.columns.drop(["mean_test_score", "iter"]),
    r=2,
)
# Sorted iteration labels; panels are addressed by position rather than by
# the raw label, so non-contiguous labels cannot index out of bounds.
iterations = sorted(grids_df["iter"].unique())

for combo in combos:
    # All-NaN grid spanning every value of both parameters seen in any
    # iteration, so each panel shares the same axes.
    dummy = grids_df.groupby(
        by=list(combo)
    ).agg(lambda x: np.nan).pivot_table(
        columns=combo[0],
        index=combo[1],
        values="mean_test_score",
        dropna=False,
    )
    fig, axs = plt.subplots(
        ncols=len(iterations),
        figsize=(15, 5),
    )
    # plt.subplots returns a bare Axes when ncols == 1; normalise to an array.
    axs = np.atleast_1d(axs)
    for ax, i in zip(axs, iterations):
        data = grids_df.loc[grids_df["iter"] == i]
        # Mean score per parameter pair within this iteration (averaged over
        # the remaining parameter), laid onto the shared grid.
        scores = data.groupby(
            by=list(combo)
        ).mean().pivot_table(
            columns=combo[0],
            index=combo[1],
            values="mean_test_score",
        )
        sns.heatmap(
            dummy.fillna(scores),
            cmap="cet_rainbow4",
            ax=ax,
        )
        ax.set_title(f"iter {i}")
    fig.tight_layout()
    plt.show()

In [None]:
# 3-D view of the candidates from the highest iteration, coloured by their
# mean test score.
data = grids_df[grids_df["iter"].eq(grids_df["iter"].max())]
fig = px.scatter_3d(
    data,
    x="n_comps",
    y="n_neighbors",
    z="resolution",
    color="mean_test_score",
    color_continuous_scale=cc.rainbow4,
)
fig.show()

In [None]:
# PCA on the scaled layer, restricted to highly variable genes, using the
# number of components selected by the search.
sc.pp.pca(
    merged_data,
    n_comps=grids.best_params_["scpca__n_comps"],
    layer="norm_scaled_genes",
    mask_var="highly_variable",
)

# Variance ratio per component on a log scale.
sc.pl.pca_variance_ratio(merged_data, log=True)

In [None]:
# Four pairs of principal components (two panels per row), coloured by
# "subsets_mito_percent".
sc.pl.pca(
    merged_data,
    color="subsets_mito_percent",
    dimensions=[(0, 1), (2, 3), (4, 5), (6, 7)],
    ncols=2,
)

In [None]:
# Neighbour graph with the n_neighbors value selected by the search.
sc.pp.neighbors(
    merged_data,
    n_neighbors= grids.best_params_["scneighbors__n_neighbors"],
)
# UMAP embedding computed from that graph.
sc.tl.umap(
    merged_data,
    min_dist= 0.6,
)
# Leiden clustering at the resolution selected by the search.
# NOTE(review): the ScLeiden pipeline step passes flavor="igraph", but this
# call uses scanpy's default flavor, so the tuned resolution is applied to a
# different implementation. Later cells hard-code cluster ids, so switching
# the flavor here would require re-checking those ids.
sc.tl.leiden(
    merged_data,
    resolution= grids.best_params_["scleiden__resolution"],
)
# Silhouette score of the final clustering in PCA space (same metric as
# ScScore.score); displayed as the cell output.
skm.silhouette_score(
    merged_data.obsm["X_pca"],
    labels= merged_data.obs["leiden"]
)

In [None]:
# Persist the processed object.
# NOTE(review): this overwrites the file loaded at the top of the notebook,
# so the notebook is not idempotent -- a re-run starts from processed data.
# Consider writing to a separate path.
merged_data.write_h5ad("data/merged.h5ad")

In [None]:
# Display the object's summary after processing.
merged_data

In [None]:
# UMAP panels coloured by four obs annotations plus
# "subsets_mito_percent"; point size shrinks as the number of cells grows.
sc.pl.umap(
    merged_data,
    color=["tissue_location", "disease_timing", "sample_id", "leiden", "subsets_mito_percent"],
    gene_symbols="gene_symbol",
    cmap=custom_cm,
    palette=cc.glasbey_category10,
    size=240000 / merged_data.n_obs,
    ncols=2,
    wspace=0.25,
)

In [None]:
# UMAP panels for four genes looked up by symbol. Values below vmin=0.1 use
# the colormap's "under" colour.
sc.pl.umap(
    merged_data,
    color=["CDH1", "SFTPB", "KDM1A", "FLI1"],
    gene_symbols="gene_symbol",
    vmin=0.1,
    cmap=custom_cm,
    palette=cc.glasbey_category10,
    size=240000 / merged_data.n_obs,
    ncols=2,
    wspace=0.25,
)

In [None]:
# Read the whitespace-separated gene list.
with open("gene_lists/CDH1_up.txt") as f:
    cdh1_up = pd.Series(f.read().split())

# Membership has to be measured BEFORE filtering; once the list is restricted
# to genes present in the dataset the fraction would trivially be 1.0.
cdh1_up_found = cdh1_up.isin(merged_data.var["gene_symbol"])
cdh1_up = cdh1_up[cdh1_up_found]

# Fraction of listed genes that are present in the dataset (cell output).
cdh1_up_found.mean()

In [None]:
# Read the whitespace-separated gene list.
with open("gene_lists/CDH1_dn.txt") as f:
    cdh1_dn = pd.Series(f.read().split())

# Membership has to be measured BEFORE filtering; once the list is restricted
# to genes present in the dataset the fraction would trivially be 1.0.
cdh1_dn_found = cdh1_dn.isin(merged_data.var["gene_symbol"])
cdh1_dn = cdh1_dn[cdh1_dn_found]

# Fraction of listed genes that are present in the dataset (cell output).
cdh1_dn_found.mean()

In [None]:
# Matrix plot (genes as rows via swap_axes) of the CDH1_up list plus CDH1 and
# KDM1A per Leiden cluster, using the scaled layer.
# np.concatenate is used instead of the np.concat alias, which only exists in
# NumPy >= 2.0.
sc.pl.matrixplot(
    merged_data,
    groupby="leiden",
    var_names=np.concatenate([cdh1_up, ["CDH1", "KDM1A"]]),
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    dendrogram=True,
    save="cdh1_up.png",
    swap_axes=True,
)

In [None]:
# Matrix plot (genes as rows via swap_axes) of the CDH1_dn list plus CDH1 and
# KDM1A per Leiden cluster, using the scaled layer.
# np.concatenate is used instead of the np.concat alias, which only exists in
# NumPy >= 2.0.
sc.pl.matrixplot(
    merged_data,
    groupby="leiden",
    var_names=np.concatenate([cdh1_dn, ["CDH1", "KDM1A"]]),
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    dendrogram=True,
    save="cdh1_dn.png",
    swap_axes=True,
)

In [None]:
# Heatmap of the CDH1_dn list plus CDH1 and KDM1A, grouped by Leiden
# cluster; the colour scale is capped at 10.
# np.concatenate is used instead of the np.concat alias, which only exists in
# NumPy >= 2.0.
sc.pl.heatmap(
    merged_data,
    groupby="leiden",
    var_names=np.concatenate([cdh1_dn, ["CDH1", "KDM1A"]]),
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    cmap=custom_cm,
    vmax=10,
    save="cm_cdh1_dn.png",
)

In [None]:
# Scaled expression of the CDH1_dn genes as a cells x genes DataFrame.
cm_df = merged_data[:, merged_data.var["gene_symbol"].isin(cdh1_dn)].to_df(
    layer="norm_scaled_genes"
)
# Relabel the columns from var index entries to gene symbols and attach the
# cluster label of every cell.
cm_df.columns = cm_df.columns.map(merged_data.var["gene_symbol"])
cm_df["leiden"] = merged_data.obs["leiden"]

# Hierarchically clustered heatmap of the per-cluster means (genes as rows).
g = sns.clustermap(cm_df.groupby("leiden").mean().T, cmap="viridis")
g.savefig("figures/cdh1_dn_cm.png")

In [None]:
# Scaled expression of the CDH1_up genes as a cells x genes DataFrame.
cm_df = merged_data[:, merged_data.var["gene_symbol"].isin(cdh1_up)].to_df(
    layer="norm_scaled_genes"
)
# Relabel the columns from var index entries to gene symbols and attach the
# cluster label of every cell.
cm_df.columns = cm_df.columns.map(merged_data.var["gene_symbol"])
cm_df["leiden"] = merged_data.obs["leiden"]

# Hierarchically clustered heatmap of the per-cluster means (genes as rows).
g = sns.clustermap(cm_df.groupby("leiden").mean().T, cmap="viridis")
g.savefig("figures/cdh1_up_cm.png")

In [None]:
# Wilcoxon ranking of highly variable genes per Leiden cluster on the "log"
# layer; results go to scanpy's default key.
sc.tl.rank_genes_groups(
    merged_data,
    groupby="leiden",
    method="wilcoxon",
    layer="log",
    mask_var="highly_variable",
)

# Apply scanpy's default filtering to the ranking just computed.
sc.tl.filter_rank_genes_groups(merged_data)

In [None]:
# Top-ranked genes for clusters "4" and "11", labelled by gene symbol.
sc.pl.rank_genes_groups(
    merged_data,
    key="rank_genes_groups",
    groups=["4", "11"],
    gene_symbols="gene_symbol",
)

In [None]:
# Heatmap of the ranked genes for every Leiden cluster, taken from the
# scaled layer and standardised per gene; saved with the "_merged.png"
# suffix.
sc.pl.rank_genes_groups_heatmap(
    merged_data,
    groupby="leiden",
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    standard_scale="var",
    cmap="viridis",
    figsize=(10, 10),
    save="_merged.png",
)

In [None]:
# Same heatmap restricted to clusters "4", "5", "6" and "11", with
# n_genes=-10 (negative value passed through to scanpy unchanged).
sc.pl.rank_genes_groups_heatmap(
    merged_data,
    groupby="leiden",
    groups=["4", "5", "6", "11"],
    n_genes=-10,
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    standard_scale="var",
    cmap="viridis",
    figsize=(10, 10),
)

In [None]:
# Hand-picked assignment of Leiden clusters to the two CDH1 states; clusters
# not listed here end up as "not expressed".
mapping = {
    "10": "active",
    "8": "active",
    "9": "active",
    "7": "inactive",
    "1": "inactive",
    "2": "inactive",
}

# Boolean flags, one per state, derived from the same mapping.
merged_data.obs["cdh1_active"] = merged_data.obs["leiden"].isin(
    [cluster for cluster, state in mapping.items() if state == "active"]
)
merged_data.obs["cdh1_inactive"] = merged_data.obs["leiden"].isin(
    [cluster for cluster, state in mapping.items() if state == "inactive"]
)

# Three-level ordered categorical combining both states and the remainder.
merged_data.obs["cdh1_axis"] = (
    merged_data.obs["leiden"]
    .map(mapping)
    .fillna("not expressed")
    .astype(pd.CategoricalDtype(ordered=True))
)

In [None]:
# Wilcoxon ranking of highly variable genes for each "cdh1_axis" group
# against the "not expressed" reference, stored under "cdh1_rank_genes".
sc.tl.rank_genes_groups(
    merged_data,
    groupby="cdh1_axis",
    reference="not expressed",
    method="wilcoxon",
    layer="log",
    mask_var="highly_variable",
    key_added="cdh1_rank_genes",
)

# Filtered version of that ranking, stored under "filtered_cdh1_rank_genes".
sc.tl.filter_rank_genes_groups(
    merged_data,
    groupby="cdh1_axis",
    key="cdh1_rank_genes",
    key_added="filtered_cdh1_rank_genes",
)

# Dendrogram over the "cdh1_axis" groups, used by the heatmap further down.
sc.tl.dendrogram(merged_data, groupby="cdh1_axis")

In [None]:
# Top-ranked genes per "cdh1_axis" group (unfiltered ranking), labelled by
# gene symbol.
sc.pl.rank_genes_groups(merged_data, key="cdh1_rank_genes", gene_symbols="gene_symbol")

In [None]:
# Heatmap of the filtered "cdh1_axis" ranking on the scaled layer,
# standardised per gene and ordered by the stored dendrogram.
# NOTE(review): `var_group_labels` normally takes a sequence of labels; a
# bare string is passed here -- confirm it has the intended effect.
sc.pl.rank_genes_groups_heatmap(
    merged_data,
    key="filtered_cdh1_rank_genes",
    groupby="cdh1_axis",
    var_group_labels="cdh1_axis",
    gene_symbols="gene_symbol",
    layer="norm_scaled_genes",
    standard_scale="var",
    dendrogram="dendrogram_cdh1_axis",
    cmap="viridis",
)

In [None]:
# Score every cell on the combined CDH1 up/down gene lists.
# The lists hold gene symbols, whereas score_genes takes var index entries,
# so select the var rows whose symbol is listed and pass their index.
# (Taking `.index` of the boolean `isin` result itself would return every
# gene, because the mask has the same index as `var`.)
in_cdh1_lists = merged_data.var["gene_symbol"].isin(pd.concat([cdh1_dn, cdh1_up]))
sc.tl.score_genes(
    merged_data,
    merged_data.var.loc[in_cdh1_lists].index,
    layer="norm_scaled_genes",
)

In [None]:
# sc.pl.paga reads the result of sc.tl.paga from `merged_data.uns["paga"]`.
# No cell in this notebook computes it, so do so here when it is missing
# (e.g. on a fresh run); an existing result is left untouched.
# NOTE(review): a "paga" entry loaded from disk may predate the Leiden
# clustering above -- recompute it if the cluster labels have changed.
if "paga" not in merged_data.uns:
    sc.tl.paga(merged_data, groups="leiden")

sc.pl.paga(
    merged_data,
    node_size_scale=5,
    fontoutline=2,
)