# Packages

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn as nn
import torch.nn.functional as F
import umap
import pandas as pd
import anndata
import scanpy as sc
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from modules.sparse_autoencoder import *
import pickle
from tqdm import tqdm
from modules.deg_analysis import *
from modules.visualize import *
from sklearn.neighbors import NearestNeighbors

  from .autonotebook import tqdm as notebook_tqdm
2024-06-26 10:22:29.456498: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Import dataset and annotation

In [None]:
# NOTE: legacy loader for the earlier "fede" dataset, kept commented out for reference.
# Its Sample_Tag labels (WT-DMSO, 3xTg-DMSO, WT-SCDi, 3xTg-SCDi, ...) differ from the
# labels assigned to the A-D samples loaded in the next cell.
# sample_tag_mapping = {'SampleTag17_flex':'WT-DMSO',
#                       'SampleTag18_flex':'3xTg-DMSO',
#                       'SampleTag19_flex':'WT-SCDi',
#                       'SampleTag20_flex':'3xTg-SCDi',
#                       'Undetermined':'Undetermined',
#                       'Multiplet':'Multiplet'}
# adata = anndata.read_h5ad("data/fede_count.h5ad")
# adata.obs['Sample_Tag'] = adata.obs['Sample_Tag'].map(sample_tag_mapping)
# anno_df = pd.read_csv("data/fede_mapping.csv", skiprows=4)

In [None]:
# Each count matrix is paired with the Sample_Tag label written onto all of its cells.
count_sources = [
    ("data/A_count.h5ad", 'LD_5xFAD'),
    ("data/B_count.h5ad", "LD_NC"),
    ("data/C_count.h5ad", "run_5xFAD"),
    ("data/D_count.h5ad", "run_NC"),
]
per_sample_adatas = []
for count_path, sample_label in count_sources:
    sample_adata = anndata.read_h5ad(count_path)
    sample_adata.obs['Sample_Tag'] = sample_label
    per_sample_adatas.append(sample_adata)
# Keep the individual per-sample objects addressable under their usual names.
adata1, adata2, adata3, adata4 = per_sample_adatas
# Concatenate the four samples along axis 0 into a single AnnData object.
adata = anndata.concat(per_sample_adatas, axis=0)

# Matching annotation tables; the first four lines of every CSV are skipped.
mapping_paths = [
    "data/A_mapping.csv",
    "data/B_mapping.csv",
    "data/C_mapping.csv",
    "data/D_mapping.csv",
]
anno_df1, anno_df2, anno_df3, anno_df4 = [
    pd.read_csv(mapping_path, skiprows=4) for mapping_path in mapping_paths
]
anno_df = pd.concat([anno_df1, anno_df2, anno_df3, anno_df4])

In [None]:
# Combine the concatenated counts with the mapping tables via the project helper
# annotate_adata. NOTE(review): presumably this adds the class_name / subclass_name
# obs columns used by later cells — confirm against the helper's implementation.
adata = annotate_adata(adata, anno_df)

# Data preprocessing

In [None]:
# Basic QC filtering (in place): keep cells with at least 150 detected genes,
# then keep genes detected in at least 3 cells.
sc.pp.filter_cells(adata, min_genes=150)
sc.pp.filter_genes(adata, min_cells=3)

In [None]:
# Flag mitochondrial genes by the lowercase 'mt-' gene-name prefix.
adata.var['mt'] = adata.var_names.str.startswith('mt-')
# Compute per-cell QC metrics in place; with qc_vars=['mt'] this provides the
# 'pct_counts_mt' obs column read on the next line.
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
# Mark cells whose mitochondrial count percentage exceeds 5.
adata.obs['high_mt'] = adata.obs['pct_counts_mt'] > 5

In [None]:
# Drop the cells flagged as high mitochondrial content.
adata = adata[~adata.obs['high_mt'], :]
# Drop cells tagged as multiplets.
# NOTE(review): the Sample_Tag values assigned at load time are LD_5xFAD / LD_NC /
# run_5xFAD / run_NC, so this filter looks like a no-op for the current dataset
# unless annotate_adata rewrites Sample_Tag — confirm.
adata = adata[adata.obs['Sample_Tag'] != "Multiplet", :]

