In [1]:
import os
import sys
import scanpy as sc
import numpy as np
import pandas as pd
import pickle as pkl
import torch
from scipy.sparse import lil_matrix, csr_matrix
import importlib
import src.utils_BERT
from src.performer_pytorch.performer_pytorch import PerformerLM
importlib.reload(src.utils_BERT)
importlib.reload(src.performer_pytorch.performer_pytorch)

<module 'src.performer_pytorch.performer_pytorch' from '/data/lyx/scCHiP/scATAC/LDA/PBMC_10k/src/performer_pytorch/performer_pytorch.py'>

In [2]:
data_dir = '/data/lyx/scCHiP/scATAC/LDA/PBMC_10k/processed_data'
work_dir = '/data/lyx/scCHiP/scATAC/LDA/PBMC_10k'
sc_count_file = os.path.join(data_dir,"10k_PBMC_Multiome_filtered_gene_count.h5ad")
sc_anno_file = os.path.join(data_dir,"MainCelltype.txt")

In [3]:
meta_data = pd.read_table(sc_anno_file, header=None,index_col=0)
meta_data.columns = ["Celltype"]
cell_type = list(set(meta_data.Celltype))
ntopics_list = list(range(len(cell_type), 3*len(cell_type)+1))

In [4]:
data = sc.read_h5ad(sc_count_file)
gene_names = pd.read_table(os.path.join(work_dir,"src/data_BERT/gene2vec_16906_names.txt"),sep="\t",header=None)[0].to_list()
data = data[:,data.var_names.isin(gene_names)].copy()
indices = [index for index, element in enumerate(gene_names) if element in data.var_names] 

sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data)

data_csr = np.zeros((data.shape[0], len(gene_names)), dtype=np.float32) 
data_csr[:, indices]=np.array(data.X.todense())
data_csr = csr_matrix(data_csr)

In [5]:
label_names, label = np.unique(np.array(meta_data['Celltype']), return_inverse=True)  
# Convert strings categorical to integrate categorical, and label_names[label] can be restored
#store the label dict and label for prediction
# with open('label_names', 'wb') as fp:
#     pkl.dump(label_names, fp)
# with open('label', 'wb') as fp:
#     pkl.dump(label, fp)

class_num = np.unique(label, return_counts=True)[1].tolist()

In [6]:
class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num])
label = torch.from_numpy(label)

In [7]:
from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report

In [8]:
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import StepLR, CosineAnnealingWarmRestarts, CyclicLR
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

In [25]:
import os
import gc
# import argparse
import json
import random
import math
import random
from functools import reduce
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split, ShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_recall_fscore_support, classification_report

class SCDataset(Dataset):
    def __init__(self, data, label,CLASS,device):
        super().__init__()
        self.data = data
        self.label = label
        self.CLASS = CLASS
        self.device = device

    def __getitem__(self, index):
        rand_start = random.randint(0, self.data.shape[0]-1)
        full_seq = self.data[rand_start].toarray()[0]
        full_seq[full_seq > (self.CLASS - 2)] = self.CLASS - 2
        full_seq = torch.from_numpy(full_seq).long()
        full_seq = torch.cat((full_seq, torch.tensor([0]))).to(self.device)
        seq_label = self.label[rand_start]
        return full_seq, seq_label

    def __len__(self):
        return self.data.shape[0]

class MinimalIdentity(torch.nn.Module):
    def __init__(self,out_dim = 10,SEQ_LEN=16907):
        super(MinimalIdentity, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, (1, 200))
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=out_dim, bias=True)

    def forward(self, x):
        x = x[:,None,:,:]
        x = self.conv1(x)
        x = self.act(x)
        x = x.view(x.shape[0],-1)
        x = self.fc1(x)
        return x
class MinimalClassifier(nn.Module):
    def __init__(self, embed_dim=200, num_classes=6):
        super().__init__()
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),  # [batch,200,1]
            nn.Flatten()
        )
        self.classifier = nn.Linear(200, num_classes)
        
    def forward(self, x):
        return self.classifier(self.pool(x.permute(0,2,1)))

    
