# In [13]:  (Jupyter notebook cell marker kept from the export; not code)
import os
import torch
import torch.nn as nn
import scanpy as sc
from sklearn.metrics import accuracy_score, f1_score
from gensim.corpora import Dictionary
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score
from scgpt.tokenizer import GeneVocab, random_mask_value, tokenize_and_pad_batch
from scgpt.model import TransformerModel
from anndata import AnnData
import scipy.sparse
import TOSICA
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
# Ensure SCGPTModel inherits from nn.Module
class SCGPTModel(nn.Module):
    """Transformer-encoder classifier over sequences of token ids.

    Token ids are embedded, passed through a stack of Transformer encoder
    layers, mean-pooled along the sequence axis and projected by a linear
    head.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_size: Embedding width, also used as the encoder ``d_model``.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        d_hid: Width of the feed-forward sub-network in each layer.
        dropout: Dropout probability passed to the encoder layers.
        vocab: Sized object; ``len(vocab)`` is the number of output logits.
    """

    def __init__(self, vocab_size, emb_size, num_heads, num_layers, d_hid, dropout, vocab):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_size,
            nhead=num_heads,
            dim_feedforward=d_hid,
            dropout=dropout,
            # forward() hands the encoder (batch, seq, emb) tensors. With the
            # default batch_first=False the first axis would be read as the
            # sequence axis, so attention would mix the samples of a batch
            # with each other instead of the tokens of one sample.
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(emb_size, len(vocab))  # Output size matches vocab

    def forward(self, x):
        """Return logits of shape ``(batch, len(vocab))``.

        Args:
            x: Integer tensor of token ids, shape ``(batch, seq_len)``.
        """
        embedded = self.embedding(x)      # (batch, seq_len, emb_size)
        encoded = self.encoder(embedded)  # same shape; attention runs along seq_len
        pooled = encoded.mean(dim=1)      # average over the sequence axis
        return self.classifier(pooled)
#Preprocessing Functions
def preprocess_adata(adata, n_top_genes=2000, target_sum=1e4):
    """Normalize, log-transform and gene-filter an AnnData object.

    The three scanpy steps are applied to ``adata`` itself (their return
    values are not used), and that same object is handed back.

    Args:
        adata: AnnData object to preprocess.
        n_top_genes: Number of highly variable genes requested from scanpy.
        target_sum: Value forwarded to ``sc.pp.normalize_total``.

    Returns:
        The ``adata`` object that was passed in.
    """
    hvg_options = {"n_top_genes": n_top_genes, "subset": True}

    # Step 1: total-count normalization.
    sc.pp.normalize_total(adata, target_sum=target_sum)
    # Step 2: log1p transform.
    sc.pp.log1p(adata)
    # Step 3: highly-variable-gene selection; subset=True asks scanpy to drop
    # the genes that were not selected.
    sc.pp.highly_variable_genes(adata, **hvg_options)

    return adata

# Tokenization Function
def tokenize_data(adata, vocab, max_len=512):
    """Convert an AnnData expression matrix into padded token batches.

    Each gene name in ``adata.var_names`` is mapped to an id through
    ``vocab.dictionary.token2id`` (falling back to ``vocab.default_index``),
    the expression matrix is densified if it is sparse, and both are handed
    to ``tokenize_and_pad_batch``.

    Args:
        adata: AnnData object providing ``var_names`` and ``X``.
        vocab: Vocabulary object exposing ``dictionary.token2id`` and
            ``default_index``.
        max_len: Maximum sequence length forwarded to the tokenizer.

    Returns:
        Whatever ``tokenize_and_pad_batch`` returns (requested with
        ``return_pt=True``); the rest of this file reads its ``"genes"`` and
        ``"values"`` entries.
    """
    # NOTE(review): the vocab attributes used here and the cls_id/pad_id
    # keywords below should be checked against the installed scgpt version.
    token_to_id = vocab.dictionary.token2id

    gene_ids = []
    for gene in adata.var_names:
        gene_ids.append(token_to_id.get(gene, vocab.default_index))

    # The tokenizer is given a dense array, so sparse matrices are expanded.
    data_matrix = adata.X
    if scipy.sparse.issparse(data_matrix):
        data_matrix = data_matrix.toarray()

    cls_id = token_to_id.get("<cls>", 0)
    pad_id = token_to_id.get("<pad>", 0)

    return tokenize_and_pad_batch(
        data=data_matrix,
        gene_ids=np.array(gene_ids),
        max_len=max_len,
        cls_id=cls_id,
        pad_id=pad_id,
        pad_value=0,
        append_cls=True,
        include_zero_gene=False,
        return_pt=True,
    )