In [None]:
# Remove the genes flagged as mitochondrial from the feature set.
adata = adata[:, ~adata.var['mt']]

# Train/Test split

In [None]:
# Reproducible 80/20 split over cell positions (fixed random_state).
all_cell_positions = np.arange(adata.n_obs)
train_indices, test_indices = train_test_split(all_cell_positions, test_size=0.2, random_state=42)

# Materialise each subset as an independent copy.
adata_train = adata[train_indices].copy()
adata_test = adata[test_indices].copy()

# Input width of the model equals the number of genes (variables) in adata.
n_inputs = adata.n_vars
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap both subsets for batched iteration; only the training batches are shuffled.
train_dataset = AnnDataDataset(adata_train)
test_dataset = AnnDataDataset(adata_test)
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Train model

In [None]:
# Build the autoencoder: 32 latents, TopK(k=16) activation, tied weights, normalisation on.
autoencoder = Autoencoder(
    n_latents=32,
    n_inputs=n_inputs,
    activation=TopK(k=16),
    tied=True,
    normalize=True,
)
autoencoder.to(device)

# Training hyper-parameters gathered in one place and forwarded as keyword arguments.
training_options = dict(
    num_epochs=100,
    learning_rate=0.0001,
    prune_interval=10,
    prune_amount=0.95,
)
train_autoencoder(autoencoder, train_loader, test_loader, device, **training_options)

# Load trained model

In [None]:
# Re-create the architecture with the same hyper-parameters that were used for training.
autoencoder = Autoencoder(n_latents=32, n_inputs=n_inputs, activation=TopK(k=16), tied=True, normalize=True).to(device)
# map_location='cpu' lets a checkpoint saved on a GPU be loaded on a CPU-only machine;
# load_state_dict then copies the values onto whatever device the model lives on.
state_dict = torch.load('best_autoencoder.pth', map_location='cpu')
model_state_dict = autoencoder.state_dict()
# Only keys that exist in the freshly built model are loaded.
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
# Report what the filter drops and what the checkpoint fails to provide, instead of
# letting strict=False hide a mismatch between checkpoint and model.
skipped_keys = sorted(set(state_dict) - set(model_state_dict))
missing_keys = sorted(set(model_state_dict) - set(state_dict))
if skipped_keys:
    print(f"Checkpoint keys not present in the model (ignored): {skipped_keys}")
if missing_keys:
    print(f"Model keys not found in the checkpoint (left at initial values): {missing_keys}")
autoencoder.load_state_dict(filtered_state_dict, strict=False)
# Inference mode on the CPU for the downstream analysis cells.
autoencoder.eval()
autoencoder.cpu()
# Apply the project pruning helper with the same amount that was passed to training.
apply_pruning(autoencoder, amount=0.95)

# Extract latent representations

In [None]:
# Build a dense float32 tensor from the test-set expression matrix. Sparse matrices
# expose .toarray(); an already-dense X is passed through unchanged.
test_matrix = adata_test.X
if hasattr(test_matrix, "toarray"):
    test_matrix = test_matrix.toarray()
data = torch.tensor(np.asarray(test_matrix), dtype=torch.float32)
# Inference only: disable autograd so no computation graph is kept for the whole test set.
with torch.no_grad():
    latents, _ = autoencoder.encode(data)
# Detach and move to host memory before converting to a NumPy array.
latents_np = latents.detach().cpu().numpy()

In [None]:
# Insert the latents into adata_test.obsm
# Store the encoded representation under obsm['latents'] so that the neighbour
# graph below can be built on it via use_rep='latents'.
adata_test.obsm['latents'] = latents_np

# Clustering

