In [1]:
import warnings
warnings.filterwarnings("ignore")
import json
import multiprocessing
import os
import platform
import copy
import gc
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
import pickle
import torch
from anndata import AnnData
import scanpy as sc
import seaborn as sns
import numpy as np
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix
import scgpt as scg
from scgpt.model import TransformerModel
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value,tokenize_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics
from tqdm import tqdm
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)

In [2]:
sc.set_figure_params(figsize=(6, 6))

In [3]:
hyperparameter_defaults = dict(
    seed=0,
    do_train=True,
    load_model= "best_model.pt",
    mask_ratio=0.15,
    epochs=20,
    n_bins=51,
    MVC=False, 
    ecs_thres=0,  
    dab_weight=0,  
    lr=1e-4,
    batch_size=32, 
    layer_size=256,  
    nlayers=6,
    nhead=8,
    dropout=0.2,
    schedule_ratio=0.9,
    save_eval_interval=10,
    fast_transformer=False, 
    pre_norm=True,  
    amp=torch.cuda.is_available(),
    include_zero_gene=False,
    freeze=False,  
    DSBN=False,  
)


In [4]:
class DotDict(dict):
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__

config = DotDict(hyperparameter_defaults)
set_seed(config.seed)

In [5]:
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"
include_zero_gene = config.include_zero_gene
max_seq_len = 1201
n_bins = config.n_bins
input_style="binned"
output_style = "binned"

MLM = False
CLS = True
ADV = False
CCE = False
MVC = config.MVC
ECS = config.ecs_thres > 0
DAB = False
INPUT_BATCH_LABELS = False
input_emb_style = "continuous"
cell_emb_style = "cls"
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight

explicit_zero_prob = MLM and include_zero_gene
do_sample_in_train = False and explicit_zero_prob

In [6]:
lr = config.lr
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 2

fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"
embsize = config.layer_size
d_hid = config.layer_size
nlayers = config.nlayers
nhead = config.nhead
dropout = config.dropout

log_interval = 100
save_eval_interval = config.save_eval_interval
do_eval_scib_metrics = True

In [7]:
mask_value = -2
pad_value = -1
n_input_bins = n_bins

In [8]:
DAB_separate_optim = True if DAB > 1 else False

In [9]:
dataset_name = "Melonoma"
save_dir = Path(f"./save/dev_classification_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to save/dev_classification_Melonoma-Jul02-16-03


In [10]:
train = sc.read_h5ad("Main_train.h5ad")
test = sc.read_h5ad("Main_test.h5ad")

In [11]:
filter_gene_by_counts = False

In [12]:
train.var["gene_name"] = train.var.index.tolist()
test.var["gene_name"] = test.var.index.tolist()

In [2]:
model_config_file = "model/args.json"
model_file = "model/best_model.pt"
vocab_file = "model/vocab.json"

In [14]:
vocab = GeneVocab.from_file(vocab_file)
shutil.copy(vocab_file,save_dir/"vocab.json")

PosixPath('save/dev_classification_Melonoma-Jul02-16-03/vocab.json')

In [15]:
special_tokens

['<pad>', '<cls>', '<eoc>']

In [16]:
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

In [17]:
train.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in train.var["gene_name"]]

In [18]:
gene_ids_in_vocab = np.array(train.var["id_in_vocab"])
logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )

scGPT - INFO - match 1200/1200 genes in vocabulary of size 60697.


In [19]:
train = train[:,train.var["id_in_vocab"]>=0]

In [20]:
train

View of AnnData object with n_obs × n_vars = 5503 × 1200
    obs: 'cells', 'samples', 'cell.types', 'treatment.group', 'Cohort', 'no.of.genes', 'no.of.reads'
    var: 'gene_name', 'id_in_vocab'

In [21]:
with open(model_config_file, "r") as f:
        model_configs = json.load(f)
logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )

scGPT - INFO - Resume model from best_model.pt, the model args will override the config args.json.


In [22]:
embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = 12
n_layers_cls = model_configs["n_layers_cls"]

In [23]:
flat = train.X.toarray().flatten()

# Remove duplicates and sort
unique_vals = np.unique(flat)
second_max = unique_vals[-40:]

In [24]:
np.mean(train.X)

array(32.019123, dtype=float32)

In [25]:
preprocessor = Preprocessor(
    use_key="X",
    filter_gene_by_counts=filter_gene_by_counts,
    filter_cell_by_counts=False,
    normalize_total=1e4,
    result_normed_key="X_normed",
    log1p=True,
    result_log1p_key="X_log1p",
    subset_hvg=False,
    hvg_flavor="seurat_v3" if True else "cell_ranger",
    binning=n_bins,
    result_binned_key="X_binned",
)

