# Packages

In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import harmonypy as hm
from modules.visualize import *
from modules.deg_analysis import *
import seaborn as sns
from collections import Counter
from MCML.modules import MCML, bMCML

# Import dataset and annotation

In [None]:
# Imported explicitly: the visible import cell does not import anndata, so
# this cell would otherwise depend on a star import happening to export it.
import anndata

# Human-readable experimental group for each raw Sample_Tag value.
# Tags missing from this dict become NaN after `Series.map`.
sample_tag_mapping = {'SampleTag17_flex':'WT-DMSO',
                      'SampleTag18_flex':'3xTg-DMSO',
                      'SampleTag19_flex':'WT-SCDi',
                      'SampleTag20_flex':'3xTg-SCDi',
                      'Undetermined':'Undetermined',
                      'Multiplet':'Multiplet'}
adata = anndata.read_h5ad("data/fede_count.h5ad")
adata.obs['Sample_Tag'] = adata.obs['Sample_Tag'].map(sample_tag_mapping)
# Cell-type mapping table. The first 4 lines of the CSV are skipped
# (presumably a metadata preamble; TODO confirm against the export format).
anno_df = pd.read_csv("data/fede_mapping.csv", skiprows=4)

In [None]:
# adata1 = anndata.read_h5ad("data/A_count.h5ad")
# adata1.obs['Sample_Tag'] = 'LD_5xFAD'
# adata2 = anndata.read_h5ad("data/B_count.h5ad")
# adata2.obs['Sample_Tag'] = "LD_NC"
# adata3 = anndata.read_h5ad("data/C_count.h5ad")
# adata3.obs['Sample_Tag'] = "run_5xFAD"
# adata4 = anndata.read_h5ad("data/D_count.h5ad")
# adata4.obs['Sample_Tag'] = "run_NC"
# adata = anndata.concat([adata1, adata2, adata3, adata4], axis=0)

# anno_df1 = pd.read_csv("data/A_mapping.csv", skiprows=4)
# anno_df2 = pd.read_csv("data/B_mapping.csv", skiprows=4)
# anno_df3 = pd.read_csv("data/C_mapping.csv", skiprows=4)
# anno_df4 = pd.read_csv("data/D_mapping.csv", skiprows=4)
# anno_df = pd.concat([anno_df1, anno_df2, anno_df3, anno_df4])

In [None]:
# Attach the reference mapping table (anno_df) to adata. Later cells read
# adata.obs['class_name'] and 'subclass_name', presumably added here (TODO confirm).
adata = annotate_adata(adata, anno_df)

# Data preprocessing

In [None]:
# Minimal detection thresholds applied before any QC metrics are computed.
MIN_GENES_PER_CELL = 150  # cells with fewer detected genes are dropped
MIN_CELLS_PER_GENE = 3    # genes seen in fewer cells are dropped

sc.pp.filter_cells(adata, min_genes=MIN_GENES_PER_CELL)
sc.pp.filter_genes(adata, min_cells=MIN_CELLS_PER_GENE)

In [None]:
# Flag genes whose symbol starts with "mt-" and compute per-cell QC metrics
# (this fills obs['pct_counts_mt'] among others).
mito_genes = adata.var_names.str.startswith('mt-')
adata.var['mt'] = mito_genes
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# Cells with more than this percentage of mitochondrial counts are flagged.
MAX_PCT_MT = 50
adata.obs['high_mt'] = adata.obs['pct_counts_mt'].gt(MAX_PCT_MT)

In [None]:
# QC distributions before filtering. scanpy builds the file name as
# "violin" + save, so pass only a suffix: this writes figures/violin_qc_metrics.png
# (the former save='violin.png' produced "violinviolin.png").
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], jitter=0.4, multi_panel=True, save='_qc_metrics.png')

In [None]:
# Keep only cells not flagged as high-mitochondrial (the result is an AnnData view).
low_mt_cells = ~adata.obs['high_mt']
adata = adata[low_mt_cells]

