In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import scib
import torch.nn as nn
import torch
import random
import tensorflow as tf
import warnings
from IPython.display import display
from functions import data_preprocessing as dp
from models import cl_dummy_model5 as scRNASeq_model

  from tensorflow.tsl.python.lib.core import pywrap_ml_dtypes


In [2]:
# Merged Oetjen immune-cell dataset, stored as an .h5ad file.
data_path = '../../data/processed/immune_cells/merged/Oetjen_merged.h5ad'
adata = sc.read(data_path, cache=True)

# Mirror the patient identifier into a "batch" column of the per-cell
# annotations; "batch" is the key handed to the training module later on.
patient_ids = adata.obs["patientID"]
adata.obs["batch"] = patient_ids

# Ensure reproducibility
def rep_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's built-in ``random`` module, NumPy's global RNG, PyTorch
    (CPU, plus the current and all CUDA devices when CUDA is available) and
    TensorFlow with the same value, then switches cuDNN to deterministic
    mode and turns its benchmarking mode off.

    Args:
        seed: Integer seed passed unchanged to each library.

    Returns:
        None. The function only mutates global RNG / backend state.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # GPU generators are only touched when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    tf.random.set_seed(seed)

    # cuDNN: request deterministic kernels and disable benchmarking mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix all RNG seeds before any further processing.
rep_seed(42)

# Flag the top 4000 highly variable genes ("cell_ranger" flavor); the call
# writes a boolean "highly_variable" column into adata.var.
sc.pp.highly_variable_genes(adata, flavor="cell_ranger", n_top_genes=4000)

# Keep only the flagged genes. NOTE: this rebinds `adata`, so re-running the
# cell without reloading operates on the already-subset object.
hvg_mask = adata.var["highly_variable"]
adata = adata[:, hvg_mask].copy()

In [3]:
# JSON file with the pathway information consumed by the training module.
pathways_path = '../../data/processed/pathway_information/all_pathways.json'

# Build the training environment from the in-memory AnnData object.
# NOTE(review): `data_path` receives the AnnData object itself rather than a
# file path, and HVG=False is presumably because gene selection was already
# done in the previous cell -- confirm both against `train_module`.
train_env = scRNASeq_model.train_module(
    data_path=adata,
    json_file_path=pathways_path,
    num_pathways=300,
    save_model_path="",
    HVG=False,
    HVGs=4000,
    Scaled=False,
    target_key="cell_type",
    batch_keys=["batch"],
)

In [5]:
# Train the model. The return value is bound to `_` so that the cell does not
# echo it as output.
_ = train_env.train(
    # Runtime
    device=None,
    seed=42,
    batch_size=256,
    # Architecture
    attn_embed_dim=24 * 4,
    depth=2,
    num_heads=4,
    output_dim=100,
    attn_drop_out=0.,
    proj_drop_out=0.2,
    attn_bias=False,
    act_layer=nn.ReLU,
    norm_layer=nn.BatchNorm1d,  # nn.LayerNorm was the alternative tried here
    pathway_emb_dim=30,
    # Loss / temperature
    loss_with_weights=True,
    init_temperature=0.25,
    min_temperature=0.1,
    max_temperature=2.0,
    # Optimisation schedule
    init_lr=0.001,
    lr_scheduler_warmup=4,
    lr_scheduler_maxiters=25,
    eval_freq=4,
    epochs=20,
    earlystopping_threshold=3,
)

Number of parameters: 11771162

Start Training



  5%|▌         | 1/20 [06:34<2:04:56, 394.54s/it]

Epoch 1 | Training loss: 0.0781 | Validation loss: 0.0777


 25%|██▌       | 5/20 [28:01<1:24:05, 336.34s/it]

Epoch 5 | Training loss: 0.0772 | Validation loss: 0.0756


 25%|██▌       | 5/20 [28:40<1:26:02, 344.15s/it]


KeyboardInterrupt: 