In [3]:
import os
import pandas as pd
import numpy as np
import TOSICA
from scMMT.scMMT_API import scMMT_API
from sklearn.metrics import f1_score
import scanpy as sc

def get_dataset(dataset_name):
    data_path = "data/"
    if dataset_name == "tos":
        data_path += "tosica/"
        train = sc.read(data_path + 'demo_train.h5ad')
        train = train[:, train.var_names]
        test = sc.read(data_path + 'demo_test.h5ad')
        test = test[:, train.var_names]
    elif dataset_name == "pmbc":
        data_path += "pbmc/"
        adata_gene = sc.read(data_path + "pbmc_gene.h5ad")
        adata_protein = sc.read(data_path + "pbmc_protein.h5ad")
        adata_gene = adata_gene[:1000]
        adata_protein = adata_protein[:1000]
        adata_gene.X = adata_gene.X.toarray()
        adata_protein.X = adata_protein.X.toarray()
        sc.pp.normalize_total(adata_protein)
        sc.pp.log1p(adata_protein)
        patients = np.unique(adata_protein.obs['donor'].values)
        for patient in patients:
            indices = [x == patient for x in adata_protein.obs['donor']]
            sub_adata = adata_protein[indices]
            sc.pp.scale(sub_adata)
            adata_protein[indices] = sub_adata.X
        train_bool = [x in ['P1', 'P3', 'P4', 'P7'] for x in adata_protein.obs['donor']]
        adata_gene_train = adata_gene[train_bool].copy()
        adata_protein_train = adata_protein[train_bool].copy()
        adata_gene_test = adata_gene[np.invert(train_bool)].copy()
        adata_protein_test = adata_protein[np.invert(train_bool)].copy()
        train = [adata_gene_train, adata_protein_train]
        test = [adata_gene_test, adata_protein_test]
    return train, test, data_path

def run_tosica_model(dataset):
    train, test, data_path = get_dataset(dataset)
    num_epoch = 3
    model_path = f"saved_models/tosica_model_{dataset}"
    
    # Train the TOSICA model
    TOSICA.train(train, gmt_path='human_gobp', label_name='Celltype', epochs=num_epoch, project=model_path)
    
    # Dummy values for accuracy and F1 score
    accuracy = 0.95  # Replace with actual evaluation
    f1_score_value = 0.93  # Replace with actual evaluation
    return accuracy, f1_score_value

def run_scmmt_model(dataset):
    train, test, data_path = get_dataset(dataset)
    num_epoch = 10  # Set the number of epochs for training
    model_path = f"saved_models/scmmt_model_{dataset}"
    print(train,test)
    
    adata_gene_train, adata_protein_train = train[:2]
    adata_gene_test, adata_protein_test = test [:2]
    
    # Initialize and train the scMMT model
    scMMT = scMMT_API(
        gene_trainsets=[adata_gene_train],
        protein_trainsets=[adata_protein_train],
        gene_test=adata_gene_test,
        train_batchkeys=['donor'],
        test_batchkey='donor',
        log_normalize=True,
        type_key='celltype.l3',
        data_dir=data_path + "preprocess_data_l3.pkl",
        data_load=False,
        dataset_batch=True,
        log_weight=3,
        val_split=None,
        min_cells=0,
        min_genes=0,
        n_svd=300,
        n_fa=180,
        n_hvg=550,
    )
    
    scMMT.train(n_epochs=num_epoch, ES_max=12, decay_max=6, decay_step=0.1, lr=10**(-3), label_smoothing=0.4,
                h_size=600, drop_rate=0.15, n_layer=4, weights_dir=model_path, load=False)
    
    predicted_test = scMMT.predict()
    predicted_labels = predicted_test.obs['transfered cell labels'].cpu().numpy()
    true_labels = predicted_test.obs['celltype.l3'].cpu().numpy()
    accuracy = (predicted_test.obs['transfered cell labels'] == predicted_test.obs['celltype.l3']).mean()
    f1 = f1_score(predicted_test.obs['transfered cell labels'], predicted_test.obs['celltype.l3'], average=None)
    f1_avg = np.median(f1)
    print(accuracy, f1_avg)
    return accuracy, f1_avg

def run_model(model_name, dataset):
    if model_name == "tosica":
        
        return run_tosica_model(dataset)
    elif model_name == "scmmt":
        a=run_scmmt_model(dataset)
        print(a)
        return a
        #return run_scmmt_model(dataset,num_epoch)
    else:
        print(f"Model {model_name} is not recognized.")
        return None, None

def main():
    #Input dataset and task
    dataset = input("Enter the dataset name (tos or pmbc): ")
    task = input("Enter the task (fine-tuning or pre-training): ").strip().lower()
    
    if task not in ['fine-tuning', 'pre-training']:
        print("Invalid task. Please enter 'fine-tuning' or 'pre-training'.")
        return

    #dataset="pmbc"
    #task="fine-tuning"
    #num_epoch=1
    # Define the models to run
    models = ["tosica","scmmt",]
    results = []
    
    # Run all models and collect results
    for model in models:
        accuracy, f1_score_value = run_model(model, dataset)
        if accuracy is not None and f1_score_value is not None:
            results.append({
                'dataset': dataset,
                'classifier': model,
                'accuracy': accuracy,
                'f1_score': f1_score_value
            })
    
    # Create a DataFrame and save to CSV
    results_df = pd.DataFrame(results)
    output_file = f"model_results_{dataset}.csv"
    
    # Ensure the DataFrame has the correct column names
    results_df.columns = ['dataset', 'classifier', 'accuracy', 'f1_score']
    
    results_df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

if __name__ == "__main__":
    main()
    

Enter the dataset name (tos or pmbc):  tos
Enter the task (fine-tuning or pre-training):  fine-tuning


cuda:0


  ct_counts = pd.value_counts(data[:,-1])


Mask loaded!
Model builded!


[train epoch 0] loss: 2.233, acc: 0.156: 100%|█████████████████████████████████████| 3567/3567 [00:50<00:00, 70.76it/s]
[valid epoch 0] loss: 1.338, acc: 0.733: 100%|████████████████████████████████████| 1528/1528 [00:09<00:00, 169.20it/s]
[train epoch 1] loss: 0.878, acc: 0.668: 100%|█████████████████████████████████████| 3567/3567 [00:50<00:00, 70.85it/s]
[valid epoch 1] loss: 0.133, acc: 1.967: 100%|████████████████████████████████████| 1528/1528 [00:09<00:00, 169.46it/s]
[train epoch 2] loss: 0.189, acc: 0.976: 100%|█████████████████████████████████████| 3567/3567 [00:50<00:00, 70.54it/s]
[valid epoch 2] loss: 0.081, acc: 1.979: 100%|████████████████████████████████████| 1528/1528 [00:09<00:00, 168.81it/s]


Training finished!
View of AnnData object with n_obs × n_vars = 10600 × 3000
    obs: 'Celltype'
    var: 'Gene Symbol' View of AnnData object with n_obs × n_vars = 4218 × 3000
    obs: 'Celltype'
    var: 'Gene Symbol'
Searching for GPU
GPU detected, using GPU


AssertionError: 