## We evaluate the cell type annotation performance of the pretrained model, 
## using a multilayer perceptron (MLP) as the baseline.

In [None]:

import copy
import inspect
import json
import os
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
import ray.cloudpickle as pickle
import scanpy as sc
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray import train,tune
from ray.tune import CLIReporter, Checkpoint
from ray.tune.schedulers import ASHAScheduler
from sklearn.metrics import classification_report, confusion_matrix,f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Subset,DataLoader, TensorDataset

from MachineLearnFunc import SparseToDenseDataset, HyperparameterTune,load_config
work_dir='./'  # root directory; dataset, pretrained-model and output paths below are built from it

### step 1. load dataset.
#### This dataset was downloaded from CELLxGENE and contains 156,726 single-cell RNA-sequencing profiles of the human embryonic meninges at 5–13 weeks post-conception.

In [None]:
## dataset download from here: https://datasets.cellxgene.cziscience.com/f82c6e2b-8056-45a3-967a-fbadf057566f.h5ad
filepath = f"{work_dir}/datasets/f82c6e2b-8056-45a3-967a-fbadf057566f.h5ad"
sc_data = sc.read_h5ad(filepath)  ## load file as an AnnData object, n_obs × n_vars = 156726 × 33159
print(sc_data)

## QC: filter low-quality cells, then normalize and log-transform.
## Observed QC ranges before filtering, as (min, median, max):
##   obs['fraction_mitochondrial'] = (0, 0.04, 0.82)
##   obs['total_UMIs']             = (1000, 7057, 290169)
##   obs['total_genes']            = (39, 2677, 14076)
## Indexing an AnnData returns a *view*; sc.pp.filter_cells annotates .obs in place, which on a
## view forces an implicit view->actual conversion (ImplicitModificationWarning). Copy explicitly.
sc_data = sc_data[sc_data.obs['fraction_mitochondrial'] < 0.2, :].copy()  # keep cells with mitochondrial fraction < 0.2
# filter_cells accepts only one threshold per call, hence two calls
sc.pp.filter_cells(sc_data, min_genes=200)   # drop cells with < 200 detected genes
sc.pp.filter_cells(sc_data, max_genes=6000)  # drop cells with > 6000 detected genes
# sc.pp.filter_genes(sc_data, min_cells=3)  # optional: drop genes expressed in < 3 cells

## Normalization is usually fit after the train/test split to avoid leakage, but normalize_total
## rescales each cell by its own total count, so no statistics are shared across cells.
sc.pp.normalize_total(sc_data, target_sum=1e4)  # every cell sums to 1e4 after scaling
sc.pp.log1p(sc_data)  # log(1 + x) transform
# sc.pp.highly_variable_genes(sc_data, n_top_genes=3000, flavor='seurat')  # optional: top-3000 HVGs
# sc_data = sc_data[:, sc_data.var.highly_variable]

# Encode cell-type labels as integers. The encoder is fit on ALL retained cells so that every
# class appearing in any split gets an id; only the label vocabulary is learned here.
label_encoder = LabelEncoder()
label_encoder.fit(sc_data.obs['cell_type'])
num_classes = len(label_encoder.classes_)
print(f'Number of cell types: {num_classes}')  # 58 cell types
y = label_encoder.transform(sc_data.obs['cell_type'])

# Split cell *indices* rather than the matrix itself so no copies of the expression data are made;
# rows are presumably densified later on demand (SparseToDenseDataset inside HyperparameterTune — confirm).
# Stratified by cell type. Resulting proportions: test 20%, dev 0.8*0.3 = 24%, train 0.8*0.7 = 56%.
all_indices = np.arange(sc_data.n_obs)
train_dev_indices, test_indices, y_train_dev, y_test = train_test_split(
    all_indices, y, test_size=0.2, random_state=42, stratify=y)
train_indices, dev_indices, y_train, y_dev = train_test_split(
    train_dev_indices, y_train_dev, test_size=0.3, random_state=42, stratify=y_train_dev)
print('Data prepared!')