In [None]:
# Drop multiplet-tagged cells. Cells whose tag is NaN (unmapped) are kept.
# `.copy()` materialises the subset. Indexing an AnnData returns a view, and
# later cells modify this object (adata.raw assignment, normalize_total). On a
# view, that forces an implicit "initializing view as actual" copy with a warning.
adata = adata[adata.obs['Sample_Tag'] != "Multiplet", :].copy()

In [None]:
#adata = adata[:, ~adata.var['mt']]

# Select adata object to use for analysis

In [None]:
#file_path = 'data/filtered_adata.h5ad'
#adata = anndata.read_h5ad(file_path)

In [None]:
# Store the QC-filtered matrix in `.raw` before normalize_total runs in the next cell.
adata.raw = adata

In [None]:
# Library-size normalisation only, with scanpy defaults spelled out.
# log1p / HVG selection / scaling are intentionally left disabled.
sc.pp.normalize_total(adata, target_sum=None, inplace=True)
#sc.pp.log1p(adata)
#sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True)
#sc.pp.scale(adata)

# Dimensionality reduction

In [None]:
#elbow_plot(adata, save_path='figures/elbow_plot.png')

In [None]:
#dimension_heatmap(adata, n_components=20, n_cells=500, save_path='figures/dimension_heatmap.png')

In [None]:
#sc.tl.pca(adata, n_comps=10)

In [None]:
import scipy.sparse as sp

# Learn a 50-dimensional MCML embedding, supervised by the reference class labels.
mcml = MCML(n_latent = 50, epochs = 100)
class_name = adata.obs['class_name'].values.tolist()
#sample_tag = adata.obs['Sample_Tag'].values.tolist()

# MCML receives a dense matrix. Densify only when X is sparse: `.toarray()`
# does not exist on a plain NumPy array.
X_dense = adata.X.toarray() if sp.issparse(adata.X) else np.asarray(adata.X)
# Labels are passed as a single row, i.e. shape (1, n_cells).
latentMCML = mcml.fit(X_dense, np.array([class_name]), fracNCA = 0.8, silent = True)
del X_dense  # don't keep a dense copy of the whole matrix alive in the notebook
mcml.plotLosses(figsize=(10, 3), axisFontSize=10, tickFontSize=8)
adata.obsm['latents'] = latentMCML

# Batch correction

In [None]:
#harmony_out = hm.run_harmony(adata.obsm['X_pca'], adata.obs, 'Sample_Tag')
#adata.obsm['X_pca_harmony'] = harmony_out.Z_corr.T

In [None]:
# Harmony batch correction of the MCML latents, using Sample_Tag as the batch variable.
harmony_out = hm.run_harmony(adata.obsm['latents'], adata.obs, 'Sample_Tag')
# Transposed before storing so rows line up with cells (harmonypy returns Z_corr dims-first).
latents_corrected = harmony_out.Z_corr
adata.obsm['latents_harmony'] = latents_corrected.T

# Clustering

In [None]:
# kNN graph on the Harmony-corrected latents. n_pcs is read from the
# representation itself, so it always matches the MCML latent size; scanpy
# raises if n_pcs exceeds the number of columns in `use_rep`.
sc.pp.neighbors(adata, use_rep='latents_harmony', n_pcs=adata.obsm['latents_harmony'].shape[1])
# Coarse Leiden clustering.
sc.tl.leiden(adata, resolution=0.25)

# Visualization

In [None]:
# Compute the UMAP embedding (scanpy defaults) from the neighbors graph built above.
sc.tl.umap(adata)

In [None]:
# One UMAP per obs column. scanpy prefixes each file name with "umap".
umap_panels = {
    'leiden': '_leiden.png',
    'Sample_Tag': '_sample_tag.png',
    'high_mt': '_high_mt.png',
}
for obs_key, file_suffix in umap_panels.items():
    sc.pl.umap(adata, color=[obs_key], save=file_suffix)