In [26]:
preprocessor(train, batch_key=None)
preprocessor(test, batch_key=None)

scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Binning data ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Binning data ...


In [27]:
n_bins

51

In [28]:
train.obs

Unnamed: 0,cells,samples,cell.types,treatment.group,Cohort,no.of.genes,no.of.reads
1001,cy79_p1_CD45_neg_PDL1_neg_AS_C4_R1_G06_S174_comb,Mel79,Mal,treatment.naive,Tirosh,4781,43875
6710,cy121.1_CD45pos_S341,Mel121.1,T.cell,post.treatment,New,3736,438258
6034,cy110_CD45pos_S362,Mel110,Macrophage,post.treatment,New,6864,1323754
1447,CY89A_CD45_POS_6_G10_S178_comb,Mel89,T.cell,treatment.naive,Tirosh,2857,94005
4080,Merck_CD45pos_pl4_S83,Mel194,T.CD4,post.treatment,New,4770,1351127
...,...,...,...,...,...,...,...
3965,merck_cd45pos_PL3_S289,Mel194,T.cell,post.treatment,New,9177,1514743
5436,cy102_CD45neg_CD90neg_S320,Mel102,Mal,post.treatment,New,6954,678925
5471,cy102_CD45neg_CD90neg_S353,Mel102,Mal,post.treatment,New,6871,1153541
5641,cy106_CD45pos_S250,Mel106,B.cell,post.treatment,New,3784,704292


In [29]:
id2type_train = dict(enumerate(train.obs["cell.types"].astype("category").cat.categories))

In [30]:
type2id_train = {v: k for k, v in id2type_train.items()}

In [31]:
train.obs["class id"] = train.obs["cell.types"].map(type2id_train)

In [32]:
input_layer_key="X_binned"
allcounts = (
    train.layers[input_layer_key].toarray()
    if issparse(train.layers[input_layer_key])
    else train.layers[input_layer_key]
)

In [33]:
train.obs

Unnamed: 0,cells,samples,cell.types,treatment.group,Cohort,no.of.genes,no.of.reads,class id
1001,cy79_p1_CD45_neg_PDL1_neg_AS_C4_R1_G06_S174_comb,Mel79,Mal,treatment.naive,Tirosh,4781,43875,4
6710,cy121.1_CD45pos_S341,Mel121.1,T.cell,post.treatment,New,3736,438258,8
6034,cy110_CD45pos_S362,Mel110,Macrophage,post.treatment,New,6864,1323754,3
1447,CY89A_CD45_POS_6_G10_S178_comb,Mel89,T.cell,treatment.naive,Tirosh,2857,94005,8
4080,Merck_CD45pos_pl4_S83,Mel194,T.CD4,post.treatment,New,4770,1351127,6
...,...,...,...,...,...,...,...,...
3965,merck_cd45pos_PL3_S289,Mel194,T.cell,post.treatment,New,9177,1514743,8
5436,cy102_CD45neg_CD90neg_S320,Mel102,Mal,post.treatment,New,6954,678925,4
5471,cy102_CD45neg_CD90neg_S353,Mel102,Mal,post.treatment,New,6871,1153541,4
5641,cy106_CD45pos_S250,Mel106,B.cell,post.treatment,New,3784,704292,0


In [34]:
celltypes_labels = train.obs["class id"].tolist()  # make sure count from 0
celltypes_labels = np.array(celltypes_labels)
(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
) = train_test_split(
    allcounts, celltypes_labels, test_size=0.1, shuffle=True
)


In [35]:
def compute_sparsity(array):
    total_elements = array.size
    zero_elements = np.count_nonzero(array == 0)
    sparsity = zero_elements / total_elements
    return sparsity

In [36]:
compute_sparsity(train_data)

0.8638755385029617

In [37]:
compute_sparsity(valid_data)

0.868132183908046

In [38]:
train_data.shape

(4952, 1200)

In [39]:
valid_data.shape

(551, 1200)

In [40]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )

In [41]:
genes = train.var.gene_name.tolist()

