In [1]:
%reload_ext autoreload
%autoreload 2
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.sparse import issparse
import numpy as np
from tqdm import tqdm
import wandb
import scanpy as sc
import pandas as pd
import os
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics





  from .autonotebook import tqdm as notebook_tqdm


In [2]:
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="clonotype_genes_filtered_counts_data",
    do_train=True,
    load_model="../model",
    mask_ratio=0.5,
    epochs=1,
    n_bins=101,
    MVC=True, # Masked value prediction for cell embedding
    ecs_thres=0.0, # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-4,
    batch_size=8,
    log_interval=100,
    layer_size=128,
    nlayers=4,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead=4,  # number of heads in nn.MultiheadAttention
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene = False,
    freeze = False, #freeze
    DSBN = False,  # Domain-spec batchnorm
)

In [3]:
run = wandb.init(
    config=hyperparameter_defaults,
    project="Fine tune scGPT for generative task",
    # reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mkristint[0m ([33mmackall_lab[0m). Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import HTML, display  # type: ignore


{'seed': 0, 'dataset_name': 'clonotype_genes_filtered_counts_data', 'do_train': True, 'load_model': '../model', 'mask_ratio': 0.5, 'epochs': 1, 'n_bins': 101, 'MVC': True, 'ecs_thres': 0.0, 'dab_weight': 0.0, 'lr': 0.0001, 'batch_size': 8, 'log_interval': 100, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'include_zero_gene': False, 'freeze': False, 'DSBN': False}


In [4]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # for masked values, now it should always be auto

include_zero_gene = config.include_zero_gene  # if True, include zero genes among hvgs in the training
max_seq_len = 3001
n_bins = config.n_bins

# input/transfer representation
input_style = "binned"  # "normed_raw", "log1p", or "binned"
output_style = "binned"  # "normed_raw", "log1p", or "binned"

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = False  # celltype classification objective
ADV = False  # Adversarial training for batch correction
CCE = False  # Contrastive cell embedding objective
MVC = config.MVC  # Masked value prediction for cell embedding
ECS = config.ecs_thres > 0  # Elastic cell similarity objective
DAB = False  # Domain adaptation by reverse backpropagation, set to 2 for separate optimizer
INPUT_BATCH_LABELS = False  # TODO: have these help MLM and MVC, while not to classifier
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
cell_emb_style = "cls"  # "avg-pool" or "w-pool" or "cls"
adv_E_delay_epochs = 0  # delay adversarial training on encoder for a few epochs
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight

explicit_zero_prob = MLM and include_zero_gene  # whether explicit bernoulli for zeros
do_sample_in_train = False and explicit_zero_prob  # sample the bernoulli in training

per_seq_batch_sample = False

# settings for optimizer
lr = config.lr  # TODO: test learning rate ratio between two tasks
lr_ADV = 1e-3  # learning rate for discriminator, used when ADV is True
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1

# settings for the model
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"  # "linear" or "flash"
embsize = config.layer_size  # embedding dimension
d_hid = config.layer_size  # dimension of the feedforward network in TransformerEncoder
nlayers = config.nlayers  # number of TransformerEncoderLayer in TransformerEncoder
nhead = config.nhead  # number of heads in nn.MultiheadAttention
dropout = config.dropout  # dropout probability

# logging
log_interval = 100  # iterations
save_eval_interval = config.save_eval_interval  # epochs
do_eval_scib_metrics = True

In [5]:
# %% validate settings
assert input_style in ["normed_raw", "log1p", "binned"]
assert output_style in ["normed_raw", "log1p", "binned"]
assert input_emb_style in ["category", "continuous", "scaling"]
if input_style == "binned":
    if input_emb_style == "scaling":
        raise ValueError("input_emb_style `scaling` is not supported for binned input.")
elif input_style == "log1p" or input_style == "normed_raw":
    if input_emb_style == "category":
        raise ValueError(
            "input_emb_style `category` is not supported for log1p or normed_raw input."
        )

if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins  # for padding gene expr values
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

if ADV and DAB:
    raise ValueError("ADV and DAB cannot be both True.")
DAB_separate_optim = True if DAB > 1 else False

In [6]:
dataset_name = config.dataset_name
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to save/dev_clonotype_genes_filtered_counts_data-Aug10-21-50


In [7]:
Tcell_data = sc.read_h5ad(Path.cwd()/"Clono_MT_overlap_gene_filtered_merged_counts.h5ad")

In [8]:
Tcell_data.layers["counts"] = Tcell_data.X.copy()

In [9]:
# Normalize to median total count
sc.pp.normalize_total(Tcell_data)
sc.pp.log1p(Tcell_data)

In [10]:
sc.pp.highly_variable_genes(Tcell_data, n_top_genes=3000, batch_key="sample_source", inplace=False)

Unnamed: 0_level_0,means,dispersions,dispersions_norm,highly_variable,highly_variable_nbatches,highly_variable_intersection
gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
11902,0.347741,4.386477,13.132321,True,4,True
11900,1.084330,4.461092,8.271159,True,4,True
7927,0.076210,4.908865,7.876472,True,4,True
11899,0.765927,3.163127,7.151582,True,4,True
4341,0.260919,2.420961,6.309089,True,4,True
...,...,...,...,...,...,...
2193,0.000054,-1.218939,-2.775038,False,0,False
8396,0.000207,-0.976199,-2.780292,False,0,False
9481,0.000164,-1.023284,-2.796171,False,0,False
6739,0.000065,-1.344594,-2.835213,False,0,False


In [11]:
Tcell_hvgs = Tcell_data[:, Tcell_data.var["highly_variable"]].copy()

In [12]:
sc.tl.pca(Tcell_hvgs, layer="counts", svd_solver="arpack")

In [13]:
sc.pp.neighbors(Tcell_hvgs, use_rep="X_pca", n_neighbors=15, metric="cosine")
sc.tl.umap(Tcell_hvgs)

In [14]:
# %% Preprocess data
if dataset_name == "clonotype_genes_filtered_counts_data":
    # data_dir = Path("./sample_data_cell_ann")
    adata = Tcell_hvgs
    
              
    adata.var.set_index("gene_name", inplace=True) # using gene name directly as index
    # fixed typo here adata_test.var instead of adata.var
    
    data_is_raw = True
    filter_gene_by_counts = False
    
    
                
# make the batch category column

adata.var["gene_name"] = adata.var.index.tolist()

In [15]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    shutil.copy(vocab_file, save_dir / "vocab.json")
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]

