In [1]:
import json
from pathlib import Path

import numpy as np
import torch
from gears import PertData
from scgpt.loss import masked_mse_loss
from scgpt.model import TransformerGenerator
from scgpt.tokenizer import GeneVocab

# Load the Adamson dataset with GEARS and build its train/val/test dataloaders.
# One SEED drives both the GEARS split and torch's global RNG (used by dropout
# and any shuffling DataLoader), so Restart & Run All reproduces the same run.
SEED = 1
torch.manual_seed(SEED)

pert_data = PertData("./data")
pert_data.load(data_name="adamson")
pert_data.prepare_split(split="simulation", seed=SEED)
pert_data.get_dataloader(batch_size=64, test_batch_size=64)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Found local copy...
Extracting zip file...
Done!
Creating pyg object for each cell in the data...
100%|██████████| 87/87 [01:04<00:00,  1.36it/s]
Saving new dataset pyg object at ./data/adamson/data_pyg/cell_graphs.pkl


: 

In [None]:
from scgpt.model import TransformerGenerator
from scgpt.tokenizer import GeneVocab

# Build a TransformerGenerator from a pretrained scGPT checkpoint.
# Replace the placeholder "path/to/..." paths before running this cell.
vocab = GeneVocab.from_file("path/to/vocab.json")

# Architecture hyperparameters saved alongside the checkpoint.
with open("path/to/model_config.json", "r") as f:
    model_configs = json.load(f)

# The first five constructor arguments are passed positionally, in the same
# order as the TransformerGenerator call in the later training-setup cell.
# do_dab / use_batch_labels (both False before) are no longer passed —
# NOTE(review): they appear to be TransformerModel options, not
# TransformerGenerator ones; confirm against the installed scGPT.
model = TransformerGenerator(
    len(vocab),
    model_configs["embsize"],
    model_configs["nheads"],
    model_configs["d_hid"],
    model_configs["nlayers"],
    nlayers_cls=model_configs["n_layers_cls"],
    n_cls=1,
    vocab=vocab,
    dropout=model_configs["dropout"],
    pad_token=model_configs["pad_token"],
    pad_value=model_configs["pad_value"],
    do_mvc=True,
)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
pretrained_state = torch.load("path/to/pretrained_model.pt", map_location=device)

# Copy only tensors whose name and shape match this model. Layers that the
# checkpoint does not contain keep their fresh initialisation instead of
# making a strict load_state_dict fail.
model_state = model.state_dict()
matched_state = {
    name: tensor
    for name, tensor in pretrained_state.items()
    if name in model_state and model_state[name].shape == tensor.shape
}
print(f"Loaded {len(matched_state)}/{len(model_state)} parameter tensors from checkpoint")
model_state.update(matched_state)
model.load_state_dict(model_state)
model.to(device)

In [None]:
import torch.nn as nn

class CellBehaviorClassifier(nn.Module):
    """Two-layer MLP head that maps an embedding to class logits.

    Args:
        input_dim: Size of the last dimension of the input embedding.
        num_classes: Number of output logits.
        hidden_dim: Width of the hidden layer (defaults to 256).
        dropout: Dropout probability applied after the hidden ReLU
            (defaults to 0.3).
    """

    def __init__(self, input_dim, num_classes, hidden_dim=256, dropout=0.3):
        super().__init__()
        # Layer order is kept as Linear/ReLU/Dropout/Linear so state_dict keys
        # (classifier.0.*, classifier.3.*) match previously saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return unnormalised logits with last dimension ``num_classes``."""
        return self.classifier(x)

# Three cell behavior classes are assumed; change num_classes to match the
# real label set. The head's input width is the backbone's embedding size.
classifier = CellBehaviorClassifier(
    num_classes=3,
    input_dim=model_configs["embsize"],
)
classifier.to(device)

In [None]:
import torch.optim as optim
from torch.nn import functional as F