In [42]:
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [43]:
def tokenize_batchy(
    data: np.ndarray,
    gene_ids: np.ndarray,
    return_pt: bool = True,
    append_cls: bool = True,
    include_zero_gene: bool = False,
    cls_id: int = "<cls>",
    mod_type: np.ndarray = None,
    cls_id_mod_type: int = None,
) -> List[Tuple[Union[torch.Tensor, np.ndarray]]]:
    """
    Tokenize a batch of data. Returns a list of tuple (gene_id, count).

    Args:
        data (array-like): A batch of data, with shape (batch_size, n_features).
            n_features equals the number of all genes.
        gene_ids (array-like): A batch of gene ids, with shape (n_features,).
        return_pt (bool): Whether to return torch tensors of gene_ids and counts,
            default to True.

    Returns:
        list: A list of tuple (gene_id, count) of non zero gene expressions.
    """
    if data.shape[1] != len(gene_ids):
        raise ValueError(
            f"Number of features in data ({data.shape[1]}) does not match "
            f"number of gene_ids ({len(gene_ids)})."
        )
    if mod_type is not None and data.shape[1] != len(mod_type):
        raise ValueError(
            f"Number of features in data ({data.shape[1]}) does not match "
            f"number of mod_type ({len(mod_type)})."
        )

    tokenized_data = []
    for i in range(len(data)):
        row = data[i]
        mod_types = None
        if include_zero_gene:
            values = row
            genes = gene_ids
            if mod_type is not None:
                mod_types = mod_type
        else:
            idx = np.nonzero(row)[0]
            values = row[idx]
            genes = gene_ids[idx]
            if mod_type is not None:
                mod_types = mod_type[idx]
        if append_cls:
            genes = np.insert(genes, 0, cls_id)
            values = np.insert(values, 0, 0)
            if mod_type is not None:
                mod_types = np.insert(mod_types, 0, cls_id_mod_type)
        if return_pt:
            genes = torch.from_numpy(genes).long()
            values = torch.from_numpy(values).float()
            if mod_type is not None:
                mod_types = torch.from_numpy(mod_types).long()
        tokenized_data.append((genes, values, mod_types))
    return tokenized_data

def pad_batchy(
    batch: List[Tuple],
    max_len: int,
    vocab: Vocab,
    pad_token: str = "<pad>",
    pad_value: int = 0,
    cls_appended: bool = True,
    vocab_mod: Vocab = None,
) -> Dict[str, torch.Tensor]:
    """
    Pad a batch of data. Returns a list of Dict[gene_id, count].

    Args:
        batch (list): A list of tuple (gene_id, count).
        max_len (int): The maximum length of the batch.
        vocab (Vocab): The vocabulary containing the pad token.
        pad_token (str): The token to pad with.

    Returns:
        Dict[str, torch.Tensor]: A dictionary of gene_id and count.
    """
    #max_ori_len = max(len(batch[i][0]) for i in range(len(batch)))
    #max_len = min(max_ori_len, max_len)

    pad_id = vocab[pad_token]
    if vocab_mod is not None:
        mod_pad_id = vocab_mod[pad_token]
    gene_ids_list = []
    values_list = []
    mod_types_list = []

    for i in range(len(batch)):
        gene_ids, values, mod_types = batch[i]

        if len(gene_ids) > max_len:
            # sample max_len genes
            if not cls_appended:
                idx = np.random.choice(len(gene_ids), max_len, replace=False)
            else:
                idx = np.random.choice(len(gene_ids) - 1, max_len - 1, replace=False)
                idx = idx + 1
                idx = np.insert(idx, 0, 0)
            gene_ids = gene_ids[idx]
            values = values[idx]
            if mod_types is not None:
                mod_types = mod_types[idx]
        if len(gene_ids) < max_len:
            gene_ids = torch.cat(
                [
                    gene_ids,
                    torch.full(
                        (max_len - len(gene_ids),), pad_id, dtype=gene_ids.dtype
                    ),
                ]
            )
            values = torch.cat(
                [
                    values,
                    torch.full((max_len - len(values),), pad_value, dtype=values.dtype),
                ]
            )
            if mod_types is not None:
                mod_types = torch.cat(
                    [
                        mod_types,
                        torch.full(
                            (max_len - len(mod_types),),
                            mod_pad_id,
                            dtype=mod_types.dtype,
                        ),
                    ]
                )

        gene_ids_list.append(gene_ids)
        values_list.append(values)
        if mod_types is not None:
            mod_types_list.append(mod_types)

    batch_padded = {
        "genes": torch.stack(gene_ids_list, dim=0),
        "values": torch.stack(values_list, dim=0),
    }
    if mod_types is not None:
        batch_padded["mod_types"] = torch.stack(mod_types_list, dim=0)
    return batch_padded

def tokenize_and_pad_batchy(
    data: np.ndarray,
    gene_ids: np.ndarray,
    max_len: int,
    vocab: Vocab,
    pad_token: str,
    pad_value: int,
    append_cls: bool = True,
    include_zero_gene: bool = False,
    cls_token: str = "<cls>",
    return_pt: bool = True,
    mod_type: np.ndarray = None,
    vocab_mod: Vocab = None,
) -> Dict[str, torch.Tensor]:
    """
    Tokenize and pad a batch of data. Returns a list of tuple (gene_id, count).
    """
    cls_id = vocab[cls_token]
    if mod_type is not None:
        cls_id_mod_type = vocab_mod[cls_token]
    tokenized_data = tokenize_batchy(
        data,
        gene_ids,
        return_pt=return_pt,
        append_cls=append_cls,
        include_zero_gene=include_zero_gene,
        cls_id=cls_id,
        mod_type=mod_type,
        cls_id_mod_type=cls_id_mod_type if mod_type is not None else None,
    )

    batch_padded = pad_batchy(
        tokenized_data,
        max_len,
        vocab,
        pad_token,
        pad_value,
        cls_appended=append_cls,
        vocab_mod=vocab_mod,
    )
    return batch_padded