class MinimalIdentity_v2(torch.nn.Module):
    def __init__(self,out_dim = 10,SEQ_LEN=16907,h_dim=512,dropout=0.0):
        super(MinimalIdentity_v2, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, (1, 200))
        self.act = nn.ReLU()
        self.fc1 = nn.Linear(in_features=SEQ_LEN, out_features=h_dim, bias=True)
        self.act1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(in_features=h_dim, out_features=out_dim, bias=True)

    def forward(self, x):
        x = x[:,None,:,:]
        x = self.conv1(x)
        x = self.act(x)
        x = x.view(x.shape[0],-1)
        x = self.fc1(x)
        return x
    
class GeneChannelAttention(nn.Module):
    """基因通道注意力（移除位置相关操作）"""
    def __init__(self, in_dim):
        super().__init__()
        self.channel_scale = nn.Sequential(
            nn.Linear(in_dim, in_dim//4),
            nn.ReLU(),
            nn.Linear(in_dim//4, in_dim),
            nn.Sigmoid()  # 添加Sigmoid确保权重在[0,1]范围
        )
        
    def forward(self, x):
        """输入形状: [batch, channels, num_genes]"""
        # 全局平均 → 通道权重
        gene_weights = self.channel_scale(x.mean(dim=-1))  # [batch, channels]
        # 加权聚合所有基因特征
        return (x * gene_weights.unsqueeze(-1)).sum(dim=-1)  # [batch, channels]

class BioClassifier_v2(nn.Module):
    def __init__(self, num_genes=16907, embed_dim=200, num_classes=6):
        super().__init__()
        
        # 特征聚合模块
        self.feature_aggregator = nn.Sequential(
            GeneChannelAttention(200),    # 通道注意力加权
            nn.BatchNorm1d(200),
            nn.Dropout(0.2)
        )
        
        # 分类器
        self.classifier = nn.Sequential(
            # nn.Linear(200, 512),
            # nn.GELU(),
            nn.Linear(200, num_classes)
        )

    def forward(self, x):
        # x形状: [batch, num_genes, embed_dim]
        x = x.permute(0, 2, 1)         # [24,200,16907] (channels_first)
        x = self.feature_aggregator(x) # [24,200]
        return self.classifier(x)      # [24,6]

# first train #

In [9]:
SEED = 1
BATCH_SIZE = 24
CLASS = 5+2 #Number of bins.'+2
SEQ_LEN = len(gene_names)+1#gene_num", type=int, default=16906
POS_EMBED_USING = True #'Using Gene2vec encoding or not.'
LEARNING_RATE = 1e-4
EPOCHS = 20
GRADIENT_ACCUMULATION = 60
VALIDATE_EVERY =1
PATIENCE = 10
UNASSIGN_THRES = 0.0
path = "/data/lyx/scCHiP/scATAC/LDA/PBMC_10k/src/data_BERT/panglao_pretrain.pth"
device = torch.device("cpu")

POS_EMBED_USING =True
model_name = "2025_finetune_simple_PBMC_MainCelltype_scBert"
ckpt_dir = os.path.join(work_dir,"scBert_model","scBert_PBMC_10k/")
# world_size = torch.distributed.get_world_size()


In [10]:
len(gene_names)

16906

# tree train #

In [10]:
SEED = 3
BATCH_SIZE = 24
CLASS = 5+2 #Number of bins.'+2
SEQ_LEN = len(gene_names)+1#gene_num", type=int, default=16906
POS_EMBED_USING = True #'Using Gene2vec encoding or not.'
LEARNING_RATE = 1e-4
EPOCHS = 20
GRADIENT_ACCUMULATION = 60
VALIDATE_EVERY =1
PATIENCE = 10
UNASSIGN_THRES = 0.0
path = "/data/lyx/scCHiP/scATAC/LDA/PBMC_10k/scBert_model/scBert_PBMC_10k/2025_finetune_simple_PBMC_MainCelltype_scBert_v3_best.pth"
device = torch.device("cpu")

POS_EMBED_USING =True
model_name = "2025_finetune_simple_PBMC_MainCelltype_scBert_v4"
ckpt_dir = os.path.join(work_dir,"scBert_model","scBert_PBMC_10k/")

## test ##

In [None]:
# 测试代码
test_input = torch.randn(24, 16907, 200)  # 模拟真实输入
model = BioClassifier()

# 逐步检查维度变化
x_proj = model.proj(test_input)          # 应得到[24,16907,256]
x_perm = x_proj.permute(0, 2, 1)         # → [24,256,16907]
conv_out = model.feature_extractor[0](x_perm)  # → [24,256,16907]
print(conv_out.shape)  # 应保持通道数256

# 完整前向传播验证
output = model(test_input)
print(output.shape)  # 预期[24,6]

## run ##

In [11]:
acc = []
f1 = []
f1w = []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
pred_list = pd.Series(['un'] * data_csr.shape[0])

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
for index_train, index_val in sss.split(data_csr, label):
    data_train, label_train = data_csr[index_train], label[index_train]
    data_val, label_val = data_csr[index_val], label[index_val]
    train_dataset = SCDataset(data_train, label_train,CLASS,device)
    val_dataset = SCDataset(data_val, label_val,CLASS,device)

In [12]:
train_sampler = src.utils_BERT.SimpleSampler(train_dataset)
val_sampler = src.utils_BERT.SimpleSampler(val_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

In [13]:
16907 % 512

11

In [14]:
model = PerformerLM(
    num_tokens = CLASS,
    dim = 200,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 10,
    gene_weight_file='./src/data_BERT/gene2vec_16906.npy',
    local_attn_heads = 0,
    g2v_position_emb = POS_EMBED_USING
)

The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2416.)
  q, r = torch.qr(unstructured_block.cpu(), some = True)


In [18]:
model.to_out = MinimalIdentity(out_dim=6)

In [19]:
device = torch.device("cpu")

In [20]:
ckpt = torch.load(path,map_location=device)

  ckpt = torch.load(path,map_location=device)


In [21]:
model.load_state_dict(ckpt['model_state_dict'])
for param in model.parameters():
    param.requires_grad = False
for param in model.norm.parameters():
    param.requires_grad = True
for param in model.performer.net.layers[-2].parameters():
    param.requires_grad = True

In [28]:
model.to_out = MinimalIdentity_v2(out_dim = 6,SEQ_LEN=16907,h_dim=512,dropout=0.2)

In [29]:
device = torch.device("cuda:1")
model = model.to(device)

In [30]:
USE_AMP = True  # 启用混合精度
MAX_GRAD_NORM = 1.0  # 梯度裁剪阈值

# 初始化混合精度训练
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# 改进的参数分组策略
param_groups = [
    {'params': model.performer.parameters(), 'lr': 1e-5, 'weight_decay': 0.001},
    {'params': model.to_out.parameters(), 'lr': 1e-3, 'weight_decay': 0.01}
]

# 动态调整的优化器配置
optimizer = torch.optim.AdamW(param_groups)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=[group['lr'] for group in param_groups],  # 分组学习率峰值
    total_steps=1000,
    pct_start=0.2
)

loss_fn = nn.CrossEntropyLoss(weight=None).to(device)

  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [31]:
class EarlyStopper:
    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_f1 = -float('inf')
        self.counter = 0

    def __call__(self, val_f1):
        if val_f1 > self.best_f1 + self.min_delta:
            self.best_f1 = val_f1
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

early_stopper = EarlyStopper(patience=PATIENCE)

In [32]:
# -*- coding: utf-8 -*-
import torch
from sklearn.metrics import f1_score, accuracy_score

# 训练主循环（完整改进版）
def train_model(model, optimizer, scheduler, train_loader, val_loader, 
                loss_fn, device, EPOCHS, GRADIENT_ACCUMULATION, 
                MAX_GRAD_NORM, VALIDATE_EVERY, UNASSIGN_THRES, 
                early_stopper, model_name, ckpt_dir):
    # 初始化记录器
    train_loss_history = []
    val_loss_history = []
    val_epoch_indices = []
    lr_history = []
    
    # 显式确保模型在目标设备
    model = model.to(device)
    print(f"[Init] 模型已加载到设备: {next(model.parameters()).device}")
    
    # 梯度缩放器
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    
    # 获取训练集总样本量（用于精确损失计算）
    total_train_samples = len(train_loader.dataset)
    
    for epoch in range(1, EPOCHS+1):
        train_loader.sampler.set_epoch(epoch)
        model.train()
        
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        
        optimizer.zero_grad()
        
        # ================== 训练阶段 ==================
        for batch_idx, (data, labels) in enumerate(train_loader, 1):
            # 确保数据与模型在同一设备
            data, labels = data.to(device), labels.to(device)
            assert data.device == labels.device == device  # 设备一致性检查
            
            # ---- 混合精度前向 ----
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=USE_AMP):
                logits = model(data)
                loss = loss_fn(logits, labels) / GRADIENT_ACCUMULATION
                
            # ---- 异常损失检测 ----
            if torch.isnan(loss).any() or torch.isinf(loss).any():
                print(f"[Epoch {epoch} Batch {batch_idx}] 检测到异常损失值: {loss.item():.4f}，跳过该batch")
                optimizer.zero_grad()
                continue
                
            # ---- 梯度累积 ----
            scaler.scale(loss).backward()
            
            # ---- 统计量计算 ----
            with torch.no_grad():
                preds = logits.detach().argmax(dim=-1)
                correct_predictions += (preds == labels).sum().item()
                total_samples += labels.size(0)
                # 按样本比例计算损失（考虑梯度累积）
                running_loss += loss.item() * (data.size(0) / total_train_samples)
                
            # ---- 参数更新 ----
            if batch_idx % GRADIENT_ACCUMULATION == 0 or batch_idx == len(train_loader):
                # 梯度裁剪与检查
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                
                # 梯度异常处理
                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    print(f"[Epoch {epoch} Batch {batch_idx}] 检测到异常梯度范数: {grad_norm:.4f}，跳过更新")
                    optimizer.zero_grad()
                    scaler.update()  # 必须更新scaler状态
                    continue
                    
                # 优化器步进
                try:
                    scaler.step(optimizer)
                    scaler.update()
                except RuntimeError as e:
                    print(f"[Epoch {epoch} Batch {batch_idx}] 优化器步进失败: {str(e)}")
                    optimizer.zero_grad()
                    continue
                    
                optimizer.zero_grad()
        
        # ================== 训练后处理 ==================
        # 计算epoch指标
        epoch_loss = running_loss * GRADIENT_ACCUMULATION  # 修正累积损失
        epoch_acc = 100 * correct_predictions / total_samples
        train_loss_history.append(epoch_loss)
        lr_history.append(optimizer.param_groups[0]['lr'])
        
        print(f"Epoch {epoch} | Train Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")
        
        # ================== 验证阶段 ==================
        if epoch % VALIDATE_EVERY == 0:
            model.eval()
            val_preds = []
            val_labels = []
            val_loss = 0.0
            
            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=USE_AMP):
                for data_v, labels_v in val_loader:
                    data_v, labels_v = data_v.to(device), labels_v.to(device)
                    logits = model(data_v)
                    
                    # 损失计算
                    val_loss += loss_fn(logits, labels_v).item()
                    
                    # 预测处理
                    probs = torch.softmax(logits, dim=-1)
                    preds = probs.argmax(dim=-1)
                    mask = probs.max(dim=-1).values >= UNASSIGN_THRES
                    val_preds.append(preds[mask].cpu())
                    val_labels.append(labels_v[mask].cpu())
                    
            # 合并结果
            val_loss = val_loss / len(val_loader)
            val_preds = torch.cat(val_preds) if len(val_preds) > 0 else torch.tensor([])
            val_labels = torch.cat(val_labels) if len(val_labels) > 0 else torch.tensor([])
            
            # 空样本处理
            if len(val_labels) == 0:
                print(f"[Epoch {epoch}] 警告：验证集无有效预测样本，跳过指标计算")
                val_f1 = 0.0
                val_acc = 0.0
            else:
                val_f1 = f1_score(val_labels, val_preds, average='macro')
                val_acc = accuracy_score(val_labels, val_preds)
                print(f"Epoch {epoch} | Val Loss: {val_loss:.4f} | F1: {val_f1:.4f} | Acc: {val_acc:.2f}%")
                
                # 早停判断与模型保存
                if val_f1 > early_stopper.best_f1:
                    src.utils_BERT.save_best_ckpt(epoch, model, optimizer, scheduler, 
                                                val_loss, model_name, ckpt_dir)
                
                if early_stopper(val_f1):
                    print(f"Early stopping at epoch {epoch}")
                    break
            
            val_loss_history.append(val_loss)
            val_epoch_indices.append(epoch)
            
            # 学习率调度（示例：ReduceLROnPlateau）
            if len(val_labels) > 0:
                scheduler.step(val_loss)  # 根据验证损失调整
            else:
                scheduler.step()  # 无验证指标时使用默认调整
        
        # ================== 资源管理 ==================
        # 每5个epoch清理一次显存
        if epoch % 5 == 0:
            torch.cuda.empty_cache()
            print(f"[Memory] 已清理显存，当前使用量: {torch.cuda.memory_allocated(device)/1e9:.2f} GB")
    
    return {
        "train_loss": train_loss_history,
        "val_loss": val_loss_history,
        "val_epochs": val_epoch_indices,
        "lr_history": lr_history
    }


