# Toturial 1: Fine-tuning CellFM for Cell Type Annotation
In this tutorial, we will demonstrate how to fine-tune a pre-trained CellFM model to perform cell type annotation on a new single-cell dataset. The workflow consists of the following stages:

1. Data Preprocessing

2. Data Loading

3. Model construct

4. Weight Loading

5. Fine-tuning

6. Result Visualization

Before starting, import the following packages and set up the configuration:

In [2]:
import os
#os.chdir("..")
os.chdir("/home/yinghsin/M2LAB/CellFM-torch")
print("Current working directory:", os.getcwd())

import scanpy as sc
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
from layers.utils import *
import pandas as pd
import numpy as np
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

from model import Cell_FM

cfg = Config_80M()
cfg.ecs_threshold = 0.8
cfg.ecs = True
cfg.add_zero = True
cfg.pad_zero = True
cfg.use_bs = 8 #this is batch_size batch size for training
cfg.mask_ratio = 0.5
### Main param ###
cfg.dataset = "Pancrm4"
cfg.feature_col = "cell_type"
cfg.ckpt_path = "/bigdat2/user/shanggny/checkpoint/para80m/6300w_18000_19479-1_38071.ckpt"
cfg.device = "cuda:0"
cfg.epoch = 5
#cfg.num_cls = 1 # by default 之後會改成cfg.num_cls = int(batch["feat"].max().item() + 1)



Current working directory: /home/yinghsin/M2LAB/CellFM-torch


  from anndata import __version__ as anndata_version
  if Version(anndata.__version__) >= Version("0.11.0rc2"):
  if Version(anndata.__version__) >= Version("0.11.0rc2"):


In [3]:
import torch
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
cfg.device = "cuda:0"
#cfg.device = "cpu"

cuda available: True
cuda device count: 1


In [4]:
import os
import glob
import time
import numpy as np
import scipy as sp
import pandas as pd
import scanpy as sc
import pickle as pk
import anndata as ad
import requests as rq
import multiprocessing as mp
from tqdm import tqdm,trange
from functools import partial
from scipy.sparse import csr_matrix as csr
from scipy.sparse import csc_matrix as csc
from multiprocessing import Process,Pool
from sklearn.neighbors import NearestNeighbors as NN
from scipy.spatial.distance import cdist
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Data Preprocessing