In [None]:
# Clustering using 'latents'
# Build the neighbour graph on the autoencoder latents instead of the default representation.
sc.pp.neighbors(adata_test, use_rep='latents')
# Leiden clustering on that graph at resolution 0.5; later cells read the labels
# from obs['leiden'].
sc.tl.leiden(adata_test, resolution=0.5)

# Visualization

In [None]:
# UMAP
# Compute the UMAP embedding for adata_test (uses the neighbour graph built earlier).
sc.tl.umap(adata_test)

In [None]:
# Plotting
# One UMAP per obs column, each saved under its own file-name suffix.
# NOTE(review): cells flagged high_mt were filtered out before the split, so the
# 'high_mt' colouring should be uniform — confirm whether this plot is still wanted.
umap_colourings = (
    ('leiden', 'umap_leiden.png'),
    ('Sample_Tag', 'umap_sampletag.png'),
    ('high_mt', 'umap_highmt.png'),
)
for obs_key, figure_suffix in umap_colourings:
    sc.pl.umap(adata_test, color=[obs_key], save=figure_suffix)

# Cluster annotation

In [None]:
# Annotate the Leiden clusters at two annotation levels with the project helper.
# NOTE(review): presumably this creates the 'cluster_class_name' and
# 'cluster_subclass_name' obs columns that later cells colour by — confirm.
assign_unique_cell_type_names(adata_test, cluster_key='leiden', cluster_types=['class_name', 'subclass_name'])

# Cluster annotation visualization

In [None]:
# Plot UMAP with unique cell type annotations
# UMAP coloured by the class-level cluster annotation. Saved under its own file name
# so it is not overwritten by the subclass-level plot, and titled with the number of
# cells actually shown (adata_test), not the size of the full dataset.
sc.pl.umap(adata_test, color=['cluster_class_name'], save='umap_cluster_class_name.png', title=f'After QC - {adata_test.n_obs} cells', size=10)

In [None]:
# Plot UMAP with unique cell type annotations
# UMAP coloured by the subclass-level cluster annotation. Saved under its own file name
# so it does not overwrite the class-level plot, and titled with the number of
# cells actually shown (adata_test), not the size of the full dataset.
sc.pl.umap(adata_test, color=['cluster_subclass_name'], save='umap_cluster_subclass_name.png', title=f'After QC - {adata_test.n_obs} cells', size=10)

In [None]:
# Plot separately by sample tag
# Distinct sample tags present in the test subset.
# NOTE(review): sample_tags is not passed to plot_umap below; it is kept only as a
# notebook variable.
sample_tags = adata_test.obs['Sample_Tag'].unique()
# Project plotting helper, called with the subclass-level cluster annotation.
# NOTE(review): the per-sample-tag splitting suggested by save_path presumably happens
# inside plot_umap — confirm.
plot_umap(adata_test, cluster_type='cluster_subclass_name', legend_fontsize=7, save_path='_sample_tag')

# Cluster composition analysis

In [None]:
# Annotation level and the matching cluster column passed to create_ditto_plot below.
class_level, cluster_type = 'subclass_name', 'cluster_subclass_name'

In [None]:
# One composition (ditto) plot per sample tag actually present in the test subset.
# The previously hard-coded tags (WT-DMSO, 3xTg-DMSO, WT-SCDi, 3xTg-SCDi, Undetermined)
# belong to the legacy dataset and do not occur in the A-D samples loaded above.
# File names follow the earlier convention: lower-cased tag with '-' replaced by '_'
# (e.g. 'WT-DMSO' -> 'figures/wt_dmso_ditto.png').
for ditto_tag in adata_test.obs['Sample_Tag'].unique():
    ditto_file_stem = str(ditto_tag).lower().replace('-', '_')
    create_ditto_plot(adata_test, [ditto_tag], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=f'figures/{ditto_file_stem}_ditto.png')

# Top contributing genes for each latent dimension

In [None]:
# Per-latent gene lists from the project helper. Later cells index the result by
# latent number and read its 'UP_genes_name' / 'DOWN_genes_name' entries.
top_genes = get_top_genes(autoencoder, adata_test)

In [None]:
# Plot the 25 top contributing genes for latent dimension 0 (project helper).
plot_top_contributing_genes(autoencoder, adata_test, latent_dim=0, top_n=25)

In [None]:
# Inspect the gene lists of latent dimension 0 (the one plotted in the previous cell).
# A literal index is used because `i` is only bound by the enrichment loop further
# down, so it is undefined on a clean top-to-bottom run.
top_genes[0]

# GO term and KEGG pathway enrichment analysis

In [None]:
# Per-latent enrichment results; position i in each list belongs to latent dimension i.
UP_GO_results = []
DOWN_GO_results = []
UP_KEGG_results = []
DOWN_KEGG_results = []
for i in tqdm(range(32)):  # 32 = n_latents passed to Autoencoder
    up_genes = top_genes[i]['UP_genes_name']
    down_genes = top_genes[i]['DOWN_genes_name']
    # Each direction is guarded on its own: an empty gene list yields an empty
    # DataFrame placeholder, but no longer suppresses the analysis of the other,
    # populated direction. len() is used because the containers may be arrays.
    has_up = len(up_genes) > 0
    has_down = len(down_genes) > 0

    UP_GO_results.append(
        go_enrichment_analysis(up_genes, save_path=None) if has_up else pd.DataFrame()
    )
    DOWN_GO_results.append(
        go_enrichment_analysis(down_genes, save_path=None) if has_down else pd.DataFrame()
    )
    UP_KEGG_results.append(
        kegg_enrichment_analysis(up_genes, save_path=None) if has_up else pd.DataFrame()
    )
    DOWN_KEGG_results.append(
        kegg_enrichment_analysis(down_genes, save_path=None) if has_down else pd.DataFrame()
    )

In [None]:
# Latent dimension whose enrichment results are displayed in the following cells.
idx=29

In [None]:
# Show the UP-gene GO enrichment of the selected latent once per namespace (BP, MF, CC).
for go_namespace in ('BP', 'MF', 'CC'):
    display_go_enrichment(UP_GO_results[idx], namespace=go_namespace, fig_title=None, save_path=None)

In [None]:
# Show the DOWN-gene GO enrichment of the selected latent once per namespace (BP, MF, CC).
for go_namespace in ('BP', 'MF', 'CC'):
    display_go_enrichment(DOWN_GO_results[idx], namespace=go_namespace, fig_title=None, save_path=None)

In [None]:
# KEGG enrichment for the UP genes of the selected latent dimension (idx).
display_kegg_enrichment(UP_KEGG_results[idx], fig_title=None, save_path=None)

In [None]:
# KEGG enrichment for the DOWN genes of the selected latent dimension (idx).
display_kegg_enrichment(DOWN_KEGG_results[idx], fig_title=None, save_path=None)

# Latent dimensions heatmap

In [None]:
# Distinct sample tags in the test subset, as a plain list; passed to plot_latent_heatmap below.
tags = adata_test.obs['Sample_Tag'].unique().tolist()

In [None]:
# Distinct class-level cluster names in the test subset, as a plain list.
classes = adata_test.obs['cluster_class_name'].unique().tolist()

In [None]:
# Display the available cluster names (useful when choosing `clusters=` for the heatmaps below).
classes

In [None]:
# Latent-activation heatmap for one cluster, across all sample tags (project helper).
# NOTE(review): the hard-coded cluster name presumably has to be one of the values
# listed in `classes`, which depend on the clustering result — confirm before running.
plot_latent_heatmap(autoencoder, data, adata_test, sample_tags=tags, clusters=['IT-ET Glut_1'], subclusters=None, num_cells=1000)

In [None]:
# Latent-activation heatmap for one cluster, across all sample tags (project helper).
# NOTE(review): the hard-coded cluster name presumably has to be one of the values
# listed in `classes`, which depend on the clustering result — confirm before running.
plot_latent_heatmap(autoencoder, data, adata_test, sample_tags=tags, clusters=['CTX-MGE GABA_1'], subclusters=None, num_cells=1000)