## We evaluate the cell type annotation performance of the pretrained model, 
## using a multilayer perceptron (MLP) as the baseline.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset,DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix,f1_score
import ray
from ray import train,tune
from ray.tune import CLIReporter, Checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle
import inspect
from MachineLearnFunc import SparseToDenseDataset, HyperparameterTune
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os
import tempfile
# Root directory: the input .h5ad file is read from datasets/ under it and
# tuning results are written to Pretrained_model/ under it.
work_dir='./'


### Step 1. Load the dataset.
#### The dataset was downloaded from CELLxGENE and contains single-cell RNA-sequencing profiles of 156,726 cells from the human embryonic meninges at 5–13 weeks post-conception.

In [None]:
## dataset download from here: https://datasets.cellxgene.cziscience.com/f82c6e2b-8056-45a3-967a-fbadf057566f.h5ad
filepath = os.path.join(work_dir, "datasets", "f82c6e2b-8056-45a3-967a-fbadf057566f.h5ad")
sc_data = sc.read_h5ad(filepath)  # AnnData object, n_obs × n_vars = 156726 × 33159
print(sc_data)

## Summary of the per-cell QC metrics used for filtering below.
## Values observed for this file (min, median, max):
##   fraction_mitochondrial = (0, 0.04, 0.82)
##   total_UMIs             = (1000, 7057, 290169)
##   total_genes            = (39, 2677, 14076)
print(sc_data.obs[['fraction_mitochondrial', 'total_UMIs', 'total_genes']].describe())

## Filter low-quality cells.
## .copy() materialises the subset: otherwise sc_data is an AnnData *view* and
## the in-place filter_cells calls below would have to copy it implicitly.
sc_data = sc_data[sc_data.obs['fraction_mitochondrial'] < 0.2, :].copy()  # keep cells with <20% mitochondrial fraction
sc.pp.filter_cells(sc_data, min_genes=200)   # drop cells with <200 detected genes
sc.pp.filter_cells(sc_data, max_genes=6000)  # drop cells with >6000 detected genes
# Optional: sc.pp.filter_genes(sc_data, min_cells=3)  # drop genes expressed in <3 cells

## Normalization is done before splitting on purpose: normalize_total scales each
## cell by its own total count and log1p is element-wise, so no statistics are
## shared between cells and nothing leaks between train/dev/test.
## NOTE(review): this assumes sc_data.X holds raw counts. CELLxGENE files may
## store already-normalized values in X (raw counts in sc_data.raw) — confirm
## before normalizing again.
sc.pp.normalize_total(sc_data, target_sum=1e4)  # rescale every cell to a total of 1e4
sc.pp.log1p(sc_data)                            # log(1 + x) transform
# Optional (computes statistics across cells, so ideally fit on the training split only):
# sc.pp.highly_variable_genes(sc_data, n_top_genes=3000, flavor='seurat')  # top 3000 highly variable genes
# sc_data = sc_data[:, sc_data.var.highly_variable]

## Encode cell-type labels as integers. The encoder is fitted on *all* cells so
## every class receives a code; it only defines the label mapping (no expression
## data is involved), so fitting it before the split leaks nothing.
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(sc_data.obs['cell_type'])
num_classes = len(label_encoder.classes_)
print(f'Number of cell types: {num_classes}')  # 58 cell types

## Split cell *indices* rather than the expression matrix, so no extra copies of
## the data are made; a custom dataset class (SparseToDenseDataset) later turns
## the selected sparse rows into dense tensors.
## Stratified splitting keeps class proportions; sklearn raises ValueError if a
## cell type has fewer than 2 cells left after filtering.
all_indices = np.arange(sc_data.n_obs)
train_dev_indices, test_indices, y_train_dev, y_test = train_test_split(
    all_indices, y, test_size=0.2, random_state=42, stratify=y
)  # 80% train+dev / 20% test
train_indices, dev_indices, y_train, y_dev = train_test_split(
    train_dev_indices, y_train_dev, test_size=0.3, random_state=42, stratify=y_train_dev
)  # of train+dev: 70% train / 30% dev
print('Data prepared!')



### Step 2. Pretrain a decoder-only model