# Updated evaluate function
def evaluate(model, data_loader, device):
    """Compute accuracy and macro-F1 of ``model`` over ``data_loader``.

    Args:
        model: Callable module returning per-class scores of shape
            ``(batch, n_classes)`` for a batch of inputs.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(accuracy, macro_f1)``.

    Raises:
        ValueError: If the number of collected predictions differs from the
            number of collected labels.
    """
    all_preds = []
    all_labels = []

    # Put the model in evaluation mode and run without gradient tracking.
    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_inputs)

            # Index of the highest score along the class axis.
            top_class = torch.max(scores, dim=1).indices

            # Move to host memory and store as integer arrays.
            pred_array = top_class.cpu().numpy().astype(int)
            label_array = batch_labels.cpu().numpy().astype(int)

            all_preds.extend(pred_array)
            all_labels.extend(label_array)

    # Guard against mismatched counts before handing off to the metrics.
    if len(all_preds) != len(all_labels):
        raise ValueError(
            f"Final length mismatch: len(all_preds) ({len(all_preds)}) != len(all_labels) ({len(all_labels)})"
        )

    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average="macro")
    return accuracy, macro_f1

# Paths
# Hard-coded absolute locations of the hPancreas train/test AnnData files;
# adjust them for the machine this script runs on.
train_path = "C:/Users/gaiacronus/Downloads/work/combine/adata/hPancreas_train_adata.h5ad"
test_path = "C:/Users/gaiacronus/Downloads/work/combine/adata/hPancreas_test_adata.h5ad"

# Data Loading
adata_train = sc.read_h5ad(train_path)
adata_test = sc.read_h5ad(test_path)

# Preprocess Datasets
# NOTE(review): each split is preprocessed on its own, so each one selects its
# own highly variable genes; only genes selected in BOTH splits survive the
# intersection below. Confirm this is intended rather than reusing the
# training split's gene selection for the test split.
adata_train = preprocess_adata(adata_train)
adata_test = preprocess_adata(adata_test)

# Align Gene Names
# Restrict both splits to the shared genes, in the same column order.
common_genes = adata_train.var_names.intersection(adata_test.var_names)
adata_train = adata_train[:, common_genes]
adata_test = adata_test[:, common_genes]

# Vocabulary Creation
# The vocabulary is built from the aligned training gene names plus the two
# special tokens that tokenize_data looks up ("<pad>" and "<cls>").
# NOTE(review): verify the `gene_list=` keyword against the installed scgpt
# GeneVocab constructor.
special_tokens = ["<pad>", "<cls>"]
gene_vocab = GeneVocab(
    gene_list=adata_train.var_names.tolist(),
    specials=special_tokens,
    default_token="<pad>"
)

# Tokenization
# Both splits are tokenized with the same vocabulary.
train_tokenized = tokenize_data(adata_train, gene_vocab)
test_tokenized = tokenize_data(adata_test, gene_vocab)

# Model Initialization
# `vocab=gene_vocab` makes the classifier head emit len(gene_vocab) logits.
# NOTE(review): that sizes the output by vocabulary entries, not by the number
# of cell-type classes -- confirm this is what the downstream labels expect.
scgpt_model = SCGPTModel(
    vocab_size=len(gene_vocab),
    emb_size=128,
    num_heads=4,
    num_layers=4,
    d_hid=512,
    dropout=0.2,
    vocab=gene_vocab
)
def create_data_loader(tokenized_data, batch_size=32):
    """Wrap tokenized tensors in a ``DataLoader``.

    Args:
        tokenized_data: Mapping with ``"genes"`` and ``"values"`` tensors that
            share the same first dimension.
        batch_size: Number of rows per batch.

    Returns:
        DataLoader yielding ``(genes, values)`` batches in dataset order.
    """
    gene_tokens = tokenized_data["genes"]
    expression_values = tokenized_data["values"]
    return DataLoader(
        TensorDataset(gene_tokens, expression_values),
        batch_size=batch_size,
    )