scGPT - INFO - match 4961/4989 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from ../model/best_model.pt, the model args will override the config ../model/args.json.


In [16]:
# set up the preprocessor, use the args to config the workflow
preprocessor = Preprocessor(
    use_key="counts",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=filter_gene_by_counts,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=False,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)




preprocessor(adata, batch_key=None)


scGPT - INFO - Normalizing total counts ...


  view_to_actual(adata)


scGPT - INFO - Log1p transforming ...
scGPT - INFO - Binning data ...


In [17]:
input_layer_key = {  # the values of this map coorespond to the keys in preprocessing
    "normed_raw": "X_normed",
    "log1p": "X_normed",
    "binned": "X_binned",
}[input_style]
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()





# (
#     train_data,
#     valid_data,
#     train_response_labels, 
#     valid_response_labels,
#     train_batch_labels,
#     valid_batch_labels,
# ) = train_test_split(
#     all_counts, response_labels, batch_ids, test_size=0.2, shuffle=True
# )

In [18]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [19]:
tokenized_train = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)

logger.info(
    f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
)



scGPT - INFO - train set number of samples: 25066, 
	 feature length: 3001


In [20]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    # masked_values_valid = random_mask_value(
    #     tokenized_valid["values"],
    #     mask_ratio=mask_ratio,
    #     mask_value=mask_value,
    #     pad_value=pad_value,
    # )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train = tokenized_train["genes"]
        
    input_values_train = masked_values_train
    target_values_train, target_values_valid = tokenized_train["values"]
        

    # tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    # tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    # tensor_response_labels_train = torch.from_numpy(train_response_labels).long()
    # tensor_response_labels_valid = torch.from_numpy(valid_response_labels).long()

    # if sort_seq_batch:  # TODO: update to random pick seq source in each traning batch
    #     train_sort_ids = np.argsort(train_batch_labels)
    #     input_gene_ids_train = input_gene_ids_train[train_sort_ids]
    #     input_values_train = input_values_train[train_sort_ids]
    #     target_values_train = target_values_train[train_sort_ids]
    #     tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]
    #     tensor_response_labels_train = tensor_response_labels_train[train_sort_ids]

    #     valid_sort_ids = np.argsort(valid_batch_labels)
    #     input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
    #     input_values_valid = input_values_valid[valid_sort_ids]
    #     target_values_valid = target_values_valid[valid_sort_ids]
    #     tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]
    #     tensor_response_labels_valid = tensor_response_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        # "batch_labels": tensor_batch_labels_train,
        # "response_labels": tensor_response_labels_train,
    }
    

    return train_data_pt

# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    if num_workers == 0:
        num_workers = min(len(os.sched_getaffinity(0)), batch_size // 2)

    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls= 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    # num_batch_labels=num_batch_types,
    domain_spec_batchnorm=config.DSBN,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=config.pre_norm,

    ## transfer learning
    # transfer_learning=True,
    # transfer_hidden_dim = 256,
    # transfer_dropout = 0.2,
)
if config.load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

# Freeze all pre-decoder weights
for name, para in model.named_parameters():
    print("-"*20)
    print(f"name: {name}")
    if config.freeze and "encoder" in name and "transformer_encoder" not in name:
    # if config.freeze and "encoder" in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False

post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")
wandb.log(
        {
            "info/pre_freeze_param_count": pre_freeze_param_count,
            "info/post_freeze_param_count": post_freeze_param_count,
        },
)

model.to(device)
wandb.watch(model)

# if ADV:
#     discriminator = AdversarialDiscriminator(
#         d_model=embsize,
#         n_cls=num_batch_types,
#     ).to(device)




scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Si

[]

In [22]:
def verify_trainable(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable: {name}")
        else:
            print(f"Frozen: {name}")

# freeze_original_model(model)
verify_trainable(model)

Trainable: encoder.embedding.weight
Trainable: encoder.enc_norm.weight
Trainable: encoder.enc_norm.bias
Trainable: value_encoder.linear1.weight
Trainable: value_encoder.linear1.bias
Trainable: value_encoder.linear2.weight
Trainable: value_encoder.linear2.bias
Trainable: value_encoder.norm.weight
Trainable: value_encoder.norm.bias
Trainable: transformer_encoder.layers.0.self_attn.in_proj_weight
Trainable: transformer_encoder.layers.0.self_attn.in_proj_bias
Trainable: transformer_encoder.layers.0.self_attn.out_proj.weight
Trainable: transformer_encoder.layers.0.self_attn.out_proj.bias
Trainable: transformer_encoder.layers.0.linear1.weight
Trainable: transformer_encoder.layers.0.linear1.bias
Trainable: transformer_encoder.layers.0.linear2.weight
Trainable: transformer_encoder.layers.0.linear2.bias
Trainable: transformer_encoder.layers.0.norm1.weight
Trainable: transformer_encoder.layers.0.norm1.bias
Trainable: transformer_encoder.layers.0.norm2.weight
Trainable: transformer_encoder.layers

In [23]:
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()

# new criterion for transfer learning
criterion_transfer = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    model.parameters(), lr=lr, weight_decay=0.0001, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, schedule_interval, gamma=config.schedule_ratio
)
if DAB_separate_optim:
    optimizer_dab = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler_dab = torch.optim.lr_scheduler.StepLR(
        optimizer_dab, schedule_interval, gamma=config.schedule_ratio
    )
if ADV:
    criterion_adv = nn.CrossEntropyLoss()  # consider using label smoothing
    optimizer_E = torch.optim.Adam(model.parameters(), lr=lr_ADV)
    scheduler_E = torch.optim.lr_scheduler.StepLR(
        optimizer_E, schedule_interval, gamma=config.schedule_ratio
    )
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr_ADV)
    scheduler_D = torch.optim.lr_scheduler.StepLR(
        optimizer_D, schedule_interval, gamma=config.schedule_ratio
    )

scaler = torch.cuda.amp.GradScaler(enabled=config.amp)



In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import wandb
import logging

logger = logging.getLogger(__name__)

def train(model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, criterion: nn.Module, device: torch.device, epoch: int, config: dict) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss = 0.0
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        target_values = batch_data["target_values"].to(device)
        input_values = batch_data["values"].to(device)
        
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])

        

        with torch.cuda.amp.autocast(enabled=config['amp']):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=None,
                MVC=True,
                do_sample=False,
            )

            masked_positions = input_values.eq(mask_value)  # the postions to predict
            loss = 0.0
            metrics_to_log = {}
            if MVC:
                loss_mvc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                metrics_to_log.update({"train/mvc": loss_mvc.item()})
                loss = loss + loss_mvc

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        wandb.log(metrics_to_log)     

        total_loss += loss.item()

        if batch % config['log_interval'] == 0 and batch > 0:
            lr = optimizer.param_groups[0]['lr']
            ms_per_batch = (time.time() - start_time) * 1000 / config['log_interval']
            cur_loss = total_loss / config['log_interval']
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f}"
            )
            wandb.log({
                "train/loss": cur_loss,
                "train/lr": lr,
                "train/ms_per_batch": ms_per_batch,
                "epoch": epoch,
                "batch": batch,
                "train/progress": batch / num_batches,
            })
            total_loss = 0
            start_time = time.time()

    return total_loss / num_batches