In [33]:
EPOCHS = 10

In [34]:
%%time
# 调用训练函数
training_stats = train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    device=device,
    EPOCHS=EPOCHS,
    GRADIENT_ACCUMULATION=GRADIENT_ACCUMULATION,
    MAX_GRAD_NORM=MAX_GRAD_NORM,
    VALIDATE_EVERY=VALIDATE_EVERY,
    UNASSIGN_THRES=UNASSIGN_THRES,
    early_stopper=early_stopper,
    model_name=model_name,
    ckpt_dir=ckpt_dir
)

[Init] 模型已加载到设备: cuda:1


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


Epoch 1 | Train Loss: 5.9255 | Acc: 17.72%
Epoch 1 | Val Loss: 5.5467 | F1: 0.0269 | Acc: 0.31%




Epoch 2 | Train Loss: 5.3046 | Acc: 32.56%
Epoch 2 | Val Loss: 4.9285 | F1: 0.1088 | Acc: 0.37%




Epoch 3 | Train Loss: 4.6485 | Acc: 39.71%
Epoch 3 | Val Loss: 4.2659 | F1: 0.1080 | Acc: 0.42%




Epoch 4 | Train Loss: 4.1078 | Acc: 43.15%
Epoch 4 | Val Loss: 3.9116 | F1: 0.1072 | Acc: 0.44%




