In [None]:
import os
from typing import List, Optional
import numpy as np

import torch
from lightning import LightningDataModule

# `sys` must be imported before sys.path can be edited; without this import
# the two lines below raise NameError on a fresh kernel.
import sys

# Make project code that lives outside the notebook directory importable.
# Each path is pushed to the front, so "../reproduce/" ends up ahead of "../".
sys.path.insert(0, "../")  # parent directory (original note: load the cellm outside of notebooks)
sys.path.insert(0, "../reproduce/")  # original note: load the rep_utils outside of notebooks

from cellm.data.data_scimilarity_gred import SampleCellsDataModule_disease, scDataset_disease

# from cellm.data.data_structures import CellSample
import pickle

from dataclasses import dataclass

import torch
import json
import sklearn

from captum.attr import IntegratedGradients

import scipy.stats
import pandas as pd

import matplotlib.pyplot as plt
from sklearn.metrics import classification_report #1e-3, best save
from sklearn.metrics import roc_auc_score
import scanpy as sc

def most_frequent_per_row(array):
    """Return the most common value of every row in ``array``.

    Each row is counted with ``np.bincount``, so rows must contain
    non-negative integers. When several values are tied, the smallest one
    is returned because ``np.argmax`` picks the first maximum.

    Parameters
    ----------
    array : iterable of 1-D integer sequences
        Typically a 2-D integer array; it is iterated row by row.

    Returns
    -------
    numpy.ndarray
        One entry per row holding that row's most frequent value.
    """
    modes = [np.argmax(np.bincount(values)) for values in array]
    return np.array(modes)

from cellm.components.cell_to_cell import CellToCellPytorchTransformer
from cellm.components.cell_to_output import CellToOutputMLP
from cellm.components.gene_to_cell import GeneToCellLinear
from cellm.components.masking import Masking
from cellm.data.data_structures import CellSample
from rep_exp_utils import CellClassifyModel


In [None]:
# Presumably accumulators for F1 / AUROC scores; nothing in this cell fills them.
f1score_list = []
finalauroc_list = []

# Restore the trained classifier from its Lightning checkpoint. The keyword
# arguments are forwarded to the model so it is rebuilt with the same
# configuration that was used when this checkpoint was written.
class_model = CellClassifyModel.load_from_checkpoint(
    "./disease_class/disease2classlr1e3wd1e4batch32epoch50_weightaverage_dim1-epoch=36-val_accuracy=0.82.ckpt",
    num_genes=28231,
    masking_strategy=None,
    attn='linear_attn',
)

# The model is only used for inference and attribution from here on, so put
# it in evaluation mode: any dropout / normalisation layers then behave
# deterministically. Gradients are still available for Integrated Gradients.
class_model.eval()

# Integrated Gradients explainer wrapped around the classifier.
ig = IntegratedGradients(class_model)

In [None]:
# Read the combined, pre-processed dataset from its .h5ad file.
adata = sc.read_h5ad("./combined_4dataset_processed_data.h5ad")

# The gene-order file is read without a header row; its first column, taken
# as a numpy array, is the reference gene list.
# NOTE(review): the file is named .tsv but is parsed with the default
# separator -- fine if it holds a single column; confirm.
gene_list = pd.read_csv(
    "/gstore/data/omni/scdb/cleaned_h5ads/gene_order.tsv",
    header=None,
)[0].values


In [None]:
# One all-zero placeholder row whose variables are exactly the reference
# gene list.
placeholder_shape = (1, len(gene_list))
adata_new = sc.AnnData(np.zeros(placeholder_shape))
adata_new.var_names = gene_list

# Outer-join the placeholder with the real data so the variables of both
# objects are present; 'batch_new' records which input each row came from.
adata_c = sc.concat([adata_new, adata], join='outer', label='batch_new')

# Keep the rows labelled '1' (presumably those coming from `adata`, the
# second input -- confirm) and put the columns in gene_list order.
adata_c_f = adata_c[adata_c.obs['batch_new'] == '1'][:, gene_list]

# Displayed as the cell output.
adata_c_f.obs['sampleID']

In [None]:
# Per-sample class probabilities, followed by a Mild-vs-other comparison.
# A single loop serves both GPU and CPU runs; only the tensor placement
# differs between the two.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

sample_data_list = []  # one numpy array of predicted probabilities per sample
dis_label = []         # condition of each sample, in the same order

with torch.no_grad():
    for item in adata_c_f.obs['sampleID'].unique():
        adata_s = adata_c_f[adata_c_f.obs['sampleID'] == item]
        # All rows of this sample form one batch entry: (1, rows of X, columns of X).
        input_data = adata_s.X.toarray().reshape(1, adata_s.X.shape[0], adata_s.X.shape[1])
        query = torch.FloatTensor(input_data).to(device)
        sample_data, probs = class_model.obtain_annotation_directly(query)
        sample_data_list.append(probs.cpu().numpy())
        # Use .iloc for a positional lookup: obs is indexed by labels, and
        # integer keys in Series[...] are deprecated as positions.
        dis_label.append(adata_s.obs['condition'].iloc[0])

# Split the value at [0][1] of every prediction by condition: 'Mild' samples
# go to healthy_average, every other condition to disease_average.
# NOTE(review): index 1 is presumably the second class of the classifier -- confirm.
healthy_average = [
    probs[0][1] for probs, label in zip(sample_data_list, dis_label) if label == 'Mild'
]
disease_average = [
    probs[0][1] for probs, label in zip(sample_data_list, dis_label) if label != 'Mild'
]

# One-sided Wilcoxon rank-sum test (alternative='less').
print(scipy.stats.ranksums(healthy_average, disease_average, alternative='less'))