In [44]:
include_zero_gene

False

max_seq_len

In [45]:
include_zero_gene

False

In [46]:
tokenized_train = tokenize_and_pad_batchy(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)
tokenized_valid = tokenize_and_pad_batchy(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
logger.info(
    f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
)
logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)

scGPT - INFO - train set number of samples: 4952, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 551, 
	 feature length: 1201


In [47]:
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}

In [48]:
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = None,  # set to None for auto-detection
) -> DataLoader:
    if num_workers is None:
        # Auto-detect optimal number of workers
        try:
            if platform.system() == "Windows":
                num_workers = max(1, batch_size // 2)
            else:
                num_workers = min(os.cpu_count() or 4, batch_size)
        except Exception:
            num_workers = 0  # fallback if detection fails

    dataset = SeqDataset(data_pt)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # Only pin if using GPU
    )


In [49]:
def prepare_data() -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    ) 
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    ) 

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    #input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_celltype_labels_train = torch.from_numpy(train_celltype_labels).long()
    tensor_celltype_labels_valid = torch.from_numpy(valid_celltype_labels).long()
    input_values_train, input_values_valid = masked_values_train, masked_values_valid

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target":target_values_train,
        "celltype_labels": tensor_celltype_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target":target_values_valid,
        "celltype_labels": tensor_celltype_labels_valid,
    }

    return train_data_pt, valid_data_pt

In [50]:
num_types = len(np.unique(celltypes_labels))
num_types

9

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntokens = len(vocab)  
model = TransformerModel(
    ntoken=ntokens,
    d_model=embsize,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=12,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    domain_spec_batchnorm=config.DSBN,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=config.pre_norm,
)

In [52]:
num_types

9

In [113]:
model_file

'best_model.pt'

In [114]:
try:
    model.load_state_dict(torch.load(model_file))
    logger.info(f"Loading all model params from {model_file}")
except:
        # only load params that are in the model and match the size
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Si

In [115]:
for p,_ in model.named_parameters():
    print(p)

encoder.embedding.weight
encoder.enc_norm.weight
encoder.enc_norm.bias
value_encoder.linear1.weight
value_encoder.linear1.bias
value_encoder.linear2.weight
value_encoder.linear2.bias
value_encoder.norm.weight
value_encoder.norm.bias
transformer_encoder.layers.0.self_attn.in_proj_weight
transformer_encoder.layers.0.self_attn.in_proj_bias
transformer_encoder.layers.0.self_attn.out_proj.weight
transformer_encoder.layers.0.self_attn.out_proj.bias
transformer_encoder.layers.0.linear1.weight
transformer_encoder.layers.0.linear1.bias
transformer_encoder.layers.0.linear2.weight
transformer_encoder.layers.0.linear2.bias
transformer_encoder.layers.0.norm1.weight
transformer_encoder.layers.0.norm1.bias
transformer_encoder.layers.0.norm2.weight
transformer_encoder.layers.0.norm2.bias
transformer_encoder.layers.1.self_attn.in_proj_weight
transformer_encoder.layers.1.self_attn.in_proj_bias
transformer_encoder.layers.1.self_attn.out_proj.weight
transformer_encoder.layers.1.self_attn.out_proj.bias
tra

In [57]:
pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
for name, para in model.named_parameters():
    print("-"*20)
    print(f"name: {name}")
    if config.freeze and "encoder" in name and "transformer_encoder" not in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False
post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")
model.to(device)

--------------------
name: encoder.embedding.weight
--------------------
name: encoder.enc_norm.weight
--------------------
name: encoder.enc_norm.bias
--------------------
name: value_encoder.linear1.weight
--------------------
name: value_encoder.linear1.bias
--------------------
name: value_encoder.linear2.weight
--------------------
name: value_encoder.linear2.bias
--------------------
name: value_encoder.norm.weight
--------------------
name: value_encoder.norm.bias
--------------------
name: transformer_encoder.layers.0.self_attn.in_proj_weight
--------------------
name: transformer_encoder.layers.0.self_attn.in_proj_bias
--------------------
name: transformer_encoder.layers.0.self_attn.out_proj.weight
--------------------
name: transformer_encoder.layers.0.self_attn.out_proj.bias
--------------------
name: transformer_encoder.layers.0.linear1.weight
--------------------
name: transformer_encoder.layers.0.linear1.bias
--------------------
name: transformer_encoder.layers.0.linear

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [59]:
criterion_cls = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, schedule_interval, gamma=config.schedule_ratio
)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

