In [None]:
import sys
import os 
import torch

from torch import nn
from torch.utils.data import DataLoader
import numpy as np
sys.path.append("/home/local/kyeonghunjeong_920205/nipa_bu/COVID19/3.analysis/9.MIL/scAMIL_cell/scMILD")
from utils import *
from dataset import *
from model import *
import pandas as pd
from sklearn.model_selection import train_test_split
import scanpy as sc
from scipy import sparse
import modin.pandas as pd
import ray
# Start the Ray runtime. Note that `pd` is now modin.pandas: the modin import above
# rebinds the plain-pandas `pd` imported earlier, and modin presumably executes on
# this Ray runtime (confirm modin's engine setting if that matters).
# ignore_reinit_error=True lets this cell be re-run in the same kernel instead of
# raising because Ray is already initialised.
ray.init(ignore_reinit_error=True)

In [None]:
# Dataset selection: the data paths below are all derived from this one name.
dir_path = "Lupus"
base_path = "../../data/" + dir_path + "/"
# NOTE(review): target_dir is not used anywhere in this notebook, and because
# base_path already ends in '/', it contains a doubled slash ('Lupus//AE/').
target_dir = base_path + "/AE/"

In [None]:
device_num = 4  # index of the CUDA GPU to use when CUDA is available
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device(f'cuda:{device_num}')
print(f"INFO: Using device: {device}")

In [None]:

def save_cell_scores(saved_model_path, exp, test_dataset, label_encoder, device, suffix=None):
    """Score every test cell with the teacher's attention module and write a CSV.

    The encoder and teacher checkpoints for experiment ``exp`` are loaded from
    ``saved_model_path``. The instance-level test data is embedded with the
    encoder, and each cell is scored with ``model_teacher.attention_module``.

    Parameters
    ----------
    saved_model_path : str
        Directory containing ``model_encoder_exp{exp}.pt`` and
        ``model_teacher_exp{exp}.pt`` (whole pickled modules, not state_dicts).
    exp : int
        Experiment index. Selects the checkpoints and names the output file.
    test_dataset :
        Passed through ``update_instance_labels_with_bag_labels`` (defined outside
        this notebook). The result must expose ``data``, ``instance_labels`` and
        ``bag_labels`` tensors.
    label_encoder :
        Object whose ``inverse_transform`` maps ``instance_labels`` back to names.
    device : torch.device
        Device the checkpoints are mapped to and inference runs on.
    suffix : str, optional
        If given, the output file is ``cell_score_{exp}_{suffix}.csv``.
        Otherwise it is ``cell_score_{exp}.csv``. Both are relative to the CWD.

    Returns
    -------
    int
        Always 0.

    The CSV has columns ``feature_0..feature_{d-1}``, ``cell_type``,
    ``cell_score``, ``bag_labels``, ``instance_labels`` and ``cell_score_minmax``.
    ``cell_score_minmax`` is ``cell_score`` scaled to [0, 1], or all 0 when
    every score is identical.
    """
    instance_test_dataset = update_instance_labels_with_bag_labels(test_dataset, device)

    def _load(role):
        # The .pt files hold whole pickled nn.Module objects, so weights_only must be
        # False: it defaults to True from PyTorch 2.6 on, and the kwarg itself needs
        # PyTorch >= 1.13. This unpickles arbitrary objects, so only load trusted files.
        return torch.load(f'{saved_model_path}/model_{role}_exp{exp}.pt',
                          map_location=device, weights_only=False)

    model_encoder = _load('encoder')
    model_teacher = _load('teacher')
    # The student checkpoint is not needed to produce teacher scores, so it is not loaded.
    model_encoder.eval()
    model_teacher.eval()

    with torch.no_grad():
        inputs = instance_test_dataset.data.clone().detach().float().to(device)
        # Keep only the first `input_dims` encoder outputs, which is what the teacher consumes.
        features = model_encoder(inputs)[:, :model_teacher.input_dims]
        cell_score_teacher = model_teacher.attention_module(features).squeeze(0)

    # Under no_grad these tensors carry no autograd history, so they convert directly.
    features_np = features.cpu().numpy()
    scores_np = cell_score_teacher.cpu().numpy()
    instance_labels_np = instance_test_dataset.instance_labels.detach().cpu().numpy()

    df = pd.DataFrame(features_np, columns=[f'feature_{i}' for i in range(features_np.shape[1])])
    df['cell_type'] = label_encoder.inverse_transform(instance_labels_np)
    df['cell_score'] = scores_np
    df['bag_labels'] = instance_test_dataset.bag_labels.detach().cpu().numpy()
    df['instance_labels'] = instance_labels_np

    score_min = scores_np.min()
    score_range = scores_np.max() - score_min
    if score_range == 0:
        # All scores identical: the plain formula would be 0/0 (NaN for every cell).
        df['cell_score_minmax'] = np.zeros_like(scores_np)
    else:
        df['cell_score_minmax'] = (scores_np - score_min) / score_range

    file_stem = f'cell_score_{exp}' if suffix is None else f'cell_score_{exp}_{suffix}'
    df.to_csv(f'{file_stem}.csv', index=False)

    return 0