def fine_tune(model: nn.Module, train_loader: DataLoader, config: dict):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-2)
    scheduler = optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=config.schedule_ratio)

    save_dir = Path.cwd() / "best_finetuned_models"
    save_dir.mkdir(parents=True, exist_ok=True)
    

    best_loss = float('inf')
    best_model = None

    for epoch in range(1, config['epochs'] + 1):
        epoch_loss = train(model, train_loader, optimizer, criterion, device, epoch, config)
        scheduler.step()

        # Log epoch-level metrics
        wandb.log({
            "epoch": epoch,
            "epoch_loss": epoch_loss,
            "learning_rate": optimizer.param_groups[0]['lr']
        })

        logger.info(f"| end of epoch {epoch:3d} | average loss {epoch_loss:5.2f}")
        
        # Save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            
            best_model_state = model.state_dict()
            
            save_path = save_dir / f'best_model_epoch_{epoch}_loss_{best_loss:.4f}.pth'
            try:
                torch.save(best_model_state, save_path)
                logger.info(f"Saving best model to {save_path}")
                wandb.log({
                "best_model/epoch": epoch,
                "best_model/loss": best_loss,
                "best_model/save_path": save_path
            })
            except Exception as e:
                logger.error(f"Failed to save model {e}")

            

    wandb.finish()

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model




tokenized_trn = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=include_zero_gene,
    )

input_values_test = random_mask_value(
        tokenized_trn["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

trn_data_pt = {
        "gene_ids": tokenized_trn["genes"],
        "values": input_values_test,
        "target_values": tokenized_trn["values"],
        # "batch_labels": torch.from_numpy(batch_ids).long(),
        
    }
train_loader = DataLoader(
        dataset=SeqDataset(trn_data_pt),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=min(len(os.sched_getaffinity(0)), eval_batch_size // 2),
        pin_memory=True,
    )



In [25]:
best_model = fine_tune(model, train_loader, config)

Failed to save model Object of type PosixPath is not JSON serializable
  from IPython.core.display import HTML, display  # type: ignore


0,1
batch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch_loss,▁
info/post_freeze_param_count,▁
info/pre_freeze_param_count,▁
learning_rate,▁
train/loss,█▄▃▃▄▃▃▂▁▂▁▁▂▂▅▅▃▃▆▆▇▄▄▄▃▅▅▅▄▃▄
train/lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/ms_per_batch,█▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁
train/mvc,▅▄▃▃▄▄▄▂▁▃▄▂▂▁▄▁▄▂▄█▃▄▅▅▅▅▅▅▄▄▄▅▃▅▂▅▃▃▂▄

0,1
batch,3100.0
epoch,1.0
epoch_loss,5.24924
info/post_freeze_param_count,51856898.0
info/pre_freeze_param_count,51856898.0
learning_rate,9e-05
train/loss,508.37923
train/lr,0.0001
train/ms_per_batch,591.39114
train/mvc,500.90289


In [26]:
best_model

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [27]:
import os

def save_finetuned_model(model, save_path, model_name):
    """
    Save the finetuned model using both methods: entire model and state dict.
    
    Args:
    model (torch.nn.Module): The finetuned model to save
    save_path (str): Directory to save the model
    model_name (str): Name to use for the saved model files
    """
    # Ensure the save directory exists
    os.makedirs(save_path, exist_ok=True)
    
    # 1. Save the entire model
    entire_model_path = os.path.join(save_path, f"{model_name}_entire.pth")
    torch.save(model, entire_model_path)
    print(f"Entire model saved to {entire_model_path}")
    
    # 2. Save only the state dict
    state_dict_path = os.path.join(save_path, f"{model_name}_state_dict.pth")
    torch.save(model.state_dict(), state_dict_path)
    print(f"Model state dict saved to {state_dict_path}")


save_dir = Path.cwd() / "best_finetuned_models"
save_finetuned_model(best_model, save_dir, "finetuned_generative_task")

Entire model saved to /home/bench-user/scGPT/Tcell_GPT/tutorials/best_finetuned_models/finetuned_generative_task_entire.pth
Model state dict saved to /home/bench-user/scGPT/Tcell_GPT/tutorials/best_finetuned_models/finetuned_generative_task_state_dict.pth
