In [13]:
import torch
import scanpy as sc
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import TOSICA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import umap
from run import config
from scgpt.model import TransformerModel
from scgpt.tokenizer import tokenize_and_pad_batch, GeneVocab
from scgpt.preprocess import Preprocessor
import wandb
import json
import os
from torch.utils.data import DataLoader
project = config['project']
pretrained_path = os.path.normpath(os.path.dirname(config['scgpt_params']['pretrained_path']))
scgpt_ref=config['ref_dataset_path']
scgpt_query=config['query_dataset_path']
def pre_process(data_path, project):
    
    if model_name == 'TOSICA':
        project=project
        adata = sc.read(data_path)
        adata = adata[:, adata.var_names]
        return adata
    elif model_name=='scgpt':
        project=project
        adata = sc.read(data_path)
        preprocessor = Preprocessor(
            use_key="X",
            filter_gene_by_counts=False,
            normalize_total=1e4,
            log1p=True,
            binning=config['scgpt_params']['n_bins']
        )
        preprocessor(adata)
        adata.obs['Celltype'] = adata.obs['Celltype'].astype('category')
        
    return adata

def scgpt_annotation(ref_adata, query_adata):
    """scGPT cell type annotation pipeline"""
    
    # Load gene vocab
    vocab_path = os.path.join(pretrained_path, "vocab.json")
    with open(vocab_path, 'r') as file:
        gene_list = json.load(file)
    
    gene_names = [gene for gene, _ in gene_list.items()]
    vocab = GeneVocab(gene_list=gene_names)
    
    # Add padding token if missing
    pad_token = '<pad>'
    if pad_token not in vocab:
        vocab.add_token(pad_token)
    print("Available vocab keys:", vocab.keys())
    print("Pad Token:", pad_token)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    query_adata=sc.read(config['query_dataset_path'])
    celltype_categories = query_adata.obs['Celltype'].cat.categories

    model = TransformerModel(
        ntoken=len(vocab),
        d_model=128,
        nhead=4,
        d_hid=128,
        nlayers=4,
        nlayers_cls=config['scgpt_params']['n_layers_cls'],
        n_cls=len(celltype_categories)  # Fix: use correct category length
    ).to(device)

    # Load pretrained model weights
    model.load_state_dict(torch.load(config['scgpt_params']['pretrained_path'] + "/best_model.pt"))

    def prepare_scgpt_batches(adata):
        """Tokenize and batch scGPT data"""
        gene_ids = np.array([vocab[gene] for gene in adata.var_names], dtype=int)
        tokenized = tokenize_and_pad_batch(
            adata.layers['X_binned'],
            gene_ids,
            max_len=3001,
            vocab=vocab,
            pad_value=vocab['<pad>']
        )
        return DataLoader(SeqDataset(tokenized), batch_size=config['scgpt_params']['batch_size'])

    # Fine-tuning model
    optimizer = torch.optim.Adam(model.parameters(), lr=config['scgpt_params']['lr'])
    criterion = torch.nn.CrossEntropyLoss()

    train_loader = prepare_scgpt_batches(train_data)
    for epoch in range(config['scgpt_params']['epochs']):
        model.train()
        for batch in train_loader:
            inputs = batch['gene_ids'].to(device)
            values = batch['values'].to(device)
            labels = batch['Celltype_labels'].to(device)
            
            outputs = model(inputs, values)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Predict on query data
    query_loader = prepare_scgpt_batches(query_data)
    predictions = []
    with torch.no_grad():
        for batch in query_loader:
            inputs = batch['gene_ids'].to(device)
            values = batch['values'].to(device)
            outputs = model(inputs, values)
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())

    return predictions

def pre_train(ref_adata, epochs, project):
    TOSICA.train(ref_adata, gmt_path='human_gobp', label_name='Celltype', epochs=epochs, project=project)
    return f'./{project}/model-0.pth'


def fine_tune(model_weight_path, query_adata, project):
    model_weight_path= f'./{project}/model-0.pth'
    new_adata = TOSICA.pre(query_adata, model_weight_path=model_weight_path, project=project)
    new_adata.write('tosica_att.h5ad')
    return new_adata


def evaluate(predictions, labels):
    accuracy = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='weighted')
    return accuracy, f1

