Import Necessary Libraries & Setup


In [1]:
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
import pickle
import torch
from anndata import AnnData
import scanpy as sc
import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse, csr_matrix
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind
from sklearn.metrics import confusion_matrix

# Insert path for scGPT if needed
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

# ‚úÖ Set PyTorch GPU Optimization
torch.backends.cudnn.benchmark = True  # Optimizes CUDA performance
torch.backends.cudnn.enabled = True    # Ensures CuDNN is used

# ‚úÖ Check and Set Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
print(f"üöÄ Using {num_gpus} GPU(s) | Device: {device}")

# Set default visualization parameters and ignore warnings
sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')


Global seed set to 0


üöÄ Using 3 GPU(s) | Device: cuda


  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
# ‚úÖ **Hyperparameters and Configurations**
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="AIDA1_sample",  # Human dataset name
    do_train=True,
    load_model="/data/cellular_aging/references/scGPT_human_pretrained_model",  # Pretrained human model path
    mask_ratio=0.0005,
    epochs=50,
    n_bins=51,
    MLM=False,  # Explicitly add MLM key to avoid KeyError
    MVC=False,  
    ecs_thres=0.0,
    dab_weight=0.0,
    lr=1e-5,
    batch_size=max(1, torch.cuda.device_count()),  
    layer_size=512,
    nlayers=12,
    nhead=8,
    dropout=0.2,
    schedule_ratio=0.9,
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True, 
    include_zero_gene=False,
    freeze=False,
    DSBN=False,
    DAB_separate_optim = False  # 

)

