In [None]:
# Make the in-repo `stella` package (under ../../src) importable.
# This must run before the `from stella ...` imports further down.
import sys
sys.path.append("../../src")

# Restrict which GPUs this process can see. It is set before transformers/stella
# are imported so the restriction is in place before any CUDA initialisation.
# NOTE(review): the value contains a space after the comma; "6,7" is the canonical
# form — confirm both GPUs are really visible (torch.cuda.device_count() == 2).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6, 7"

import evaluate
import numpy as np
from datasets import load_from_disk
from stella import STELLADataCollatorV1
from stella.models import STELLAForSequenceClassification
from transformers import set_seed, Trainer, TrainingArguments
from stella.tokenizer import TranscriptomeTokenizerForCellClassification

# Seed the RNGs via transformers' helper so shuffling/splitting/training are reproducible.
set_seed(seed=42)

### 1. Tokenize your AnnData (`.h5ad`) file

- Note: gene symbols should be stored in `adata.var_names`.

In [None]:
# Build the cell-classification tokenizer. `custom_attr_name_dict` carries the
# "celltype" annotation through to the tokenized dataset (presumably as
# {source attribute: output column}), where step 3 reads it as ds["celltype"].
tokenizer = TranscriptomeTokenizerForCellClassification(
    custom_attr_name_dict={"celltype": "celltype"},
    max_length=2048,
    nproc=8,
    seed=42,
)

In [None]:
# Tokenize the h5ad file; the result is also written to `tokenized_dir`, the same
# directory step 2 reloads from.
h5ad_path = "./zheng68k/zheng68k.h5ad"  # your h5ad file path
tokenized_dir = "./tokenized_zheng68k"  # save your tokenized dataset
ds = tokenizer(h5ad_file=h5ad_path, save_dir=tokenized_dir)

### 2. Load the tokenized dataset (if your data is already tokenized, skip step 1 and start here)

In [None]:
# Reload the tokenized dataset from the directory the tokenizer saved to, and
# show its summary (columns and number of rows) as the cell output.
ds = load_from_disk(dataset_path="./tokenized_zheng68k")
ds

### 3. Process celltype labels

In [None]:
from datasets import ClassLabel

# Sorted unique cell-type names; a name's position in this list is its label id.
uniq_ct = np.unique(ds["celltype"]).tolist()

# celltype -> label_id
ct2label = {ct: label_id for label_id, ct in enumerate(uniq_ct)}


def process_func(example):
    """Add an integer ``labels`` field looked up from the example's ``celltype``."""
    example["labels"] = ct2label[example["celltype"]]
    return example


ds = ds.map(process_func, num_proc=8, remove_columns=["celltype"])

# Stratified splitting needs `labels` to be a ClassLabel column. Cast with the
# explicit names from `uniq_ct` instead of using `class_encode_column`: on an
# integer column that method stringifies the ids and sorts them lexicographically
# ("10" < "2"), which silently renumbers the labels once there are more than 10
# classes, so they would no longer agree with `ct2label`.
ds = ds.cast_column("labels", ClassLabel(names=[str(ct) for ct in uniq_ct]))

### 4. Split the dataset into a training set, a validation set, and a test set

In [None]:
# Shuffle into a new name so `ds` itself is left untouched: re-running this cell
# then reproduces exactly the same split instead of re-shuffling an
# already-shuffled dataset (which would change the split and could move cells the
# model was trained on into the test set).
shuffled_ds = ds.shuffle(seed=42)

# 80% train+validation / 20% test, stratified by label.
train_test_split = shuffled_ds.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
train_val_ds, test_ds = train_test_split["train"], train_test_split["test"]

# Hold out 10% of the train+validation part for validation, again stratified.
train_validation_split = train_val_ds.train_test_split(test_size=0.1, seed=42, stratify_by_column="labels")
train_ds, validation_ds = train_validation_split["train"], train_validation_split["test"]

# Every cell ends up in exactly one of the three splits.
assert len(train_ds) + len(validation_ds) + len(test_ds) == len(ds)

# train_size, validation_size, test_size
train_ds.shape[0], validation_ds.shape[0], test_ds.shape[0]

### 5. Load Pretrained Model