In [71]:
'''
def train(model: nn.Module, loader: DataLoader,epoch:int) -> None:
    """
    Train the model for one epoch on CPU.
    """
    model.train()

    (
        total_loss
    ) = (0.0)
    total_error = 0.0
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(tqdm(loader, desc=f"Training Epoch {epoch}", leave=False)):
        input_gene_ids = batch_data["gene_ids"].to(device,non_blocking=True)
        input_values = batch_data["values"].to(device,non_blocking=True)
        target_values = batch_data["target_values"].to(device,non_blocking=True)
        celltype_labels = batch_data["celltype_labels"].to(device,non_blocking=True)

        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        masked_positions = input_values.eq(mask_value)
        output_dict = model(
            input_gene_ids,
            input_values,
            src_key_padding_mask=src_key_padding_mask,
            batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
            CLS=CLS,
            CCE=CCE,
            MVC=MVC,
            ECS=ECS,
            do_sample=do_sample_in_train,
        )

        masked_positions = input_values.eq(mask_value)
        loss = 0.0
        metrics_to_log = {}

                
        model.zero_grad()
        loss = torch.tensor(loss, device=device, requires_grad=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | "
            )

            # Reset
            total_loss = 0
            start_time = time.time()

def chatgpt_train(model: nn.Module, loader: DataLoader,epoch:int) -> None:
    model.train()

    total_loss = 0.0
    total_error = 0.0
    start_time = time.time()
    num_batches = len(loader)
    
    for batch, batch_data in enumerate(tqdm(loader, desc=f"Training Epoch {epoch}", leave=False)):
        input_gene_ids = batch_data["gene_ids"].to(device, non_blocking=True)
        input_values = batch_data["values"].to(device, non_blocking=True)
        celltype_labels = batch_data["celltype_labels"].to(device, non_blocking=True)
        batch_labels = batch_data.get("batch_labels", None)
        if batch_labels is not None:
            batch_labels = batch_labels.to(device, non_blocking=True)

        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])

        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
            )

            loss = None

            if MLM:
                output = output_dict["mlm_output"]
                masked_positions = input_values.eq(mask_value)
                for idx, row in enumerate(output):
                    masked_position = src_key_padding_mask[idx]
                    if masked_position.any():
                        predicted_values = row[masked_position]
                        gene_tokens = input_gene_ids[idx][masked_position]
                        cancer_type, cell_type = id2type_train[celltype_labels[idx].item()].split()
                        single_loss = mse_main_loss_vectorized(
                            predicted_values,
                            cancer_type,
                            cell_type,
                            gene_tokens,
                            device
                        )
                        if loss is None:
                            loss = single_loss
                        else:
                            loss = loss + single_loss
            else:
                raise ValueError("MLM must be True — no loss defined otherwise.")
            
            # === Skip training step if loss is None ===
        if loss is not None:
                model.zero_grad()
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
            
                with warnings.catch_warnings(record=True) as w:
                    warnings.filterwarnings("always")
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), 1.0,
                        error_if_nonfinite=False if scaler.is_enabled() else True,
                    )
                    if len(w) > 0:
                        logger.warning(
                            f"⚠️ Infinite gradient detected (scale={scaler.get_scale()}) — may stabilize after autoscaling."
                        )
            
                scaler.step(optimizer)
                scaler.update()
            
                total_loss += loss.item()

            loss = torch.tensor(0.0, device=device)
            if MLM:
                output = output_dict["mlm_output"]
                masked_positions = input_values.eq(mask_value)
                for idx, row in enumerate(output):
                    masked_position = src_key_padding_mask[idx]
                    predicted_values = row[masked_position]
                    gene_tokens = input_gene_ids[idx][masked_position]
                    cancer_type, cell_type = id2type_train[celltype_labels[idx].item()].split()
                    loss += mse_main_loss_vectorized(
                                predicted_values,
                                cancer_type,
                                cell_type,
                                gene_tokens,
                                device
                                )

        # Convert float loss to tensor for backward

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), 1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"⚠️ Infinite gradient detected (scale={scaler.get_scale()}) — may stabilize after autoscaling."
                )

        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()



        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval

            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | loss {cur_loss:5.4f} |"
            )

            # Reset logging stats
            total_loss = 0.0
            start_time = time.time() '''

