In [None]:
import torch
from scLLM.Models.scBERT.model import PerformerLM
from scLLM.Models.scBERT.paras import scBERT_para


# ---- Reproducibility --------------------------------------------------------
# Seed torch's global RNG so torch-based random draws (layer initialisation,
# weighted sampling) repeat on every Restart & Run All.
SEED = 2023  # same value the train/validation split cell passes as random_state
torch.manual_seed(SEED)

# ---- Fine-tuning settings ---------------------------------------------------
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cls_nb = 20          # number of output classes (used as out_dim of the new head and as bin_nb for labels)
upper_bound = 5      # presumably the largest expression-bin value fed to the model — TODO confirm
ckpt_pth = "/Path/to/panglao_pretrain.pth"   # pre-trained checkpoint (placeholder path)
EPOCHS = 10
GRADIENT_ACCUMULATION = 8                    # optimizer steps once every this many batches
LEARNING_RATE = 1e-3

# ---- Model hyper-parameters -------------------------------------------------
N_GENES = 16906      # gene count of the gene2vec_16906 files referenced in this notebook

model_para = scBERT_para()
# Was written as the literal `5+2`; derived from upper_bound here so the two
# cannot drift apart. NOTE(review): assumes the "5" is upper_bound and the "+2"
# are extra special tokens — confirm against the scBERT tokenisation.
model_para.num_tokens = upper_bound + 2           # num of tokens (value unchanged: 7)
model_para.max_seq_len = N_GENES + 1              # max length of sequence (value unchanged: 16907)
model_para.dim = 200                              # dim of tokens
model_para.depth = 6                              # layers
model_para.heads = 10
model_para.local_attn_heads = 0
model_para.g2v_position_emb = True
model_para.g2v_weight_loc = "Path/to/gene2vec_16906_200.npy"
# A 24447-gene variant was used previously (gene2vec_24447_200.npy with
# max_seq_len = 24447 + 1); switch N_GENES and the path together if needed.


In [None]:
# Optional step, left commented out: build the gene2vec weights from the raw Gene2vec text embedding
#gene2vec_emb_loc = "github/Gene2vec/pre_trained_emb/gene2vec_dim_200_iter_9_w2v.txt"
#from scLLM.Models.scBERT.utils import transfer_gene2vec_as_weight
# get model from embedding file with saved numpy file
#gs_model = transfer_gene2vec_as_weight(gene2vec_emb_loc,model_para.g2v_weight_loc)
# get model from embedding file without saved numpy file
#gs_model = transfer_gene2vec_as_weight(gene2vec_emb_loc)

In [None]:
# Instantiate the scBERT PerformerLM from the hyper-parameters collected in model_para.
model = PerformerLM(model_para)

In [None]:
import torch
# Load the pre-trained checkpoint onto the target device and restore the weights.
# torch.load unpickles the file, so only point ckpt_pth at a trusted checkpoint.
pre_model = torch.load(ckpt_pth, map_location=device)
model.load_state_dict(pre_model["model_state_dict"])

# change output layer: swap the pre-training head for a fresh head with cls_nb outputs
from scLLM.Modules.out_layer import Identity
model.to_out = Identity(in_dim=model_para.max_seq_len,
                        dropout=0.,
                        h_dim=128,
                        out_dim=cls_nb)

# Move the whole model, including the newly created head, to the same device
# the training loop moves each batch to; otherwise a CUDA run fails with a
# CPU/GPU tensor mismatch.
model = model.to(device)

In [None]:
# Freeze every parameter first ...
for param in model.parameters():
    param.requires_grad = False

# ... then re-enable training for the parts that are fine-tuned:
#   * the final norm layer,
#   * the second-to-last performer layer,
#   * the replacement output head `model.to_out`. The head is assigned before
#     this cell runs, so the blanket freeze above also switched it off; without
#     re-enabling it the randomly initialised classifier would never learn.
for trainable_part in (model.norm, model.performer.net.layers[-2], model.to_out):
    for param in trainable_part.parameters():
        param.requires_grad = True

In [None]:
# Input files: the pickled gene vocabulary and the AnnData (.h5ad) dataset.
# NOTE(review): these are machine-specific absolute paths; adjust before running elsewhere.
vocab_loc = "/Users/shipan/Documents/scLLM_workspace/pre_trained/scBERT/vocab_gene2vec_16906.pkl"
adata_loc = "/Users/shipan/Documents/scLLM_workspace/data/Eloise/allMCF.h5ad"

# Read the gene vocabulary back from disk.
# pickle executes code on load, so only use vocabulary files from a trusted source.
import pickle
with open(vocab_loc, mode="rb") as vocab_file:
    vocab = pickle.load(vocab_file)
# init preprocessor
from scLLM.Dataset.preprocessor import Preprocessor
from scLLM.Dataset.paras import Dataset_para
# Pre-processing options, chosen to follow the original scBERT implementation.
preprocess_settings = {
    "gene_vocab": vocab,
    "filter_gene_by_counts": False,
    "filter_cell_by_counts": 200,
    "log1p": True,
    "log1p_base": 2,
    "batch_size": 1,
}
dataset_para = Dataset_para(**preprocess_settings)