Epoch 5 | Train Loss: 3.6740 | Acc: 44.40%
Epoch 5 | Val Loss: 3.3112 | F1: 0.1031 | Acc: 0.45%
[Memory] 已清理显存，当前使用量: 0.18 GB




Epoch 6 | Train Loss: 3.0593 | Acc: 45.31%
Epoch 6 | Val Loss: 2.7151 | F1: 0.1050 | Acc: 0.46%




Epoch 7 | Train Loss: 2.3850 | Acc: 44.32%
Epoch 7 | Val Loss: 2.1057 | F1: 0.1012 | Acc: 0.44%




Epoch 8 | Train Loss: 1.9690 | Acc: 44.27%
Epoch 8 | Val Loss: 1.8254 | F1: 0.1039 | Acc: 0.45%




Epoch 9 | Train Loss: 1.7002 | Acc: 44.79%
Epoch 9 | Val Loss: 1.6192 | F1: 0.0993 | Acc: 0.42%




Epoch 10 | Train Loss: 1.5383 | Acc: 44.84%
Epoch 10 | Val Loss: 1.5165 | F1: 0.1413 | Acc: 0.46%
[Memory] 已清理显存，当前使用量: 0.18 GB
CPU times: user 2h 9min 59s, sys: 13.8 s, total: 2h 10min 13s
Wall time: 2h 12min 26s