# ‚úÖ Initialize WandB
run = wandb.init(
    config=hyperparameter_defaults,
    dir="/data/cellular_aging/results/fine-tuning",
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config  # WandB converts hyperparameter_defaults into config

# ‚úÖ **Ensure Proper Input Style is Set**
if not hasattr(config, "input_style"):
    config.input_style = "binned"  # Default input style

print(f"‚úÖ Configurations Loaded. Using input style: {config.input_style}")

# ‚úÖ **Set Up Input Representations**
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # Always set to auto for masked values
include_zero_gene = config.include_zero_gene
max_seq_len = min(2048, 4417)  # ‚úÖ Dynamically limit max_seq_len if needed
n_bins = config.n_bins
pad_value = 0  # Default padding value

# ‚úÖ **Select Input/Output Representation**
input_style = "binned"
output_style = "binned"

# ‚úÖ **Training Settings**
MLM = False
CLS = True  # Classification objective for age bin prediction
ADV = False  # Disable adversarial training here
CCE = False
MVC = config.MVC
ECS = config.ecs_thres > 0
DAB = False
INPUT_BATCH_LABELS = False
input_emb_style = "continuous"  # Using continuous embedding style for inputs
cell_emb_style = "cls"
adv_E_delay_epochs = 0
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight
explicit_zero_prob = MLM and include_zero_gene
do_sample_in_train = False and explicit_zero_prob
per_seq_batch_sample = False

# ‚úÖ **Optimizer Settings**
lr = config.lr
lr_ADV = 1e-3
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1

# ‚úÖ **Model Architecture Settings**
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"
embsize = config.layer_size
d_hid = config.layer_size
nlayers = config.nlayers
nhead = config.nhead
dropout = config.dropout

# ‚úÖ **Logging & Evaluation Settings**
log_interval = 100
save_eval_interval = config.save_eval_interval
do_eval_scib_metrics = True


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmalam007[0m ([33mmalam007-old-dominion-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


‚úÖ Configurations Loaded. Using input style: binned


In [3]:
# ‚úÖ **Load full dataset (AnnData format)**
adata = sc.read_h5ad('/data/cellular_aging/dataset/AIDA.h5ad')

# ‚úÖ Track Total Expression Before Filtering
total_expression_before = adata.X.sum()

# ‚úÖ Load FINAL gene mapping (from BioMart + MyGene.info)
df_final_mapping = pd.read_csv("final_gene_mapping.csv")
final_gene_dict = dict(zip(df_final_mapping["ensembl_id"], df_final_mapping["gene_symbol"]))

# ‚úÖ Extract Ensembl IDs & Apply Gene Mapping
adata.var["ensembl_id"] = adata.var_names.str.split(".").str[0]
adata.var["gene_symbol"] = adata.var["ensembl_id"].map(final_gene_dict)

# ‚úÖ Load `vocab.json`
vocab_path = "/data/cellular_aging/references/scGPT_human_pretrained_model/vocab.json"
with open(vocab_path, "r") as f:
    vocab = json.load(f)

# ‚úÖ Convert vocab genes to a set for fast lookup
vocab_genes = set(vocab.keys())

# ‚úÖ Filter dataset to keep only genes with symbols in `vocab.json`
adata_filtered_vocab = adata[:, adata.var["gene_symbol"].isin(vocab_genes)].copy()

# ‚úÖ Debugging: Check the number of genes after filtering
print(f"üîç Total Genes After Vocab Filtering: {adata_filtered_vocab.shape[1]}")
print(f"üîç Expected Number of Genes in Vocabulary: {len(vocab_genes)}")

# ‚úÖ Track Total Expression After Filtering
total_expression_after = adata_filtered_vocab.X.sum()

# ‚úÖ Compute Gene & Expression Coverage
gene_coverage_vocab = (adata_filtered_vocab.shape[1] / adata.shape[1]) * 100
expression_coverage_vocab = (total_expression_after / total_expression_before) * 100

# ‚úÖ Print Final Summary
print("\nüìä **Final Gene Mapping & Expression Coverage Summary:**")
print(f"üîπ Total Genes Before Vocab Filtering: {adata.shape[1]}")
print(f"üîπ Total Genes After Vocab Filtering: {adata_filtered_vocab.shape[1]}")
print(f"üîπ Total Expression Before Vocab Filtering: {total_expression_before:.2f}")
print(f"üîπ Total Expression After Vocab Filtering: {total_expression_after:.2f}")
print(f"‚úÖ Gene Coverage After Vocab Filtering: {gene_coverage_vocab:.2f}%")
print(f"‚úÖ Expression Coverage After Vocab Filtering: {expression_coverage_vocab:.2f}%")

# ‚úÖ Ensure Sample Size is Valid
sample_fraction = 0.001  # Adjust as needed
num_sample = max(1, int(sample_fraction * adata_filtered_vocab.n_obs))  # Avoid zero or negative samples
num_sample = min(num_sample, adata_filtered_vocab.n_obs)  # Prevent exceeding dataset size

# ‚úÖ Perform Random Sampling
random_indices = np.random.choice(adata_filtered_vocab.n_obs, num_sample, replace=False)
adata_sample = adata_filtered_vocab[random_indices, :].copy()
print(f"‚úÖ Sampled dataset shape: {adata_sample.shape}")

# ‚úÖ Convert to Sparse Format (if not already sparse)
if not issparse(adata_sample.X):
    adata_sample.X = csr_matrix(adata_sample.X)

# ‚úÖ Apply Preprocessing ONLY to the Sampled Data
preprocessor = Preprocessor(
    use_key="X",
    filter_gene_by_counts=False,
    filter_cell_by_counts=False,
    normalize_total=True,
    result_normed_key="X_normed",
    log1p=True,
    result_log1p_key="X_log1p",
    subset_hvg=False,
    hvg_flavor="seurat_v3",
    binning=config.n_bins,  # Use binning defined in the config
    result_binned_key="X_binned",
)

# ‚úÖ Apply Preprocessing to the Sampled Data (do not alter the full dataset)
preprocessor(adata_sample, batch_key=None)

# ‚úÖ Ensure Preprocessing was Applied Correctly
print(f"‚úÖ Preprocessing applied. Available layers: {adata_sample.layers.keys()}")

# ‚úÖ Ensure Gene Symbols are Retained in Sampled Data
if "gene_symbol" not in adata_sample.var.columns:
    print("‚ùå Warning: `gene_symbol` column missing in `adata_sample`. Reassigning from `adata_filtered_vocab`.")
    adata_sample.var["gene_symbol"] = adata_filtered_vocab.var["gene_symbol"]

# ‚úÖ Extract Filtered Gene Symbols from Sampled Data
filtered_gene_symbols = adata_sample.var["gene_symbol"].tolist()

# ‚úÖ Debugging: Check Gene Count After Sampling
print(f"üîç Total Genes in Sampled Dataset: {adata_sample.shape[1]}")
print(f"üîç Total Genes in `filtered_gene_symbols`: {len(filtered_gene_symbols)}")
print(f"üîç First 10 genes: {filtered_gene_symbols[:10]}")

# ‚úÖ Check that the sampled genes are a subset of the vocabulary
if not set(filtered_gene_symbols).issubset(vocab_genes):
    print("‚ö†Ô∏è Warning: Some genes in the sampled data are missing in `vocab.json`!")

# ‚úÖ Ensure correct preprocessing layer is used to create the input tensor
if "X_binned" in adata_sample.layers:
    data_layer = adata_sample.layers["X_binned"]
elif "X_normed" in adata_sample.layers:
    data_layer = adata_sample.layers["X_normed"]
elif "X_log1p" in adata_sample.layers:
    data_layer = adata_sample.layers["X_log1p"]
else:
    raise ValueError("‚ùå No valid processed data layer found in `adata_sample`!")

# ‚úÖ Convert sparse matrix to dense format **before** creating tensors
if issparse(data_layer):
    data_layer = data_layer.toarray()

# ‚úÖ Create Tensors and Move to GPU
values_tensor = torch.tensor(data_layer, dtype=torch.float32, device="cuda")  # Move tensor to GPU

# ‚úÖ Debugging: Check if `values_tensor` shape matches the number of sampled genes
print(f"üîç `values_tensor` shape: {values_tensor.shape}")

num_sampled_genes = len(filtered_gene_symbols)
if values_tensor.shape[1] != num_sampled_genes:
    raise ValueError(
        f"‚ùå Mismatch: `values_tensor` has {values_tensor.shape[1]} features, "
        f"but `filtered_gene_symbols` has {num_sampled_genes} genes. Fix dataset filtering."
    )

print(f"‚úÖ `values_tensor` shape after filtering: {values_tensor.shape}")


üîç Total Genes After Vocab Filtering: 23794
üîç Expected Number of Genes in Vocabulary: 60697

üìä **Final Gene Mapping & Expression Coverage Summary:**
üîπ Total Genes Before Vocab Filtering: 36161
üîπ Total Genes After Vocab Filtering: 23794
üîπ Total Expression Before Vocab Filtering: 2888279552.00
üîπ Total Expression After Vocab Filtering: 2841698816.00
‚úÖ Gene Coverage After Vocab Filtering: 65.80%
‚úÖ Expression Coverage After Vocab Filtering: 98.39%
‚úÖ Sampled dataset shape: (1058, 23794)
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Binning data ...
‚úÖ Preprocessing applied. Available layers: KeysView(Layers with keys: X_normed, X_log1p, X_binned)
üîç Total Genes in Sampled Dataset: 23794
üîç Total Genes in `filtered_gene_symbols`: 23794
üîç First 10 genes: ['MIR1302-2HG', 'FAM138A', 'OR4F5', 'OR4F29', 'OR4F16', 'LINC01409', 'FAM87B', 'LINC01128', 'LINC00115', 'FAM41C']
üîç `values_tensor` shape: torch.Size([105

In [4]:
# ‚úÖ Ensure the index is properly formatted for splitting as integer indices
obs_index = np.arange(adata_sample.n_obs, dtype=np.int64)  # Ensure integer type

# ‚úÖ Split Data into Training (80%) and Testing (20%)
train_idx, test_idx = train_test_split(obs_index, test_size=0.2, random_state=42)

# ‚úÖ Convert to NumPy arrays for efficient indexing
train_idx = np.array(train_idx, dtype=np.int64)
test_idx = np.array(test_idx, dtype=np.int64)

# ‚úÖ Debug: Print dataset size and split verification
print(f"‚úÖ Total Samples: {adata_sample.n_obs}")
print(f"üîπ Train Samples: {len(train_idx)} ({len(train_idx) / adata_sample.n_obs * 100:.2f}%)")
print(f"üîπ Test Samples: {len(test_idx)} ({len(test_idx) / adata_sample.n_obs * 100:.2f}%)")

# ‚úÖ Assign `batch_id` within the Sampled Data
adata_sample = adata_sample.copy()  # Avoid pandas SettingWithCopyWarning
adata_sample.obs["batch_id"] = 0  # Initialize all cells as train (batch 0)
adata_sample.obs.iloc[test_idx, adata_sample.obs.columns.get_loc("batch_id")] = 1  # Mark test cells as batch 1

# ‚úÖ Convert to Numeric Batch Labels and Ensure Tensor Format
batch_ids = torch.tensor(adata_sample.obs["batch_id"].values, dtype=torch.long, device="cuda")  # Move to GPU

# ‚úÖ Debugging: Check batch label distribution
num_batch_labels = batch_ids.unique().numel()
print(f"‚úÖ Assigned {num_batch_labels} unique batch labels (Train/Test Split).")


‚úÖ Total Samples: 1058
üîπ Train Samples: 846 (79.96%)
üîπ Test Samples: 212 (20.04%)
‚úÖ Assigned 2 unique batch labels (Train/Test Split).


In [5]:
# ‚úÖ **Extract Numeric Age from Development Stage (within the Sampled Data)**
age_col = "development_stage"
if "numeric_age" not in adata_sample.obs.columns:
    adata_sample.obs["numeric_age"] = (
        adata_sample.obs[age_col]
        .astype(str)
        .str.extract(r"(\d+)")  # Extracts numeric values from text
        .astype(float)
    )

# ‚úÖ Handle missing values (fill with median to avoid NaNs)
adata_sample.obs["numeric_age"].fillna(adata_sample.obs["numeric_age"].median(), inplace=True)

# ‚úÖ **Convert Numeric Age to Categorical Labels using qcut**
# Here, retbins=True returns the bin edges as well.
age_codes, bin_edges = pd.qcut(
    adata_sample.obs["numeric_age"],
    q=5,
    labels=False,
    retbins=True,
    duplicates="drop"  # In case there are duplicate bin edges
)

adata_sample.obs["age_id"] = age_codes.astype(np.int64)  # Ensure integer format

# ‚úÖ **Create Age Category Mapping from Bin Edges**
id2type = {
    i: f"{bin_edges[i]:.1f} - {bin_edges[i+1]:.1f}"
    for i in range(len(bin_edges) - 1)
}

# ‚úÖ Convert Age Labels to PyTorch Tensor and Move to GPU
age_labels_tensor = torch.tensor(
    adata_sample.obs["age_id"].values, dtype=torch.long, device="cuda"
)

# ‚úÖ Debugging: Count Unique Age Groups and Ensure Proper Formatting
num_types = len(id2type)
print(f"‚úÖ Found {num_types} unique age categories.")
print(f"üîπ Age Tensor Shape: {age_labels_tensor.shape}")
print(f"üîπ First 10 Age Labels: {age_labels_tensor[:10].tolist()}")


‚úÖ Found 5 unique age categories.
üîπ Age Tensor Shape: torch.Size([1058])
üîπ First 10 Age Labels: [1, 3, 3, 0, 3, 1, 2, 0, 2, 0]


In [6]:
# ‚úÖ **Ensure `input_style` is correctly mapped to available preprocessing layers**
input_style = getattr(config, "input_style", "binned")  # Default to "binned" if missing

input_layer_key = {
    "normed_raw": "X_normed",
    "log1p": "X_log1p",
    "binned": "X_binned",
}.get(input_style, "X_binned")  # Default to "X_binned" if invalid

# ‚úÖ **Ensure preprocessing was applied correctly**
if input_layer_key not in adata_sample.layers:
    raise KeyError(f"‚ùå Layer {input_layer_key} not found in adata_sample! Available: {adata_sample.layers.keys()}")

# ‚úÖ **Convert Gene Symbols to Vocab Indices**
gene_symbols = adata_sample.var["gene_symbol"].tolist()
gene_ids_np = np.array([vocab.get(gene, vocab.get(pad_token, 0)) for gene in gene_symbols], dtype=np.int64)  # Map to vocab indices

# ‚úÖ **Tokenize and Pad Gene Expression Data**
tokenized_data = tokenize_and_pad_batch(
    adata_sample.layers[input_layer_key],  # Select appropriate input style
    gene_ids=gene_ids_np,  # Now using vocab-mapped indices
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=0,
    append_cls=True,  # Append <cls> token at the beginning
    include_zero_gene=config.include_zero_gene,
)

# ‚úÖ **Extract Tokenized Values**
gene_ids = tokenized_data["genes"]
input_values = tokenized_data["values"]

# ‚úÖ **Randomly Mask Some Values for MLM (If Enabled)**
if config.MLM:  # Ensure MLM is fetched from config
    input_values = random_mask_value(
        input_values,
        mask_ratio=config.mask_ratio,
        mask_value="auto",  # Always set to auto for masked values
        pad_value=0,
    )

# ‚úÖ **Extract PyTorch Tensors and Move to GPU**
device = "cuda" if torch.cuda.is_available() else "cpu"

gene_ids_tensor = torch.tensor(gene_ids, dtype=torch.long, device=device)
values_tensor = torch.tensor(input_values, dtype=torch.float32, device=device)
target_values_tensor = torch.tensor(tokenized_data["values"], dtype=torch.float32, device=device)
batch_labels_tensor = torch.tensor(adata_sample.obs["batch_id"].values, dtype=torch.long, device=device)
age_labels_tensor = torch.tensor(adata_sample.obs["age_id"].values, dtype=torch.long, device=device)

# ‚úÖ **Split Each Tensor into Train & Test Using Correct Indices**
obs_index = np.arange(adata_sample.n_obs)  # Use numerical indices

train_idx, test_idx = train_test_split(obs_index, test_size=0.2, shuffle=True, random_state=42)

# ‚úÖ Convert indices to NumPy arrays for efficiency
train_idx = np.array(train_idx, dtype=np.int64)
test_idx = np.array(test_idx, dtype=np.int64)

# ‚úÖ **Create Train/Test Masks**
train_mask = torch.tensor(np.isin(obs_index, train_idx), dtype=torch.bool, device=device)
test_mask = torch.tensor(np.isin(obs_index, test_idx), dtype=torch.bool, device=device)

# ‚úÖ **Ensure Train/Test Split Uses the Correct Sampled Indices**
train_data = {
    "gene_ids": gene_ids_tensor[train_mask],
    "values": values_tensor[train_mask],
    "target_values": target_values_tensor[train_mask],
    "batch_labels": batch_labels_tensor[train_mask],
    "age_labels": age_labels_tensor[train_mask],
}

test_data = {
    "gene_ids": gene_ids_tensor[test_mask],
    "values": values_tensor[test_mask],
    "target_values": target_values_tensor[test_mask],
    "batch_labels": batch_labels_tensor[test_mask],
    "age_labels": age_labels_tensor[test_mask],
}

# ‚úÖ Debugging: Print Train/Test Set Sizes
print(f"‚úÖ Tokenization Complete! Train Data: {train_data['gene_ids'].shape[0]} samples, Test Data: {test_data['gene_ids'].shape[0]} samples.")


‚úÖ Tokenization Complete! Train Data: 846 samples, Test Data: 212 samples.


In [7]:
# ‚úÖ **Free Up GPU Memory Before Initializing Model**
torch.cuda.empty_cache()  # Prevent memory fragmentation issues

# ‚úÖ **Initialize Transformer Model**
model = TransformerModel(
    ntoken=len(vocab),  # Size of vocabulary
    d_model=config.layer_size,  # Embedding dimension
    nhead=config.nhead,  # Number of attention heads
    d_hid=config.layer_size,  # Hidden layer size
    nlayers=config.nlayers,  # Number of Transformer layers
    nlayers_cls=3,  # Number of classification layers
    n_cls=num_types if CLS else 1,  # Classification categories
    vocab=vocab,  # Vocabulary mapping
    dropout=config.dropout,  # Dropout probability
    pad_token=pad_token,  # Padding token
    pad_value=0,  # Padding value for embeddings
    do_mvc=MVC,  # Enable masked value prediction
    use_batch_labels=INPUT_BATCH_LABELS,  # Include batch labels
    num_batch_labels=num_batch_labels,  # Number of batch labels
    domain_spec_batchnorm=config.DSBN,  # Enable domain-specific batch norm
    input_emb_style=input_emb_style,  # Input embedding type
    n_input_bins=config.n_bins + 2,  # Number of input bins
    cell_emb_style=cell_emb_style,  # Cell embedding type
    mvc_decoder_style=mvc_decoder_style,  # MVC decoder style
    ecs_threshold=ecs_threshold,  # Elastic cell similarity threshold
    explicit_zero_prob=explicit_zero_prob,  # Explicit zero probability modeling
    use_fast_transformer=fast_transformer,  # Use optimized Transformer
    fast_transformer_backend=fast_transformer_backend,  # Transformer backend
    pre_norm=config.pre_norm,  # Pre-normalization setting
)

# ‚úÖ **Multi-GPU Support: Wrap Model with DataParallel**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()

if num_gpus > 1:
    model = torch.nn.DataParallel(model)  # ‚úÖ Enables training on multiple GPUs

# ‚úÖ **Move Model to GPU**
model.to(device)

print(f"‚úÖ Transformer Model Initialized on {device} using {num_gpus} GPU(s).")


‚úÖ Transformer Model Initialized on cuda using 3 GPU(s).


In [8]:
# ‚úÖ Move Model to GPU/CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()

# ‚úÖ Apply DataParallel if multiple GPUs are available
if num_gpus > 1:
    model = torch.nn.DataParallel(model)

model.to(device)

# ‚úÖ Load Pretrained Weights (If Available)
if config.load_model:
    model_checkpoint = Path(config.load_model) / "best_model.pt"
    
    try:
        # ‚úÖ Load checkpoint while handling multi-GPU and single-GPU compatibility
        state_dict = torch.load(model_checkpoint, map_location=device)

        # ‚úÖ Handle DataParallel model loading mismatch
        if num_gpus > 1:
            # If model is wrapped in DataParallel, adjust the state_dict keys
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict[k.replace("module.", "")] = v  # Remove "module." prefix
            model.load_state_dict(new_state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=False)

        print(f"‚úÖ Loaded pretrained model from {model_checkpoint} on {num_gpus} GPU(s).")

    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Failed to load full model. {e}")


‚úÖ Loaded pretrained model from /data/cellular_aging/references/scGPT_human_pretrained_model/best_model.pt on 3 GPU(s).


In [9]:
import logging

# ‚úÖ Configure Logger
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ‚úÖ Helper function to count trainable parameters
def count_params(model):
    """Count trainable parameters, handling DataParallel models correctly."""
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # Extract actual model for parameter counting
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# ‚úÖ Count Parameters Before Freezing
pre_freeze = count_params(model)

# ‚úÖ Handle Freezing Logic (Works with Multi-GPU)
if config.freeze:
    if isinstance(model, torch.nn.DataParallel):  
        model_to_freeze = model.module  # Extract actual model from DataParallel
    else:
        model_to_freeze = model

    for name, p in model_to_freeze.named_parameters():
        if "encoder" in name and "transformer_encoder" not in name:
            print(f"üîí Freezing weights for: {name}")
            p.requires_grad = False

# ‚úÖ Count Parameters After Freezing
post_freeze = count_params(model)

# ‚úÖ Log Parameter Changes
logger.info(f"üìå Pre-freeze: {pre_freeze:,}, Post-freeze: {post_freeze:,}")  # Format with commas for readability

# ‚úÖ Log to WandB
wandb.log({
    "info/pre_freeze_param_count": pre_freeze,
    "info/post_freeze_param_count": post_freeze
})


2025-02-20 07:45:50,051 - INFO - üìå Pre-freeze: 51,334,150, Post-freeze: 51,334,150


In [10]:
def define_wandb_metrics():
    """
    Define WandB metrics to monitor training and validation performance.
    These metrics help track key performance indicators such as mean squared error (MSE),
    mean relative error (MRE), and any additional metrics (like domain adaptation loss or test metrics).
    """
    try:
        # ‚úÖ Ensure WandB is initialized before defining metrics
        if wandb.run is None:
            logger.warning("‚ö†Ô∏è WandB is not initialized. Skipping metric definition.")
            return
        
        # ‚úÖ Define primary validation metrics
        wandb.define_metric("epoch")  # Ensure epoch is tracked correctly
        wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
        wandb.define_metric("valid/mre", summary="min", step_metric="epoch")
        wandb.define_metric("valid/dab", summary="min", step_metric="epoch")
        wandb.define_metric("valid/sum_mse_dab", summary="min", step_metric="epoch")
        wandb.define_metric("test/avg_bio", summary="max")

        logger.info("‚úÖ WandB metrics successfully defined.")

    except Exception as e:
        logger.error(f"‚ùå Error defining WandB metrics: {e}")


In [11]:
def train(model: nn.Module, loader: DataLoader, epoch: int) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    
    # ‚úÖ Handle DataParallel model for multi-GPU training
    if isinstance(model, torch.nn.DataParallel):
        model_to_train = model.module  # Extract actual model
    else:
        model_to_train = model

    (
        total_loss,
        total_mse,
        total_cls,
        total_cce,
        total_mvc,
        total_ecs,
        total_dab,
        total_adv_E,
        total_adv_D,
        total_zero_log_prob,
        total_mvc_zero_log_prob,
    ) = (0.0,) * 11
    
    total_error = 0.0
    start_time = time.time()
    num_batches = len(loader)

    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device, non_blocking=True)
        input_values = batch_data["values"].to(device, non_blocking=True)
        target_values = batch_data["target_values"].to(device, non_blocking=True)
        batch_labels = batch_data["batch_labels"].to(device, non_blocking=True)
        age_labels = batch_data["age_labels"].to(device, non_blocking=True)
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])

        with torch.cuda.amp.autocast(enabled=config.amp):  # ‚úÖ Automatic Mixed Precision (AMP) for efficiency
            output_dict = model_to_train(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
            )
            masked_positions = input_values.eq(mask_value)  # Identify positions to predict
            loss = 0.0
            metrics_to_log = {}

            # ‚úÖ Masked Language Modeling Loss
            if MLM:
                loss_mse = criterion(output_dict["mlm_output"], target_values, masked_positions)
                loss += loss_mse
                metrics_to_log["train/mse"] = loss_mse.item()

            # ‚úÖ Explicit Zero Probability Loss
            if explicit_zero_prob:
                loss_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss += loss_zero_log_prob
                metrics_to_log["train/nzlp"] = loss_zero_log_prob.item()

            # ‚úÖ Classification Loss
            if CLS:
                loss_cls = criterion_cls(output_dict["cls_output"], age_labels)
                loss += loss_cls
                metrics_to_log["train/cls"] = loss_cls.item()
                error_rate = 1 - (output_dict["cls_output"].argmax(1) == age_labels).sum().item() / age_labels.size(0)

            # ‚úÖ Contrastive Cell Embedding Loss
            if CCE:
                loss_cce = 10 * output_dict["loss_cce"]
                loss += loss_cce
                metrics_to_log["train/cce"] = loss_cce.item()

            # ‚úÖ Masked Value Prediction Loss
            if MVC:
                loss_mvc = criterion(output_dict["mvc_output"], target_values, masked_positions)
                loss += loss_mvc
                metrics_to_log["train/mvc"] = loss_mvc.item()

            if MVC and explicit_zero_prob:
                loss_mvc_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss += loss_mvc_zero_log_prob
                metrics_to_log["train/mvc_nzlp"] = loss_mvc_zero_log_prob.item()

            # ‚úÖ Elastic Cell Similarity Loss
            if ECS:
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss += loss_ecs
                metrics_to_log["train/ecs"] = loss_ecs.item()

            # ‚úÖ Domain Adaptation Loss
            if DAB:
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
                loss += dab_weight * loss_dab
                metrics_to_log["train/dab"] = loss_dab.item()

        # ‚úÖ Backpropagation
        optimizer.zero_grad(set_to_none=True)  # More efficient than model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)

        # ‚úÖ Clip Gradients to Prevent Exploding Gradients
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), 1.0, 
            error_if_nonfinite=False if scaler.is_enabled() else True
        )

        # ‚úÖ Step the Optimizer and Update the Scaler
        scaler.step(optimizer)
        scaler.update()

        # ‚úÖ Adversarial Training (If Enabled)
        if ADV:
            output_dict = model_to_train(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                CLS=CLS,
                CCE=CCE,
                MVC=MVC,
                ECS=ECS,
                do_sample=do_sample_in_train,
            )

            # ‚úÖ Train Discriminator
            loss_adv_D = criterion_adv(discriminator(output_dict["cell_emb"].detach()), batch_labels)
            if epoch > adv_D_delay_epochs:
                optimizer_D.zero_grad(set_to_none=True)
                loss_adv_D.backward()
                optimizer_D.step()

            # ‚úÖ Train Encoder
            loss_adv_E = -criterion_adv(discriminator(output_dict["cell_emb"]), batch_labels)
            if epoch > adv_E_delay_epochs:
                optimizer_E.zero_grad(set_to_none=True)
                loss_adv_E.backward()
                optimizer_E.step()

        # ‚úÖ Log Metrics
        wandb.log(metrics_to_log)

        # ‚úÖ Aggregate Losses
        total_loss += loss.item()
        total_mse += loss_mse.item() if MLM else 0.0
        total_cls += loss_cls.item() if CLS else 0.0
        total_cce += loss_cce.item() if CCE else 0.0
        total_mvc += loss_mvc.item() if MVC else 0.0
        total_ecs += loss_ecs.item() if ECS else 0.0
        total_dab += loss_dab.item() if DAB else 0.0
        total_adv_E += loss_adv_E.item() if ADV else 0.0
        total_adv_D += loss_adv_D.item() if ADV else 0.0
        total_zero_log_prob += loss_zero_log_prob.item() if explicit_zero_prob else 0.0
        total_mvc_zero_log_prob += loss_mvc_zero_log_prob.item() if MVC and explicit_zero_prob else 0.0
        total_error += error_rate

        # ‚úÖ Logging at Intervals
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            logger.info(
                f"| Epoch {epoch:3d} | Batch {batch:3d}/{num_batches:3d} | LR {lr:.5f} | Time {ms_per_batch:.2f}ms | "
                f"Loss {total_loss / log_interval:.4f} | "
                f"MSE {total_mse / log_interval:.4f} | "
                f"CLS {total_cls / log_interval:.4f} | "
                f"CCE {total_cce / log_interval:.4f} | "
                f"MVC {total_mvc / log_interval:.4f} | "
                f"ECS {total_ecs / log_interval:.4f} | "
                f"DAB {total_dab / log_interval:.4f} | "
                f"ADV_E {total_adv_E / log_interval:.4f} | "
                f"ADV_D {total_adv_D / log_interval:.4f} | "
                f"NZLP {total_zero_log_prob / log_interval:.4f} | "
                f"MVC_NZLP {total_mvc_zero_log_prob / log_interval:.4f} | "
                f"Error {total_error / log_interval:.4f}"
            )

            total_loss = total_mse = total_cls = total_cce = total_mvc = total_ecs = total_dab = 0
            total_adv_E = total_adv_D = total_zero_log_prob = total_mvc_zero_log_prob = total_error = 0
            start_time = time.time()


In [12]:
def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> Union[float, Tuple[float, float]]:
    """
    Evaluate the model on the validation/test dataset.
    """
    model.eval()
    
    # ‚úÖ Handle DataParallel model for multi-GPU evaluation
    if isinstance(model, torch.nn.DataParallel):
        model_to_eval = model.module  # Extract actual model
    else:
        model_to_eval = model

    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []

    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device, non_blocking=True)
            input_values = batch_data["values"].to(device, non_blocking=True)
            target_values = batch_data["target_values"].to(device, non_blocking=True)
            batch_labels = batch_data["batch_labels"].to(device, non_blocking=True)
            age_labels = batch_data["age_labels"].to(device, non_blocking=True)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model_to_eval(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,  # Only classification used here, disable MLM/CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                )

                output_values = output_dict["cls_output"]
                loss = criterion_cls(output_values, age_labels)

                loss_dab = 0.0
                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            # ‚úÖ Aggregate Metrics
            batch_size = input_gene_ids.shape[0]
            total_loss += loss.item() * batch_size
            accuracy = (output_values.argmax(1) == age_labels).sum().item()
            total_error += (1 - accuracy / batch_size) * batch_size
            total_dab += loss_dab.item() * batch_size if DAB else 0.0
            total_num += batch_size

            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    # ‚úÖ Compute Final Metrics
    avg_loss = total_loss / total_num
    avg_error = total_error / total_num
    avg_dab = total_dab / total_num if DAB else 0.0
    sum_mse_dab = avg_loss + dab_weight * avg_dab if DAB else avg_loss

    # ‚úÖ Log metrics to WandB
    wandb.log(
        {
            "valid/mse": avg_loss,
            "valid/err": avg_error,
            "valid/dab": avg_dab,
            "valid/sum_mse_dab": sum_mse_dab,
            "epoch": epoch,
        }
    )

    # ‚úÖ Debugging Output
    logger.info(
        f"‚úÖ Evaluation Complete | Loss: {avg_loss:.4f} | Error: {avg_error:.4f} | DAB: {avg_dab:.4f}"
    )

    if return_raw:
        return np.concatenate(predictions, axis=0)

    return avg_loss, avg_error


In [13]:
def prepare_data(tokenized_train, tokenized_valid, 
                 train_batch_labels, valid_batch_labels, 
                 train_age_labels, valid_age_labels,
                 mask_ratio, mask_value, pad_value, epoch, 
                 sort_seq_batch=False):
    """
    Prepare data for training and validation, optimizing for multi-GPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ‚úÖ Apply Random Masking for Masked Language Modeling
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )

    # ‚úÖ Debugging: Print Masking Information
    print(
        f"üîπ Random masking applied at epoch {epoch:3d}, "
        f"ratio of masked values in train: "
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}"
    )

    # ‚úÖ Extract Inputs & Targets
    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    # ‚úÖ Convert Batch Labels & Age Labels to PyTorch Tensors (Move to GPU)
    tensor_batch_labels_train = torch.tensor(train_batch_labels, dtype=torch.long, device=device)
    tensor_batch_labels_valid = torch.tensor(valid_batch_labels, dtype=torch.long, device=device)

    tensor_age_labels_train = torch.tensor(train_age_labels, dtype=torch.long, device=device)
    tensor_age_labels_valid = torch.tensor(valid_age_labels, dtype=torch.long, device=device)

    # ‚úÖ Optional Sorting for Sequence Batching
    if sort_seq_batch:  
        train_sort_ids = np.argsort(train_batch_labels)
        valid_sort_ids = np.argsort(valid_batch_labels)

        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]
        tensor_age_labels_train = tensor_age_labels_train[train_sort_ids]

        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]
        tensor_age_labels_valid = tensor_age_labels_valid[valid_sort_ids]

    # ‚úÖ Convert Input Data to PyTorch Tensors and Move to GPU
    train_data_pt = {
        "gene_ids": torch.tensor(input_gene_ids_train, dtype=torch.long, device=device),
        "values": torch.tensor(input_values_train, dtype=torch.float32, device=device),
        "target_values": torch.tensor(target_values_train, dtype=torch.float32, device=device),
        "batch_labels": tensor_batch_labels_train,
        "age_labels": tensor_age_labels_train,
    }
    
    valid_data_pt = {
        "gene_ids": torch.tensor(input_gene_ids_valid, dtype=torch.long, device=device),
        "values": torch.tensor(input_values_valid, dtype=torch.float32, device=device),
        "target_values": torch.tensor(target_values_valid, dtype=torch.float32, device=device),
        "batch_labels": tensor_batch_labels_valid,
        "age_labels": tensor_age_labels_valid,
    }

    return train_data_pt, valid_data_pt


In [14]:
# ‚úÖ Dataset wrapper
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        # ‚úÖ Ensure tensors are on CPU before passing to DataLoader
        self.data = {k: v.cpu() for k, v in data.items()}

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}




# ‚úÖ Data loader preparation function
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: Optional[int] = None,
) -> DataLoader:
    """
    Prepare the DataLoader optimized for multi-GPU training.
    """

    # ‚úÖ Set number of workers dynamically based on available CPU cores
    if num_workers is None:
        num_workers = max(1, min(os.cpu_count() // 2, batch_size // 2))  # Avoid excessive CPU usage

    # ‚úÖ Check if running on GPU
    use_cuda = torch.cuda.is_available()
    
    # ‚úÖ Pin memory for efficient GPU data transfer
    pin_memory = use_cuda

    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # ‚úÖ Handle per-sequence batch sampling
        subsets = []
        
        # ‚úÖ Ensure batch labels are on CPU before converting to NumPy
        batch_labels_array = data_pt["batch_labels"].cpu().numpy()

        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)

        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        return data_loader

    # ‚úÖ Standard DataLoader for non-sequence-batched data
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return data_loader


In [15]:
from torch.utils.data import DataLoader 
import scipy.sparse

# ‚úÖ Define mask_value and pad_value if not defined
mask_value = 0.0
pad_value = 0
DAB_separate_optim = False  # ‚úÖ Set a default value

device = "cuda" if torch.cuda.is_available() else "cpu"

# ‚úÖ Generate observation indices and split into train and test
obs_index = np.arange(adata_sample.n_obs)
train_idx, test_idx = train_test_split(obs_index, test_size=0.2, random_state=42)
train_idx = np.array(train_idx, dtype=np.int64)
test_idx = np.array(test_idx, dtype=np.int64)

# ‚úÖ Define classification criterion
criterion_cls = nn.CrossEntropyLoss()

# ‚úÖ Define Optimizer Before Training
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

# ‚úÖ If using a learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# ‚úÖ Initialize best scores
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None

# ‚úÖ Define WandB metrics
define_wandb_metrics()

# ‚úÖ Training Loop
for epoch in range(1, config.epochs + 1):    
    epoch_start_time = time.time()

    # ‚úÖ Extract gene symbols and map to vocab indices
    filtered_gene_symbols = adata_sample.var["gene_symbol"].tolist()
    filtered_gene_ids = np.array(
        [vocab.get(gene, vocab.get("<pad>", 0)) for gene in filtered_gene_symbols],
        dtype=np.int64,
    )

    print(f"üîç Total Genes: {len(filtered_gene_symbols)} | Mapped: {filtered_gene_ids.shape[0]} | First 10 IDs: {filtered_gene_ids[:10]}")

    # ‚úÖ Select the correct preprocessed data layer
    if "X_binned" in adata_sample.layers:
        data_layer = adata_sample.layers["X_binned"]
    elif "X_normed" in adata_sample.layers:
        data_layer = adata_sample.layers["X_normed"]
    elif "X_log1p" in adata_sample.layers:
        data_layer = adata_sample.layers["X_log1p"]
    else:
        raise ValueError("‚ùå No valid processed data layer found in `adata_sample`!")

    # ‚úÖ Convert sparse matrix to dense
    if issparse(data_layer):
        data_layer = data_layer.toarray()

    # ‚úÖ Convert to PyTorch tensor and move to GPU
    values_tensor = torch.tensor(data_layer, dtype=torch.float32, device=device)

    print(f"üîç `values_tensor` shape before filtering: {values_tensor.shape}")

    # ‚úÖ Ensure correct number of features
    num_sampled_genes = len(filtered_gene_ids)
    if values_tensor.shape[1] != num_sampled_genes:
        raise ValueError(
            f"‚ùå Mismatch: `values_tensor` has {values_tensor.shape[1]} features, "
            f"but `filtered_gene_ids` has {num_sampled_genes} genes. Fix dataset filtering."
        )
    
    print(f"‚úÖ `values_tensor` shape after filtering: {values_tensor.shape}")

    # ‚úÖ Convert gene IDs to tensor and move to GPU
    gene_ids_tensor = torch.tensor(filtered_gene_ids, dtype=torch.long, device=device)

    # ‚úÖ Convert train/test data and move to GPU
    train_data = {
        "gene_ids": gene_ids_tensor,  # Same mapping for all cells
        "values": values_tensor[train_idx],
        "target_values": target_values_tensor[train_idx].to(device, non_blocking=True),
        "batch_labels": batch_labels_tensor[train_idx].to(device, non_blocking=True),
        "age_labels": age_labels_tensor[train_idx].to(device, non_blocking=True),
    }
    test_data = {
        "gene_ids": gene_ids_tensor,
        "values": values_tensor[test_idx],
        "target_values": target_values_tensor[test_idx].to(device, non_blocking=True),
        "batch_labels": batch_labels_tensor[test_idx].to(device, non_blocking=True),
        "age_labels": age_labels_tensor[test_idx].to(device, non_blocking=True),
    }

    print(f"üîç Train values shape: {train_data['values'].shape} | Test values shape: {test_data['values'].shape}")

    # ‚úÖ Tokenize train/validation data
    tokenized_train = tokenize_and_pad_batch(
        train_data["values"].cpu().numpy(),  # Convert back to NumPy only when necessary
        gene_ids=filtered_gene_ids.copy(),
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=config.include_zero_gene,
    )
    tokenized_valid = tokenize_and_pad_batch(
        test_data["values"].cpu().numpy(),
        gene_ids=filtered_gene_ids.copy(),
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=config.include_zero_gene,
    )

    # ‚úÖ Prepare data for model input
    train_data_pt, valid_data_pt = prepare_data(
        tokenized_train, tokenized_valid,
        train_batch_labels=train_data["batch_labels"].cpu().numpy(),
        valid_batch_labels=test_data["batch_labels"].cpu().numpy(),
        train_age_labels=train_data["age_labels"].cpu().numpy(),
        valid_age_labels=test_data["age_labels"].cpu().numpy(),
        mask_ratio=config.mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
        epoch=epoch,
        sort_seq_batch=per_seq_batch_sample
    )

    # ‚úÖ Prepare Data Loaders
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=config.batch_size,
        shuffle=True,
        intra_domain_shuffle=True,
        drop_last=False,
        num_workers=0,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    # ‚úÖ Initialize GradScaler for mixed precision training if enabled
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # ‚úÖ Train the model
    if config.do_train:
        train(model, loader=train_loader, epoch=epoch)

    # ‚úÖ Evaluate the model
    val_loss, val_err = evaluate(model, loader=valid_loader)
    elapsed = time.time() - epoch_start_time

    logger.info("-" * 89)
    logger.info(
        f"| End of Epoch {epoch:3d} | Time: {elapsed:5.2f}s | "
        f"Validation Loss/MSE: {val_loss:.4f} | Error Rate: {val_err:.4f}"
    )
    logger.info("-" * 89)

    # ‚úÖ Track best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        logger.info(f"‚úÖ New Best Model Saved at Epoch {epoch} | Score: {best_val_loss:.4f}")

    # ‚úÖ Step the scheduler(s)
    scheduler.step()
    if DAB_separate_optim:
        scheduler_dab.step()
    if ADV:
        scheduler_D.step()
        scheduler_E.step()


# ‚úÖ Ensure `save_dir` is defined
save_dir = Path("/data/cellular_aging/results/fine-tuning")
save_dir.mkdir(parents=True, exist_ok=True)

# ‚úÖ Define `best_model_path`
best_model_path = save_dir / "best_model.pt"

# ‚úÖ Handle multi-GPU model saving
if isinstance(model, torch.nn.DataParallel):
    torch.save(best_model.module.state_dict(), best_model_path)  # Remove DataParallel wrapper
else:
    torch.save(best_model.state_dict(), best_model_path)

print(f"‚úÖ Best model saved at: {best_model_path}")

2025-02-20 07:45:50,113 - INFO - ‚úÖ WandB metrics successfully defined.


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   1, ratio of masked values in train: 0.1235


2025-02-20 07:46:05,988 - INFO - | Epoch   1 | Batch 100/282 | LR 0.00001 | Time 155.57ms | Loss 1.7141 | MSE 0.0000 | CLS 1.7141 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8333
2025-02-20 07:46:19,726 - INFO - | Epoch   1 | Batch 200/282 | LR 0.00001 | Time 137.38ms | Loss 1.6763 | MSE 0.0000 | CLS 1.6763 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8533
2025-02-20 07:46:37,220 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6419 | Error: 0.7642 | DAB: 0.0000
2025-02-20 07:46:37,222 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:46:37,223 - INFO - | End of Epoch   1 | Time: 47.11s | Validation Loss/MSE: 1.6419 | Error Rate: 0.7642
2025-02-20 07:46:37,223 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:46

üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   2, ratio of masked values in train: 0.1235


2025-02-20 07:46:51,471 - INFO - | Epoch   2 | Batch 100/282 | LR 0.00001 | Time 138.80ms | Loss 1.6549 | MSE 0.0000 | CLS 1.6549 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 07:47:05,399 - INFO - | Epoch   2 | Batch 200/282 | LR 0.00001 | Time 139.27ms | Loss 1.6525 | MSE 0.0000 | CLS 1.6525 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7833
2025-02-20 07:47:22,461 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6647 | Error: 0.8113 | DAB: 0.0000
2025-02-20 07:47:22,463 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:47:22,463 - INFO - | End of Epoch   2 | Time: 45.22s | Validation Loss/MSE: 1.6647 | Error Rate: 0.8113
2025-02-20 07:47:22,464 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   3, ratio of masked values in train: 0.1235


2025-02-20 07:47:36,903 - INFO - | Epoch   3 | Batch 100/282 | LR 0.00001 | Time 141.01ms | Loss 1.6507 | MSE 0.0000 | CLS 1.6507 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7733
2025-02-20 07:47:50,859 - INFO - | Epoch   3 | Batch 200/282 | LR 0.00001 | Time 139.55ms | Loss 1.6227 | MSE 0.0000 | CLS 1.6227 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 07:48:07,898 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6202 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:48:07,900 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:48:07,900 - INFO - | End of Epoch   3 | Time: 45.43s | Validation Loss/MSE: 1.6202 | Error Rate: 0.7972
2025-02-20 07:48:07,901 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:48

üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   4, ratio of masked values in train: 0.1235


2025-02-20 07:48:22,353 - INFO - | Epoch   4 | Batch 100/282 | LR 0.00001 | Time 140.96ms | Loss 1.6368 | MSE 0.0000 | CLS 1.6368 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7900
2025-02-20 07:48:36,132 - INFO - | Epoch   4 | Batch 200/282 | LR 0.00001 | Time 137.78ms | Loss 1.6141 | MSE 0.0000 | CLS 1.6141 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7400
2025-02-20 07:48:53,039 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6195 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:48:53,041 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:48:53,042 - INFO - | End of Epoch   4 | Time: 45.12s | Validation Loss/MSE: 1.6195 | Error Rate: 0.7972
2025-02-20 07:48:53,042 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:48

üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   5, ratio of masked values in train: 0.1235


2025-02-20 07:49:07,477 - INFO - | Epoch   5 | Batch 100/282 | LR 0.00001 | Time 140.76ms | Loss 1.6411 | MSE 0.0000 | CLS 1.6411 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7733
2025-02-20 07:49:21,263 - INFO - | Epoch   5 | Batch 200/282 | LR 0.00001 | Time 137.85ms | Loss 1.6352 | MSE 0.0000 | CLS 1.6352 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8133
2025-02-20 07:49:38,148 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6179 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:49:38,150 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:49:38,151 - INFO - | End of Epoch   5 | Time: 45.09s | Validation Loss/MSE: 1.6179 | Error Rate: 0.7972
2025-02-20 07:49:38,151 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:49

üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   6, ratio of masked values in train: 0.1235


2025-02-20 07:49:52,453 - INFO - | Epoch   6 | Batch 100/282 | LR 0.00001 | Time 139.40ms | Loss 1.6360 | MSE 0.0000 | CLS 1.6360 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 07:50:06,460 - INFO - | Epoch   6 | Batch 200/282 | LR 0.00001 | Time 140.06ms | Loss 1.6251 | MSE 0.0000 | CLS 1.6251 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7633
2025-02-20 07:50:23,386 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6459 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:50:23,388 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:50:23,388 - INFO - | End of Epoch   6 | Time: 45.22s | Validation Loss/MSE: 1.6459 | Error Rate: 0.7972
2025-02-20 07:50:23,389 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   7, ratio of masked values in train: 0.1235


2025-02-20 07:50:37,672 - INFO - | Epoch   7 | Batch 100/282 | LR 0.00001 | Time 139.42ms | Loss 1.6387 | MSE 0.0000 | CLS 1.6387 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8000
2025-02-20 07:50:51,657 - INFO - | Epoch   7 | Batch 200/282 | LR 0.00001 | Time 139.84ms | Loss 1.6100 | MSE 0.0000 | CLS 1.6100 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7700
2025-02-20 07:51:08,588 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6322 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:51:08,590 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:51:08,591 - INFO - | End of Epoch   7 | Time: 45.20s | Validation Loss/MSE: 1.6322 | Error Rate: 0.7972
2025-02-20 07:51:08,591 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   8, ratio of masked values in train: 0.1235


2025-02-20 07:51:22,889 - INFO - | Epoch   8 | Batch 100/282 | LR 0.00001 | Time 139.51ms | Loss 1.6333 | MSE 0.0000 | CLS 1.6333 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7933
2025-02-20 07:51:36,871 - INFO - | Epoch   8 | Batch 200/282 | LR 0.00001 | Time 139.81ms | Loss 1.6242 | MSE 0.0000 | CLS 1.6242 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 07:51:53,804 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6303 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:51:53,806 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:51:53,807 - INFO - | End of Epoch   8 | Time: 45.21s | Validation Loss/MSE: 1.6303 | Error Rate: 0.7972
2025-02-20 07:51:53,808 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch   9, ratio of masked values in train: 0.1235


2025-02-20 07:52:08,073 - INFO - | Epoch   9 | Batch 100/282 | LR 0.00001 | Time 139.15ms | Loss 1.6320 | MSE 0.0000 | CLS 1.6320 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7967
2025-02-20 07:52:22,039 - INFO - | Epoch   9 | Batch 200/282 | LR 0.00001 | Time 139.65ms | Loss 1.6190 | MSE 0.0000 | CLS 1.6190 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7400
2025-02-20 07:52:38,960 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6338 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:52:38,962 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:52:38,963 - INFO - | End of Epoch   9 | Time: 45.15s | Validation Loss/MSE: 1.6338 | Error Rate: 0.7972
2025-02-20 07:52:38,964 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  10, ratio of masked values in train: 0.1235


2025-02-20 07:52:53,193 - INFO - | Epoch  10 | Batch 100/282 | LR 0.00001 | Time 138.89ms | Loss 1.6459 | MSE 0.0000 | CLS 1.6459 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8167
2025-02-20 07:53:07,181 - INFO - | Epoch  10 | Batch 200/282 | LR 0.00001 | Time 139.87ms | Loss 1.6148 | MSE 0.0000 | CLS 1.6148 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 07:53:24,112 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6219 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:53:24,114 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:53:24,114 - INFO - | End of Epoch  10 | Time: 45.15s | Validation Loss/MSE: 1.6219 | Error Rate: 0.7972
2025-02-20 07:53:24,115 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  11, ratio of masked values in train: 0.1235


2025-02-20 07:53:38,403 - INFO - | Epoch  11 | Batch 100/282 | LR 0.00001 | Time 139.48ms | Loss 1.6054 | MSE 0.0000 | CLS 1.6054 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7167
2025-02-20 07:53:52,216 - INFO - | Epoch  11 | Batch 200/282 | LR 0.00001 | Time 138.11ms | Loss 1.6201 | MSE 0.0000 | CLS 1.6201 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8267
2025-02-20 07:54:09,336 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6056 | Error: 0.7642 | DAB: 0.0000
2025-02-20 07:54:09,338 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:54:09,338 - INFO - | End of Epoch  11 | Time: 45.22s | Validation Loss/MSE: 1.6056 | Error Rate: 0.7642
2025-02-20 07:54:09,339 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:54

üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  12, ratio of masked values in train: 0.1235


2025-02-20 07:54:23,639 - INFO - | Epoch  12 | Batch 100/282 | LR 0.00001 | Time 139.39ms | Loss 1.6287 | MSE 0.0000 | CLS 1.6287 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7500
2025-02-20 07:54:37,621 - INFO - | Epoch  12 | Batch 200/282 | LR 0.00001 | Time 139.82ms | Loss 1.6250 | MSE 0.0000 | CLS 1.6250 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 07:54:54,789 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6242 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:54:54,791 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:54:54,791 - INFO - | End of Epoch  12 | Time: 45.43s | Validation Loss/MSE: 1.6242 | Error Rate: 0.7972
2025-02-20 07:54:54,792 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  13, ratio of masked values in train: 0.1235


2025-02-20 07:55:09,074 - INFO - | Epoch  13 | Batch 100/282 | LR 0.00001 | Time 139.39ms | Loss 1.6300 | MSE 0.0000 | CLS 1.6300 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 07:55:22,870 - INFO - | Epoch  13 | Batch 200/282 | LR 0.00001 | Time 137.95ms | Loss 1.6207 | MSE 0.0000 | CLS 1.6207 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7833
2025-02-20 07:55:39,926 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6287 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:55:39,928 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:55:39,929 - INFO - | End of Epoch  13 | Time: 45.14s | Validation Loss/MSE: 1.6287 | Error Rate: 0.7972
2025-02-20 07:55:39,929 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  14, ratio of masked values in train: 0.1235


2025-02-20 07:55:54,388 - INFO - | Epoch  14 | Batch 100/282 | LR 0.00001 | Time 141.10ms | Loss 1.6292 | MSE 0.0000 | CLS 1.6292 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 07:56:08,340 - INFO - | Epoch  14 | Batch 200/282 | LR 0.00001 | Time 139.51ms | Loss 1.6009 | MSE 0.0000 | CLS 1.6009 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7433
2025-02-20 07:56:25,529 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6308 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:56:25,530 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:56:25,531 - INFO - | End of Epoch  14 | Time: 45.60s | Validation Loss/MSE: 1.6308 | Error Rate: 0.7972
2025-02-20 07:56:25,532 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  15, ratio of masked values in train: 0.1235


2025-02-20 07:56:39,800 - INFO - | Epoch  15 | Batch 100/282 | LR 0.00000 | Time 139.25ms | Loss 1.6328 | MSE 0.0000 | CLS 1.6328 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8033
2025-02-20 07:56:53,911 - INFO - | Epoch  15 | Batch 200/282 | LR 0.00000 | Time 141.11ms | Loss 1.6024 | MSE 0.0000 | CLS 1.6024 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7333
2025-02-20 07:57:10,867 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6323 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:57:10,869 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:57:10,870 - INFO - | End of Epoch  15 | Time: 45.34s | Validation Loss/MSE: 1.6323 | Error Rate: 0.7972
2025-02-20 07:57:10,870 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  16, ratio of masked values in train: 0.1235


2025-02-20 07:57:25,310 - INFO - | Epoch  16 | Batch 100/282 | LR 0.00000 | Time 141.00ms | Loss 1.6173 | MSE 0.0000 | CLS 1.6173 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 07:57:39,106 - INFO - | Epoch  16 | Batch 200/282 | LR 0.00000 | Time 137.95ms | Loss 1.6171 | MSE 0.0000 | CLS 1.6171 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7767
2025-02-20 07:57:56,006 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6201 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:57:56,008 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:57:56,008 - INFO - | End of Epoch  16 | Time: 45.14s | Validation Loss/MSE: 1.6201 | Error Rate: 0.7972
2025-02-20 07:57:56,009 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  17, ratio of masked values in train: 0.1235


2025-02-20 07:58:10,456 - INFO - | Epoch  17 | Batch 100/282 | LR 0.00000 | Time 141.03ms | Loss 1.6287 | MSE 0.0000 | CLS 1.6287 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7500
2025-02-20 07:58:24,258 - INFO - | Epoch  17 | Batch 200/282 | LR 0.00000 | Time 138.01ms | Loss 1.5986 | MSE 0.0000 | CLS 1.5986 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7300
2025-02-20 07:58:41,186 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6325 | Error: 0.8113 | DAB: 0.0000
2025-02-20 07:58:41,188 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:58:41,189 - INFO - | End of Epoch  17 | Time: 45.18s | Validation Loss/MSE: 1.6325 | Error Rate: 0.8113
2025-02-20 07:58:41,190 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  18, ratio of masked values in train: 0.1235


2025-02-20 07:58:55,636 - INFO - | Epoch  18 | Batch 100/282 | LR 0.00000 | Time 140.97ms | Loss 1.6149 | MSE 0.0000 | CLS 1.6149 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7733
2025-02-20 07:59:09,425 - INFO - | Epoch  18 | Batch 200/282 | LR 0.00000 | Time 137.88ms | Loss 1.6101 | MSE 0.0000 | CLS 1.6101 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7500
2025-02-20 07:59:26,341 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6144 | Error: 0.7972 | DAB: 0.0000
2025-02-20 07:59:26,343 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 07:59:26,343 - INFO - | End of Epoch  18 | Time: 45.15s | Validation Loss/MSE: 1.6144 | Error Rate: 0.7972
2025-02-20 07:59:26,344 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  19, ratio of masked values in train: 0.1235


2025-02-20 07:59:40,808 - INFO - | Epoch  19 | Batch 100/282 | LR 0.00000 | Time 141.17ms | Loss 1.6196 | MSE 0.0000 | CLS 1.6196 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7767
2025-02-20 07:59:54,623 - INFO - | Epoch  19 | Batch 200/282 | LR 0.00000 | Time 138.14ms | Loss 1.6083 | MSE 0.0000 | CLS 1.6083 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7433
2025-02-20 08:00:11,558 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6252 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:00:11,560 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:00:11,560 - INFO - | End of Epoch  19 | Time: 45.21s | Validation Loss/MSE: 1.6252 | Error Rate: 0.7972
2025-02-20 08:00:11,561 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  20, ratio of masked values in train: 0.1235


2025-02-20 08:00:25,991 - INFO - | Epoch  20 | Batch 100/282 | LR 0.00000 | Time 140.90ms | Loss 1.6278 | MSE 0.0000 | CLS 1.6278 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7967
2025-02-20 08:00:39,772 - INFO - | Epoch  20 | Batch 200/282 | LR 0.00000 | Time 137.80ms | Loss 1.6144 | MSE 0.0000 | CLS 1.6144 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7767
2025-02-20 08:00:56,987 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6218 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:00:56,989 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:00:56,989 - INFO - | End of Epoch  20 | Time: 45.43s | Validation Loss/MSE: 1.6218 | Error Rate: 0.7972
2025-02-20 08:00:56,990 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  21, ratio of masked values in train: 0.1235


2025-02-20 08:01:11,401 - INFO - | Epoch  21 | Batch 100/282 | LR 0.00000 | Time 140.67ms | Loss 1.6377 | MSE 0.0000 | CLS 1.6377 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.8100
2025-02-20 08:01:25,186 - INFO - | Epoch  21 | Batch 200/282 | LR 0.00000 | Time 137.84ms | Loss 1.6020 | MSE 0.0000 | CLS 1.6020 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:01:42,128 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6265 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:01:42,130 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:01:42,130 - INFO - | End of Epoch  21 | Time: 45.14s | Validation Loss/MSE: 1.6265 | Error Rate: 0.7972
2025-02-20 08:01:42,131 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  22, ratio of masked values in train: 0.1235


2025-02-20 08:01:56,604 - INFO - | Epoch  22 | Batch 100/282 | LR 0.00000 | Time 141.30ms | Loss 1.6237 | MSE 0.0000 | CLS 1.6237 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:02:10,407 - INFO - | Epoch  22 | Batch 200/282 | LR 0.00000 | Time 138.02ms | Loss 1.5865 | MSE 0.0000 | CLS 1.5865 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7567
2025-02-20 08:02:27,331 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6256 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:02:27,334 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:02:27,335 - INFO - | End of Epoch  22 | Time: 45.20s | Validation Loss/MSE: 1.6256 | Error Rate: 0.7972
2025-02-20 08:02:27,335 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  23, ratio of masked values in train: 0.1235


2025-02-20 08:02:41,778 - INFO - | Epoch  23 | Batch 100/282 | LR 0.00000 | Time 140.96ms | Loss 1.6174 | MSE 0.0000 | CLS 1.6174 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:02:55,566 - INFO - | Epoch  23 | Batch 200/282 | LR 0.00000 | Time 137.86ms | Loss 1.6133 | MSE 0.0000 | CLS 1.6133 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 08:03:12,508 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6254 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:03:12,510 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:03:12,511 - INFO - | End of Epoch  23 | Time: 45.17s | Validation Loss/MSE: 1.6254 | Error Rate: 0.7972
2025-02-20 08:03:12,512 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  24, ratio of masked values in train: 0.1235


2025-02-20 08:03:26,951 - INFO - | Epoch  24 | Batch 100/282 | LR 0.00000 | Time 140.93ms | Loss 1.6210 | MSE 0.0000 | CLS 1.6210 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:03:40,730 - INFO - | Epoch  24 | Batch 200/282 | LR 0.00000 | Time 137.78ms | Loss 1.6137 | MSE 0.0000 | CLS 1.6137 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:03:57,641 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6237 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:03:57,643 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:03:57,644 - INFO - | End of Epoch  24 | Time: 45.13s | Validation Loss/MSE: 1.6237 | Error Rate: 0.7972
2025-02-20 08:03:57,645 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  25, ratio of masked values in train: 0.1235


2025-02-20 08:04:12,068 - INFO - | Epoch  25 | Batch 100/282 | LR 0.00000 | Time 140.83ms | Loss 1.6112 | MSE 0.0000 | CLS 1.6112 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 08:04:25,832 - INFO - | Epoch  25 | Batch 200/282 | LR 0.00000 | Time 137.63ms | Loss 1.6290 | MSE 0.0000 | CLS 1.6290 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 08:04:42,732 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6233 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:04:42,734 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:04:42,735 - INFO - | End of Epoch  25 | Time: 45.09s | Validation Loss/MSE: 1.6233 | Error Rate: 0.7972
2025-02-20 08:04:42,735 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  26, ratio of masked values in train: 0.1235


2025-02-20 08:04:57,173 - INFO - | Epoch  26 | Batch 100/282 | LR 0.00000 | Time 140.86ms | Loss 1.6216 | MSE 0.0000 | CLS 1.6216 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:05:10,942 - INFO - | Epoch  26 | Batch 200/282 | LR 0.00000 | Time 137.68ms | Loss 1.6076 | MSE 0.0000 | CLS 1.6076 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7700
2025-02-20 08:05:28,115 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6270 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:05:28,117 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:05:28,118 - INFO - | End of Epoch  26 | Time: 45.38s | Validation Loss/MSE: 1.6270 | Error Rate: 0.7972
2025-02-20 08:05:28,119 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  27, ratio of masked values in train: 0.1235


2025-02-20 08:05:42,365 - INFO - | Epoch  27 | Batch 100/282 | LR 0.00000 | Time 139.04ms | Loss 1.6282 | MSE 0.0000 | CLS 1.6282 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7567
2025-02-20 08:05:56,297 - INFO - | Epoch  27 | Batch 200/282 | LR 0.00000 | Time 139.31ms | Loss 1.6032 | MSE 0.0000 | CLS 1.6032 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7500
2025-02-20 08:06:13,346 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6303 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:06:13,347 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:06:13,348 - INFO - | End of Epoch  27 | Time: 45.23s | Validation Loss/MSE: 1.6303 | Error Rate: 0.7972
2025-02-20 08:06:13,349 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  28, ratio of masked values in train: 0.1235


2025-02-20 08:06:27,762 - INFO - | Epoch  28 | Batch 100/282 | LR 0.00000 | Time 140.66ms | Loss 1.6142 | MSE 0.0000 | CLS 1.6142 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7933
2025-02-20 08:06:41,536 - INFO - | Epoch  28 | Batch 200/282 | LR 0.00000 | Time 137.73ms | Loss 1.5974 | MSE 0.0000 | CLS 1.5974 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7367
2025-02-20 08:06:58,432 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6212 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:06:58,434 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:06:58,435 - INFO - | End of Epoch  28 | Time: 45.09s | Validation Loss/MSE: 1.6212 | Error Rate: 0.7972
2025-02-20 08:06:58,436 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  29, ratio of masked values in train: 0.1235


2025-02-20 08:07:12,857 - INFO - | Epoch  29 | Batch 100/282 | LR 0.00000 | Time 140.66ms | Loss 1.6310 | MSE 0.0000 | CLS 1.6310 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7833
2025-02-20 08:07:26,606 - INFO - | Epoch  29 | Batch 200/282 | LR 0.00000 | Time 137.48ms | Loss 1.5866 | MSE 0.0000 | CLS 1.5866 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7200
2025-02-20 08:07:43,499 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6100 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:07:43,501 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:07:43,501 - INFO - | End of Epoch  29 | Time: 45.06s | Validation Loss/MSE: 1.6100 | Error Rate: 0.7972
2025-02-20 08:07:43,501 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  30, ratio of masked values in train: 0.1235


2025-02-20 08:07:57,887 - INFO - | Epoch  30 | Batch 100/282 | LR 0.00000 | Time 140.40ms | Loss 1.6312 | MSE 0.0000 | CLS 1.6312 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7467
2025-02-20 08:08:11,633 - INFO - | Epoch  30 | Batch 200/282 | LR 0.00000 | Time 137.46ms | Loss 1.6031 | MSE 0.0000 | CLS 1.6031 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 08:08:28,500 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6174 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:08:28,502 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:08:28,503 - INFO - | End of Epoch  30 | Time: 45.00s | Validation Loss/MSE: 1.6174 | Error Rate: 0.7972
2025-02-20 08:08:28,504 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  31, ratio of masked values in train: 0.1235


2025-02-20 08:08:42,948 - INFO - | Epoch  31 | Batch 100/282 | LR 0.00000 | Time 140.99ms | Loss 1.6174 | MSE 0.0000 | CLS 1.6174 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7433
2025-02-20 08:08:56,738 - INFO - | Epoch  31 | Batch 200/282 | LR 0.00000 | Time 137.89ms | Loss 1.6026 | MSE 0.0000 | CLS 1.6026 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7733
2025-02-20 08:09:13,675 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6230 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:09:13,677 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:09:13,677 - INFO - | End of Epoch  31 | Time: 45.17s | Validation Loss/MSE: 1.6230 | Error Rate: 0.7972
2025-02-20 08:09:13,678 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  32, ratio of masked values in train: 0.1235


2025-02-20 08:09:28,127 - INFO - | Epoch  32 | Batch 100/282 | LR 0.00000 | Time 141.03ms | Loss 1.6223 | MSE 0.0000 | CLS 1.6223 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7467
2025-02-20 08:09:41,930 - INFO - | Epoch  32 | Batch 200/282 | LR 0.00000 | Time 138.02ms | Loss 1.5984 | MSE 0.0000 | CLS 1.5984 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:09:58,870 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6177 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:09:58,872 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:09:58,873 - INFO - | End of Epoch  32 | Time: 45.19s | Validation Loss/MSE: 1.6177 | Error Rate: 0.7972
2025-02-20 08:09:58,874 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  33, ratio of masked values in train: 0.1235


2025-02-20 08:10:13,335 - INFO - | Epoch  33 | Batch 100/282 | LR 0.00000 | Time 141.20ms | Loss 1.6034 | MSE 0.0000 | CLS 1.6034 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7100
2025-02-20 08:10:27,144 - INFO - | Epoch  33 | Batch 200/282 | LR 0.00000 | Time 138.09ms | Loss 1.6170 | MSE 0.0000 | CLS 1.6170 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7700
2025-02-20 08:10:44,082 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6121 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:10:44,085 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:10:44,085 - INFO - | End of Epoch  33 | Time: 45.21s | Validation Loss/MSE: 1.6121 | Error Rate: 0.7972
2025-02-20 08:10:44,086 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  34, ratio of masked values in train: 0.1235


2025-02-20 08:10:58,542 - INFO - | Epoch  34 | Batch 100/282 | LR 0.00000 | Time 141.11ms | Loss 1.6234 | MSE 0.0000 | CLS 1.6234 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7867
2025-02-20 08:11:12,341 - INFO - | Epoch  34 | Batch 200/282 | LR 0.00000 | Time 137.98ms | Loss 1.6048 | MSE 0.0000 | CLS 1.6048 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 08:11:29,295 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6270 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:11:29,297 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:11:29,298 - INFO - | End of Epoch  34 | Time: 45.21s | Validation Loss/MSE: 1.6270 | Error Rate: 0.7972
2025-02-20 08:11:29,299 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  35, ratio of masked values in train: 0.1235


2025-02-20 08:11:43,775 - INFO - | Epoch  35 | Batch 100/282 | LR 0.00000 | Time 141.22ms | Loss 1.6285 | MSE 0.0000 | CLS 1.6285 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7767
2025-02-20 08:11:57,582 - INFO - | Epoch  35 | Batch 200/282 | LR 0.00000 | Time 138.05ms | Loss 1.6123 | MSE 0.0000 | CLS 1.6123 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 08:12:14,511 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6276 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:12:14,513 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:12:14,514 - INFO - | End of Epoch  35 | Time: 45.21s | Validation Loss/MSE: 1.6276 | Error Rate: 0.7972
2025-02-20 08:12:14,514 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  36, ratio of masked values in train: 0.1235


2025-02-20 08:12:28,970 - INFO - | Epoch  36 | Batch 100/282 | LR 0.00000 | Time 140.99ms | Loss 1.6212 | MSE 0.0000 | CLS 1.6212 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7767
2025-02-20 08:12:42,770 - INFO - | Epoch  36 | Batch 200/282 | LR 0.00000 | Time 137.99ms | Loss 1.6101 | MSE 0.0000 | CLS 1.6101 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7567
2025-02-20 08:12:59,676 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6217 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:12:59,678 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:12:59,679 - INFO - | End of Epoch  36 | Time: 45.16s | Validation Loss/MSE: 1.6217 | Error Rate: 0.7972
2025-02-20 08:12:59,679 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  37, ratio of masked values in train: 0.1235


2025-02-20 08:13:14,108 - INFO - | Epoch  37 | Batch 100/282 | LR 0.00000 | Time 140.76ms | Loss 1.6166 | MSE 0.0000 | CLS 1.6166 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7500
2025-02-20 08:13:27,897 - INFO - | Epoch  37 | Batch 200/282 | LR 0.00000 | Time 137.89ms | Loss 1.5991 | MSE 0.0000 | CLS 1.5991 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7400
2025-02-20 08:13:44,835 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6216 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:13:44,837 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:13:44,838 - INFO - | End of Epoch  37 | Time: 45.16s | Validation Loss/MSE: 1.6216 | Error Rate: 0.7972
2025-02-20 08:13:44,838 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  38, ratio of masked values in train: 0.1235


2025-02-20 08:13:59,314 - INFO - | Epoch  38 | Batch 100/282 | LR 0.00000 | Time 141.24ms | Loss 1.6226 | MSE 0.0000 | CLS 1.6226 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7867
2025-02-20 08:14:13,120 - INFO - | Epoch  38 | Batch 200/282 | LR 0.00000 | Time 138.05ms | Loss 1.5837 | MSE 0.0000 | CLS 1.5837 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7033
2025-02-20 08:14:30,055 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6159 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:14:30,057 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:14:30,058 - INFO - | End of Epoch  38 | Time: 45.22s | Validation Loss/MSE: 1.6159 | Error Rate: 0.7972
2025-02-20 08:14:30,059 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  39, ratio of masked values in train: 0.1235


2025-02-20 08:14:44,489 - INFO - | Epoch  39 | Batch 100/282 | LR 0.00000 | Time 140.87ms | Loss 1.6325 | MSE 0.0000 | CLS 1.6325 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7967
2025-02-20 08:14:58,268 - INFO - | Epoch  39 | Batch 200/282 | LR 0.00000 | Time 137.78ms | Loss 1.5916 | MSE 0.0000 | CLS 1.5916 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 08:15:15,448 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6238 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:15:15,449 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:15:15,450 - INFO - | End of Epoch  39 | Time: 45.39s | Validation Loss/MSE: 1.6238 | Error Rate: 0.7972
2025-02-20 08:15:15,451 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  40, ratio of masked values in train: 0.1235


2025-02-20 08:15:29,695 - INFO - | Epoch  40 | Batch 100/282 | LR 0.00000 | Time 138.99ms | Loss 1.6268 | MSE 0.0000 | CLS 1.6268 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7633
2025-02-20 08:15:43,628 - INFO - | Epoch  40 | Batch 200/282 | LR 0.00000 | Time 139.32ms | Loss 1.6016 | MSE 0.0000 | CLS 1.6016 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7567
2025-02-20 08:16:00,682 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6235 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:16:00,684 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:16:00,685 - INFO - | End of Epoch  40 | Time: 45.23s | Validation Loss/MSE: 1.6235 | Error Rate: 0.7972
2025-02-20 08:16:00,685 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  41, ratio of masked values in train: 0.1235


2025-02-20 08:16:15,137 - INFO - | Epoch  41 | Batch 100/282 | LR 0.00000 | Time 141.02ms | Loss 1.6316 | MSE 0.0000 | CLS 1.6316 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7700
2025-02-20 08:16:28,924 - INFO - | Epoch  41 | Batch 200/282 | LR 0.00000 | Time 137.86ms | Loss 1.5943 | MSE 0.0000 | CLS 1.5943 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7700
2025-02-20 08:16:45,851 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6234 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:16:45,853 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:16:45,854 - INFO - | End of Epoch  41 | Time: 45.17s | Validation Loss/MSE: 1.6234 | Error Rate: 0.7972
2025-02-20 08:16:45,854 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  42, ratio of masked values in train: 0.1235


2025-02-20 08:17:00,287 - INFO - | Epoch  42 | Batch 100/282 | LR 0.00000 | Time 140.87ms | Loss 1.6213 | MSE 0.0000 | CLS 1.6213 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7533
2025-02-20 08:17:14,043 - INFO - | Epoch  42 | Batch 200/282 | LR 0.00000 | Time 137.56ms | Loss 1.6170 | MSE 0.0000 | CLS 1.6170 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7900
2025-02-20 08:17:30,959 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6254 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:17:30,961 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:17:30,961 - INFO - | End of Epoch  42 | Time: 45.11s | Validation Loss/MSE: 1.6254 | Error Rate: 0.7972
2025-02-20 08:17:30,962 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  43, ratio of masked values in train: 0.1235


2025-02-20 08:17:45,406 - INFO - | Epoch  43 | Batch 100/282 | LR 0.00000 | Time 140.97ms | Loss 1.6151 | MSE 0.0000 | CLS 1.6151 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7800
2025-02-20 08:17:59,173 - INFO - | Epoch  43 | Batch 200/282 | LR 0.00000 | Time 137.66ms | Loss 1.6064 | MSE 0.0000 | CLS 1.6064 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7333
2025-02-20 08:18:16,059 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6211 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:18:16,061 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:18:16,062 - INFO - | End of Epoch  43 | Time: 45.10s | Validation Loss/MSE: 1.6211 | Error Rate: 0.7972
2025-02-20 08:18:16,062 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  44, ratio of masked values in train: 0.1235


2025-02-20 08:18:30,491 - INFO - | Epoch  44 | Batch 100/282 | LR 0.00000 | Time 140.89ms | Loss 1.6203 | MSE 0.0000 | CLS 1.6203 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7600
2025-02-20 08:18:44,256 - INFO - | Epoch  44 | Batch 200/282 | LR 0.00000 | Time 137.64ms | Loss 1.5949 | MSE 0.0000 | CLS 1.5949 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7467
2025-02-20 08:19:01,145 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6188 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:19:01,147 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:19:01,148 - INFO - | End of Epoch  44 | Time: 45.08s | Validation Loss/MSE: 1.6188 | Error Rate: 0.7972
2025-02-20 08:19:01,148 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  45, ratio of masked values in train: 0.1235


2025-02-20 08:19:15,570 - INFO - | Epoch  45 | Batch 100/282 | LR 0.00000 | Time 140.74ms | Loss 1.6132 | MSE 0.0000 | CLS 1.6132 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7733
2025-02-20 08:19:29,329 - INFO - | Epoch  45 | Batch 200/282 | LR 0.00000 | Time 137.58ms | Loss 1.5953 | MSE 0.0000 | CLS 1.5953 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7367
2025-02-20 08:19:46,515 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6183 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:19:46,517 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:19:46,517 - INFO - | End of Epoch  45 | Time: 45.37s | Validation Loss/MSE: 1.6183 | Error Rate: 0.7972
2025-02-20 08:19:46,518 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  46, ratio of masked values in train: 0.1235


2025-02-20 08:20:00,920 - INFO - | Epoch  46 | Batch 100/282 | LR 0.00000 | Time 140.57ms | Loss 1.6189 | MSE 0.0000 | CLS 1.6189 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7633
2025-02-20 08:20:14,698 - INFO - | Epoch  46 | Batch 200/282 | LR 0.00000 | Time 137.77ms | Loss 1.5931 | MSE 0.0000 | CLS 1.5931 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7267
2025-02-20 08:20:31,610 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6207 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:20:31,612 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:20:31,612 - INFO - | End of Epoch  46 | Time: 45.09s | Validation Loss/MSE: 1.6207 | Error Rate: 0.7972
2025-02-20 08:20:31,613 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  47, ratio of masked values in train: 0.1235


2025-02-20 08:20:46,038 - INFO - | Epoch  47 | Batch 100/282 | LR 0.00000 | Time 140.78ms | Loss 1.6059 | MSE 0.0000 | CLS 1.6059 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7400
2025-02-20 08:20:59,808 - INFO - | Epoch  47 | Batch 200/282 | LR 0.00000 | Time 137.69ms | Loss 1.6013 | MSE 0.0000 | CLS 1.6013 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7567
2025-02-20 08:21:16,716 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6197 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:21:16,718 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:21:16,719 - INFO - | End of Epoch  47 | Time: 45.10s | Validation Loss/MSE: 1.6197 | Error Rate: 0.7972
2025-02-20 08:21:16,719 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  48, ratio of masked values in train: 0.1235


2025-02-20 08:21:31,137 - INFO - | Epoch  48 | Batch 100/282 | LR 0.00000 | Time 140.69ms | Loss 1.6133 | MSE 0.0000 | CLS 1.6133 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7467
2025-02-20 08:21:44,897 - INFO - | Epoch  48 | Batch 200/282 | LR 0.00000 | Time 137.59ms | Loss 1.6070 | MSE 0.0000 | CLS 1.6070 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7633
2025-02-20 08:22:01,805 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6177 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:22:01,807 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:22:01,807 - INFO - | End of Epoch  48 | Time: 45.09s | Validation Loss/MSE: 1.6177 | Error Rate: 0.7972
2025-02-20 08:22:01,808 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  49, ratio of masked values in train: 0.1235


2025-02-20 08:22:16,239 - INFO - | Epoch  49 | Batch 100/282 | LR 0.00000 | Time 140.83ms | Loss 1.6247 | MSE 0.0000 | CLS 1.6247 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7667
2025-02-20 08:22:29,979 - INFO - | Epoch  49 | Batch 200/282 | LR 0.00000 | Time 137.39ms | Loss 1.5904 | MSE 0.0000 | CLS 1.5904 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7300
2025-02-20 08:22:46,874 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6192 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:22:46,876 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:22:46,877 - INFO - | End of Epoch  49 | Time: 45.07s | Validation Loss/MSE: 1.6192 | Error Rate: 0.7972
2025-02-20 08:22:46,877 - INFO - -----------------------------------------------------------------------------------------


üîç Total Genes: 23794 | Mapped: 23794 | First 10 IDs: [16625  8892 18555 18552 18549 14169  9123 13965 13453  9070]
üîç `values_tensor` shape before filtering: torch.Size([1058, 23794])
‚úÖ `values_tensor` shape after filtering: torch.Size([1058, 23794])
üîç Train values shape: torch.Size([846, 23794]) | Test values shape: torch.Size([212, 23794])
üîπ Random masking applied at epoch  50, ratio of masked values in train: 0.1235


2025-02-20 08:23:01,271 - INFO - | Epoch  50 | Batch 100/282 | LR 0.00000 | Time 140.47ms | Loss 1.6162 | MSE 0.0000 | CLS 1.6162 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7333
2025-02-20 08:23:15,019 - INFO - | Epoch  50 | Batch 200/282 | LR 0.00000 | Time 137.47ms | Loss 1.5895 | MSE 0.0000 | CLS 1.5895 | CCE 0.0000 | MVC 0.0000 | ECS 0.0000 | DAB 0.0000 | ADV_E 0.0000 | ADV_D 0.0000 | NZLP 0.0000 | MVC_NZLP 0.0000 | Error 0.7400
2025-02-20 08:23:31,930 - INFO - ‚úÖ Evaluation Complete | Loss: 1.6173 | Error: 0.7972 | DAB: 0.0000
2025-02-20 08:23:31,931 - INFO - -----------------------------------------------------------------------------------------
2025-02-20 08:23:31,932 - INFO - | End of Epoch  50 | Time: 45.05s | Validation Loss/MSE: 1.6173 | Error Rate: 0.7972
2025-02-20 08:23:31,932 - INFO - -----------------------------------------------------------------------------------------


‚úÖ Best model saved at: /data/cellular_aging/results/fine-tuning/best_model.pt


In [22]:
def test(model: nn.Module, test_data: Dict[str, torch.Tensor]) -> float:
    """
    Test the model on the test data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []

    # Prepare DataLoader for testing
    test_loader = DataLoader(
        dataset=SeqDataset(test_data),
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=min(len(os.sched_getaffinity(0)), eval_batch_size // 2),
        pin_memory=True,
    )

    with torch.no_grad():
        for batch_data in test_loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)
            age_labels = batch_data["age_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                )
                output_values = output_dict["cls_output"]
                loss = criterion_cls(output_values, age_labels)

                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == age_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)
            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    wandb.log(
        {
            "test/mse": total_loss / total_num,
            "test/err": total_error / total_num,
            "test/dab": total_dab / total_num,
            "test/sum_mse_dab": (total_loss + dab_weight * total_dab) / total_num,
        },
    )

    return np.concatenate(predictions, axis=0)

In [28]:
# %% Inference
def test(model: nn.Module, test_data: Dict[str, torch.Tensor]) -> np.ndarray:
    """
    Perform inference on the test dataset.
    """
    model.eval()
    predictions = []

    # ‚úÖ Ensure model is not wrapped in DataParallel
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  

    # ‚úÖ Prepare DataLoader for testing
    test_loader = DataLoader(
        dataset=SeqDataset(test_data),
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=min(len(os.sched_getaffinity(0)), eval_batch_size // 2),
        pin_memory=True,
    )

    with torch.no_grad():
        for batch_data in test_loader:
            input_gene_ids = batch_data["gene_ids"].to(device)  # Expected: (batch_size, seq_len)
            input_values = batch_data["values"].to(device)  # Expected: (batch_size, seq_len)

            # ‚úÖ Ensure input_gene_ids is correctly expanded
            if input_gene_ids.dim() == 1:
                input_gene_ids = input_gene_ids.unsqueeze(1).expand(-1, input_values.shape[1])

            # ‚úÖ Debugging: Print corrected shapes
            print(f"Fixed input_gene_ids shape: {input_gene_ids.shape}")  # Should match input_values
            print(f"Fixed input_values shape: {input_values.shape}")  

            # Ensure input shapes are now matching
            assert input_gene_ids.shape == input_values.shape, "Mismatch in input_gene_ids and input_values shape!"

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])

            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    CLS=True,
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                )

                output_values = output_dict["cls_output"]
                batch_preds = output_values.argmax(1).cpu().numpy()
                predictions.append(batch_preds)

    return np.concatenate(predictions, axis=0)

