In [39]:
# %%
import json
import os
from pathlib import Path
import sys
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
# from . import asyn
import pickle
import torch
from anndata import AnnData
import scanpy as sc
#import scvi
import seaborn as sns
import numpy as np
#import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

In [2]:
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="ms",
    do_train=True,
    load_model= r"C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human",
    mask_ratio=0.0,
    epochs=10,
    n_bins=51,
    MVC=False, # Masked value prediction for cell embedding
    ecs_thres=0.0, # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-4,
    batch_size=32,
    layer_size=128,
    nlayers=4,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead=4,  # number of heads in nn.MultiheadAttention
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene = False,
    freeze = False, #freeze
    DSBN = False,  # Domain-spec batchnorm
)

In [3]:
config = hyperparameter_defaults

In [4]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config["mask_ratio"]
mask_value = "auto"  # for masked values, now it should always be auto

include_zero_gene = config["include_zero_gene"]  # if True, include zero genes among hvgs in the training
max_seq_len = 3001
n_bins = config["n_bins"]

# input/output representation
input_style = "binned"  # "normed_raw", "log1p", or "binned"
output_style = "binned"  # "normed_raw", "log1p", or "binned"

# settings for training
MLM = False  # whether to use masked language modeling, currently it is always on.
CLS = True  # celltype classification objective
ADV = False  # Adversarial training for batch correction+
CCE = False  # Contrastive cell embedding objective
#MVC = config.MVC  # Masked value prediction for cell embedding
#ECS = config.ecs_thres > 0  # Elastic cell similarity objective
DAB = False  # Domain adaptation by reverse backpropagation, set to 2 for separate optimizer
INPUT_BATCH_LABELS = False  # TODO: have these help MLM and MVC, while not to classifier
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
cell_emb_style = "cls"  # "avg-pool" or "w-pool" or "cls"
adv_E_delay_epochs = 0  # delay adversarial training on encoder for a few epochs
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config["ecs_thres"]
dab_weight = config["dab_weight"]

explicit_zero_prob = MLM and include_zero_gene  # whether explicit bernoulli for zeros
do_sample_in_train = False and explicit_zero_prob  # sample the bernoulli in training

per_seq_batch_sample = False

# settings for optimizer
lr = config["lr"]  # TODO: test learning rate ratio between two tasks
lr_ADV = 1e-3  # learning rate for discriminator, used when ADV is True
batch_size = config["batch_size"]
eval_batch_size = config["batch_size"]
epochs = config["epochs"] 
schedule_interval = 1




In [5]:
# %% validate settings
assert input_style in ["normed_raw", "log1p", "binned"]
assert output_style in ["normed_raw", "log1p", "binned"]
assert input_emb_style in ["category", "continuous", "scaling"]
if input_style == "binned":
    if input_emb_style == "scaling":
        raise ValueError("input_emb_style `scaling` is not supported for binned input.")
elif input_style == "log1p" or input_style == "normed_raw":
    if input_emb_style == "category":
        raise ValueError(
            "input_emb_style `category` is not supported for log1p or normed_raw input."
        )

if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins  # for padding gene expr values
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

if ADV and DAB:
    raise ValueError("ADV and DAB cannot be both True.")
DAB_separate_optim = True if DAB > 1 else False

In [6]:
import scgpt as scg
import time 
dataset_name = config["dataset_name"]
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to save\dev_ms-Jan23-16-47


## Load and preprocess data
In this part we find 
- Data loading of dataset name
- making the bach category columns
- loading the model dir and vocab and args


In [7]:
"""
In this step we: 
1. Standardize their cell type labels
2. Adding batch information to track data sources
3. Organizing gene names as indices
4. Setting up processing flags
5. creating a backup
Combining the datasets 
"""
if dataset_name == "ms":
    data_dir = Path(r"C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\ms")
    adata = sc.read(data_dir / "c_data.h5ad") #this is the reference data 
    #adata_test = sc.read(data_dir / "filtered_ms_adata.h5ad") #this is the query data 
    adata.obs["celltype"] = adata.obs["Factor Value[inferred cell type - authors labels]"].astype("category") #adding a new column "celltype"
    #adata_test.obs["celltype"] = adata_test.obs["Factor Value[inferred cell type - authors labels]"].astype("category") #adding a new column "celltype" for the dquery data
    adata.obs["batch_id"]  = adata.obs["str_batch"] = "0" #creates 2 identical columns in the metadata which tracks from which dataset each cell comes 
    #adata_test.obs["batch_id"]  = adata_test.obs["str_batch"] = "1"     #for the query data this is stored as 1      
    adata.var.set_index(adata.var["gene_name"], inplace=True) #changes the index of the genes to the name of the gene 
    data_is_raw = False #Assuming it is already normalized / pre-processed / 
    filter_gene_by_counts = False #we don't filter any genes on their expression counts
    #adata_test_raw = adata_test.copy() #make a safety backup of the test dataset (This is good practice!)
    #adata = adata.concatenate(adata_test, batch_key="str_batch") #this combines both the test and the training data
                
