In [1]:
import os, random
from unicodedata import name
import pandas as pd
import numpy as np
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from model import APANET, APAData
from train_script import build_dataloaders
import pickle
# from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
# Root directory of the trained APA-Net checkpoints. Note the trailing slash:
# the fold loop builds paths as f'{MODELS_ROOT}model_outs/...' with no separator.
MODELS_ROOT = '/data/users/goodarzilab/aiden/projects/APA/codes/CNN_MATTs/APA-Net_2024_05_08_v2/'
# Root directory of the per-fold test arrays. No trailing slash here:
# the fold loop builds paths as f'{DATA_ROOT}/test_fold_{i}.npy'.
DATA_ROOT = '/data/users/goodarzilab/aiden/projects/APA/input_data/2024_rev/22May/significant'

In [5]:
# Hyper-parameters passed to APANET(config) before the trained weights are loaded.
# NOTE(review): the architecture entries (conv/attention/fc sizes) presumably have
# to match the values the checkpoints were trained with -- do not "tidy" numbers
# such as 4048 without retraining.
config = {
    "batch_size": 128,
    "epochs": 100,
    "project_name": 'test_performace',
    # Was 'cuda:03'. A zero-padded index is not a valid torch device string,
    # and every other cell in this notebook uses 'cuda:3'.
    "device": 'cuda:3',
    "opt": "Adam",
    "loss": "mse",
    "lr": 2.5e-05,
    "adam_weight_decay": 0.09,  # 0.06 before
    "conv1kc": 128,  # 128, 64
    "conv1ks": 12,
    "conv1st": 1,
    "pool1ks": 16,
    "pool1st": 16,
    "cnvpdrop1": 0,
    "Matt_heads": 8,
    "Matt_drop": 0.2,
    "fc1_dims": [
        8192,  # 8192, 5120
        4048,
        1024,
        512,
        256,
    ],  # first dimension will be calculated dynamically
    "fc1_dropouts": [0.25, 0.25, 0.25, 0, 0],
    "fc2_dims": [128, 32, 16, 1],  # first dimension will be calculated dynamically
    "fc2_dropouts": [0.2, 0.2, 0, 0],
    'psa_query_dim': 128,  # make sure this is correct
    'psa_num_layers': 1,
    'psa_nhead': 1,
    'psa_dim_feedforward': 1024,
    'psa_dropout': 0,
}

def get_model(model_path='model.pth', device='cuda:3'):
    """Build an APANET from the notebook-level ``config`` and load trained weights.

    Args:
        model_path: Path to a checkpoint containing a ``state_dict``.
        device: Device the checkpoint tensors are mapped to while loading
            (anything ``torch.device`` accepts). Defaults to ``'cuda:3'``,
            the value that used to be hard-coded here.

    Returns:
        The APANET instance with the weights loaded and switched to eval mode.
        The model is not moved to ``device`` here; callers call ``.to(...)``.
    """
    model = APANET(config)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [6]:
# Collect model predictions and labels for every sample served by a data loader
def get_performance(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients and gather the results.

    Args:
        model: Module called as ``model(seq_X, celltype)``; it is put in eval mode.
        data_loader: Iterable yielding 5-tuples
            ``(seq_X, Y, celltype, celltype_name, switch_name)`` per batch.

    Returns:
        Tuple ``(y_pred, y_true, celltypes, switch_names)`` of NumPy arrays,
        each the concatenation of the per-batch values in iteration order.

    Raises:
        ValueError: If ``data_loader`` yields no batches (raised by
            ``np.concatenate`` on an empty list).
    """
    model.eval()
    y_pred, y_true, celltypes, switch_names = [], [], [], []
    with torch.no_grad():
        for seq_X, Y, celltype, celltype_name, switch_name in data_loader:
            outputs = torch.squeeze(model(seq_X, celltype))
            # squeeze() collapses a single-sample batch to a 0-d tensor, which
            # np.concatenate rejects; atleast_1d keeps every batch at >= 1-D
            # and is a no-op for batches with more than one sample.
            y_pred.append(np.atleast_1d(outputs.cpu().numpy()))
            y_true.append(np.atleast_1d(Y.cpu().numpy()))
            celltypes.append(celltype_name)
            switch_names.append(switch_name)
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    celltypes = np.concatenate(celltypes)
    switch_names = np.concatenate(switch_names)
    return y_pred, y_true, celltypes, switch_names

In [7]:
# Evaluate each fold's checkpoint on its own held-out test split and pool the
# per-sample predictions into one long-format dataframe.
device = 'cuda:3'
n_folds = 5
fold_frames = []
for i in range(n_folds):
    model = get_model(f'{MODELS_ROOT}model_outs/significant/model_out_fold{i}.pt').to(device)
    # allow_pickle=True unpickles objects on load -- only use with trusted files.
    test_data = np.load(f'{DATA_ROOT}/test_fold_{i}.npy', allow_pickle=True)
    # The test split is passed for both data arguments; only the second
    # returned loader is used.
    _, test_loader = build_dataloaders(
        device,
        test_data,
        test_data,
        config["batch_size"],
    )
    y_pred, y_true, celltypes, switch_names = get_performance(model, test_loader)
    fold_frames.append(
        pd.DataFrame({
            'y_pred': y_pred,
            'y_true': y_true,
            'celltypes': celltypes,
            'switch_names': switch_names,
            'fold': f'fold_{i}',  # scalar is broadcast to every row of this fold
        })
    )

# Concatenate once instead of growing the frame inside the loop. The row index
# still restarts at 0 within each fold, exactly as before.
perfdf = pd.concat(fold_frames)
# Pooled correlation over all folds; previously recomputed on every iteration,
# with only the last value surviving.
valid_R = stats.pearsonr(perfdf['y_true'], perfdf['y_pred'])
perfdf.head()

Unnamed: 0,y_pred,y_true,celltypes,switch_names,fold
0,-0.126883,-0.243074,Astro,chr6:AARS2:44299560:44300312:-,fold_0
1,-0.170426,0.463603,Astro,chr17:ACADVL:7222082:7225225:+,fold_0
2,0.357559,-0.104181,Astro,chr1:ACP6:147647368:147649817:-,fold_0
3,-0.505027,0.43645,Astro,chr1:ACP6:147649817:147659061:-,fold_0
4,-0.948004,-0.6193,Astro,chr15:ACSBG1:78169717:78170845:-,fold_0


In [8]:
# Save the pooled per-sample test predictions (all folds) to CSV in the working
# directory; index=False leaves the row index out of the file.
perfdf.to_csv('test_set_performance.csv', index=False)

In [9]:
# Pearson correlation between true and predicted values over all rows of perfdf
stats.pearsonr(perfdf['y_true'], perfdf['y_pred'])

PearsonRResult(statistic=0.5854367301153052, pvalue=0.0)

In [10]:
# Number of rows in perfdf per cell type
perfdf['celltypes'].value_counts()

Inh          8217
Exc_upper    6851
OPC          6200
Astro        5789
Microglia    5755
Oligo        5681
Exc_deep     4561
Exc_int      4250
Name: celltypes, dtype: int64