# ‚úÖ Run inference
test_predictions = test(best_model, test_data)


Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])
Fixed input_gene_ids shape: torch.Size([3, 23794])
Fixed input_values shape: torch.Size([3, 23794])


IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/data/miniconda/envs/myenv/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/data/miniconda/envs/myenv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data/miniconda/envs/myenv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_88838/2443685756.py", line 11, in __getitem__
    return {k: v[idx] for k, v in self.data.items()}
  File "/tmp/ipykernel_88838/2443685756.py", line 11, in <dictcomp>
    return {k: v[idx] for k, v in self.data.items()}
IndexError: index 212 is out of bounds for dimension 0 with size 212


In [None]:
# Run inference on the test dataset
test_predictions, test_labels, test_results = test(best_model, test_data)


In [24]:
print(type(test_data))  # Check the type of test_data
if isinstance(test_data, dict):
    print(test_data.keys())  # Print keys if it's a dictionary

<class 'dict'>
dict_keys(['gene_ids', 'values', 'target_values', 'batch_labels', 'age_labels'])


NameError: name 'test' is not defined

In [None]:
sc.pl.umap(
    adata_test_raw,
    color=["age", "predictions"],
    palette=palette_,
    show=False,
)
plt.savefig(save_dir / "results.png", dpi=300)


In [None]:
cm = confusion_matrix(labels, predictions)
sns.heatmap(cm, annot=True, fmt=".1f", cmap="Blues")
plt.savefig(save_dir / "confusion_matrix.png", dpi=300)


In [9]:
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
import pickle
import torch
from anndata import AnnData
import scanpy as sc
import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse, csr_matrix
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind
from sklearn.metrics import confusion_matrix
sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

# Ignore warnings
sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

# ‚úÖ **Hyperparameters and Configurations**
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="AIDA1_sample",
    do_train=True,
    load_model="/data/cellular_aging/references/scGPT_human_pretrained_model",
    mask_ratio=0.15,
    epochs=5,
    n_bins=51,
    MVC=False,  # Masked value prediction for cell embedding
    ecs_thres=0.0,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-5,
    batch_size=2,
    layer_size=512,
    nlayers=12,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead=8,  # number of heads in nn.MultiheadAttention
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene=False,
    freeze=False,  # freeze
    DSBN=False,  # Domain-specific batch normalization
)

# ‚úÖ **Initialize WandB**
run = wandb.init(
    config=hyperparameter_defaults,
    dir="/data/cellular_aging/results/fine-tuning",
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
set_seed(config.seed)

# ‚úÖ **Preprocessing Settings**
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # Always set to auto for masked values
include_zero_gene = config.include_zero_gene  # Include zero genes in HVGs if True
max_seq_len = 4417
n_bins = config.n_bins

# ‚úÖ **Input/Output Representation**
input_style = "binned"  # Options: "normed_raw", "log1p", or "binned"
output_style = "binned"  # Options: "normed_raw", "log1p", or "binned"

# ‚úÖ **Training Settings**
MLM = False
CLS = True
ADV = False
CCE = False
MVC = config.MVC
ECS = config.ecs_thres > 0
DAB = False
INPUT_BATCH_LABELS = False
input_emb_style = "continuous"
cell_emb_style = "cls"
adv_E_delay_epochs = 0
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight
explicit_zero_prob = MLM and include_zero_gene
do_sample_in_train = False and explicit_zero_prob
per_seq_batch_sample = False

# ‚úÖ **Optimizer Settings**
lr = config.lr
lr_ADV = 1e-3
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1

# ‚úÖ **Model Architecture Settings**
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"
embsize = config.layer_size
d_hid = config.layer_size
nlayers = config.nlayers
nhead = config.nhead
dropout = config.dropout

# ‚úÖ **Logging & Evaluation**
log_interval = 100
save_eval_interval = config.save_eval_interval
do_eval_scib_metrics = True

# -----------------------------------
# ‚úÖ **Filtering Step: Load & Preprocess Dataset**
# -----------------------------------

# ‚úÖ Load full dataset (AnnData format)
adata = sc.read_h5ad('/data/cellular_aging/dataset/AIDA.h5ad')

# ‚úÖ Track Total Expression Before Filtering
total_expression_before = adata.X.sum()

# ‚úÖ Load FINAL gene mapping (from BioMart + MyGene.info)
df_final_mapping = pd.read_csv("final_gene_mapping.csv")
final_gene_dict = dict(zip(df_final_mapping["ensembl_id"], df_final_mapping["gene_symbol"]))

# ‚úÖ Extract Ensembl IDs & Apply Gene Mapping
adata.var["ensembl_id"] = adata.var_names.str.split(".").str[0]
adata.var["gene_symbol"] = adata.var["ensembl_id"].map(final_gene_dict)

# ‚úÖ Load `vocab.json`
vocab_path = "/data/cellular_aging/references/scGPT_human_pretrained_model/vocab.json"
with open(vocab_path, "r") as f:
    vocab = json.load(f)

# ‚úÖ Convert vocab genes to a set
vocab_genes = set(vocab.keys())

# ‚úÖ Filter dataset to keep only genes in `vocab.json`
adata_filtered_vocab = adata[:, adata.var["gene_symbol"].isin(vocab_genes)].copy()

# ‚úÖ Track Total Expression After Filtering
total_expression_after = adata_filtered_vocab.X.sum()

# ‚úÖ Compute Gene & Expression Coverage
gene_coverage_vocab = (adata_filtered_vocab.shape[1] / adata.shape[1]) * 100
expression_coverage_vocab = (total_expression_after / total_expression_before) * 100

# ‚úÖ Print Final Summary
print("\nüìä **Final Gene Mapping & Expression Coverage Summary:**")
print(f"üîπ Total Genes Before Vocab Filtering: {adata.shape[1]}")
print(f"üîπ Total Genes After Vocab Filtering: {adata_filtered_vocab.shape[1]}")
print(f"üîπ Total Expression Before Vocab Filtering: {total_expression_before:.2f}")
print(f"üîπ Total Expression After Vocab Filtering: {total_expression_after:.2f}")
print(f"‚úÖ Gene Coverage After Vocab Filtering: {gene_coverage_vocab:.2f}%")
print(f"‚úÖ Expression Coverage After Vocab Filtering: {expression_coverage_vocab:.2f}%")

# ‚úÖ **Step 1: Random Sampling**
sample_fraction = 0.001  # Adjust as needed
num_sample = int(sample_fraction * adata_filtered_vocab.n_obs)

random_indices = np.random.choice(adata_filtered_vocab.n_obs, num_sample, replace=False)
adata_sample = adata_filtered_vocab[random_indices, :].copy()
print(f"‚úÖ Sampled dataset shape: {adata_sample.shape}")

# ‚úÖ **Step 2: Convert to Sparse Format**
adata_sample.X = csr_matrix(adata_sample.X)



üìä **Final Gene Mapping & Expression Coverage Summary:**
üîπ Total Genes Before Vocab Filtering: 36161
üîπ Total Genes After Vocab Filtering: 23794
üîπ Total Expression Before Vocab Filtering: 2888279552.00
üîπ Total Expression After Vocab Filtering: 2841698816.00
‚úÖ Gene Coverage After Vocab Filtering: 65.80%
‚úÖ Expression Coverage After Vocab Filtering: 98.39%
‚úÖ Sampled dataset shape: (1058, 23794)


In [10]:
# ‚úÖ Settings for Input and Preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # Always set to auto for masked values
include_zero_gene = config.include_zero_gene  # Include zero genes in HVGs if True
max_seq_len = 4417
n_bins = config.n_bins

# ‚úÖ Input/Output Representation
input_style = "binned"  # Options: "normed_raw", "log1p", or "binned"
output_style = "binned"  # Options: "normed_raw", "log1p", or "binned"

# ‚úÖ Training Settings
MLM = False  # Masked Language Modeling (MLM), always on
CLS = True  # Enable Classification Objective
ADV = False  # Adversarial Training for Batch Correction
CCE = False  # Contrastive Cell Embedding Objective
MVC = config.MVC  # Masked Value Prediction for Cell Embedding
ECS = config.ecs_thres > 0  # Elastic Cell Similarity Objective (Enabled if > 0)
DAB = False  # Domain Adaptation via Reverse Backpropagation (set to 2 for separate optimizer)
INPUT_BATCH_LABELS = False  # Helps MLM and MVC, but not classifier
input_emb_style = "continuous"  # Options: "category", "continuous", "scaling"
cell_emb_style = "cls"  # Options: "avg-pool", "w-pool", "cls"
adv_E_delay_epochs = 0  # Delay Adversarial Training (Encoder)
adv_D_delay_epochs = 0  # Delay Adversarial Training (Discriminator)
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight
explicit_zero_prob = MLM and include_zero_gene  # Use explicit Bernoulli for zero values
do_sample_in_train = False and explicit_zero_prob  # Sample Bernoulli in training
per_seq_batch_sample = False  # Per-sequence batch sampling disabled

# ‚úÖ Optimizer Settings
lr = config.lr  # Learning Rate
lr_ADV = 1e-3  # Learning Rate for Adversarial Discriminator (if ADV is True)
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1  # Interval for learning rate scheduling

# ‚úÖ Model Architecture Settings
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"  # Options: "linear", "flash"
embsize = config.layer_size  # Embedding Dimension
d_hid = config.layer_size  # Hidden Dimension in TransformerEncoder
nlayers = config.nlayers  # Number of Transformer Encoder Layers
nhead = config.nhead  # Number of Attention Heads in MultiheadAttention
dropout = config.dropout  # Dropout Probability

# ‚úÖ Logging & Evaluation
log_interval = 100  # Log every 100 iterations
save_eval_interval = config.save_eval_interval  # Save evaluation results every `save_eval_interval` epochs
do_eval_scib_metrics = True  # Enable evaluation using SciB metrics


In [26]:
def evaluate(model: nn.Module, loader: DataLoader, return_raw: bool = False) -> Union[float, Tuple[float, float]]:
    """
    Evaluate the model on the validation/test dataset.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    predictions = []

    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)
            age_labels = batch_data["age_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if INPUT_BATCH_LABELS or config.DSBN else None,
                    CLS=CLS,  # Only classification used here, disable MLM/CCE
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=do_sample_in_train,
                )

                output_values = output_dict["cls_output"]
                loss = criterion_cls(output_values, age_labels)

                if DAB:
                    loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            accuracy = (output_values.argmax(1) == age_labels).sum().item()
            total_error += (1 - accuracy / len(input_gene_ids)) * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids) if DAB else 0.0
            total_num += len(input_gene_ids)

            preds = output_values.argmax(1).cpu().numpy()
            predictions.append(preds)

    # Log metrics
    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/err": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + dab_weight * total_dab) / total_num,
            "epoch": epoch,
        }
    )

    if return_raw:
        return np.concatenate(predictions, axis=0)

    return total_loss / total_num, total_error / total_num


In [None]:
# ‚úÖ **Log Best Model Metrics**
wandb.log(
    {
        "best_val_loss": best_val_loss,
        "best_model_epoch": best_model_epoch,
    }
)

# ‚úÖ **Save Best Model**
best_model_path = save_dir / "best_model.pt"
torch.save(best_model.state_dict(), best_model_path)
logger.info(f"‚úÖ Best model saved at: {best_model_path}")

# ‚úÖ **Cleanup Resources**
wandb.finish()
gc.collect()
torch.cuda.empty_cache()  # Free up GPU memory
logger.info("‚úÖ Cleanup complete. WandB session closed, memory freed.")