# Fine-tuning, prediction and evaluation of the scGPT wrapper model.

# Pick the device once so the model, every batch and the later evaluation all
# agree on it (previously it was only bound inside the training loop and the
# model itself was never moved).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tune the model using the training data
# Adjust learning rate and other parameters as needed for fine-tuning
fine_tuning_epochs = 3
fine_tuning_batch_size = 32

train_loader = create_data_loader(train_tokenized, batch_size=fine_tuning_batch_size)

# Load the pre-trained scGPT model
model_weight_path = "C:/Users/gaiacronus/Downloads/work/combine/best_model.pt"  # Path to the pre-trained model weights
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
load_report = scgpt_model.load_state_dict(
    torch.load(model_weight_path, map_location=torch.device('cpu')),
    strict=False,
)
# strict=False skips keys that do not line up; report that instead of hiding
# it, since a fully mismatched checkpoint would otherwise load nothing.
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"Warning: checkpoint/model key mismatch - "
        f"{len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected."
    )
scgpt_model.to(device)

# Fine-tuning logic
optimizer = torch.optim.Adam(scgpt_model.parameters(), lr=1e-4)  # Define optimizer
criterion = nn.CrossEntropyLoss()  # Built once instead of once per step
scgpt_model.train()  # Set model to training mode

for epoch in range(fine_tuning_epochs):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # NOTE(review): the loader's second tensor is the tokenized "values"
        # tensor; its per-row argmax is used as the class target here. Confirm
        # this is the intended label -- cell-type annotations would normally
        # come from adata.obs.
        labels = torch.argmax(labels, dim=1)
        optimizer.zero_grad()  # Clear gradients
        outputs = scgpt_model(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

# Save the fine-tuned model
fine_tuned_model_path = "fine_tuned_scGPT_model.pt"
torch.save(scgpt_model.state_dict(), fine_tuned_model_path)
project_name = "28try"

# Evaluate the fine-tuned model
# NOTE(review): TOSICA.pre receives the state dict saved above; confirm it
# accepts weights produced by this wrapper model rather than by TOSICA's own
# training routine.
predicted_adata = TOSICA.pre(
    adata_test,
    model_weight_path=fine_tuned_model_path,
    project=project_name
)

test_loader = create_data_loader(test_tokenized)
predictions = []
scgpt_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    # Each batch is a (genes, values) pair; only the gene tokens are model input.
    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        outputs = scgpt_model(inputs)
        _, preds = torch.max(outputs, dim=1)
        predictions.extend(preds.cpu().tolist())

# Store predictions in the test data
adata_test.obs["scGPT_predictions"] = predictions

# Check the predicted labels and probabilities
print(predicted_adata.obs.head())

# Save the results for further analysis
predicted_adata.write("h_pancreas_predicted.h5ad")

# evaluate() expects one integer label per row, so derive the labels exactly as
# the training loop does (argmax over the values tensor) instead of passing
# the raw 2-D values through.
eval_loader = DataLoader(
    TensorDataset(
        test_tokenized["genes"],
        torch.argmax(test_tokenized["values"], dim=1),
    ),
    batch_size=fine_tuning_batch_size,
)
test_acc, test_f1 = evaluate(scgpt_model, eval_loader, device)

# Results
# Only test metrics are computed, so the table has a single "Test" row; every
# column must have the same length or pandas raises ValueError.
results = pd.DataFrame({
    "Dataset": ["Test"],
    "Accuracy": [test_acc],
    "F1 Score": [test_f1],
})
results.to_csv("results.csv", index=False)
print("Results saved to results.csv.")

# Residual notebook output (not code), captured when the cell was interrupted:
#   scgpt_model.load_state_dict(torch.load(model_weight_path, map_location=torch.device('cpu')), strict=False)  # Load the model weights
#
#
# KeyboardInterrupt: