In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from modules.utils import *
from modules.deg_analysis import *
from modules.visualize import *
from sklearn.model_selection import train_test_split
import torch
from modules.dataloader import PairedDataset
from torch.utils.data import DataLoader, DistributedSampler
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from modules.mlp_model import MLP, SiameseMLP
from modules.kan_model import DeepKAN, SiameseKAN
import torch.nn as nn
import scanpy as sc

2024-07-08 15:26:46.613038: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# sample_tag_mapping = {'SampleTag17_flex':'WT-DMSO',
#                       'SampleTag18_flex':'3xTg-DMSO',
#                       'SampleTag19_flex':'WT-SCDi',
#                       'SampleTag20_flex':'3xTg-SCDi',
#                       'Undetermined':'Undetermined',
#                       'Multiplet':'Multiplet'}
# adata = anndata.read_h5ad("data/fede_count.h5ad")
# adata.obs['Sample_Tag'] = adata.obs['Sample_Tag'].map(sample_tag_mapping)
# anno_df = pd.read_csv("data/fede_mapping.csv", skiprows=4)

In [3]:
# Count matrices, each paired with the experimental-group label that is
# written into its `Sample_Tag` column before concatenation.
sample_tag_by_count_file = {
    "data/A_count.h5ad": 'LD_5xFAD',
    "data/B_count.h5ad": "LD_NC",
    "data/C_count.h5ad": "run_5xFAD",
    "data/D_count.h5ad": "run_NC",
}

per_sample_adatas = []
for count_file, sample_tag in sample_tag_by_count_file.items():
    sample_adata = anndata.read_h5ad(count_file)
    sample_adata.obs['Sample_Tag'] = sample_tag
    per_sample_adatas.append(sample_adata)

# Keep the individual per-sample objects available under their usual names.
adata1, adata2, adata3, adata4 = per_sample_adatas
# Stack the four samples along the observation axis (axis=0).
adata = anndata.concat(per_sample_adatas, axis=0)

# Mapping tables, one per sample; the first four lines of each CSV are skipped.
mapping_files = [
    "data/A_mapping.csv",
    "data/B_mapping.csv",
    "data/C_mapping.csv",
    "data/D_mapping.csv",
]
anno_df1, anno_df2, anno_df3, anno_df4 = [
    pd.read_csv(mapping_file, skiprows=4) for mapping_file in mapping_files
]
anno_df = pd.concat([anno_df1, anno_df2, anno_df3, anno_df4])

  utils.warn_names_duplicates("obs")


In [4]:
# Hand the concatenated AnnData and the combined mapping table to the project
# helper `annotate_adata`, and rebind `adata` to whatever it returns.
# NOTE(review): presumably this attaches the per-cell annotation columns
# (e.g. the class/subclass names used further down) — confirm in modules.
adata = annotate_adata(adata, anno_df)

In [5]:
# --- Quality control -------------------------------------------------------
# Drop cells with fewer than 150 detected genes, then genes seen in fewer
# than 3 cells.
sc.pp.filter_cells(adata, min_genes=150)
sc.pp.filter_genes(adata, min_cells=3)

# Flag genes whose name starts with 'mt-' and compute QC metrics for that
# gene set; `pct_counts_mt` read below comes from this call.
# NOTE(review): 'mt-' is presumably the mitochondrial gene prefix for this
# species' annotation — confirm.
adata.var['mt'] = adata.var_names.str.startswith('mt-')
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# Remove cells where more than 5% of counts come from the flagged genes.
adata.obs['high_mt'] = adata.obs['pct_counts_mt'] > 5
# `.copy()` materialises the subset. Without it `adata` is only a view of the
# unfiltered object, and the in-place normalisation that follows has to
# convert it implicitly (the "view_to_actual" warning).
adata = adata[~adata.obs['high_mt'], :].copy()
#adata = adata[adata.obs['Sample_Tag'] != "Multiplet", :]
#adata = adata[:, ~adata.var['mt']]

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


In [6]:
# Normalise per-cell total counts with scanpy's defaults (no `target_sum` is
# passed and the result is not reassigned, so `adata` is updated in place).
# NOTE(review): the log-transform and scaling steps below are disabled, so
# everything downstream works on normalised, non-log values — confirm that
# this is intended.
sc.pp.normalize_total(adata)
#sc.pp.log1p(adata)
#sc.pp.scale(adata)

  view_to_actual(adata)
  utils.warn_names_duplicates("obs")