In [25]:
# 输出训练结果
print("最终验证损失:", training_stats["val_loss"][-1])

最终验证损失: 1.303528324250252


In [24]:
training_stats

{'train_loss': [1.288267878155803,
  1.288598245695821,
  1.306687069889516,
  1.2957995911584594,
  1.315560081370334,
  1.313788087495373,
  1.3372783117582736,
  1.3026923176528968,
  1.313859963851606,
  1.3171471798779317],
 'val_loss': [1.3215542863453589,
  1.3019985897887139,
  1.3039210330574744,
  1.310948115683371,
  1.3551416704731603,
  1.329684595907888,
  1.3016742991824304,
  1.2968803324526357,
  1.2975076653303639,
  1.303528324250252],
 'val_epochs': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'lr_history': [3.9999999999999956e-07,
  4.010446218278138e-07,
  4.010139360665647e-07,
  4.0101693239045346e-07,
  4.010279224240047e-07,
  4.010983927816881e-07,
  4.0105751408840324e-07,
  4.0101343106229606e-07,
  4.0100598029073404e-07,
  4.010069537268468e-07]}

In [44]:
src.utils_BERT.save_best_ckpt("16", model, optimizer, scheduler, 
                              0, '2025_finetune_PBMC_MainCelltype_scBert_last', ckpt_dir)