# Load the AnnData file and extract the model inputs.
preprocess = Preprocessor(dataset_para)
preprocess.load_adata(adata_loc)
data = preprocess.to_data(data_type="log1p")

# Turn the "pseudotimes" column into cls_nb classes using the
# "equal_instance" binarisation; class_weight is returned alongside the labels.
label, class_weight = preprocess.to_label(label_key="pseudotimes",
                                          binarize="equal_instance",
                                          bin_nb=cls_nb)

from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
# Stratified 5-fold splitter (created here but not consumed in this cell).
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2023)
# Single stratified shuffle split holding out 20% of the samples for validation.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2023)

# Take the one (train, validation) index pair the splitter yields.
idx_tr, idx_val = next(sss.split(data, label))

# Slice inputs and labels with the same indices so they stay aligned.
data_train, data_val = data[idx_tr], data[idx_val]
label_train, label_val = label[idx_tr], label[idx_val]

from torch.utils.data.sampler import WeightedRandomSampler
def _weights_and_sampler(labels):
    """Return (per-sample weights, WeightedRandomSampler) for ``labels``.

    Each sample's weight is looked up in ``class_weight`` by its label, and the
    sampler draws as many samples as there are labels.
    """
    per_sample = [class_weight[labels[pos]] for pos in range(labels.shape[0])]
    sampler = WeightedRandomSampler(torch.DoubleTensor(per_sample), len(per_sample))
    return per_sample, sampler


# NOTE(review): the validation split is also drawn through a weighted random
# sampler — confirm that is intended for evaluation.
weights_train, sampler_train = _weights_and_sampler(label_train)
weights_val, sampler_val = _weights_and_sampler(label_val)


from scLLM.Dataset.dataset import SCDataset
# Wrap each (inputs, labels) pair in an SCDataset configured with the same
# class count and device.
train_dataset, val_dataset = (
    SCDataset(inputs, targets, cls_nb=cls_nb, device=device)
    for inputs, targets in ((data_train, label_train), (data_val, label_val))
)

# Build the loaders through the preprocessor, each with its own weighted sampler.
train_loader = preprocess.to_dataloader(train_dataset, sampler=sampler_train)
val_loader = preprocess.to_dataloader(val_dataset, sampler=sampler_val)

In [None]:
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR
from scLLM.Models.scBERT.utils import CosineAnnealingWarmupRestarts

# Optimizer over all model parameters, starting at LEARNING_RATE.
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Cosine-annealing schedule with warm-up and restarts; the peak learning rate
# is LEARNING_RATE and the floor is 1e-6.
scheduler_settings = dict(
    first_cycle_steps=15,
    cycle_mult=2,
    max_lr=LEARNING_RATE,
    min_lr=1e-6,
    warmup_steps=5,
    gamma=0.9,
)
scheduler = CosineAnnealingWarmupRestarts(optimizer, **scheduler_settings)

# Unweighted cross-entropy loss, placed on the training device.
loss_fn = nn.CrossEntropyLoss(weight=None)
loss_fn = loss_fn.to(device)

In [None]:

# Fine-tuning loop: one pass over train_loader per epoch with gradient
# accumulation, then one scheduler step per epoch.
for i in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    index = 0
    # Start each epoch from clean gradients, so a re-run of this cell after an
    # interrupted epoch does not inherit stale accumulated gradients.
    optimizer.zero_grad()
    # The batch is named `batch_data` so it does not overwrite the dataset
    # array `data` created in the preprocessing cell.
    for index, (batch_data, labels) in enumerate(train_loader, start=1):
        batch_data, labels = batch_data.to(device), labels.to(device)

        # Forward and backward run on every batch; gradients accumulate until
        # the optimizer steps below.
        logits = model(batch_data)
        loss = loss_fn(logits, labels)
        loss.backward()

        if index % GRADIENT_ACCUMULATION == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
            optimizer.step()
            optimizer.zero_grad()

        running_loss += loss.item()
        # argmax over the logits picks the same class as argmax over their
        # softmax, so no Softmax module is built on every step.
        final = logits.argmax(dim=-1)
        pred_num = labels.size(0)
        correct_num = torch.eq(final, labels).sum(dim=-1)
        cum_acc += torch.true_divide(correct_num, pred_num).mean().item()

    if index == 0:
        raise RuntimeError("train_loader produced no batches")

    # Apply the gradients of a trailing, incomplete accumulation window.
    # Without this they are never stepped and leak into the next epoch.
    if index % GRADIENT_ACCUMULATION != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
        optimizer.step()
        optimizer.zero_grad()

    epoch_loss = running_loss / index
    epoch_acc = 100 * cum_acc / index

    print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:6.4f}%  ==')

    scheduler.step()