In [7]:
# Give every distinct Sample_Tag an integer id, numbered in the order in which
# `.unique()` reports the tags, and store the per-cell id in `obs['target']`.
tag_order = list(adata.obs['Sample_Tag'].unique())
unique_labels = dict(zip(tag_order, range(len(tag_order))))
adata.obs['target'] = adata.obs['Sample_Tag'].map(unique_labels)

In [8]:
# Split the cell positions 80/20 with a fixed seed, then materialise each
# side as an independent AnnData copy.
all_cell_positions = np.arange(adata.n_obs)
train_indices, test_indices = train_test_split(
    all_cell_positions,
    test_size=0.2,
    random_state=42,
)
adata_train, adata_test = [
    adata[positions].copy() for positions in (train_indices, test_indices)
]

  utils.warn_names_duplicates("obs")


In [9]:
# Integer targets for each split, taken from the `target` column.
train_targets = adata_train.obs['target'].tolist()
test_targets = adata_test.obs['target'].tolist()
Y_train, Y_test = np.array(train_targets), np.array(test_targets)

# Dense feature matrices; `.X` must expose `.toarray()` (i.e. be sparse).
X_train, X_test = adata_train.X.toarray(), adata_test.X.toarray()

In [10]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Both loaders use the same batching options.
loader_options = {"batch_size": 512}

# NOTE(review): the third positional argument (5) is passed straight to the
# project's PairedDataset; its meaning is not visible here — confirm.
train_dataset = PairedDataset(X_train, Y_train, 5)
train_loader = DataLoader(train_dataset, **loader_options)

test_dataset = PairedDataset(X_test, Y_test, 5)
test_loader = DataLoader(test_dataset, **loader_options)

In [11]:
# Build the Siamese network on the selected device.
# The input width is the size of the last axis of the training matrix.
input_dim = X_train.shape[-1]
# Sizes handed to the project's MLP as its layer specification.
# NOTE(review): presumably hidden-layer widths ending in a 32-dimensional
# embedding — confirm against modules.mlp_model.
shared_layers = [4096,1024,256,32]
base_net = MLP(input_dim, shared_layers).to(device)
# SiameseMLP wraps the base network; the training loop later saves
# `siamese_model.base_network`.
siamese_model = SiameseMLP(base_net).to(device)

In [12]:
# RMSprop over every parameter of the Siamese model, learning rate 1e-4.
optimizer = optim.RMSprop(siamese_model.parameters(), lr=0.0001)

In [13]:
# --- Training with early stopping -----------------------------------------
# Tunable settings are named once here instead of being repeated inline.
epochs = 50                          # upper bound on training epochs
patience = 5                         # epochs without a new best accuracy before stopping
best_model_path = 'best_model.pth'   # where the best base-network weights are written

best_accuracy = 0
no_improvement_count = 0

for epoch in range(epochs):
    # `train_epoch` / `eval_model` are project helpers; their return values
    # are used as the epoch's training loss and validation accuracy.
    train_loss = train_epoch(siamese_model, train_loader, optimizer, device, epoch)
    val_accuracy = eval_model(siamese_model, test_loader, device, epoch)
    print(f"Epoch {epoch}, Train Loss: {train_loss}, Validation Accuracy: {val_accuracy}")

    if val_accuracy > best_accuracy:
        # New best: reset the stall counter and checkpoint. Only the shared
        # base network is saved, not the Siamese wrapper.
        best_accuracy = val_accuracy
        no_improvement_count = 0
        torch.save(siamese_model.base_network.state_dict(), best_model_path)
        print("Model saved as best model")
    else:
        no_improvement_count += 1

    # Stop once accuracy has failed to improve for `patience` epochs in a row.
    if no_improvement_count >= patience:
        print(f"No improvement in validation accuracy for {patience} consecutive epochs. Training stopped.")
        break

 40%|████████████████▍                        | 295/737 [00:38<00:57,  7.67it/s]


KeyboardInterrupt: 

In [14]:
# --- Reload the best base network for CPU inference -------------------------
# Rebuild the architecture with the same dimensions used for training.
input_dim = X_train.shape[-1]
shared_layers = [4096,1024,256,32]
base_net = MLP(input_dim, shared_layers)