'''
We concatenate so that we can create consistent category encodings across both datasets. 
'''

# make the batch category column
batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
adata.obs["batch_id"] = batch_id_labels
celltype_id_labels = adata.obs["celltype"].astype("category").cat.codes.values
celltypes = adata.obs["celltype"].unique()
num_types = len(np.unique(celltype_id_labels))
id2type = dict(enumerate(adata.obs["celltype"].astype("category").cat.categories))
adata.obs["celltype_id"] = celltype_id_labels
adata.var["gene_name"] = adata.var.index.tolist()

In [58]:
print(id2type)
print(celltype_id_labels)
print(celltypes)
print(num_types)
print(adata.obs["celltype_id"])

{0: 'PVALB-expressing interneuron', 1: 'SST-expressing interneuron', 2: 'SV2C-expressing interneuron', 3: 'VIP-expressing interneuron', 4: 'astrocyte', 5: 'cortical layer 2-3 excitatory neuron A', 6: 'cortical layer 2-3 excitatory neuron B', 7: 'cortical layer 4 excitatory neuron', 8: 'cortical layer 5-6 excitatory neuron', 9: 'endothelial cell', 10: 'microglial cell', 11: 'mixed excitatory neuron', 12: 'mixed glial cell?', 13: 'oligodendrocyte A', 14: 'oligodendrocyte C', 15: 'oligodendrocyte precursor cell', 16: 'phagocyte', 17: 'pyramidal neuron?'}
[8 0 6 ... 0 7 1]
['cortical layer 5-6 excitatory neuron', 'PVALB-expressing interneuron', 'cortical layer 2-3 excitatory neuron B', 'oligodendrocyte C', 'VIP-expressing interneuron', ..., 'astrocyte', 'microglial cell', 'endothelial cell', 'oligodendrocyte A', 'phagocyte']
Length: 18
Categories (18, object): ['PVALB-expressing interneuron', 'SST-expressing interneuron', 'SV2C-expressing interneuron', 'VIP-expressing interneuron', ..., 'o

In [55]:
import shutil 
import json
import numpy as np
print (config["load_model"] is not None)

if config["load_model"] is not None:
    model_dir = Path(config["load_model"])
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"
    print(model_dir, model_config_file, model_file, vocab_file)

    vocab = GeneVocab.from_file(vocab_file)
    shutil.copy(vocab_file, save_dir / "vocab.json")
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]

True
C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human\args.json C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human\best_model.pt C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human\vocab.json
scGPT - INFO - match 2808/2808 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human\best_model.pt, the model args will override the config C:\Users\annel\OneDrive\Documenten\Machine Learning\scGPT_data\Human\args.json.


In [9]:
# set up the preprocessor, use the args to config the workflow
from scgpt.preprocess import Preprocessor
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=filter_gene_by_counts,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=False,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)


#adata_test = adata[adata.obs["str_batch"] == "1"]
adata = adata[adata.obs["str_batch"] == "0"]

preprocessor(adata, batch_key=None)
#preprocessor(adata_test, batch_key=None)

scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Binning data ...


In [10]:
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split
import numpy as np
# get the training data

input_layer_key = {  # the values of this map coorespond to the keys in preprocessing
    "normed_raw": "X_normed",
    "log1p": "X_normed",
    "binned": "X_binned",
}[input_style]
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

celltypes_labels = adata.obs["celltype_id"].tolist()  # make sure count from 0
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
)

In [11]:
print (valid_data.shape, valid_celltype_labels.shape, valid_batch_labels.shape)
print (valid_data)