In [45]:
torch.cuda.empty_cache()

In [17]:
device = torch.device("cuda:0")
model = model.to(device)

In [18]:
# optimizer
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=15,
    cycle_mult=2,
    max_lr=LEARNING_RATE,
    min_lr=1e-6,
    warmup_steps=5,
    gamma=0.9
)
loss_fn = nn.CrossEntropyLoss(weight=None).to(device)

In [None]:
trigger_times = 0
max_acc = 0.0
for i in range(1, EPOCHS+1):
    train_loader.sampler.set_epoch(i)
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    for index, (data, labels) in enumerate(train_loader):
        index += 1
        data, labels = data.to(device), labels.to(device)
        if index % GRADIENT_ACCUMULATION != 0:
            
            logits = model(data)
            loss = loss_fn(logits, labels)
            loss.backward()
            
        if index % GRADIENT_ACCUMULATION == 0:      
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
            optimizer.step()
            optimizer.zero_grad()
        running_loss += loss.item()
        softmax = nn.Softmax(dim=-1)
        
        final = softmax(logits)
        final = final.argmax(dim=-1)
        
        #pred_num = labels.size(0)
        #correct_num = torch.eq(final, labels).sum(dim=-1)
        #cum_acc += torch.true_divide(correct_num, pred_num).mean().item()
        cum_acc += torch.eq(final, labels).sum().item() / labels.size(0)
        
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * cum_acc / len(train_loader)
    print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:4f}%  ==')
    
    scheduler.step()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        running_loss = 0.0
        predictions = []
        truths = []
        with torch.no_grad():
            for index, (data_v, labels_v) in enumerate(val_loader):
        
                data_v, labels_v = data_v.to(device), labels_v.to(device)
            
                logits = model(data_v)
                loss = loss_fn(logits, labels_v)
                running_loss += loss.item()
                
                softmax = nn.Softmax(dim=-1)
                final_prob = softmax(logits)
                final = final_prob.argmax(dim=-1)
                final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1
                
                predictions.append(final.cpu())
                truths.append(labels_v.cpu())
            
            predictions = torch.cat(predictions)
            truths = torch.cat(truths)
            no_drop = predictions != -1
            predictions = predictions[no_drop]
            truths = truths[no_drop]
            
            cur_acc = accuracy_score(truths, predictions)
            f1 = f1_score(truths, predictions, average='macro')
            val_loss = running_loss / len(val_loader)

            print(f'Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f}')
            print(confusion_matrix(truths, predictions))
            print(classification_report(truths,predictions, target_names=label_names.tolist(), digits=4))
            
            if cur_acc > max_acc:
                max_acc = cur_acc
                trigger_times = 0
                src.utils_BERT.save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir)
            else:
                trigger_times += 1
                if trigger_times > PATIENCE:
                    break
    del predictions, truths

In [20]:
print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:4f}%  ==')
print(f'Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f}')

    ==  Epoch: 20 | Training Loss: 0.549712 | Accuracy: 78.317356%  ==
Epoch: 20 | Validation Loss: 0.565491 | F1 Score: 0.415135


In [22]:
src.utils_BERT.save_best_ckpt(i, model, optimizer, scheduler, val_loss, '2025_PBMC_MainCelltype_scBert_v1_last', ckpt_dir)

In [10]:
acc = []
f1 = []
f1w = []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
pred_list = pd.Series(['un'] * data_csr.shape[0])

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
for index_train, index_val in sss.split(data_csr, label):
    data_train, label_train = data_csr[index_train], label[index_train]
    data_val, label_val = data_csr[index_val], label[index_val]
    train_dataset = SCDataset(data_train, label_train,CLASS,device)
    val_dataset = SCDataset(data_val, label_val,CLASS,device)

NameError: name 'SCDataset' is not defined