# The checkpoint file only exists after the training loop has saved at least
# one best model; otherwise torch.load raises FileNotFoundError.
model_path = 'best_model.pth'
# The network ends up on the CPU for the steps below, so map the weights
# straight to CPU instead of loading them to `device` and moving the model
# to the accelerator and back again.
# NOTE(review): torch.load unpickles the file — only load checkpoints that
# this notebook produced.
checkpoint = torch.load(model_path, map_location='cpu')
base_net.load_state_dict(checkpoint)

# Switch to evaluation mode; as the last expression this also displays the
# module summary.
base_net.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'best_model.pth'

In [None]:
# Embed every test cell with the trained base network.
data = torch.tensor(adata_test.X.toarray(), dtype=torch.float32)
# Inference only: disable autograd so no computation graph is built and held
# in memory for the full test matrix.
with torch.no_grad():
    latents = base_net(data)
latents_np = latents.detach().numpy()

In [None]:
# Store the embedding under obsm['latents'] so it can be selected with
# `use_rep='latents'` when building the neighbour graph.
adata_test.obsm['latents'] = latents_np

In [None]:
# Build the neighbour graph on the learned latent representation rather than
# on the expression matrix, then run Leiden clustering at resolution 0.5.
sc.pp.neighbors(adata_test, use_rep='latents')
sc.tl.leiden(adata_test, resolution=0.5)

In [None]:
# Compute the UMAP layout for the test cells with scanpy's default settings.
sc.tl.umap(adata_test)

In [None]:
# Draw the UMAP twice — coloured by Leiden cluster and by Sample_Tag — and ask
# scanpy to save each figure.
# NOTE(review): scanpy treats a string `save` as a suffix appended to its
# default figure name, so the files on disk may not be named exactly as
# written here — confirm.
sc.pl.umap(adata_test, color=['leiden'], save='umap_leiden.png')
sc.pl.umap(adata_test, color=['Sample_Tag'], save='umap_sampletag.png')

In [None]:
# Call the project helper with the Leiden cluster key and the two annotation
# levels to name the clusters.
# NOTE(review): presumably this adds `cluster_class_name` and
# `cluster_subclass_name` columns to adata_test.obs (the latter is used for
# colouring right after) — confirm in modules.
assign_unique_cell_type_names(adata_test, cluster_key='leiden', cluster_types=['class_name', 'subclass_name'])

In [None]:
# UMAP coloured by the per-cluster subclass name; the title reports the number
# of cells in `adata_test` (shape[0]).
# NOTE(review): `adata_test` is the 20% hold-out split, so the "After QC"
# count in the title is the test-set size, not every cell that passed QC.
sc.pl.umap(adata_test, color=['cluster_subclass_name'], save='umap_all_groups.png', title=f'After QC - {adata_test.shape[0]} cells', size=10)

In [None]:
# Same UMAP view coloured by experimental group (Sample_Tag); the title
# reports the number of cells in `adata_test`.
# Uses a file name distinct from the cluster-subclass figure
# ('umap_all_groups.png') so the two figures do not overwrite each other.
sc.pl.umap(adata_test, color=['Sample_Tag'], save='umap_all_groups_sampletag.png', title=f'After QC - {adata_test.shape[0]} cells', size=10)

In [None]:
#sample_tags = adata_test.obs['Sample_Tag'].unique()
#plot_umap(adata_test, cluster_type='cluster_subclass_name', legend_fontsize=7, save_path='_sample_tag')

In [None]:
# Distinct Sample_Tag values present in the test split, as a plain list.
tags = adata_test.obs['Sample_Tag'].unique().tolist()

In [None]:
# Annotation level and the matching per-cluster column name, passed together
# to the ditto plot.
class_level, cluster_type = 'subclass_name', 'cluster_subclass_name'

In [None]:
# Project helper: one plot over all sample tags, written to `save_path`.
# NOTE(review): presumably a composition (stacked-proportion) plot per
# cluster with `min_cell` as a minimum-size filter — confirm in modules.
# The disabled calls below are the per-sample variants (one tag each, not saved).
create_ditto_plot(adata_test, tags, class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path='figures/dito_siamese.png')
#create_ditto_plot(adata_test, [tags[0]], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=None)
#create_ditto_plot(adata_test, [tags[1]], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=None)
#create_ditto_plot(adata_test, [tags[2]], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=None)
#create_ditto_plot(adata_test, [tags[3]], class_level=class_level, cluster_type=cluster_type, min_cell=100, save_path=None)