(785, 2808) (785,) (785,)
[[ 0  0  0 ...  0  0  0]
 [42  0  0 ...  0  0  0]
 [23  0  0 ...  0  0  0]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  6  0 ...  0  0  0]
 [40  0  0 ...  0  0  0]]


In [13]:
# Let's examine valid_data more closely
print("Data inspection:")
print(f"Type of valid_data: {type(valid_data)}")
print(f"Shape of valid_data: {valid_data.shape}")
print(f"Data type (dtype): {valid_data.dtype}")

# Make sure it's a proper numpy array with the right format
if not isinstance(valid_data, np.ndarray):
    valid_data = np.array(valid_data)
    
# Check the first few values again after conversion
print("\nFirst few values:")
print(valid_data[:2, :5])  # Show first 2 cells, first 5 genes

Data inspection:
Type of valid_data: <class 'numpy.ndarray'>
Shape of valid_data: (785, 2808)
Data type (dtype): int64

First few values:
[[ 0  0  0  0  0]
 [42  0  0  0  0]]


In [14]:
from torchtext._torchtext import (
    Vocab as VocabPybind,
)  # this is a pybind version of torchtext.vocab.Vocab

if config["load_model"] is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [None]:
from scgpt.tokenizer import tokenize_and_pad_batch
include_zero_gene = config["include_zero_gene"]  # if True, include zero genes among hvgs in the training
max_seq_len = 3001  # maximum sequence length, including <cls> and <eoc> tokens

#from #validate settings

pad_value = -2
n_input_bins = n_bins

tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)

logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)

scGPT - INFO - valid set number of samples: 785, 
	 feature length: 1259


In [16]:
from typing import List, Tuple, Dict, Union, Optional
from torch.utils.data import Dataset, DataLoader    

def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    #masked_values_train = random_mask_value(
    #    tokenized_train["values"],
    #    mask_ratio=mask_ratio,
    #    mask_value=mask_value,
    #    pad_value=pad_value,
    #)
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    #print(
    #    f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
    #    f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    #)

    input_gene_ids_valid = (
        tokenized_valid["genes"],
    )
    input_values_valid = tokenized_valid["values"]
    target_values_valid = (
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    tensor_celltype_labels_train = torch.from_numpy(train_celltype_labels).long()
    tensor_celltype_labels_valid = torch.from_numpy(valid_celltype_labels).long()

    if sort_seq_batch:  # TODO: update to random pick seq source in each traning batch
        #train_sort_ids = np.argsort(train_batch_labels)
        #input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        #input_values_train = input_values_train[train_sort_ids]
        #target_values_train = target_values_train[train_sort_ids]
        #tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]
        #tensor_celltype_labels_train = tensor_celltype_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]
        tensor_celltype_labels_valid = tensor_celltype_labels_valid[valid_sort_ids]

    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
        "celltype_labels": tensor_celltype_labels_valid,
    }

    return valid_data_pt


# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    if num_workers == 0:
        num_workers = min(os.cpu_count() or 1, batch_size // 2)


    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


In [17]:

tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)

logger.info(
    f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
    f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
)

scGPT - INFO - valid set number of samples: 785, 
	 feature length: 1259


## Load pre-trained scGPT model for inference

In [18]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerModel(
    len(vocab),
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
)

if config["load_model"] is not None:
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_file, map_location=device)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

model.to(device)
model.eval()  # Add this line for inference mode

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [19]:
def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)
            celltype_labels = batch_data["celltype_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config["DSBN"] else None,
                    CLS=CLS,  # evaluation does not need CLS or CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                    #generative_training = False,
                )
                output_values = output_dict["cls_output"]
                loss = criterion_cls(output_values, celltype_labels)

                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == celltype_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)
            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)


    if return_raw:
        return np.concatenate(predictions, axis=0)

    return total_loss / total_num, total_error / total_num

In [50]:
from tqdm import tqdm  # Import at the top of your file