# One column per group; the concatenation pads the shorter column with NaN,
# and the box plot draws one box per column.
df1 = pd.DataFrame({'Mild': healthy_average})
df2 = pd.DataFrame({'Severe': disease_average})
df = pd.concat([df1, df2])

df.boxplot(figsize=(4,4), fontsize=15)
plt.ylabel('Probability')
plt.show()

In [None]:
# Per-sample embeddings from the classifier, visualised with UMAP.
# A single loop serves both GPU and CPU runs; only the tensor placement
# differs between the two.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

sample_data_list = []  # one numpy embedding array per sample
dis_label = []         # condition of each sample, in the same order

with torch.no_grad():
    for item in adata_c_f.obs['sampleID'].unique():
        adata_s = adata_c_f[adata_c_f.obs['sampleID'] == item]
        # All rows of this sample form one batch entry: (1, rows of X, columns of X).
        input_data = adata_s.X.toarray().reshape(1, adata_s.X.shape[0], adata_s.X.shape[1])
        query = torch.FloatTensor(input_data).to(device)
        sample_data = class_model.obtain_embeddings_zs(query)
        sample_data_list.append(sample_data.cpu().numpy())
        # Use .iloc for a positional lookup: obs is indexed by labels, and
        # integer keys in Series[...] are deprecated as positions.
        dis_label.append(adata_s.obs['condition'].iloc[0])

# Stack the embeddings and take index 0 along the second axis, giving one
# row per sample.
adata_sample_emb = sc.AnnData(np.array(sample_data_list)[:,0,:])
adata_sample_emb.obs['condition'] = dis_label

# Neighbour graph computed directly on the embeddings (use_rep='X'), then a
# UMAP coloured by condition.
sc.pp.neighbors(adata_sample_emb, use_rep='X')
sc.tl.umap(adata_sample_emb)
sc.pl.umap(adata_sample_emb, color='condition')

In [None]:
# Snapshot of the full adata_c_f. A later cell rebinds `adata_c_f` to a
# subset taken from this copy, so the complete data stays available here.
adata_c_f_old = adata_c_f.copy()

In [None]:

# Integrated Gradients attributions restricted to one cell type, then
# compared between 'Mild' samples and all other samples.
adata_c_f = adata_c_f_old[adata_c_f_old.obs['celltype'] == 'Non classical monocytes']

attributions_list = []          # per-sample attribution tensors (on CPU)
approximation_error_list = []   # per-sample convergence deltas (on CPU)
label_list = []                 # condition of each sample, in the same order

# A single loop serves both GPU and CPU runs; only the tensor placement
# differs between the two.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for item in adata_c_f.obs['sampleID'].unique():
    adata_s = adata_c_f[adata_c_f.obs['sampleID'] == item]
    # All rows of this sample form one batch entry: (1, rows of X, columns of X).
    input_data = adata_s.X.toarray().reshape(1, adata_s.X.shape[0], adata_s.X.shape[1])
    # All-zero baseline with the same shape as the input.
    baselines = torch.FloatTensor(np.zeros_like(input_data)).to(device)
    query = torch.FloatTensor(input_data).to(device)
    attributions, approximation_error = ig.attribute(
        query,
        baselines=baselines,
        target=1,
        return_convergence_delta=True,
    )
    attributions_list.append(attributions.detach().to('cpu'))
    approximation_error_list.append(approximation_error.detach().to('cpu'))
    # Use .iloc for a positional lookup: obs is indexed by labels, and
    # integer keys in Series[...] are deprecated as positions.
    label_list.append(adata_s.obs['condition'].iloc[0])
    # Release the per-sample tensors before the next sample is allocated.
    del input_data
    del baselines
    del query

# Average each sample's attributions over axis 1 (the rows of that sample),
# then average those per-sample results over all samples.
gene_average = [attr.cpu().numpy().mean(axis=1) for attr in attributions_list]
mean_set = np.array(gene_average).mean(axis=0)

# Column indices of genes that are present in the original `adata` and whose
# mean attribution is positive. The set makes each membership test cheap.
known_genes = set(adata.var_names)
matched_gene = [
    idx
    for idx, gene in enumerate(adata_c_f.var_names)
    if gene in known_genes and mean_set[0][idx] > 0
]

# One scalar per sample: mean attribution over the selected genes.
overall_average = [
    attr.cpu().numpy()[:, :, matched_gene].mean() for attr in attributions_list
]

# Rescale the per-sample scores to the [0, 1] range.
scal = sklearn.preprocessing.MinMaxScaler()
overall_average = scal.fit_transform(np.array(overall_average).astype('float').reshape(-1,1))

# Split the scaled scores by condition: 'Mild' samples go to healthy_average,
# every other condition to disease_average.
scaled_scores = overall_average.T[0]
healthy_average = [
    score for score, label in zip(scaled_scores, label_list) if label == 'Mild'
]
disease_average = [
    score for score, label in zip(scaled_scores, label_list) if label != 'Mild'
]

# One-sided Wilcoxon rank-sum test (alternative='less'). It is printed
# because a bare expression in the middle of a cell is not displayed.
print(scipy.stats.ranksums(healthy_average, disease_average, alternative='less'))

# One column per group; the concatenation pads the shorter column with NaN,
# and the box plot draws one box per column.
df1 = pd.DataFrame({'Mild': healthy_average})
df2 = pd.DataFrame({'Severe': disease_average})
df = pd.concat([df1, df2])

df.boxplot(figsize=(4,4), fontsize=15)
plt.ylabel('Attributions')
plt.show()