### step 2. Cell type annotation with MLP model as baseline.

In [None]:
## define a simple MLP classifier
class CellTypeClassifierLinear(nn.Module):
    """Baseline classifier: one hidden layer MLP from expression vector to cell-type logits.

    Args:
        num_classes: number of output cell types (defaults to the global ``num_classes``).
        input_dim: number of input genes (defaults to the column count of ``sc_data.X``).
        hidden_neuron: width of the single hidden layer.
        dropout: dropout probability applied after the hidden ReLU.
    """

    def __init__(self, num_classes=num_classes,
                 input_dim=sc_data.X.shape[1], hidden_neuron=200, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_neuron),    # genes -> hidden
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_neuron, num_classes),  # hidden -> cell-type logits
        ]
        # A single Sequential named ``classifier`` keeps the state_dict keys
        # (classifier.0.*, classifier.3.*) stable for saved checkpoints.
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        # Drop axis 1; it is a no-op unless that axis has size 1
        # (assumes the loader yields (batch, 1, genes) — confirm).
        flat = x.squeeze(1)
        return self.classifier(flat)

## training function for single trial
def _evaluate_macro_f1(model, loader, device):
    """Return the macro-averaged F1 score of ``model`` over every batch of ``loader``.

    Switches the model to eval mode (disables dropout) and runs without gradients.
    """
    model.eval()
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for batch_X, batch_y in loader:
            logits = model(batch_X.to(device))
            pred_chunks.append(torch.argmax(logits, dim=1).cpu())
            true_chunks.append(batch_y.cpu())
    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(true_chunks).numpy()
    return f1_score(y_true, y_pred, average='macro')


def _report_with_checkpoint(model, metrics):
    """Report ``metrics`` to Ray Tune together with a checkpoint of ``model``.

    The checkpoint directory contains ``model.pt`` (the model state_dict) and
    ``checkpoint_data.pkl`` (the pickled metrics dict).
    """
    with tempfile.TemporaryDirectory() as tempdir:
        torch.save(model.state_dict(), os.path.join(tempdir, "model.pt"))
        with open(os.path.join(tempdir, "checkpoint_data.pkl"), "wb") as f:
            pickle.dump(metrics, f)
        # Report inside the with-block: the files must still exist while Ray
        # persists the checkpoint; the temporary directory is removed afterwards.
        tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))


def Train_CellType_Annotation(config, patience_=5):
    """Train one Ray Tune trial of a cell-type classifier with early stopping.

    Each epoch trains on (a random subset of) the training set, computes macro-F1 on
    the dev set and reports ``train_loss``, ``dev_f1``, ``best_f1`` (best dev F1 so
    far, including the current epoch) and ``epoch`` to Ray Tune. Every epoch that
    improves dev F1 also reports a checkpoint, so the trial's best model is always
    checkpointed. Training stops after ``patience_`` consecutive epochs without
    improvement.

    Args:
        config: trial config. Keys read here: ``train_data_ref`` / ``dev_data_ref``
            (Ray object refs to datasets yielding (features, label) pairs),
            ``use_subset`` (default True), ``subset_fraction``, ``classifier`` (model
            class accepting ``hidden_neuron=``), ``hidden_neuron``, ``learning_rate``,
            ``batch_size`` and ``num_epochs``.
        patience_: number of consecutive non-improving epochs before stopping.
    """
    ## load datasets from the Ray object store (instead of capturing them in the closure)
    train_dataset = ray.get(config["train_data_ref"])
    dev_dataset = ray.get(config["dev_data_ref"])
    if config.get("use_subset", True):
        ## random subset of the training data for quick tuning (not seeded: differs per trial)
        n_subset = int(config['subset_fraction'] * len(train_dataset))
        subset_indices = np.random.choice(len(train_dataset), n_subset, replace=False)
        train_subset = Subset(train_dataset, subset_indices)
    else:
        train_subset = train_dataset

    ## model, optimizer and loss
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = config['classifier'](hidden_neuron=config["hidden_neuron"])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    train_loader = DataLoader(train_subset, batch_size=config["batch_size"], shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=config["batch_size"])

    # Start below any achievable F1 (range 0..1) so the first epoch always counts as an
    # improvement and a checkpoint is guaranteed to exist for every trial.
    best_f1 = -1.0
    no_improve = 0

    for epoch in range(config["num_epochs"]):
        model.train()
        train_loss = 0.0
        for batch_X, batch_y in train_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        ## average train loss per batch (guard against an empty training subset)
        avg_train_loss = train_loss / max(len(train_loader), 1)

        dev_f1 = _evaluate_macro_f1(model, dev_loader, device)

        # Update the running best BEFORE building the metrics, so the reported
        # best_f1 includes this epoch's score.
        improved = dev_f1 > best_f1
        if improved:
            best_f1 = dev_f1
            no_improve = 0
        else:
            no_improve += 1

        metrics = {
            'train_loss': float(avg_train_loss),
            'dev_f1': float(dev_f1),
            'best_f1': float(best_f1),
            'epoch': epoch + 1}

        if improved:
            _report_with_checkpoint(model, metrics)
        else:
            tune.report(metrics)
            if no_improve >= patience_:
                break  ## early stopping