def evaluate1(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> float:
    """
    Evaluate the model on the evaluation data with a progress indicator.
    This version runs on CPU and shows real-time progress of the evaluation.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []
    
    # Calculate total number of batches for the progress bar
    total_batches = len(loader)
    
    # Create progress bar with informative description
    progress_bar = tqdm(
        enumerate(loader),
        total=total_batches,
        desc="Evaluating model",
        # Add some useful metrics to display
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} batches '
                  '[{elapsed}<{remaining}, {rate_fmt}]'
    )
    
    with torch.no_grad():
        # Use progress_bar instead of direct loader iteration
        for batch_idx, batch_data in progress_bar:
            # Get batch data
            input_gene_ids = batch_data["gene_ids"]
            input_values = batch_data["values"]
            target_values = batch_data["target_values"]
            batch_labels = batch_data["batch_labels"]
            celltype_labels = batch_data["celltype_labels"]
            
            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            
            # Handle config safely as before
            use_batch_labels = (INPUT_BATCH_LABELS or 
                              (isinstance(config, dict) and config.get('DSBN', False)) or
                              (hasattr(config, 'DSBN') and config["DSBN"]))
            
            # Forward pass
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if use_batch_labels else None,
                CLS=CLS,
                CCE=False,
                MVC=False,
                ECS=False,
                do_sample=do_sample_in_train,
            )
            
            output_values = output_dict["cls_output"]
            loss = criterion_cls(output_values, celltype_labels)
            
            if DAB:
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            
            # Calculate batch metrics
            batch_size = len(input_gene_ids)
            batch_loss = loss.item()
            accuracy = (output_values.argmax(1) == celltype_labels).sum().item()
            batch_error = 1 - accuracy / batch_size
            
            # Update totals
            total_loss += batch_loss * batch_size
            total_error += batch_error * batch_size
            total_dab += loss_dab.item() * batch_size if DAB else 0.0
            total_num += batch_size
            
            # Store predictions
            preds = output_values.argmax(1).numpy()
            predictions.append(preds)
            
            # Update progress bar with current metrics
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'error': f'{batch_error:.4f}'
            })
    
    # Close the progress bar
    progress_bar.close()
    
    if return_raw:
        return np.concatenate(predictions, axis=0)
    
    # Calculate final metrics
    avg_loss = total_loss / total_num
    avg_error = total_error / total_num
    
    print(f"\nEvaluation completed:")
    print(f"Average loss: {avg_loss:.4f}")
    print(f"Average error: {avg_error:.4f}")
    
    return avg_loss, avg_error

In [20]:
import os
valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
valid_loader = prepare_dataloader(
    valid_data_pt,
    batch_size=eval_batch_size,
    shuffle=False,
    intra_domain_shuffle=False,
    drop_last=False,
)

In [44]:
# %% inference
def test(model: nn.Module, adata: DataLoader) -> float:
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )

    celltypes_labels = adata.obs["celltype_id"].tolist()  # make sure count from 0
    celltypes_labels = np.array(celltypes_labels)

    batch_ids = adata.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    tokenized_test = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=include_zero_gene,
    )

    input_values_test = random_mask_value(
        tokenized_test["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

    test_data_pt = {
        "gene_ids": tokenized_test["genes"],
        "values": input_values_test,
        "target_values": tokenized_test["values"],
        "batch_labels": torch.from_numpy(batch_ids).long(),
        "celltype_labels": torch.from_numpy(celltypes_labels).long(),
    }

    test_loader = DataLoader(
        dataset=SeqDataset(test_data_pt),
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers = 0,
        #num_workers= min(os.cpu_count() or 1, batch_size // 2),
        pin_memory=True,
    )
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    model.eval()

    #
    predictions = evaluate1(
        model,
        loader=test_loader,
        return_raw=True,
    )

    # compute accuracy, precision, recall, f1
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

    accuracy = accuracy_score(celltypes_labels, predictions)
    precision = precision_score(celltypes_labels, predictions, average="macro")
    recall = recall_score(celltypes_labels, predictions, average="macro")
    macro_f1 = f1_score(celltypes_labels, predictions, average="macro")

    logger.info(
        f"Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, "
        f"Macro F1: {macro_f1:.3f}"
    )

    results = {
        "test/accuracy": accuracy,
        "test/precision": precision,
        "test/recall": recall,
        "test/macro_f1": macro_f1,
    }

    return predictions, celltypes_labels

In [51]:
criterion = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()
criterion_mlm = criterion_neg_log_bernoulli

In [52]:
best_model = model
print(model)
predictions, label = test(best_model, adata)


TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

KeyboardInterrupt: 

In [36]:
print(label)

[8 0 6 ... 0 7 1]


In [34]:
# Show what columns are available in adata.obs
print("Available columns in adata.obs:")
print(adata.obs.columns)

Available columns in adata.obs:
Index(['Sample Characteristic[organism]',
       'Sample Characteristic Ontology Term[organism]',
       'Sample Characteristic[individual]',
       'Sample Characteristic Ontology Term[individual]',
       'Sample Characteristic[sex]',
       'Sample Characteristic Ontology Term[sex]',
       'Sample Characteristic[age]',
       'Sample Characteristic Ontology Term[age]',
       'Sample Characteristic[developmental stage]',
       'Sample Characteristic Ontology Term[developmental stage]',
       'Sample Characteristic[organism part]',
       'Sample Characteristic Ontology Term[organism part]',
       'Sample Characteristic[sampling site]',
       'Sample Characteristic Ontology Term[sampling site]',
       'Sample Characteristic[disease]',
       'Sample Characteristic Ontology Term[disease]',
       'Sample Characteristic[organism status]',
       'Sample Characteristic Ontology Term[organism status]',
       'Sample Characteristic[cause of death]',


In [38]:
# First, let's see what columns we have in our AnnData object
print("Columns in adata.obs:")
print(adata.obs.columns)

# Let's look at how cell types are stored
print("\nCell type information:")
if 'celltype' in adata.obs.columns:
    print("Unique cell types from 'celltype' column:")
    print(adata.obs['celltype'].unique())

# Now let's create a mapping between IDs and names
if 'celltype_id' in adata.obs.columns and 'celltype' in adata.obs.columns:
    # Create a dictionary mapping IDs to names
    id_to_name = {}
    for idx, name in zip(adata.obs['celltype_id'], adata.obs['celltype']):
        if idx not in id_to_name:
            id_to_name[idx] = name
    
    print("\nMapping of cell type IDs to names:")
    for id_num in sorted(id_to_name.keys()):
        print(f"Cell type {id_num}: {id_to_name[id_num]} ({len(adata.obs[adata.obs['celltype_id'] == id_num])} cells)")

Columns in adata.obs:
Index(['Sample Characteristic[organism]',
       'Sample Characteristic Ontology Term[organism]',
       'Sample Characteristic[individual]',
       'Sample Characteristic Ontology Term[individual]',
       'Sample Characteristic[sex]',
       'Sample Characteristic Ontology Term[sex]',
       'Sample Characteristic[age]',
       'Sample Characteristic Ontology Term[age]',
       'Sample Characteristic[developmental stage]',
       'Sample Characteristic Ontology Term[developmental stage]',
       'Sample Characteristic[organism part]',
       'Sample Characteristic Ontology Term[organism part]',
       'Sample Characteristic[sampling site]',
       'Sample Characteristic Ontology Term[sampling site]',
       'Sample Characteristic[disease]',
       'Sample Characteristic Ontology Term[disease]',
       'Sample Characteristic[organism status]',
       'Sample Characteristic Ontology Term[organism status]',
       'Sample Characteristic[cause of death]',
       'Sa

In [35]:
# If cell types are stored as categories, we can get the category names
if 'celltype' in adata.obs.columns:
    print("\nCell type categories:")
    print(adata.obs['celltype'].cat.categories)

# Match cell type IDs to names
if 'celltype_id' in adata.obs.columns:
    unique_ids = np.unique(adata.obs['celltype_id'])
    print("\nMapping of IDs to cell types:")
    for id_num in unique_ids:
        # Get one example cell of this type to see its name
        cell_type_name = adata.obs.loc[adata.obs['celltype_id'] == id_num, 'celltype'].iloc[0]
        print(f"ID {id_num}: {cell_type_name}")


Cell type categories:
Index(['PVALB-expressing interneuron', 'SST-expressing interneuron',
       'SV2C-expressing interneuron', 'VIP-expressing interneuron',
       'astrocyte', 'cortical layer 2-3 excitatory neuron A',
       'cortical layer 2-3 excitatory neuron B',
       'cortical layer 4 excitatory neuron',
       'cortical layer 5-6 excitatory neuron', 'endothelial cell',
       'microglial cell', 'mixed excitatory neuron', 'mixed glial cell?',
       'oligodendrocyte A', 'oligodendrocyte C',
       'oligodendrocyte precursor cell', 'phagocyte', 'pyramidal neuron?'],
      dtype='object')

Mapping of IDs to cell types:
ID 0: PVALB-expressing interneuron
ID 1: SST-expressing interneuron
ID 2: SV2C-expressing interneuron
ID 3: VIP-expressing interneuron
ID 4: astrocyte
ID 5: cortical layer 2-3 excitatory neuron A
ID 6: cortical layer 2-3 excitatory neuron B
ID 7: cortical layer 4 excitatory neuron
ID 8: cortical layer 5-6 excitatory neuron
ID 9: endothelial cell
ID 10: microglial

In [31]:
# First, let's see the distribution of predicted cell types
unique_labels, counts = np.unique(label, return_counts=True)
print("Distribution of predictions:")
for lab, count in zip(unique_labels, counts):
    print(f"Cell type {lab}: {count} cells")

# If you want to see the actual cell type names instead of numbers
# Assuming you have a mapping of ID to cell type name in your data
celltype_names = adata.obs['celltype'].cat.categories
print("\nCell type mapping:")
for i, name in enumerate(celltype_names):
    print(f"ID {i}: {name}")

# To calculate metrics later, you'll need the true labels
true_labels = np.array(adata.obs["celltype_id"].tolist())

# Then you can calculate metrics using:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

accuracy = accuracy_score(true_labels, label)
precision = precision_score(true_labels, label, average='macro')
recall = recall_score(true_labels, label, average='macro')
f1 = f1_score(true_labels, label, average='macro')

print("\nPerformance Metrics:")
print(f"Accuracy: {accuracy:.3f}")
print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1 Score: {f1:.3f}")

Distribution of predictions:
Cell type 0: 569 cells
Cell type 1: 172 cells
Cell type 2: 234 cells
Cell type 3: 662 cells
Cell type 4: 154 cells
Cell type 5: 2010 cells
Cell type 6: 1019 cells
Cell type 7: 1284 cells
Cell type 8: 997 cells
Cell type 9: 38 cells
Cell type 10: 4 cells
Cell type 11: 114 cells
Cell type 12: 60 cells
Cell type 13: 53 cells
Cell type 14: 3 cells
Cell type 15: 55 cells
Cell type 16: 3 cells
Cell type 17: 413 cells

Cell type mapping:
ID 0: PVALB-expressing interneuron
ID 1: SST-expressing interneuron
ID 2: SV2C-expressing interneuron
ID 3: VIP-expressing interneuron
ID 4: astrocyte
ID 5: cortical layer 2-3 excitatory neuron A
ID 6: cortical layer 2-3 excitatory neuron B
ID 7: cortical layer 4 excitatory neuron
ID 8: cortical layer 5-6 excitatory neuron
ID 9: endothelial cell
ID 10: microglial cell
ID 11: mixed excitatory neuron
ID 12: mixed glial cell?
ID 13: oligodendrocyte A
ID 14: oligodendrocyte C
ID 15: oligodendrocyte precursor cell
ID 16: phagocyte
ID 1

In [134]:
sample_counts = all_counts[0:1]  # Take first sample
tokenized_sample = tokenize_and_pad_batch(
    sample_counts,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
print("Gene IDs shape:", tokenized_sample["genes"].shape)
print("First few tokens:", tokenized_sample["genes"][0][:10])

Gene IDs shape: torch.Size([1, 252])
First few tokens: tensor([60695,  4765, 17568, 31253, 21300, 34984, 11414,  5273, 32751,  3170])


In [136]:
print("Gene IDs shape:", tokenized_sample["genes"].shape)
print("First few tokens:", tokenized_sample["genes"][0][:10])

Gene IDs shape: torch.Size([1, 252])
First few tokens: tensor([60695,  4765, 17568, 31253, 21300, 34984, 11414,  5273, 32751,  3170])


In [None]:
adata_test_raw.obs["predictions"] = [id2type[p] for p in predictions]

# plot
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] 
palette_ = plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"] + plt.rcParams["axes.prop_cycle"].by_key()["color"]
palette_ = {c: palette_[i] for i, c in enumerate(celltypes)}

with plt.rc_context({"figure.figsize": (6, 4), "figure.dpi": (300)}):
    sc.pl.umap(
        adata_test_raw,
        color=["celltype", "predictions"],
        palette=palette_,
        show=False,
    )
    plt.savefig(save_dir / "results.png", dpi=300)

save_dict = {
    "predictions": predictions,
    "labels": labels,
    "results": results,
    "id_maps": id2type
}
with open(save_dir / "results.pkl", "wb") as f:
    pickle.dump(save_dict, f)