In [None]:
## 2.1: define a decoder-only model
class SingleCellDecoder(nn.Module):
    """Transformer-decoder autoencoder that reconstructs gene-expression vectors.

    Expression values are embedded with a linear layer, passed through a stack
    of ``nn.TransformerDecoderLayer`` blocks and projected back to
    ``output_dim`` values.

    There is no encoder: the decoder's cross-attention receives an all-zero
    ``memory`` tensor, and no causal ``tgt_mask`` is passed, so despite the
    "decoder-only" name, positions within a sequence attend to each other in
    both directions.

    NOTE: the ``input_dim``/``output_dim`` defaults are read from the global
    ``sc_data`` when this class is *defined*, so the data-loading cell must
    run first.
    """

    def __init__(
        self,
        input_dim: int = sc_data.X.shape[1],   # number of genes per cell (columns of the data matrix)
        hidden_dim: int = 512,   # width of the embedding / transformer layers
        num_layers: int = 2,     # number of stacked decoder layers
        nhead: int = 4,          # attention heads; must divide hidden_dim
        dropout: float = 0.1,    # dropout rate inside each decoder layer
        output_dim: int = sc_data.X.shape[1],   # defaults to the gene count, for reconstruction
    ):
        super().__init__()

        # Project each expression vector into the model dimension.
        self.embedding = nn.Linear(input_dim, hidden_dim)

        # batch_first=True: 3-D inputs are laid out as (batch, seq, feature).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dropout=dropout,
            batch_first=True,
        )

        # Stack num_layers copies of the decoder layer.
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers,
        )

        # Map hidden states back to gene-expression space.
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct gene expression from ``x``.

        Args:
            x: either ``(batch_size, seq_len, input_dim)``, or
                ``(batch_size, input_dim)`` with one expression vector per
                cell. 2-D input is treated as a batch of length-1 sequences.

        Returns:
            A tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_dim``.
        """
        # nn.TransformerDecoder interprets a 2-D tensor as ONE unbatched
        # sequence (batch_first is ignored), which would let the cells of a
        # mini-batch attend to one another. Give every cell its own length-1
        # sequence instead, and drop that axis again on the way out.
        per_cell = x.dim() == 2
        if per_cell:
            x = x.unsqueeze(1)

        x_embed = self.embedding(x)  # (batch_size, seq_len, hidden_dim)

        # No encoder output exists, so cross-attention gets zeros as memory.
        memory = torch.zeros_like(x_embed)
        output = self.transformer_decoder(
            tgt=x_embed,
            memory=memory,
        )  # (batch_size, seq_len, hidden_dim)

        recon = self.output_layer(output)
        return recon.squeeze(1) if per_cell else recon


## 2.2: tune hyperparameters
def _mean_reconstruction_loss(model, loader, criterion, device, optimizer=None):
    """Return the per-sample mean of ``criterion(model(X), X)`` over ``loader``.

    If ``optimizer`` is given, one optimisation step is taken per batch
    (training); otherwise only the loss is computed, and the caller should wrap
    the call in ``torch.no_grad()``. Each batch is an ``(X, label)`` pair whose
    label is ignored. Batch losses are weighted by batch size, so a short final
    batch does not skew the mean.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    total = 0.0
    n_seen = 0
    for batch_X, _ in loader:
        batch_X = batch_X.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        loss = criterion(model(batch_X), batch_X)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        total += loss.item() * batch_X.shape[0]
        n_seen += batch_X.shape[0]
    if n_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute a mean loss")
    return total / n_seen


def _report_with_checkpoint(metrics, state_dict, checkpoint_data):
    """Report ``metrics`` to Ray Tune along with a checkpoint of ``state_dict``.

    The checkpoint directory holds ``model.pt`` (the weights) and
    ``checkpoint_data.pkl`` (``checkpoint_data``).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        torch.save(state_dict, os.path.join(tempdir, "model.pt"))
        with open(os.path.join(tempdir, "checkpoint_data.pkl"), "wb") as f:
            pickle.dump(checkpoint_data, f)
        # wrap the directory as a Checkpoint and hand it to Tune with the metrics
        tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))


def SingleCell_Pretrain(config, patience_=5):
    """Ray Tune trainable that pretrains ``config['classifier']`` as an autoencoder.

    The model learns to reconstruct each cell's expression vector (MSE loss).
    Training runs for at most ``config['num_epochs']`` epochs and stops early
    once the dev loss has not improved for ``patience_`` consecutive epochs.

    Keys read from ``config``: ``train_data_ref``/``dev_data_ref`` (Ray object
    refs to datasets yielding ``(X, label)`` pairs), ``classifier``, ``nhead``,
    ``num_layers``, ``learning_rate``, ``batch_size``, ``num_epochs``,
    ``use_subset`` (default True), ``subset_proportion`` (when subsetting) and
    optionally ``checkpoint_frequency`` (default 5).

    Each epoch reports ``train_loss``, ``dev_loss``, ``best_loss`` (running
    minimum of the dev loss, *including* the current epoch) and ``epoch``.
    The best weights so far are checkpointed at most every
    ``checkpoint_frequency`` epochs, and always on the trial's final epoch, so
    an improvement between checkpoint epochs is not lost.
    """
    # 1. data: optionally train on a random subset of the training cells
    train_dataset = ray.get(config["train_data_ref"])
    dev_dataset = ray.get(config["dev_data_ref"])
    if config.get("use_subset", True):
        n_subset = int(config['subset_proportion'] * len(train_dataset))
        subset_indices = np.random.choice(len(train_dataset), n_subset, replace=False)
        train_subset = Subset(train_dataset, subset_indices)
    else:
        train_subset = train_dataset

    # 2. model and optimisation
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = config['classifier'](
        nhead=config["nhead"],
        num_layers=config['num_layers'],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    criterion = nn.MSELoss()

    # 3. data loaders
    train_loader = DataLoader(train_subset, batch_size=config["batch_size"], shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=config["batch_size"])

    # 4. training loop with early stopping
    num_epochs = config["num_epochs"]
    checkpoint_frequency = config.get("checkpoint_frequency", 5)
    best_loss = float('inf')
    best_epoch = 0
    best_state = None     # CPU copy of the weights that reached best_loss
    best_unsaved = False  # True while best_state has not been checkpointed yet
    no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = _mean_reconstruction_loss(model, train_loader, criterion, device, optimizer)

        model.eval()
        with torch.no_grad():
            dev_loss = _mean_reconstruction_loss(model, dev_loader, criterion, device)

        # Update the running best BEFORE building the metrics, so 'best_loss'
        # reflects this epoch instead of lagging one epoch behind.
        if dev_loss < best_loss:
            best_loss = dev_loss
            best_epoch = epoch + 1
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_unsaved = True
            no_improve = 0  # still improving
        else:
            no_improve += 1

        metrics = {
            'train_loss': float(train_loss),
            'dev_loss': float(dev_loss),
            'best_loss': float(best_loss),
            'epoch': epoch + 1,
        }

        early_stop = no_improve >= patience_
        final_epoch = early_stop or epoch + 1 == num_epochs
        if best_unsaved and ((epoch + 1) % checkpoint_frequency == 0 or final_epoch):
            _report_with_checkpoint(metrics, best_state, {**metrics, 'best_epoch': best_epoch})
            best_unsaved = False
        else:
            tune.report(metrics)

        if early_stop:
            break

os.makedirs(f"{work_dir}/Pretrained_model", exist_ok=True)

## Search space and settings for the pretraining hyperparameter search.
config = dict(
    # model class and the Ray Tune trainable that trains it
    classifier=SingleCellDecoder,
    TrainingFunc=SingleCell_Pretrain,
    # reported metrics; presumably HyperparameterTune ranks trials by
    # metric_standard in metric_mode (here: minimise best_loss) — confirm
    metriclst=['train_loss', 'dev_loss', 'best_loss', 'epoch'],
    metric_standard='best_loss',
    metric_mode='min',
    # epoch / trial budget
    num_epochs=50,
    num_epochs_atleast=10,
    totaltrials=50,
    # sampled hyperparameters
    learning_rate=tune.loguniform(1e-5, 1e-3),
    batch_size=tune.choice([500, 1000]),
    num_layers=tune.choice([4, 8]),
    nhead=tune.choice([4, 8]),
    # train each trial on a random 20% of the training cells
    use_subset=True,
    subset_proportion=0.2,
    # output location
    outputPath=f"{work_dir}/Pretrained_model",
    filename="scPretrain_tune",
)

## Launch the search with the project's custom tuning helper; the resource
## arguments presumably mean 6 CPUs / 1 GPU in total and 1 CPU / 0.2 GPU per
## trial — confirm against MachineLearnFunc.HyperparameterTune.
HyperparameterTune(
    anndata_obj=sc_data,
    y=y,
    train_indices=train_indices,
    dev_indices=dev_indices,
    test_indices=test_indices,
    config=config,
    initcpu=6,
    initgpu=1,
    onetrialcpu=1,
    onetrialgpu=0.2,
)