def chatgpt_train_main(model: nn.Module, loader: DataLoader, epoch: int) -> None:
    model.train()

    total_loss = 0.0
    total_error = 0.0
    total_mse = 0.0
    total_cls = 0.0
    total_zero_log_prob = 0.0
    start_time = time.time()
    num_batches = len(loader)

    for batch, batch_data in enumerate(tqdm(loader, desc=f"Training Epoch {epoch}", leave=False)):
        input_gene_ids = batch_data["gene_ids"].to(device, non_blocking=True)
        input_values = batch_data["values"].to(device, non_blocking=True)
        celltype_labels = batch_data["celltype_labels"].to(device, non_blocking=True)
        target_values = batch_data["target"].to(device,non_blocking=True)
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        loss = 0.0
        metrics_to_log = {}
        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
            )
        masked_positions = input_values.eq(mask_value)
        if MLM:
            loss_mse = criterion(output_dict["mlm_output"], target_values, masked_positions)
            loss += loss_mse
            metrics_to_log["train/mse"] = loss_mse.item()
        if explicit_zero_prob:
            loss_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mlm_zero_probs"], target_values, masked_positions)
            loss += loss_zero_log_prob
            metrics_to_log["train/nzlp"] = loss_zero_log_prob.item()
        if CLS:
            loss_cls = criterion_cls(output_dict["cls_output"], celltype_labels)
            loss += loss_cls
            metrics_to_log["train/cls"] = loss_cls.item()
            error_rate = 1 - ((output_dict["cls_output"].argmax(1) == celltype_labels).sum().item()) / celltype_labels.size(0)
            
        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        total_mse += loss_mse.item() if MLM else 0.0
        total_cls += loss_cls.item() if CLS else 0.0
        total_zero_log_prob += loss_zero_log_prob.item() if explicit_zero_prob else 0.0
        total_error += error_rate

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_cls = total_cls / log_interval if CLS else 0.0
            cur_zero_log_prob = total_zero_log_prob / log_interval if explicit_zero_prob else 0.0
            cur_error = total_error / log_interval

            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | "
                + (f"mse {cur_mse:5.2f} | mre {cur_error:5.2f} |" if MLM else "")
                + (f"cls {cur_cls:5.2f} | " if CLS else "")
                + (f"err {cur_error:5.2f} | " if CLS else "")
                + (f"nzlp {cur_zero_log_prob:5.2f} |" if explicit_zero_prob else "")
            )

            total_loss = 0
            total_mse = 0
            total_cls = 0
            total_zero_log_prob = 0
            total_error = 0
            start_time = time.time()


'''
def eval_no_dropout(model):
    model.train()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.eval()
    return model
'''


#@torch.no_grad()
def evaluateog(model: nn.Module, loader: DataLoader,return_raw: bool = False) -> tuple[float, float]:
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    num_batches = len(loader)
    predictions = []
    for batch_data in tqdm(loader, desc="Evaluating", leave=False):
        input_gene_ids = batch_data["gene_ids"].to(device, non_blocking=True)
        input_values   = batch_data["values"].to(device, non_blocking=True)
        target_values  = batch_data["target"].to(device, non_blocking=True)
        celltype_labels = batch_data["celltype_labels"].to(device, non_blocking=True)
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        with torch.cuda.amp.autocast(enabled=False):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=None,
                CLS=CLS, CCE=CCE, MVC=MVC, ECS=ECS, do_sample=False,
            )
            output_values = output_dict["cls_output"]
            loss = criterion_cls(output_values, celltype_labels)
            if CLS:
                loss_cls = criterion_cls(output_dict["cls_output"], celltype_labels)
                loss += loss_cls
                error = 1 - ((output_dict["cls_output"].argmax(1) == celltype_labels).sum().item()) / celltype_labels.size(0)
        preds = output_values.argmax(1).cpu().numpy()
        predictions.append(preds)
        total_loss += loss.item()
        total_error += error

    avg_loss = total_loss / num_batches
    avg_error = total_error / num_batches
    if return_raw:
        return predictions
    return avg_loss, avg_error


In [119]:
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train_data_pt, valid_data_pt = prepare_data()
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=32,
        shuffle=False,
        intra_domain_shuffle=True,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=32,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    if config.do_train:
        chatgpt_train_main(
            model,
            loader=train_loader,epoch=epoch
        )
    val_loss, val_err = evaluateog(
        model,
        loader=valid_loader,
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | err {val_err:5.4f}"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = copy.deepcopy(model.state_dict())
        torch.save(best_model_state, "final_classification_melenoma.pt")
        best_model_epoch = epoch
        logger.info(f"Best model with score {best_val_loss:5.9f}")

    scheduler.step()

random masking at epoch   1, ratio of masked values in train:  0.1471


Training Epoch 1:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.79it/s]