In [None]:
# Load the pretrained encoder with a classification head sized to the number of
# cell types. Passing id2label/label2id stores the cell-type names in the model
# config, so checkpoints saved by the Trainer can map predicted ids back to cell
# types.
# NOTE(review): assumes STELLA's config is a transformers PretrainedConfig
# (which accepts these keys) — confirm.
model = STELLAForSequenceClassification.from_pretrained(
    "../../pretrained_models/B100_L2048",
    num_labels=len(ct2label),
    id2label={label_id: ct for ct, label_id in ct2label.items()},
    label2id=dict(ct2label),
)

### 6. (Optional) Freeze the first encoder layers if GPU memory is limited

In [None]:
def freeze_first_k_layers(k=4, target_model=None):
    """Disable gradients for the parameters of encoder layers 0 .. k-1.

    Args:
        k: number of leading ``stella.encoder.layer.<i>`` blocks to freeze;
            ``k=0`` freezes nothing.
        target_model: model whose parameters are frozen; defaults to the
            notebook-level ``model``.

    Parameters are only ever set to ``requires_grad=False`` — nothing is
    unfrozen — so reload the model before re-running with a smaller ``k``.
    """
    net = model if target_model is None else target_model
    # The trailing "." makes the layer match exact: without it, "layer.1" would
    # also match "layer.10", "layer.11", ... and freeze far more than k layers.
    prefixes = tuple(f"stella.encoder.layer.{i}." for i in range(k))
    for name, param in net.named_parameters():
        if any(prefix in name for prefix in prefixes):
            param.requires_grad = False


# freeze the first k layers
freeze_first_k_layers(k=0)  # no freeze

# check the trainable status of the parameters
for name, params in model.named_parameters():
    print(name, "\t", params.requires_grad)

### 7. Start Training

In [None]:
training_args = TrainingArguments(
    # --- run bookkeeping and logging ---
    output_dir="./celltype_annotation",
    seed=42,
    report_to="tensorboard",
    logging_steps=10,
    # --- per-epoch evaluation/checkpointing; keep the checkpoint with the best f1 ---
    eval_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1",
    load_best_model_at_end=True,
    # --- schedule length and batch sizes ---
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    # --- optimiser and learning-rate schedule ---
    optim="adamw_torch",
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0,
    max_grad_norm=1.0,
    # --- precision, memory and data loading ---
    bf16=True,
    gradient_checkpointing=False,
    dataloader_num_workers=8,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=True,
    ddp_find_unused_parameters=False,
)


# Local copies of the `evaluate` metric scripts, combined so that a single
# compute() call returns all four scores in one dict.
# NOTE(review): the stock precision/recall/f1 metrics default to average="binary",
# which rejects multi-class labels; presumably these local copies use a
# multi-class average — confirm.
_clf_metric_paths = [
    "../../src/stella/metrics/accuracy",
    "../../src/stella/metrics/precision",
    "../../src/stella/metrics/recall",
    "../../src/stella/metrics/f1",
]
clf_metrics = evaluate.combine(_clf_metric_paths)


def compute_metrics(pred):
    """Score the Trainer's eval predictions with ``clf_metrics``.

    Args:
        pred: evaluation output with ``predictions`` (class logits, or a tuple
            whose first element is the logits) and ``label_ids``.

    Returns:
        dict mapping metric name to value, as returned by ``clf_metrics.compute``.
    """
    logits = pred.predictions
    # The Trainer passes a tuple when the model returns more than the logits
    # (e.g. hidden states); the logits are the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    return clf_metrics.compute(predictions=preds, references=pred.label_ids)


# Wire model, arguments, data splits and metrics into a Trainer; the validation
# split drives per-epoch evaluation and best-checkpoint selection.
# NOTE(review): STELLADataCollatorV1 is passed without being instantiated, so the
# Trainer will call it directly on each list of features — confirm it is a
# collate callable rather than a class that needs constructing first.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=validation_ds,
    data_collator=STELLADataCollatorV1,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model; keep the returned training summary and show it as the cell output.
train_result = trainer.train()
train_result

In [None]:
# Validation metrics; with load_best_model_at_end=True these are for the best checkpoint.
trainer.evaluate(eval_dataset=validation_ds)

In [None]:
# Performance on test dataset.
# Keep the full prediction output (raw predictions in `.predictions`, labels in
# `.label_ids`) for further analysis, but display only the metrics: the object's
# repr would dump every raw prediction for the test set.
test_output = trainer.predict(test_ds)
test_output.metrics