# Cluster annotation

In [None]:
# Name each Leiden cluster from the reference labels. Later cells read
# obs['cluster_class_name'] / ['cluster_subclass_name'], presumably created here (TODO confirm).
assign_unique_cell_type_names(adata, cluster_key='leiden', cluster_types=['class_name', 'subclass_name'])

# Cluster visualization

In [None]:
# UMAP coloured by cluster class label; the title reports the post-QC cell count.
sc.pl.umap(
    adata,
    color=['cluster_class_name'],
    save='_cluster_anno.png',
    title=f'After QC - {adata.n_obs} cells',
    size=10,
)

In [None]:
# UMAP coloured by cluster subclass label; the title reports the post-QC cell count.
sc.pl.umap(
    adata,
    color=['cluster_subclass_name'],
    save='_subcluster_anno.png',
    title=f'After QC - {adata.n_obs} cells',
    size=10,
)

In [None]:
# Project plotting helper: UMAP by subclass label. How save_path is turned
# into a file name depends on plot_umap (TODO confirm).
plot_umap(adata, cluster_type='cluster_subclass_name', legend_fontsize=7, save_path='_sample_tag')

# Keep only a specific cell type

In [None]:
# filtered_obs = adata.obs[adata.obs['cluster_subclass_name'].str.startswith('Astro-NT NN')]
# filtered_indices = adata.obs.index.get_indexer(filtered_obs.index)
# filtered_adata = anndata.AnnData(
#     X=adata.raw.X[filtered_indices, :],
#     obs=filtered_obs.copy(),
#     var=adata.raw.var.copy(),
# )
# file_path = 'data/filtered_adata.h5ad'
# filtered_adata.write_h5ad(file_path)

# Master table

In [None]:
# Per-cluster counts via the project helper; it also writes output under
# figures/master_table (format not visible here).
sample_tag_counts = get_master_table(adata, cluster_type='cluster_class_name', save_path='figures/master_table')

In [None]:
# Display the master table inline in the notebook.
sample_tag_counts

# Cluster composition analysis

In [None]:
# Annotation level and matching cluster column used by the ditto plots.
class_level = 'class_name'
cluster_type = 'cluster_class_name'

In [None]:
#create_ditto_plot(adata, ['WT-DMSO', '3xTg-DMSO', 'WT-SCDi', '3xTg-SCDi', 'Undetermined'], class_level=class_level, cluster_type=cluster_type, min_cell=100)
# One composition (ditto) plot per sample group, each with its own output file.
ditto_outputs = {
    'WT-DMSO': 'figures/wt_dmso_ditto.png',
    '3xTg-DMSO': 'figures/3xtg_dmso_ditto.png',
    'WT-SCDi': 'figures/wt_scdi_ditto.png',
    '3xTg-SCDi': 'figures/3xtg_scdi_ditto.png',
    'Undetermined': 'figures/undetermined_ditto.png',
}
for group_tag, out_path in ditto_outputs.items():
    create_ditto_plot(adata, [group_tag], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=out_path)

# DEG analysis

In [None]:
# Distinct cluster labels, sorted so the order handed to the DEG helpers is
# reproducible across kernel restarts. Iterating a set of strings follows hash
# order, which Python randomizes per process. key=str keeps sorting safe if a
# NaN label is present.
cell_types = sorted(set(adata.obs['cluster_class_name'].values), key=str)

In [None]:
# Control / condition pair for the DEG and enrichment cells that follow. They
# all reference `ctr` and `cnd`, so exactly one line must be active; with all
# four commented out, the next cell raised NameError on a fresh run.
ctr, cnd = 'WT-DMSO', '3xTg-DMSO'
#ctr, cnd = 'WT-DMSO', 'WT-SCDi'
#ctr, cnd = 'WT-DMSO', '3xTg-SCDi'
#ctr, cnd = '3xTg-DMSO', '3xTg-SCDi'