dataset_ratio = ['0.01', '0.05', '1.00']
for ratio in dataset_ratio:
    # one output folder per training-set fraction
    output_dir = f'{work_dir}/CellTypeAnnotation/linear{ratio}'
    os.makedirs(output_dir, exist_ok=True)
    ## hyperparameter search: learning rate is sampled; batch size and width are fixed
    config = {
        'classifier': CellTypeClassifierLinear,
        'TrainingFunc': Train_CellType_Annotation,
        'label_names': label_encoder.classes_,
        'metriclst': ["train_loss", "dev_f1", "best_f1", "epoch"],
        'metric_standard': 'best_f1',
        'metric_mode': 'max',
        'num_epochs': 10,
        'num_epochs_atleast': 5,
        'totaltrials': 10,
        'learning_rate': tune.loguniform(1e-5, 1e-3),
        'batch_size': 500,     # alternative: tune.choice([500, 1000])
        'hidden_neuron': 400,  # alternative: tune.choice([1000, 2000])
        "use_subset": True,
        'subset_fraction': float(ratio),
        'outputPath': output_dir,
        'filename': "scCellTypeAnnotation_tune",
    }

    ## launch tuning through the project's helper
    HyperparameterTune(
        anndata_obj=sc_data,
        y=y,
        train_indices=train_indices,
        dev_indices=dev_indices,
        test_indices=test_indices,
        config=config,
        initcpu=6,
        initgpu=1,
        onetrialcpu=1,
        onetrialgpu=0.2,
        evaluate_classifier=True,
    )


### step 3. Cell type annotation with pretrained model+MLP (transfer learning).