def save_test_data(exp, sample_labels, adata):
    """Export the held-out test split of experiment ``exp`` to the current directory.

    The stratified test split (25% of samples, ``random_state=exp``) is re-created
    from ``sample_labels``. Three files are written:

    * ``test_set_barcodes_{exp}.csv``: the test rows of ``sample_labels``;
    * ``obs_{exp}.csv``: the unmodified ``obs`` of the test cells;
    * ``anndata_{exp}.h5ad``: the test cells. Their ``obs``/``var`` indices are set
      from the ``cell.names``/``gene.names`` columns. Their obs column names are
      sanitised: ``(``, ``)`` and ``/`` are removed, ``=`` becomes ``.``, and
      space and ``-`` become ``_``.

    Parameters
    ----------
    exp : int
        Experiment index. Used as the split seed and in every output file name.
    sample_labels : DataFrame
        Per-sample table with ``disease_numeric`` (stratification target) and
        ``sample_id_numeric`` columns.
    adata : AnnData
        Cells whose ``obs`` has ``sample_id_numeric`` and ``cell.names`` and whose
        ``var`` has ``gene.names``. It is not modified.
    """
    split_ratio = [0.5, 0.25, 0.25]  # train / validation / test fractions
    # Only this first split decides the test set. The later train/val split of the
    # remaining samples cannot change which samples end up in test_set.
    _, test_set = train_test_split(
        sample_labels,
        test_size=split_ratio[2],
        random_state=exp,
        stratify=sample_labels['disease_numeric'],
    )
    test_set.to_csv(f"test_set_barcodes_{exp}.csv")

    in_test = adata.obs['sample_id_numeric'].isin(test_set['sample_id_numeric'])
    # .copy() turns the AnnData view into an independent object. The index/column
    # edits below then apply to what gets written and never touch `adata`.
    test_data = adata[in_test].copy()
    test_data.obs.to_csv(f"obs_{exp}.csv")

    test_data.obs.index = test_data.obs['cell.names']
    test_data.var.index = test_data.var['gene.names']

    # One-pass equivalent of the chained replacements. No output character ('.' or
    # '_') is itself rewritten, so applying them all at once gives the same result.
    column_table = str.maketrans({'(': None, ')': None, '/': None,
                                  '=': '.', ' ': '_', '-': '_'})
    test_data.obs.columns = [col.translate(column_table) for col in test_data.obs.columns]

    test_data.write(filename=f"anndata_{exp}.h5ad")


In [None]:
# Checkpoint directories to score. A path containing 'baseline' gets its CSVs
# tagged with the 'baseline' suffix; the other path writes untagged CSVs.
saved_model_paths = [
    '../../results/model_PBMC_ae_ed128_md64_lr0.0001_500_0.1_5_15',
    '../../results/model_PBMC_ae_ed128_md64_lr0.0001_500_0.1_500_15_baseline2',
]
experiments = range(1, 9)  # experiments 1..8

for saved_model_path in saved_model_paths:
    suffix = 'baseline' if 'baseline' in saved_model_path else None
    for exp in experiments:
        print(f'Experiment {exp}')
        _, _, test_dataset, label_encoder = load_dataset_and_preprocessors(base_path, exp, device)
        save_cell_scores(saved_model_path, exp, test_dataset, label_encoder, device, suffix)
        # To also export the test split, call save_test_data(exp, sample_labels, adata)
        # here. NOTE: sample_labels and adata are not defined anywhere above.
        torch.cuda.empty_cache()