scGPT - INFO - | epoch   1 | 100/155 batches | lr 0.0001 | ms/batch 230.89 | loss  1.37 | cls  1.37 | err  0.49 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   1 | time: 39.12s | valid loss/mse 0.9692 | err 0.1625
scGPT - INFO - -----------------------------------------------------------------------------------------




scGPT - INFO - Best model with score 0.969169910
random masking at epoch   2, ratio of masked values in train:  0.1471


Training Epoch 2:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.76it/s]

scGPT - INFO - | epoch   2 | 100/155 batches | lr 0.0001 | ms/batch 223.34 | loss  0.38 | cls  0.38 | err  0.12 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   2 | time: 38.50s | valid loss/mse 0.4070 | err 0.0461
scGPT - INFO - -----------------------------------------------------------------------------------------




scGPT - INFO - Best model with score 0.406984711
random masking at epoch   3, ratio of masked values in train:  0.1471


Training Epoch 3:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.74it/s]

scGPT - INFO - | epoch   3 | 100/155 batches | lr 0.0001 | ms/batch 224.55 | loss  0.21 | cls  0.21 | err  0.06 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   3 | time: 39.00s | valid loss/mse 0.4028 | err 0.0444
scGPT - INFO - -----------------------------------------------------------------------------------------




scGPT - INFO - Best model with score 0.402807827
random masking at epoch   4, ratio of masked values in train:  0.1471


Training Epoch 4:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.72it/s]

scGPT - INFO - | epoch   4 | 100/155 batches | lr 0.0001 | ms/batch 225.19 | loss  0.17 | cls  0.17 | err  0.05 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   4 | time: 38.94s | valid loss/mse 0.4505 | err 0.0496
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch   5, ratio of masked values in train:  0.1471


Training Epoch 5:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.72it/s]

scGPT - INFO - | epoch   5 | 100/155 batches | lr 0.0001 | ms/batch 226.65 | loss  0.15 | cls  0.15 | err  0.04 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   5 | time: 39.02s | valid loss/mse 0.2759 | err 0.0295
scGPT - INFO - -----------------------------------------------------------------------------------------




scGPT - INFO - Best model with score 0.275863568
random masking at epoch   6, ratio of masked values in train:  0.1471


Training Epoch 6:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:21<00:11,  4.73it/s]

scGPT - INFO - | epoch   6 | 100/155 batches | lr 0.0001 | ms/batch 221.59 | loss  0.14 | cls  0.14 | err  0.03 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   6 | time: 38.56s | valid loss/mse 0.3184 | err 0.0322
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch   7, ratio of masked values in train:  0.1471


Training Epoch 7:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.74it/s]

scGPT - INFO - | epoch   7 | 100/155 batches | lr 0.0001 | ms/batch 226.20 | loss  0.13 | cls  0.13 | err  0.03 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   7 | time: 38.95s | valid loss/mse 0.2972 | err 0.0357
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch   8, ratio of masked values in train:  0.1471


Training Epoch 8:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:22<00:11,  4.73it/s]

scGPT - INFO - | epoch   8 | 100/155 batches | lr 0.0001 | ms/batch 222.39 | loss  0.12 | cls  0.12 | err  0.03 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   8 | time: 38.57s | valid loss/mse 0.2828 | err 0.0305
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch   9, ratio of masked values in train:  0.1471


Training Epoch 9:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▊                                                    | 100/155 [00:21<00:11,  4.74it/s]

scGPT - INFO - | epoch   9 | 100/155 batches | lr 0.0001 | ms/batch 221.22 | loss  0.10 | cls  0.10 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   9 | time: 38.40s | valid loss/mse 0.3034 | err 0.0402
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  10, ratio of masked values in train:  0.1471


Training Epoch 10:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.76it/s]

scGPT - INFO - | epoch  10 | 100/155 batches | lr 0.0001 | ms/batch 224.36 | loss  0.09 | cls  0.09 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  10 | time: 38.74s | valid loss/mse 0.2943 | err 0.0305
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  11, ratio of masked values in train:  0.1471


Training Epoch 11:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.74it/s]

scGPT - INFO - | epoch  11 | 100/155 batches | lr 0.0001 | ms/batch 225.00 | loss  0.09 | cls  0.09 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  11 | time: 38.87s | valid loss/mse 0.2164 | err 0.0288
scGPT - INFO - -----------------------------------------------------------------------------------------




scGPT - INFO - Best model with score 0.216374342
random masking at epoch  12, ratio of masked values in train:  0.1471


Training Epoch 12:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.74it/s]

scGPT - INFO - | epoch  12 | 100/155 batches | lr 0.0001 | ms/batch 225.18 | loss  0.09 | cls  0.09 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  12 | time: 38.89s | valid loss/mse 0.2799 | err 0.0260
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  13, ratio of masked values in train:  0.1471


Training Epoch 13:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.74it/s]

