In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import harmonypy as hm
from modules.visualize import *
from modules.deg_analysis import *

In [None]:
# Raw sample-tag identifiers -> condition labels used throughout the notebook.
sample_tag_mapping = {'SampleTag17_flex':'WT-DMSO',
                      'SampleTag18_flex':'3xTg-DMSO',
                      'SampleTag19_flex':'WT-SCDi',
                      'SampleTag20_flex':'3xTg-SCDi',
                      'Undetermined':'Undetermined',
                      'Multiplet':'Multiplet'}
# Load h5ad object.
# Uses scanpy's re-export of read_h5ad: `anndata` is never imported explicitly
# in this notebook, so `anndata.read_h5ad` only worked if a star import
# happened to leak that name into the namespace.
adata = sc.read_h5ad("data/fede_count.h5ad")
# Map appropriate condition tags (any tag missing from the mapping becomes NaN)
adata.obs['Sample_Tag'] = adata.obs['Sample_Tag'].map(sample_tag_mapping)

# Load cell annotation info.
# skiprows=4 skips the first four lines of the CSV
# (presumably a metadata preamble above the header row -- TODO confirm).
anno_df = pd.read_csv("data/fede_mapping.csv", skiprows=4)
# Index by cell_id and keep only the four annotation levels used downstream.
anno_df = anno_df.set_index('cell_id')[['class_name', "subclass_name", "supertype_name", 'cluster_name']]

In [None]:
# Initial QC and filtering (both calls modify adata in place):
# drop cells with fewer than 150 detected genes, then drop genes
# detected in fewer than 3 cells.
sc.pp.filter_cells(adata, min_genes=150)
sc.pp.filter_genes(adata, min_cells=3)

In [None]:
# Calculate mitochondrial genes percentage
# Flag genes whose name starts with lowercase 'mt-'
# (presumably mouse-style gene symbols -- confirm against the reference used).
adata.var['mt'] = adata.var_names.str.startswith('mt-')
# inplace=True writes the QC columns onto adata, including the
# n_genes_by_counts / total_counts / pct_counts_mt columns used later in the notebook.
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

In [None]:
# Plot QC metrics: one violin panel each for genes per cell, total counts
# per cell and percent mitochondrial counts; `save` also writes the figure to disk.
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'], jitter=0.4, multi_panel=True, save='violin.png')

In [None]:
# Filter cells: keep only cells with less than 50% mitochondrial counts.
# .copy() materializes the subset as a standalone AnnData. Without it `adata`
# is a view of the unfiltered object, and the in-place steps that follow
# (raw assignment, normalization, scaling) have to convert it implicitly,
# which emits ImplicitModificationWarning and keeps the parent object alive.
adata = adata[(adata.obs.pct_counts_mt < 50), :].copy()

In [None]:
# Filter multiplets if applicable (assuming Sample_Name is a column in adata.obs).
# NOTE(review): the 'Multiplet' label is also present in the Sample_Tag mapping;
# confirm Sample_Name is the intended column. 'Undetermined' cells are kept.
# .copy() materializes the subset so later in-place steps operate on a real
# AnnData object rather than on a view of the previous one.
adata = adata[adata.obs['Sample_Name'] != "Multiplet", :].copy()

In [None]:
# Saving raw counts before data transformation.
# The snapshot is taken after cell/gene filtering but before normalization,
# log-transform, highly-variable-gene subsetting and scaling.
adata.raw = adata

In [None]:
# Normalization and identifying variable genes
# Scale each cell's counts to sum to 1; genes taking more than 5% of a cell's
# counts are excluded when computing the per-cell size factor.
# NOTE(review): with target_sum=1 the normalized values are small fractions, so
# the following log1p is close to the identity; a larger target_sum (e.g. 1e4)
# is the usual choice -- confirm this is intentional.
sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, max_fraction=0.05)
sc.pp.log1p(adata)
# subset=True drops everything except the 2000 selected genes from adata
# (the full gene set remains available only through adata.raw).
sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=True)
# Scale each gene; values are clipped at 10.
sc.pp.scale(adata, max_value=10)

In [None]:
# Plot explained variance vs PCs
# NOTE(review): sc.tl.pca is only called in a later cell; elbow_plot comes from
# the star-imported project modules and presumably computes its own PCA -- confirm.
elbow_plot(adata)

In [None]:
# Dimension heatmap
# NOTE(review): called before sc.tl.pca; dimension_heatmap is a project helper
# and presumably derives the 15 components itself -- confirm.
dimension_heatmap(adata, n_components=15, n_cells=500)

In [None]:
# PCA (10 components); stored in adata.obsm['X_pca'] and used as Harmony input
sc.tl.pca(adata, n_comps=10)

# Harmony batch correction across Sample_Tag groups
harmony_out = hm.run_harmony(adata.obsm['X_pca'], adata.obs, 'Sample_Tag')
# Z_corr is transposed before storing -- assumes harmonypy returns it as
# (components, cells); TODO confirm for the installed harmonypy version.
adata.obsm['X_pca_harmony'] = harmony_out.Z_corr.T

# Neighbors and clustering using Harmony-corrected PCA
sc.pp.neighbors(adata, use_rep='X_pca_harmony', n_pcs=10)
sc.tl.leiden(adata, resolution=0.5)

# UMAP and t-SNE
sc.tl.umap(adata)
sc.tl.tsne(adata)

In [None]:
# Plotting: UMAP and t-SNE embeddings colored by Leiden cluster;
# `save` also writes each figure to disk.
sc.pl.umap(adata, color=['leiden'], save='umap.png')
sc.pl.tsne(adata, color=['leiden'], save='tsne.png')

In [None]:
# Attach the per-cell annotations from anno_df to adata.obs, matching the
# obs index (as a string `cell_id` column) against anno_df's index.
anno_df.index = anno_df.index.astype(str)
adata.obs = (
    adata.obs
    # Drop annotation columns left behind by a previous run of this cell;
    # otherwise re-running would produce suffixed duplicates
    # (class_name_x / class_name_y, ...) instead of refreshed annotations.
    .drop(columns=list(anno_df.columns), errors='ignore')
    .assign(cell_id=lambda obs: obs.index.astype(str))
    # Left merge keeps every cell; cells without an annotation row get NaN.
    .merge(anno_df, left_on='cell_id', right_index=True, how='left')
)
# Assign unique cell type names to each cluster
# (project helper; the lines below rely on it creating adata.obs['annotated_cluster'])
assign_unique_cell_type_names(adata)
# Ensure leiden and annotated_cluster are strings for concatenation
adata.obs['leiden'] = adata.obs['leiden'].astype(str)
adata.obs['annotated_cluster'] = adata.obs['annotated_cluster'].astype(str)

In [None]:
# Plot UMAP with unique cell type annotations; the title reports the number
# of cells remaining (adata.shape[0]) and the figure is also saved to disk.
sc.pl.umap(adata, color=['annotated_cluster'], save='umap_all_groups.png', title=f'After QC - {adata.shape[0]} cells', size=10)

In [None]:
# Distinct condition labels present in the data
sample_tags = adata.obs['Sample_Tag'].unique()
# Project helper; presumably draws one UMAP per sample tag -- confirm in modules.visualize
plot_umap(adata, sample_tags, legend_loc='on data', legend_fontsize=5)

In [None]:
# Project helper; presumably tabulates cell counts per sample tag and cell type -- confirm
sample_tag_counts = get_master_table(adata)

In [None]:
# Project helper applied to the table built by get_master_table
# (presumably a composition/proportion plot -- confirm)
get_ditto(sample_tag_counts)

In [None]:
# Control (ctr) and condition (cnd) Sample_Tag labels passed to DEG_analysis below.
# Alternative comparison between the DMSO groups; swap the active line to use it:
#ctr, cnd = 'WT-DMSO', '3xTg-DMSO'
ctr, cnd = 'WT-SCDi', '3xTg-SCDi'

In [None]:
# Unique annotated cluster names, sorted so that the per-cell-type DEG loop
# and the resulting chart have a stable order across kernel restarts
# (iteration order of a set of strings is not reproducible between runs).
cell_types = sorted(set(adata.obs.annotated_cluster.values))

In [None]:
# Thresholds for calling a gene differentially expressed
min_fold_change = 0.25
max_p_value = 0.05

# Collects one (cell_type, n_up, n_down) tuple per annotated cluster
cluster_n_DEGs = []
for cell_type in tqdm(cell_types):
    df = DEG_analysis(adata, ctr, cnd, [cell_type])
    # Significance mask shared by both directions
    significant = df['pvals_adj'] < max_p_value
    positive_enriched = df[(df['logfoldchanges'] > min_fold_change) & significant]
    negative_enriched = df[(df['logfoldchanges'] < -min_fold_change) & significant]
    cluster_n_DEGs.append((cell_type, len(positive_enriched), len(negative_enriched)))

In [None]:
# Project helper; receives the list of (cell_type, n_up, n_down) tuples built above
horizontal_deg_chart(cluster_n_DEGs)

In [None]:
# DEG analysis (ctr vs cnd) restricted to a single annotated cluster.
# NOTE(review): 'Astro-Epen_3' is hard-coded; confirm it exists in
# adata.obs['annotated_cluster'] for the current clustering.
results_df = DEG_analysis(adata, ctr, cnd, ['Astro-Epen_3'])

In [None]:
# Project helper; plots the single-cluster DEG results computed above
get_volcano_plot(results_df)