def extract_latent_representation(adata):
    # Extract latent representations from TOSICA if available
    adata = tosica_model.preprocess_data(adata)
    latent_representations = tosica_model.get_latent_representation(adata)
    adata.obsm['X_latent'] = latent_representations
    return adata

def save_results(results, output_csv):
    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)


def perform_umap_and_clustering(new_adata):
    embeddings = new_adata.obsm['X_latent']  # Adjust based on how embeddings are stored
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='euclidean')
    umap_embeddings = reducer.fit_transform(embeddings)

    # Clustering with K-Means
    n_clusters = 5  # Adjust based on your data
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    new_adata.obs['kmeans_labels'] = kmeans.fit_predict(umap_embeddings)

    # Visualization
    plt.figure(figsize=(10, 8))
    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=new_adata.obs['kmeans_labels'], cmap='Spectral', s=5)
    plt.title('UMAP Projection of Cell Embeddings with K-Means Clustering')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.colorbar(label='Cluster Label')
    plt.show()


#  Introduce perturbations
def introduce_noise(adata, noise_level=0.1):
    if not isinstance(adata.X, np.ndarray):
        adata.X = adata.X.toarray()
    adata.X = adata.X.astype(np.float32)
    noise = np.random.normal(0, noise_level, adata.X.shape).astype(np.float32)
    adata.X += noise
    adata.X = adata.X.astype(np.int32)
    return adata