scGPT - INFO - | epoch  13 | 100/155 batches | lr 0.0001 | ms/batch 224.78 | loss  0.10 | cls  0.10 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  13 | time: 38.79s | valid loss/mse 0.2913 | err 0.0305
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  14, ratio of masked values in train:  0.1471


Training Epoch 14:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.75it/s]

scGPT - INFO - | epoch  14 | 100/155 batches | lr 0.0001 | ms/batch 224.56 | loss  0.07 | cls  0.07 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  14 | time: 38.79s | valid loss/mse 0.2802 | err 0.0243
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  15, ratio of masked values in train:  0.1471


Training Epoch 15:  65%|██████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 100/155 [00:22<00:11,  4.73it/s]

scGPT - INFO - | epoch  15 | 100/155 batches | lr 0.0000 | ms/batch 224.94 | loss  0.07 | cls  0.07 | err  0.02 | 


                                                                                                                                                                                                            

scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch  15 | time: 38.86s | valid loss/mse 0.2781 | err 0.0288
scGPT - INFO - -----------------------------------------------------------------------------------------




random masking at epoch  16, ratio of masked values in train:  0.1471


Training Epoch 16:  22%|████████████████████████████████▏                                                                                                                  | 34/155 [00:08<00:25,  4.72it/s]Exception in thread Thread-36 (_pin_memory_loop):
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/opt/conda/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 122, i

KeyboardInterrupt: 

In [54]:
def testf(model: nn.Module, adata: DataLoader) -> float:
    adata.obs["class id"] = adata.obs["cell.types"].map(type2id_train)
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].toarray()
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )

    celltypes_labels = adata.obs["class id"].tolist()  # make sure count from 0
    celltypes_labels = np.array(celltypes_labels)

    tokenized_test = tokenize_and_pad_batchy(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=include_zero_gene,
    )

    input_values_test = random_mask_value(
        tokenized_test["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

    test_data_pt = {
        "gene_ids": tokenized_test["genes"],
        "values": input_values_test,
        "target": tokenized_test["values"],
        "celltype_labels": torch.from_numpy(celltypes_labels).long(),
    }

    test_loader = DataLoader(
        dataset=SeqDataset(test_data_pt),
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=min(len(os.sched_getaffinity(0)), eval_batch_size // 2),
        pin_memory=True,
    )

    model.eval()
    predictions = evaluateog(
        model,
        loader=test_loader,
        return_raw=True,
    )
    predictions = np.concatenate(predictions)
# compute accuracy, precision, recall, f1
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    accuracy = accuracy_score(celltypes_labels, predictions)
    precision = precision_score(celltypes_labels, predictions, average="macro")
    recall = recall_score(celltypes_labels, predictions, average="macro")
    macro_f1 = f1_score(celltypes_labels, predictions, average="macro")

    logger.info(
        f"Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, "
        f"Macro F1: {macro_f1:.3f}"
    )

    results = {
        "test/accuracy": accuracy,
        "test/precision": precision,
        "test/recall": recall,
        "test/macro_f1": macro_f1,
    }

    return predictions,celltypes_labels,results

In [55]:
model.load_state_dict(torch.load('final_classification_melenoma.pt'))

<All keys matched successfully>

In [60]:
predictions = testf(model,test)

                                                                                                                                                                                                            

scGPT - INFO - Accuracy: 0.964, Precision: 0.914, Recall: 0.904, Macro F1: 0.909




In [61]:
predictions,celltypes_labels,results = predictions

In [64]:
celltypes_labels

array([7, 0, 4, ..., 5, 1, 7])

In [67]:
results

{'test/accuracy': 0.9643895348837209,
 'test/precision': 0.914420568544977,
 'test/recall': 0.9044512056332985,
 'test/macro_f1': 0.9090877780523887}

In [84]:
sc.pp.neighbors(test, n_neighbors=15, use_rep='X')
sc.tl.umap(test)
test.obs["predictions"] = [id2type_train[p] for p in predictions]
# plot
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] 
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"]
palette_ = {c: palette_[i] for i, c in enumerate(celltypes)}

with plt.rc_context({"figure.figsize": (6, 4), "figure.dpi": (300)}):
    sc.pl.umap(
        test,
        color=["cell.type", "predictions"],
        palette=palette_,
        show=False,
    )
    plt.savefig(save_dir / "results.png", dpi=300)

save_dict = {
    "predictions": predictions,
    "labels": celltypes_labels,
    "results": results,
    "id_maps": id2type_train
}
with open(save_dir / "results.pkl", "wb") as f:
    pickle.dump(save_dict, f)
    

KeyError: 'Could not find key cell.type in .var_names or .obs.columns.'

<Figure size 4122x1200 with 0 Axes>