In [None]:
# Per-cluster DEG summary chart for the selected comparison.
horizontal_deg_chart(
    adata, cell_types, ctr, cnd,
    min_fold_change=0.25,
    max_p_value=0.05,
    fig_title=f'{ctr} vs {cnd}',
    save_path=f'figures/{ctr}_{cnd}_horizontal_deg_chart.png',
)

In [None]:
# Restrict the DESeq2 analyses below to one cluster.
# NOTE: this overwrites the full cell_types list built earlier.
cell_types = ['Astro-Epen_1']

In [None]:
# DESeq2 differential expression for the selected comparison, with 5 subsamples.
results_df = DEG_analysis_deseq2(
    adata, ctr, cnd, cell_types,
    save_path=f'figures/{ctr}_{cnd}_results_df.pkl',
    n_subsamples=5,
)

In [None]:
# Volcano plot using the same |log2FC| / p-value thresholds as the DEG chart.
volcano_plot(
    results_df,
    min_fold_change=0.25,
    max_p_value=0.05,
    fig_title=f'{ctr} vs {cnd}',
    save_path=f'figures/{ctr}_{cnd}_volcano_plot.png',
)

# Retrieve NCBI gene list

In [None]:
#genes_ncbi = query_genes(adata.raw, save_path='data/genes_ncbi.pkl')

In [None]:
# Imported explicitly rather than relying on a star import to provide it.
import pickle

# Load the cached NCBI gene table (written by the commented-out query_genes call).
# `with` closes the handle; the previous bare open() leaked it.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
with open('data/genes_ncbi.pkl', 'rb') as genes_file:
    genes_ncbi = pickle.load(genes_file)

# GO term enrichment analysis

In [None]:
# Split the DESeq2 results into up- and down-regulated gene IDs and names.
deg_split = get_DEGs(results_df, genes_ncbi, max_pval=0.05, min_fold_change=0.25)
UP_genes_id, DOWN_genes_id, UP_genes_name, DOWN_genes_name = deg_split

In [None]:
# GO enrichment for each direction; results go to separate _UP / _DOWN files.
go_result_prefix = f'figures/{ctr}_{cnd}_GO_enrichment_analysis'
UP_GO = go_enrichment_analysis(UP_genes_name, save_path=f'{go_result_prefix}_UP.pkl')
DOWN_GO = go_enrichment_analysis(DOWN_genes_name, save_path=f'{go_result_prefix}_DOWN.pkl')

In [None]:
# Plot the up-regulated GO results for each namespace (BP, MF, CC).
for go_namespace in ('BP', 'MF', 'CC'):
    display_go_enrichment(
        UP_GO,
        namespace=go_namespace,
        fig_title=f'UP {go_namespace} - {ctr} vs {cnd}',
        save_path=f'figures/{ctr}_{cnd}_display_GO_enrichment_UP_{go_namespace}',
    )

In [None]:
# Plot the down-regulated GO results for each namespace (BP, MF, CC).
for go_namespace in ('BP', 'MF', 'CC'):
    display_go_enrichment(
        DOWN_GO,
        namespace=go_namespace,
        fig_title=f'DOWN {go_namespace} - {ctr} vs {cnd}',
        save_path=f'figures/{ctr}_{cnd}_display_GO_enrichment_DOWN_{go_namespace}',
    )

In [None]:
# KEGG enrichment of the up-regulated genes. The `_UP` suffix (as in the GO
# cell) keeps this file from being overwritten by the DOWN analysis, which
# previously used the identical path.
UP_KEGG = kegg_enrichment_analysis(UP_genes_name,
                                   save_path=f'figures/{ctr}_{cnd}_KEGG_enrichment_analysis_UP.pkl')

In [None]:
# KEGG enrichment of the down-regulated genes. The `_DOWN` suffix (as in the
# GO cell) stops this from overwriting the UP result, which previously used
# the identical path.
DOWN_KEGG = kegg_enrichment_analysis(DOWN_genes_name,
                                     save_path=f'figures/{ctr}_{cnd}_KEGG_enrichment_analysis_DOWN.pkl')

In [None]:
# Plot the up-regulated KEGG pathways.
display_kegg_enrichment(
    UP_KEGG,
    save_path=f'figures/{ctr}_{cnd}_display_KEGG_enrichment_UP',
    fig_title=f'UP pathway - {ctr} vs {cnd}',
)

In [None]:
# Plot the down-regulated KEGG pathways.
display_kegg_enrichment(
    DOWN_KEGG,
    save_path=f'figures/{ctr}_{cnd}_display_KEGG_enrichment_DOWN',
    fig_title=f'DOWN pathway - {ctr} vs {cnd}',
)

In [None]:
# First comparison for the Venn overlap: genotype effect under DMSO.
ctr = 'WT-DMSO'
cnd = '3xTg-DMSO'

In [None]:
# DESeq2 results for the first Venn comparison.
results_df1 = DEG_analysis_deseq2(
    adata, ctr, cnd, cell_types,
    save_path=f'figures/{ctr}_{cnd}_results_df.pkl',
)

In [None]:
# Second comparison for the Venn overlap: WT-DMSO vs SCDi-treated 3xTg.
ctr = 'WT-DMSO'
cnd = '3xTg-SCDi'

In [None]:
# DESeq2 results for the second Venn comparison.
results_df2 = DEG_analysis_deseq2(
    adata, ctr, cnd, cell_types,
    save_path=f'figures/{ctr}_{cnd}_results_df.pkl',
)

In [None]:
# padj < 0.05 marks significance; outside_range additionally requires |log2FC| > 0.25.
# NaN padj compares False, so those genes are never flagged.
padj_pass_1 = results_df1['padj'] < 0.05
results_df1['significant'] = padj_pass_1
results_df1['outside_range'] = padj_pass_1 & results_df1['log2FoldChange'].abs().gt(0.25)

In [None]:
# padj < 0.05 marks significance; outside_range additionally requires |log2FC| > 0.25.
# NaN padj compares False, so those genes are never flagged.
padj_pass_2 = results_df2['padj'] < 0.05
results_df2['significant'] = padj_pass_2
results_df2['outside_range'] = padj_pass_2 & results_df2['log2FoldChange'].abs().gt(0.25)

In [None]:
# Gene names passing both filters, in row order. Boolean indexing is used
# rather than iterating `results_df2.T.to_dict()`: that round-trip keeps only
# one row per duplicated index label and builds a dict per row.
df2_names = results_df2.loc[results_df2['significant'] & results_df2['outside_range'], 'names'].tolist()

In [None]:
# Gene names passing both filters, in row order. Boolean indexing is used
# rather than iterating `results_df1.T.to_dict()`: that round-trip keeps only
# one row per duplicated index label and builds a dict per row.
df1_names = results_df1.loc[results_df1['significant'] & results_df1['outside_range'], 'names'].tolist()

In [None]:
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

# DEG overlap: set1 = WT-DMSO vs 3xTg-DMSO, set2 = WT-DMSO vs 3xTg-SCDi.
set1 = set(df1_names)
set2 = set(df2_names)

# Draw on a fresh figure so nothing left on the current axes leaks into the diagram.
plt.figure()
# NOTE(review): each label names a whole circle, while "restored"/"new" describe
# the non-overlapping parts; confirm this is the intended wording.
venn2([set1, set2], ('Restored DEGs', 'New DEGs'))
# Genotype spelled as in the sample tags ("3xTg").
plt.title('WT-DMSO vs 3xTg-DMSO - WT-DMSO vs 3xTg-SCDi')
plt.savefig("venn_diagram.png", bbox_inches='tight')
plt.show()