# Combine scGPT and classifier
class CombinedModel(nn.Module):
    """Backbone followed by a classification head on its cell embedding.

    Args:
        scgpt: Backbone module. It is called with the gene ids, values,
            perturbation flags and any extra arguments given to ``forward``,
            and must return a mapping holding the cell embedding under
            ``embedding_key``.
        classifier: Head applied to that embedding.
        embedding_key: Key of the cell embedding in the backbone output.
            NOTE(review): scGPT models appear to use ``"cell_emb"``; the
            default keeps the original ``"cell_embedding"`` — confirm against
            the installed scGPT.
    """

    def __init__(self, scgpt, classifier, embedding_key="cell_embedding"):
        super().__init__()
        self.scgpt = scgpt
        self.classifier = classifier
        self.embedding_key = embedding_key

    def forward(self, input_gene_ids, input_values, input_pert_flags,
                *scgpt_args, **scgpt_kwargs):
        """Run the backbone, select the cell embedding and classify it.

        Extra positional/keyword arguments are passed straight to the
        backbone (e.g. a ``src_key_padding_mask`` — NOTE(review): scGPT's
        TransformerGenerator.forward appears to require one; confirm).

        Raises:
            KeyError: if the backbone output lacks ``embedding_key``; the
                message lists the keys that were returned.
        """
        scgpt_output = self.scgpt(
            input_gene_ids, input_values, input_pert_flags,
            *scgpt_args, **scgpt_kwargs,
        )
        if self.embedding_key not in scgpt_output:
            raise KeyError(
                f"backbone output has no {self.embedding_key!r}; "
                f"available keys: {list(scgpt_output)}"
            )
        cell_embedding = scgpt_output[self.embedding_key]
        return self.classifier(cell_embedding)

# Wrap the pretrained backbone and the new head into one module on `device`.
combined_model = CombinedModel(scgpt=model, classifier=classifier)
combined_model.to(device)

# Cross-entropy on the head's logits; Adam updates backbone and head together
# because it receives every parameter of the combined module.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=1e-4)

# Training loop: fine-tune the backbone and classifier head jointly.
# NOTE(review): GEARS dataloaders appear to yield torch_geometric batches with
# fields such as `x`, `y` and `pert`; the `genes` / `expressions` / `pert_flags`
# / `cell_behavior_labels` attributes used below are assumptions — confirm
# against the installed GEARS version before running.
num_epochs = 10
for epoch in range(num_epochs):
    combined_model.train()
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch in pert_data.dataloader["train_loader"]:
        optimizer.zero_grad()

        input_gene_ids = batch.genes.to(device)
        input_values = batch.expressions.to(device)
        input_pert_flags = batch.pert_flags.to(device)

        # Assuming we have cell behavior labels
        labels = batch.cell_behavior_labels.to(device)

        outputs = combined_model(input_gene_ids, input_values, input_pert_flags)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    # Report the mean loss over the whole epoch, not just the final batch.
    # This also avoids a NameError on `loss` when the loader yields no batches.
    mean_loss = epoch_loss_sum / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

# Evaluation: top-1 accuracy of the combined model on the test loader.
# NOTE(review): the `genes` / `expressions` / `pert_flags` /
# `cell_behavior_labels` batch attributes are assumptions — GEARS batches
# appear to expose fields such as `x`, `y` and `pert`; confirm before running.
combined_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in pert_data.dataloader["test_loader"]:
        input_gene_ids = batch.genes.to(device)
        input_values = batch.expressions.to(device)
        input_pert_flags = batch.pert_flags.to(device)
        labels = batch.cell_behavior_labels.to(device)

        outputs = combined_model(input_gene_ids, input_values, input_pert_flags)
        # Predicted class = index of the largest logit. Inside no_grad there is
        # no need for the discouraged `.data` attribute.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test loader, which would otherwise divide by zero.
if total:
    print(f"Accuracy on test set: {100 * correct / total}%")
else:
    print("Accuracy on test set: no test samples were evaluated")

In [None]:
# Persist the fine-tuned weights (backbone + head). Replace the placeholder
# path. Its parent directory is created first because torch.save does not
# create missing directories.
fine_tuned_path = Path("path/to/fine_tuned_model.pt")
fine_tuned_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(combined_model.state_dict(), fine_tuned_path)

In [7]:
# --- data processing ---
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]  # appended to the gene vocab
pad_value = 0      # value used for padded expression entries
pert_pad_id = 0    # passed to TransformerGenerator as pert_pad_id
include_zero_gene = "all"
max_seq_len = 1536

# --- training objectives: masked language modeling on; celltype
# classification (CLS), contrastive cell embedding (CCE), masked value
# prediction for cell embedding (MVC) and elastic cell similarity (ECS) off ---
MLM, CLS, CCE, MVC, ECS = True, False, False, False, False
amp = True  # mixed-precision training

# --- optimizer and schedule ---
lr = 1e-4
batch_size = eval_batch_size = 64
epochs = 15
schedule_interval = 1
early_stop = 10

# --- model architecture ---
embsize = d_hid = 512
nlayers, nhead, n_layers_cls = 12, 8, 3
dropout = 0
use_fast_transformer = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Create vocabulary
genes = pert_data.adata.var["gene_name"].tolist()
vocab = Vocab(
    VocabPybind(genes + special_tokens, None)
)
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)

# Initialize the model
ntokens = len(vocab)
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=use_fast_transformer,
)
model.to(device)

# Define loss function and optimizer
criterion = masked_mse_loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=amp)

NameError: name 'Vocab' is not defined