# Predict with final model

In [None]:
import os
import sys
import pickle

import numpy as np
import torch
from sklearn.metrics import classification_report 


sys.path.insert(0, "../") #load the cellm outside of notebooks
sys.path.insert(0, "../reproduce/") #load the rep_utils outside of notebooks

from rep_utils import SampleCellsDataModuleCustom, CellClassifyModel

import warnings
import random
from lightning.pytorch import Trainer, seed_everything
warnings.filterwarnings('ignore')


# Base path of the data store (absolute, environment-specific path).
# NOTE(review): presumably the root directory of the TileDB arrays named
# below -- confirm against the data-loading code.
tiledb_base_path = "/projects/global/gred/resbioai/CeLLM/tiledb"

# Array names for the scimilarity human 10x data: raw counts plus the
# gene-level and cell-level metadata.
COUNTSURI = "scimilarity_human_10x_counts"
GENEURI = "scimilarity_human_10x_gene_metadata"
CELLURI = "scimilarity_human_10x_cell_metadata"

In [None]:
"""
Evaluation settings -- edit these before running the cells below.

model_path: path to your trained model; replace the empty string with it.
attention: attention design used in the trained model.
classify_mode: type of classification (binary or multilabel).
batch_size: number of samples per batch used for testing.
sample_size: number of cells in each sample used for testing.
resample: whether to resample cells from the same sample.
"""

# Trained model and its configuration.
model_path = ""
attention = "nonlinear_attn"
classify_mode = "multilabel"

# Test-time batching and cell sampling.
batch_size, sample_size = 1, 1500
resample = False

In [None]:

# Restore the trained classifier from the checkpoint at `model_path`, using
# the attention design and classification mode chosen in the settings.
class_model = CellClassifyModel.load_from_checkpoint(
    model_path,
    num_genes=28231,
    masking_strategy=None,
    attn=attention,
    classify_mode=classify_mode,
)

# Switch the model to evaluation mode before predicting.
class_model.eval()

# Data module whose test dataloader supplies the batches to predict on.
scd = SampleCellsDataModuleCustom(
    batch_size=batch_size,
    sample_size=sample_size,
    classify_mode=classify_mode,
    resample=resample,
)

In [None]:
# Evaluate the classifier on the test dataloader and collect the weighted
# F1 score of every evaluated seed in `f1score_list`.
f1score_list = []
finalauroc_list = []  # Not filled in this cell; kept so the name stays defined.
full_results = False  # True: evaluate seeds 0-9. False: evaluate seed 0 only.

# Both modes run the same evaluation; only the set of seeds differs, so a
# single loop replaces the two duplicated branches.
seeds = range(0, 10) if full_results else [0]

with torch.no_grad():  # Inference only: no gradients are needed.
    for seed in seeds:
        seed_everything(seed, workers=True)

        true_label = []
        pred_label = []
        auroc_list = []  # Not filled in this cell; kept so the name stays defined.
        for batch in scd.test_dataloader():
            output_annot = class_model.obtain_annotation(batch, '0')
            # The first element of the returned annotation holds the predictions.
            pred_label.extend(output_annot[0].cpu().numpy())
            true_label.extend(batch.disease_label.cpu().numpy())

        # Show which classes occur in the ground truth and in the predictions.
        print(set(true_label))
        print(set(pred_label))
        print(classification_report(true_label, pred_label, digits=4))

        # Compute the dict form once and reuse it for printing and storing.
        report = classification_report(
            true_label, pred_label, digits=4, output_dict=True
        )
        weighted_f1 = report['weighted avg']['f1-score']
        if full_results:
            # Multi-seed mode reports each seed's score as it is produced.
            print(weighted_f1)
        f1score_list.append(weighted_f1)

if not full_results:
    # Single-seed mode reports the collected list once at the end.
    print(f1score_list)
        

In [None]:
# Print the weighted F1 scores collected in `f1score_list`
# (one entry per evaluated seed).
print(f1score_list)