if __name__ == "__main__":
    model_name = config['model_name']
    task = config['task']
    ref_dataset_path = config['ref_dataset_path']
    query_dataset_path = config['query_dataset_path']
    epochs = config['epochs'] 
    model_weight_path = f'./{project}/model-0.pth'
    results = []

    if model_name == "TOSICA":
        if task == 'cell_type_annotation':
            # Perform pre-training, fine-tuning, and evaluation
            ref_adata = pre_process(ref_dataset_path, project=project)
            model_path = pre_train(ref_adata, epochs=epochs, project=project)
            query_adata = pre_process(query_dataset_path, project=project)
            new_adata = fine_tune(model_weight_path, query_adata, project=project)
            labels = new_adata.obs['Celltype'].values
            print(labels)
            predictions = new_adata.obs['Prediction'].values
            print(predictions)
            acc, f1 = evaluate(predictions, labels)
            results.append({'task': task, 'accuracy': acc, 'f1_score': f1})
            new_adata.raw = new_adata
            sc.pp.normalize_total(new_adata, target_sum=1e4)
            sc.pp.log1p(new_adata)
            sc.pp.scale(new_adata, max_value=10)
            sc.tl.pca(new_adata, svd_solver='arpack')
            sc.pp.neighbors(new_adata, n_neighbors=10, n_pcs=40)
            sc.tl.umap(new_adata)
            col = np.array([
                "#98DF8A","#E41A1C" ,"#377EB8", "#4DAF4A" ,"#984EA3" ,"#FF7F00" ,"#FFFF33" ,"#A65628" ,"#F781BF" ,"#999999","#1F77B4","#FF7F0E","#279E68","#FF9896"
                    ]).astype('<U7')
            prediction_categories = new_adata.obs['Prediction'].unique()
            celltype_categories = new_adata.obs['Celltype'].unique()
            new_adata.obs['Prediction'] = pd.Categorical(new_adata.obs['Prediction'], categories=prediction_categories)
            new_adata.obs['Celltype'] = pd.Categorical(new_adata.obs['Celltype'], categories=celltype_categories)
            common_categories = list(set(prediction_categories) & set(celltype_categories))
            new_adata.obs['Prediction'] = new_adata.obs['Prediction'].cat.set_categories(common_categories)
            new_adata.obs['Celltype'] = new_adata.obs['Celltype'].cat.set_categories(common_categories)
            prediction_order = new_adata.obs['Prediction'].value_counts().index
            new_adata.obs['Prediction'] = new_adata.obs['Prediction'].cat.reorder_categories(prediction_order)
            # Plot UMAP with 'Celltype' categories
            plt.figure(figsize=(10, 8))
            sc.pl.umap(new_adata, color='Celltype', palette='Set1', title='UMAP colored by Celltype', size=50)
            plt.show()
            latent_representations = TOSICA.get_latent_representation(ref_adata)
            kmeans = KMeans(n_clusters=10)
            cluster_labels = kmeans.fit_predict(latent_representations)
            ref_adata.obs['cluster'] = cluster_labels

            # Step 6: Visualization
            umap_embeddings = umap.UMAP().fit_transform(latent_representations)
            plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=cluster_labels, cmap='Spectral', s=5)
            plt.colorbar()
            plt.title("UMAP Visualization of Latent Representations")
            plt.show()
            # Ensure both 'Prediction' and 'Celltype' have the same categories
             
            ## common_categories = list(set(prediction_categories) & set(celltype_categories))
            # new_adata.obs['Prediction'] = new_adata.obs['Prediction'].cat.set_categories(common_categories)
            # new_adata.obs['Celltype'] = new_adata.obs['Celltype'].cat.set_categories(common_categories)

            # new_adata.obs['Prediction'] = new_adata.obs['Prediction'].astype('category')
            # print( 'prediction ='+new_adata.obs['Prediction'].cat.categories)
            # print("____________________________________________________")
            # print(celltype)
            # new_adata.obs['Prediction'] = new_adata.obs['Prediction'].cat.reorder_categories(list(celltype))
            # new_adata.uns['Prediction_colors'] = col[1:]
            
            #  celltype = new_adata.obs['Celltype'].values
            #  new_adata.obs['Celltype'] = new_adata.obs['Celltype'].astype('category')
            #  new_adata.obs['Celltype'] = new_adata.obs['Celltype'].cat.reorder_categories(list(celltype))
            # new_adata.uns['Celltype_colors'] = col[:11]
            
        elif task == 'perturbation':
            query_adata = pre_process(query_dataset_path, project=project)  
            query_adata = introduce_noise(query_adata, noise_level=0.1)
            new_adata = fine_tune(model_weight_path, query_adata, project=project)

        elif task == 'clustering':
                query_adata = pre_process(query_dataset_path, project=project)
                
                try:
                    new_adata = sc.read('tosica_att.h5ad')
                    print("Using previously fine-tuned data.")
                except FileNotFoundError:
                    print("No previous fine-tuned data found. Performing pre-training and fine-tuning.")
                    ref_adata = pre_process(ref_dataset_path,project=project)
                    model_path = pre_train(ref_adata, epochs=epochs, project=project)  
                    new_adata = fine_tune(model_weight_path, query_adata, project=project)
                    new_adata.write('tosica.h5ad')
                    results.append({'task': task, 'perturbation': 'Noise introduced and fine-tuning performed'})
                perform_umap_and_clustering(new_adata)
                results.append({'task': task, 'clustering': 'UMAP and KMeans performed'})
        elif task == 'latent_representation':
                # Load and process the query dataset
                query_adata = pre_process(query_dataset_path, project=project)
                query_adata = extract_latent_representation(query_adata)
                perform_umap_and_clustering(query_adata)
                results.append({'task': task, 'latent_representation_shape': query_adata.obsm['X_latent'].shape})
        elif task == 'gene_and_cell_embeddings':
                query_adata = pre_process(query_dataset_path, project=project)
                try:
                    new_adata = sc.read('tosica_att.h5ad')
                    print("Using previously fine-tuned data.")
                except FileNotFoundError:
                    print("No previous fine-tuned data found. Performing pre-training and fine-tuning.")
                    ref_adata = pre_process(ref_dataset_path, project=project)
                    model_path = pre_train(ref_adata, epochs=epochs, project=project)
                    new_adata = fine_tune(model_weight_path, query_adata, project=project)
                    
                embeddings = new_adata.obsm['X']  # Adjust based on how embeddings are stored
                results.append({'task': task, 'embeddings_shape': embeddings.shape})

        save_results(results, 'C:/Users/gaiacronus/Downloads/work/combine/modelresults/results.csv')
        
    else:
        if task == 'cell_type_annotation':
            print("\nRunning scGPT annotation...")
        scgpt_pred = scgpt_annotation(scgpt_ref, scgpt_query)
        results['scgpt'] = {
            'predictions': scgpt_pred,
            'accuracy': accuracy_score(scgpt_query.obs['Celltype'].cat.codes, scgpt_pred),
            'f1': f1_score(scgpt_query.obs['Celltype'].cat.codes, scgpt_pred, average='weighted')
        }
            
        
        
      



Running scGPT annotation...


AttributeError: 'GeneVocab' object has no attribute 'keys'