The data preprocessing workflow is identical to the original CellFM implementation. Please follow the same steps as described in the original [CellFM documentation](https://github.com/biomed-AI/CellFM/blob/main/tutorials/process.ipynb) to prepare your datasets.

# Load Data
Load an .h5ad single-cell dataset and return a PyTorch DataLoader ready for model training or testing.

In [19]:
adata_path = "/home/yinghsin/M2LAB/CellFM-torch/lung_processed.h5ad"
def load_data(adata_path, mode="train"):
    adata = read_h5ad(adata_path)
    adata.obs['celltype'] = adata.obs['cell_type']

    # 1) cell type label (固定全域 category)
    cat = adata.obs[cfg.feature_col].astype(str).astype("category")
    adata.obs["feat"] = cat.cat.codes.astype(np.int64)

    # ✅ 用「總類別數」，不是 nunique()
    cfg.num_cls = int(len(cat.cat.categories))

    # adata.obs['batch_id'] = 0 #this is for single batch data, 會讓adv_loss變沒意義
    # 1) cell type label
    adata.obs["celltype"] = adata.obs["cell_type"].astype(str)
    adata.obs["feat"] = adata.obs[cfg.feature_col].astype("category").cat.codes.astype(np.int64)
    cfg.num_cls = int(adata.obs["feat"].nunique())
    adata.obs["batch_id"] = adata.obs["batch"].astype(str).astype("category").cat.codes.astype(np.int64)

    cfg.num_batches = int(adata.obs["batch_id"].nunique())

    if mode == "train":
        adata.obs['train'] = 0
        dataset = SCrna(adata, mode="train")
        prep = Prepare(cfg.nonz_len, pad=0, mask_ratio=cfg.mask_ratio)
        loader = build_dataset(
            dataset,
            prep=prep,
            batch_size=cfg.use_bs,
            pad_zero=cfg.pad_zero,
            drop=True,
            shuffle=True
        )
    if mode== "test":
        adata.obs['train'] = 2
        dataset = SCrna(adata, mode="test")
        prep = Prepare(cfg.nonz_len, pad=0, mask_ratio=cfg.mask_ratio)
        loader = build_dataset(
            dataset,
            prep=prep,
            batch_size=cfg.use_bs,
            drop=True,
            shuffle=True
        )
    return loader

train_loader = load_data(
    "/home/yinghsin/M2LAB/CellFM-torch/lung_processed.h5ad",
    mode="train"
)

test_loader = load_data(
    "/home/yinghsin/M2LAB/CellFM-torch/lung_processed.h5ad",
    mode="test"
)


origin shape: (32472, 15148)
origin shape: (32472, 15148)


In [20]:
# check labels range 
train_loader = load_data("/home/yinghsin/M2LAB/CellFM-torch/lung_processed.h5ad", mode="train")
batch = next(iter(train_loader))
labels = batch["feat"].long()

print("labels min/max:", labels.min().item(), labels.max().item())
print("cfg.num_cls:", cfg.num_cls)
print("range ok:", (labels.min() >= 0) and (labels.max() < cfg.num_cls))


origin shape: (32472, 15148)
labels min/max: 2 14
cfg.num_cls: 17
range ok: tensor(True)


In [23]:
import os
os.chdir("/home/yinghsin/M2LAB/CellFM-torch")
print("cwd =", os.getcwd())

cwd = /home/yinghsin/M2LAB/CellFM-torch


# Model construct
To make CellFM more flexible and reusable across different tasks, we retain its core masked recovery module and re-implement it in PyTorch as a class called `Cell_FM`. This class can either:

- Load a set of pre-trained weights and the corresponding configuration, or

- Be trained from scratch.

`Cell_FM` allows fine-tuning on new datasets to perform masked expression recovery and outputs a cls_token representing each cell.

For tasks that require cell type classification, we can simply add a linear classification layer on top of the cls_token. In this tutorial, we use a single hidden-layer linear model for this purpose.

Below, we demonstrate how to leverage the `Cell_FM` module to build a custom PyTorch model suitable for task-specific fine-tuning:

In [24]:
class Finetune_Cell_FM(nn.Module):
    def __init__(self, cfg):
        super(Finetune_Cell_FM, self).__init__()
        self.cfg = cfg
        self.num_cls = cfg.num_cls
        self.extractor = Cell_FM(27855, self.cfg, ckpt_path=self.cfg.ckpt_path, device=self.cfg.device) # n_gene, cfg=config_80M()
        # notion: 27855 is the orignal pre_train gene set of 80M CellFM model
        self.cls = nn.Sequential(
            nn.Linear(self.cfg.enc_dims, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Linear(128, self.num_cls)
        )
    
    def forward(self, raw_nzdata,
                dw_nzdata,
                ST_feat,
                nonz_gene,
                mask_gene,
                zero_idx):
        
        mask_loss, cls_token = self.extractor(
                raw_nzdata,
                dw_nzdata,
                ST_feat,
                nonz_gene,
                mask_gene,
                zero_idx
            )
        
        cls = self.cls(cls_token)

        return cls, mask_loss, cls_token

# Weight Loading

You can easily load pre-trained weights using the `load_model` function in the `Cell_FM` module. Setting the option to False will initialize the model from scratch.

Note: During fine-tuning, we only unfreeze the `cls.` and `encoder` layers. This strategy reduces the computational cost and memory usage while still allowing effective adaptation to the new dataset.

In [25]:
net = Finetune_Cell_FM(cfg) # 27855
net = Finetune_Cell_FM(cfg).to(cfg.device)

for name, param in net.named_parameters():
    param.requires_grad = "cls." in name or "encoder" in name

print("Trainable parameters:")
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name)
net = net.to(cfg.device)
#net.extractor.load_model(weight=True, moment=False)

net.extractor.load_torch_weight("/home/yinghsin/M2LAB/CellFM-torch/cellfm_80m_pretrained.pt")

optimizer = AdamW([p for p in net.parameters() if p.requires_grad], 
                    lr=1e-4,
                    weight_decay=1e-5)

scaler = GradScaler() 
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

criterion_cls = nn.CrossEntropyLoss()


Trainable parameters:
extractor.net.encoder.0.attn.q_proj.weight
extractor.net.encoder.0.attn.k_proj.weight
extractor.net.encoder.0.attn.v_proj.weight
extractor.net.encoder.0.attn.u_proj.weight
extractor.net.encoder.0.attn.o_proj.weight
extractor.net.encoder.0.ffn.u_proj.weight
extractor.net.encoder.0.ffn.v_proj.weight
extractor.net.encoder.0.ffn.o_proj.weight
extractor.net.encoder.0.post_norm1.weight
extractor.net.encoder.0.post_norm1.bias
extractor.net.encoder.0.post_norm2.weight
extractor.net.encoder.0.post_norm2.bias
extractor.net.encoder.1.attn.q_proj.weight
extractor.net.encoder.1.attn.k_proj.weight
extractor.net.encoder.1.attn.v_proj.weight
extractor.net.encoder.1.attn.u_proj.weight
extractor.net.encoder.1.attn.o_proj.weight
extractor.net.encoder.1.ffn.u_proj.weight
extractor.net.encoder.1.ffn.v_proj.weight
extractor.net.encoder.1.ffn.o_proj.weight
extractor.net.encoder.1.post_norm1.weight
extractor.net.encoder.1.post_norm1.bias
extractor.net.encoder.1.post_norm2.weight
extracto

In [26]:
import random
import json, time, os
from pathlib import Path

import numpy as np
import scanpy as sc
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm 
from scib_metrics import silhouette_label, silhouette_batch, bras
import scib

from TOSICA_scGPT.TOSICA_model import (
    ClsDecoder,
    Mlp_reconstruction,
    ProjectionHead,
    AdversarialHead,
    DomainAffine,
    DSBatchNorm,
)
from TOSICA_scGPT.train import mask_input, sclsc_margin_loss_batch

from scipy.sparse import issparse
import warnings
warnings.filterwarnings("ignore")

import argparse
import pickle

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CellFMWithMyHead(nn.Module):
    """
    用 CellFM 的 cls_token 當 cell embedding (z)，
    後面接你們原本 4 個 head + 4 個 loss。
    """
    def __init__(
        self,
        cellfm_net,          # 你建立好的 net = Finetune_Cell_FM(cfg)
        emb_dim: int,        # 1536
        num_genes: int,
        num_cell_types: int,
        num_batches: int,
        use_domain_affine: bool = False,
        recons_hidden_dims=(256, 512),
        contras_dim: int = 64,
    ):
        super().__init__()
        self.cellfm = cellfm_net             # 這個裡面有 extractor (Cell_FM)
        self.emb_dim = emb_dim
        self.num_genes = num_genes
        self.num_cell_types = num_cell_types
        self.num_batches = num_batches
        self.use_domain_affine = use_domain_affine

        # 你們原本四個 head（你說你都搬好了，我就直接用同名）
        if use_domain_affine:
            self.domain_affine = DomainAffine(emb_dim, n_domain=num_batches)
        else:
            self.domain_affine = nn.Identity()

        self.cls_head = ClsDecoder(in_features=emb_dim, n_cls=num_cell_types, nlayers=3)

        self.dsbn = DSBatchNorm(emb_dim, n_domain=num_batches)
        self.recon_head = Mlp_reconstruction(
            in_features=emb_dim,
            n_domain=num_batches,
            hidden_dims=list(recons_hidden_dims),
            out_features=num_genes,
            act_layer=nn.ReLU,
            drop=0.2,
        )

        self.contras_head = ProjectionHead(
            input_dim=emb_dim,
            hidden_dim1=512,
            hidden_dim2=128,
            output_dim=contras_dim,
        )

        self.adv_head = AdversarialHead(in_dim=emb_dim, n_batch=num_batches)

    def encode(self, batch):
        """
        batch 是 CellFM dataloader 吐出來的那包 dict：
        raw_nzdata, dw_nzdata, ST_feat, nonz_gene, mask_gene, zero_idx, feat, batch_id...
        """
        # CellFM tutorial 的 net(...) 回傳 cls_token
        cls_logits, mask_loss, cls_token = self.cellfm(
            raw_nzdata=batch["raw_nzdata"],
            dw_nzdata=batch["dw_nzdata"],
            ST_feat=batch["ST_feat"],
            nonz_gene=batch["nonz_gene"],
            mask_gene=batch["mask_gene"],
            zero_idx=batch["zero_idx"],
        )

        z = cls_token  # [B, 1536]
        # 如果你們要做 domain affine，就用 batch_id
        if self.use_domain_affine:
            z = self.domain_affine(z, batch["batch_id"])

        return z, mask_loss, cls_logits

    def forward(self, batch, mode="cls", λ=1.0):
        z, mask_loss, _ = self.encode(batch)

        if mode == "cls":
            logits = self.cls_head(z)
            return z, logits, mask_loss

        if mode == "recons":
            recons = self.recon_head(z, batch["batch_id"])
            return z, recons, mask_loss

        if mode == "contras":
            proj = self.contras_head(z)
            return z, proj, mask_loss

        if mode == "adv":
            batch_logits = self.adv_head(z, λ=λ)
            return z, batch_logits, mask_loss

        raise ValueError(f"Unknown mode={mode}")

    def forward_adv(self, z: torch.Tensor, λ: float = 1.0):
        """
        跟原本 TOSICA 一樣的介面： core(model).forward_adv(latent, λ)
        """
        return self.adv_head(z, λ=λ)

    # ===================== loss functions =====================

    @staticmethod
    def loss_classification(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(logits, labels)

    @staticmethod
    def loss_reconstruction(
        recons: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # target: 原始 expression（log1p / scale 後都可以，但要一致）
        if mask is None:
            return F.mse_loss(recons, target)
        else:
            return F.mse_loss(recons[mask], target[mask])

    @staticmethod
    def loss_adversarial(
        batch_logits: torch.Tensor,
        batch_labels: torch.Tensor,
    ) -> torch.Tensor:
        return F.cross_entropy(batch_logits, batch_labels)

    @staticmethod
    def loss_contrastive(
        z: torch.Tensor,
        labels: torch.Tensor,
        batch: torch.Tensor,
        epoch: int,
        margin: float = 1.0,
        alpha1: float = 1.0,
        alpha2: float = 1.0,
        alpha3: float = 1.0,
        alpha4: float = 1.0,
    ) -> torch.Tensor:
        """
        用你 train.py 裡的 sclsc_margin_loss_batch
        """
        z = F.normalize(z, dim=-1)
        return sclsc_margin_loss_batch(
            z,
            labels,
            batch,
            epoch,
            margin=margin,
            alpha1=alpha1,
            alpha2=alpha2,
            alpha3=alpha3,
            alpha4=alpha4,
        )

    # ===================== 一次算四種 loss cellfm version==============
    # this is the 0119 new version
    def compute_total_loss_from_batch(
        self,
        batch: dict,
        epoch: int,
        λ_adv: float = 1.0,
        w_cls: float = 1.0,
        w_contras: float = 0.3,
        w_recons: float = 0.3,
        w_adv: float = 0.2,
        **kwargs,
    ):
        labels = batch["feat"].long()
        batch_labels = batch["batch_id"].long()

        # ✅ 只 encode 一次
        z, mask_loss, _ = self.encode(batch)   # z: [B,1536]

        # 1) cls head
        cls_logits = self.cls_head(z)
        cls_loss = self.loss_classification(cls_logits, labels)

        # 2) recons: 用 CellFM mask_loss（最穩）
        recons_loss = mask_loss

        # 3) adv head
        batch_logits = self.adv_head(z, λ=λ_adv)
        batch_loss = self.loss_adversarial(batch_logits, batch_labels)

        # 4) contras head
        proj = self.contras_head(z)
        contras_loss = self.loss_contrastive(proj, labels, batch_labels, epoch, **kwargs)

        total = w_cls*cls_loss + w_recons*recons_loss + w_adv*batch_loss + w_contras*contras_loss

        return {
            "loss": total,
            "cls_loss": cls_loss,
            "recons_loss": recons_loss,
            "batch_loss": batch_loss,
            "contras_loss": contras_loss,
            # 方便 debug / print
            "cls_logits": cls_logits,
            "batch_logits": batch_logits,
        }



In [None]:
# see lung.h5ad有何column, to set batch_id for meaningful adversarial loss
# load data處做過，這邊可以不做或不跑
import scanpy as sc

adata = sc.read_h5ad("/home/yinghsin/M2LAB/CellFM-torch/lung_processed.h5ad")
print(adata.obs.columns)
adata = sc.read_h5ad(adata_path)

# cell type（你已經在做）
adata.obs["str_batch"] = adata.obs["batch"].astype(str)
adata.obs["batch_id"]  = adata.obs["str_batch"].astype("category").cat.codes
adata.var["gene_name"] = adata.var_names.tolist()

adata.obs["str_cell_type"] = adata.obs["cell_type"].astype(str).tolist()
adata.obs["cell_type_id"] = adata.obs["str_cell_type"].astype("category").cat.codes
adata = sc.read_h5ad(adata_path)

# cell type（你原本就在做）
adata.obs["celltype"] = adata.obs["cell_type"]

# batch id（新增這個）
adata.obs["batch_id"] = (
    adata.obs["batch"]
    .astype(str)
    .astype("category")
    .cat.codes
)

# 設定給 config
cfg.num_batches = adata.obs["batch_id"].nunique()


# ✅ batch 分群（重點）
# adata.obs["batch_id"] = adata.obs["sample_id"].astype("category").cat.codes
# for col in adata.obs.columns:
#     print(col, adata.obs[col].nunique())
# batch_id = batch['batch_id'].to(cfg.device)
batch = next(iter(train_loader))
# print(batch["batch_id"].unique())

Index(['dataset', 'location', 'nGene', 'nUMI', 'patientGroup', 'percent.mito',
       'protocol', 'sanger_type', 'size_factors', 'sampling_method', 'batch',
       'cell_type', 'donor', 'n_genes', 'train'],
      dtype='object')


In [13]:
batch = next(iter(train_loader))
print("batch_id unique:", batch["batch_id"].unique())
print("feat min/max:", batch["feat"].min().item(), batch["feat"].max().item())
print("cfg.num_cls / num_batches:", cfg.num_cls, cfg.num_batches)


batch_id unique: tensor([ 1,  5,  9, 12, 14, 15])
feat min/max: 0.0 14.0
cfg.num_cls / num_batches: 17 16


In [None]:
adata = sc.read_h5ad(adata_path)

# cell type（你原本就在做）
adata.obs["celltype"] = adata.obs["cell_type"]

# batch id（新增這個）
adata.obs["batch_id"] = (
    adata.obs["batch"]
    .astype(str)
    .astype("category")
    .cat.codes
)

# 設定給 config
cfg.num_batches = adata.obs["batch_id"].nunique()
num_batches = int(adata.obs["batch_id"].nunique())
print("num_batches =", num_batches)
cfg.num_cls = int(batch["feat"].max().item() + 1)
print("cfg.num_cls =", cfg.num_cls)

num_batches = 16
cfg.num_cls = 15


In [29]:
adata.obs["feat"] = (
    adata.obs[cfg.feature_col].astype("category").cat.codes.astype(np.int64)
)
labels = batch["feat"].to(cfg.device).long()
batch_labels = batch["batch_id"].to(cfg.device).long()



In [17]:
cfg.device = "cpu"
batch = next(iter(train_loader))
need = ["raw_nzdata","dw_nzdata","ST_feat","nonz_gene","mask_gene","zero_idx","feat","batch_id"]
for k in need:
    batch[k] = batch[k].to(cfg.device)
batch["feat"] = batch["feat"].long()
batch["batch_id"] = batch["batch_id"].long()

model = CellFMWithMyHead(...).to(cfg.device)
z, logits, mask_loss = model(batch, mode="cls")
print(z.shape, logits.shape, mask_loss)


TypeError: CellFMWithMyHead.__init__() missing 4 required positional arguments: 'emb_dim', 'num_genes', 'num_cell_types', and 'num_batches'

In [18]:
batch = next(iter(train_loader))
labels = batch["feat"].long()
print("labels unique:", torch.unique(labels))
print("labels min/max:", labels.min().item(), labels.max().item())
print("num_cls:", cfg.num_cls)
print("range ok:", (labels.min() >= 0) and (labels.max() < cfg.num_cls))


labels unique: tensor([ 0,  2,  3,  5,  9, 16])
labels min/max: 0 16
num_cls: 15
range ok: tensor(False)


In [45]:
cfg.device = "cuda:0"
batch = next(iter(train_loader))
need = ["raw_nzdata","dw_nzdata","ST_feat","nonz_gene","mask_gene","zero_idx","feat","batch_id"]
for k in need:
    batch[k] = batch[k].to(cfg.device)
batch["feat"] = batch["feat"].long()
batch["batch_id"] = batch["batch_id"].long()

model = CellFMWithMyHead(
    cellfm_net=net,
    emb_dim=cfg.enc_dims,   # 1536
    num_genes=200,          # 先隨便填，下一步我再幫你對齊
    num_cell_types=17,
    num_batches=16,                  # 你現在好像 batch_id 都設 0，先填 1
).to(cfg.device)


z, logits, mask_loss = model(batch, mode="cls")
print(z.shape, logits.shape, mask_loss)

torch.Size([8, 1536]) torch.Size([8, 17]) tensor(29.7994, device='cuda:0', grad_fn=<AddBackward0>)


In [None]:
batch = next(iter(train_loader))

print("nonz_gene max =", batch["nonz_gene"].max().item())
print("nonz_gene min =", batch["nonz_gene"].min().item())
print("n_genes (model) =", net.extractor.net.n_genes)

print(batch["feat"])
print(batch["feat"].dtype)
print(model.num_cell_types)


nonz_gene max = 27781
nonz_gene min = 0
n_genes (model) = 27855


In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
for k, v in batch.items():
    if torch.is_tensor(v):
        print(
            k,
            v.dtype,
            v.shape,
            v.min().item(),
            v.max().item()
        )

print("num_cls =", cfg.num_cls)
print("cfg.num_cls =", cfg.num_cls)
print("cfg.num_batches =", cfg.num_batches)

print("batch_labels min/max =", batch_labels.min().item(), batch_labels.max().item()) #have error in this line
print("label min/max =", labels.min().item(), labels.max().item()) #have error in this line


raw_nzdata torch.float32 torch.Size([2, 2048]) 0.0 6.4710493087768555
dw_nzdata torch.float32 torch.Size([2, 2048, 2]) 0.0 6.4710493087768555
ST_feat torch.float32 torch.Size([2, 2]) 0.0305292047560215 0.14841999113559723
nonz_gene torch.int32 torch.Size([2, 2048]) 0 25688
mask_gene torch.float32 torch.Size([2, 2048]) 0.0 1.0
zero_idx torch.float32 torch.Size([2, 2048]) 0.0 1.0
celltype_label torch.int64 torch.Size([2]) 14 16
batch_id torch.int64 torch.Size([2]) 0 0
feat torch.float32 torch.Size([2]) 14.0 16.0
num_cls = 10
cfg.num_cls = 10
cfg.num_batches = 16


NameError: name 'batch_labels' is not defined

In [None]:
# 先拿一個 batch（CPU）
batch = next(iter(train_loader))

# 只用 CPU 印
print("keys:", batch.keys())
print("feat dtype/shape/min/max:",
      batch["feat"].dtype, batch["feat"].shape,
      batch["feat"].min().item(), batch["feat"].max().item())

print("batch_id dtype/shape/min/max:",
      batch["batch_id"].dtype, batch["batch_id"].shape,
      batch["batch_id"].min().item(), batch["batch_id"].max().item())

print("cfg.num_cls =", cfg.num_cls)
print("cfg.num_batches =", cfg.num_batches)

print("check labels range ok? ", batch["feat"].max().item() < cfg.num_cls)
print("check batch range ok?  ", batch["batch_id"].max().item() < cfg.num_batches)



keys: dict_keys(['raw_nzdata', 'dw_nzdata', 'ST_feat', 'nonz_gene', 'mask_gene', 'zero_idx', 'celltype_label', 'batch_id', 'feat'])
feat dtype/shape/min/max: torch.float32 torch.Size([2]) 2.0 16.0
batch_id dtype/shape/min/max: torch.int64 torch.Size([2]) 0 0
cfg.num_cls = 10
cfg.num_batches = 16
check labels range ok?  False
check batch range ok?   True


In [None]:
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
cfg.device = "cuda:0"

scaler = GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

w_cls, w_recons, w_contras, w_adv = 1.0, 0.3, 0.3, 0.2

model.train()
for step, batch in enumerate(train_loader):
    # 1) move to device
    for k, v in batch.items():
        if torch.is_tensor(v):
            batch[k] = v.to(cfg.device)

    batch["feat"] = batch["feat"].long()
    batch["batch_id"] = batch["batch_id"].long()

    optimizer.zero_grad(set_to_none=True)

    with autocast():
        out = model.compute_total_loss_from_batch(
            batch,
            epoch=10,
            λ_adv=1.0,
            w_cls=w_cls,
            w_recons=w_recons,
            w_contras=w_contras,
            w_adv=w_adv,
        )
        loss = out["loss"]

    scaler.scale(loss).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

    if step % 20 == 0:
        labels = batch["feat"]
        batch_labels = batch["batch_id"]

        pred = out["cls_logits"].argmax(dim=1)
        acc = (pred == labels).float().mean().item()

        pred_batch = out["batch_logits"].argmax(dim=1)
        batch_acc = (pred_batch == batch_labels).float().mean().item()

        print(
            f"step={step} loss={out['loss'].item():.4f} "
            f"cls={out['cls_loss'].item():.4f} recons={out['recons_loss'].item():.4f} "
            f"contras={out['contras_loss'].item():.4f} adv={out['batch_loss'].item():.4f} "
            f"acc={acc:.3f} batch_acc={batch_acc:.3f}"
        )


RuntimeError: value cannot be converted to type at::Half without overflow