In [13]:
train_sampler = src.utils_BERT.SimpleSampler(train_dataset)
val_sampler = src.utils_BERT.SimpleSampler(val_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

In [14]:
model = PerformerLM(
    num_tokens = CLASS,
    dim = 200,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 10,
    gene_weight_file='./src/data_BERT/gene2vec_16906.npy',
    local_attn_heads = 0,
    g2v_position_emb = POS_EMBED_USING
)
model.to_out = Identity(dropout=0., h_dim=128, out_dim=label_names.shape[0])

The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2416.)
  q, r = torch.qr(unstructured_block.cpu(), some = True)


In [15]:
device = torch.device("cpu")
ckpt = torch.load(path,map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
for param in model.parameters():
    param.requires_grad = False
for param in model.norm.parameters():
    param.requires_grad = True
for param in model.performer.net.layers[-2].parameters():
    param.requires_grad = True

  ckpt = torch.load(path,map_location=device)


In [16]:
ckpt.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'losses'])

In [17]:
device = torch.device("cuda:1")
model = model.to(device)

In [18]:
# optimizer
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=15,
    cycle_mult=2,
    max_lr=LEARNING_RATE,
    min_lr=1e-6,
    warmup_steps=5,
    gamma=0.9
)
loss_fn = nn.CrossEntropyLoss(weight=None).to(device)

In [None]:
trigger_times = 0
max_acc = 0.0
for i in range(1, EPOCHS+1):
    train_loader.sampler.set_epoch(i)
    model.train()
    running_loss = 0.0
    cum_acc = 0.0
    for index, (data, labels) in enumerate(train_loader):
        index += 1
        data, labels = data.to(device), labels.to(device)
        if index % GRADIENT_ACCUMULATION != 0:
            
            logits = model(data)
            loss = loss_fn(logits, labels)
            loss.backward()
            
        if index % GRADIENT_ACCUMULATION == 0:      
            torch.nn.utils.clip_grad_norm_(model.parameters(), int(1e6))
            optimizer.step()
            optimizer.zero_grad()
        running_loss += loss.item()
        softmax = nn.Softmax(dim=-1)
        
        final = softmax(logits)
        final = final.argmax(dim=-1)
        
        #pred_num = labels.size(0)
        #correct_num = torch.eq(final, labels).sum(dim=-1)
        #cum_acc += torch.true_divide(correct_num, pred_num).mean().item()
        cum_acc += torch.eq(final, labels).sum().item() / labels.size(0)
        
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * cum_acc / len(train_loader)
    print(f'    ==  Epoch: {i} | Training Loss: {epoch_loss:.6f} | Accuracy: {epoch_acc:4f}%  ==')
    
    scheduler.step()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        running_loss = 0.0
        predictions = []
        truths = []
        with torch.no_grad():
            for index, (data_v, labels_v) in enumerate(val_loader):
        
                data_v, labels_v = data_v.to(device), labels_v.to(device)
            
                logits = model(data_v)
                loss = loss_fn(logits, labels_v)
                running_loss += loss.item()
                
                softmax = nn.Softmax(dim=-1)
                final_prob = softmax(logits)
                final = final_prob.argmax(dim=-1)
                final[np.amax(np.array(final_prob.cpu()), axis=-1) < UNASSIGN_THRES] = -1
                
                predictions.append(final.cpu())
                truths.append(labels_v.cpu())
            
            predictions = torch.cat(predictions)
            truths = torch.cat(truths)
            no_drop = predictions != -1
            predictions = predictions[no_drop]
            truths = truths[no_drop]
            
            cur_acc = accuracy_score(truths, predictions)
            f1 = f1_score(truths, predictions, average='macro')
            val_loss = running_loss / len(val_loader)

            print(f'Epoch: {i} | Validation Loss: {val_loss:.6f} | F1 Score: {f1:.6f}')
            print(confusion_matrix(truths, predictions))
            print(classification_report(truths,predictions, target_names=label_names.tolist(), digits=4))
            
            if cur_acc > max_acc:
                max_acc = cur_acc
                trigger_times = 0
                src.utils_BERT.save_best_ckpt(i, model, optimizer, scheduler, val_loss, model_name, ckpt_dir)
            else:
                trigger_times += 1
                if trigger_times > PATIENCE:
                    break
    del predictions, truths

In [30]:
ckpt['epoch']

15