In [None]:
##load a decoder-only pretrained model
class SingleCellDecoder(nn.Module):
    """Transformer-decoder model that reconstructs gene expression from its input.

    Pipeline: linear embedding -> stack of ``nn.TransformerDecoderLayer`` -> linear
    projection to ``output_dim``. The submodule names ``embedding``,
    ``transformer_decoder`` and ``output_layer`` form the state_dict keys of the
    saved pretrained weights, so they must not be renamed.

    Args:
        input_dim: genes per cell (columns of the expression matrix).
        hidden_dim: model width (``d_model``) of the decoder layers.
        num_layers: number of stacked decoder layers.
        nhead: attention heads per layer.
        dropout: dropout rate inside each decoder layer.
        output_dim: width of the reconstruction; defaults to the gene count.
    """

    def __init__(
        self,
        input_dim: int = sc_data.X.shape[1],
        hidden_dim: int = 512,
        num_layers: int = 2,
        nhead: int = 4,
        dropout: float = 0.1,
        output_dim: int = sc_data.X.shape[1],
    ):
        super().__init__()
        # project expression values to the model width
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # batch_first=True: tensors are laid out as (batch, seq, feature)
        layer_template = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dropout=dropout,
            batch_first=True,
        )
        # TransformerDecoder stacks num_layers clones of the template layer
        self.transformer_decoder = nn.TransformerDecoder(layer_template, num_layers=num_layers)
        # map hidden states back to expression space
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct expression for ``x``.

        Args:
            x: assumed (batch_size, seq_len, input_dim), e.g. seq_len == 1 with one
               cell per row — confirm against the data loader.
        Returns:
            Tensor of shape (batch_size, seq_len, output_dim).
        """
        hidden = self.embedding(x)
        # There is no encoder, so the cross-attention "memory" is an all-zero tensor
        # of the same shape. No tgt_mask is passed, so self-attention is NOT causal.
        zero_memory = torch.zeros_like(hidden)
        decoded = self.transformer_decoder(tgt=hidden, memory=zero_memory)
        return self.output_layer(decoded)
## load pretrained model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The saved hyperparameters must be passed as KEYWORD arguments: passing the loaded
# config positionally would bind the whole object to ``input_dim``.
# NOTE(review): assumes load_config returns a dict. Keys that are not parameters of
# SingleCellDecoder.__init__ (e.g. optimizer settings) are skipped and listed below.
# Confirm that the JSON uses the constructor's parameter names. An architecture
# mismatch still surfaces as a size error in load_state_dict.
pretrained_hparams = load_config(f"{work_dir}/Pretrained_model/best_model_hyperparameters.json")
decoder_params = inspect.signature(SingleCellDecoder).parameters
ignored_hparams = sorted(k for k in pretrained_hparams if k not in decoder_params)
if ignored_hparams:
    print(f"Ignoring non-architecture hyperparameters: {ignored_hparams}")
pretrained_model = SingleCellDecoder(
    **{k: v for k, v in pretrained_hparams.items() if k in decoder_params})
pretrained_model.load_state_dict(
    torch.load(f"{work_dir}/Pretrained_model/best_model.pth", map_location=device))
pretrained_model.to(device)

## define a transfer learning classifier
class CellTypeClassifierTransfer(nn.Module):
    """Cell-type classifier stacked on a private copy of a pretrained model.

    The input goes through ``encoder``; its output, with axis 1 squeezed, feeds a
    one-hidden-layer MLP head that produces cell-type logits. The encoder is a
    registered submodule, so an optimizer over ``model.parameters()`` fine-tunes
    it together with the head.

    Args:
        num_classes: number of output cell types.
        encoder: pretrained module; a deep copy is stored (see __init__).
        input_dim: width of the encoder output fed to the head. It must equal the
            encoder's output size. Defaults to the gene count of ``sc_data``.
        hidden_neuron: width of the head's hidden layer.
        dropout: dropout probability after the hidden ReLU.
    """

    def __init__(self,  num_classes=num_classes, encoder=pretrained_model,
                 input_dim=sc_data.X.shape[1],hidden_neuron=200,dropout=0.1):
        super().__init__()
        # Deep-copy so each classifier fine-tunes its OWN encoder. The default
        # ``encoder`` is a single module object shared by every instance (mutable
        # default argument). Without the copy, training one instance would overwrite
        # the weights seen by all others, including the global ``pretrained_model``
        # that later trials/ratios are supposed to start from.
        self.encoder = copy.deepcopy(encoder)
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_neuron),
            nn.ReLU(),nn.Dropout(dropout),
            nn.Linear(hidden_neuron, num_classes)
        )

    def forward(self, x):
        embeddings = self.encoder(x)
        # drop the seq_len axis (assumed size 1) -> (batch_size, encoder output width)
        return self.classifier(embeddings.squeeze(1))  ## predict cell types

dataset_ratio = ['0.01', '0.05', '1.00']
for ratio in dataset_ratio:
    # one output folder per training-set fraction
    output_dir = f"{work_dir}/CellTypeAnnotation/transfer{ratio}"
    os.makedirs(output_dir, exist_ok=True)
    ## hyperparameter search space: learning rate is sampled; batch size and width are fixed
    config = {
        'classifier': CellTypeClassifierTransfer,
        'TrainingFunc': Train_CellType_Annotation,
        'label_names': label_encoder.classes_,
        'metriclst': ["train_loss", "dev_f1", "best_f1", "epoch"],
        'metric_standard': 'best_f1',
        'metric_mode': 'max',
        'num_epochs': 10,
        'num_epochs_atleast': 5,
        'totaltrials': 10,
        'learning_rate': tune.loguniform(1e-5, 1e-3),
        'batch_size': 500,     # alternative: tune.choice([500, 1000])
        'hidden_neuron': 400,  # alternative: tune.choice([400, 800])
        "use_subset": True,
        'subset_fraction': float(ratio),
        'outputPath': output_dir,
        'filename': "scCellTypeAnnotation_tune",
    }

    ## launch tuning through the project's helper
    HyperparameterTune(
        anndata_obj=sc_data,
        y=y,
        train_indices=train_indices,
        dev_indices=dev_indices,
        test_indices=test_indices,
        config=config,
        initcpu=6,
        initgpu=1,
        onetrialcpu=1,
        onetrialgpu=0.2,
        evaluate_classifier=True,
    )

### step 4. Compare cell type annotation performance between transfer learning and the MLP baseline.

In [None]:
## Collect per-cell-type results from the different runs into one long table for plotting.
## NOTE: the per-run report files were renamed manually to the suffixes below.
os.makedirs(f"{work_dir}/evaluation", exist_ok=True)
suffix_lst = ['linear1perc', 'linear5perc', 'linear100perc', 'tl1perc', 'tl5perc', 'tl100perc']
group_lst = ['mlp', 'mlp', 'mlp', 'transfer learning', 'transfer learning', 'transfer learning']
subset_prop = ['0.01', '0.05', '1.00', '0.01', '0.05', '1.00']

# Gather all frames first and concatenate once. Repeated concat onto a growing
# (initially empty) DataFrame is quadratic, and pandas >= 2.1 warns about
# concatenating empty frames.
report_frames = []
for suffix, group, prop in zip(suffix_lst, group_lst, subset_prop):
    cta_tp = pd.read_csv(f"{work_dir}/evaluation/classification_report_{suffix}.csv", index_col=0)
    # Drop the trailing 3 summary rows. They are presumably accuracy / macro avg /
    # weighted avg from sklearn's classification_report — confirm the file layout.
    cta_tp = cta_tp.iloc[:-3, :]
    report_frames.append(cta_tp.assign(group=group, subset_proportion=prop))
sumdf = pd.concat(report_frames, axis=0)

sumdf.index.name = 'cell type'
## wide -> long format for seaborn: one row per (cell type, run, metric)
df_long = sumdf.reset_index().melt(
    id_vars=['cell type', 'support', 'group', 'subset_proportion'],
    value_vars=['precision', 'recall', 'f1-score'],
    var_name='metric',
    value_name='score')


In [None]:
sns.set_style("whitegrid")
# One row of three panels, one per metric. plt.subplots creates the figure itself;
# the former extra plt.figure() call only left an empty figure behind.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metrics = ['precision', 'recall', 'f1-score']
titles = ['Precision', 'Recall', 'F1-Score']

for ax, metric, title in zip(axes, metrics, titles):
    metric_data = df_long[df_long['metric'] == metric]
    # distribution of per-cell-type scores, by training-set fraction and method
    sns.boxplot(
        data=metric_data,
        x='subset_proportion',
        y='score',
        hue='group',
        ax=ax,
        palette=['#7fbc41', '#de77ae'])
    ax.set_xlabel('Training Set Proportion', fontsize=14)
    ax.set_ylabel(title, fontsize=14)
    ax.legend(title='Method', fontsize=12, loc='lower right')
    ax.grid(True, alpha=0.3)

fig.suptitle('Model Performance Across Different Training Set Proportions', fontsize=16)
fig.tight_layout()
fig.savefig(f'{work_dir}/evaluation/boxplot.jpg')
plt.show()