In [1]:
import pandas as pd
import numpy as np
import os
import toml
import torch
from m6anet.utils.builder import build_dataloader
from m6anet.utils.training_utils import validate, get_accuracy, get_roc_auc, get_pr_auc
from m6anet.model.model import MILModel
from tqdm.notebook import tqdm

In [2]:
# Load the attention-pooling MIL model, its trained weights, and the validation dataloader.
model_config_path = "../m6anet/model/configs/model_configs/1_neighbor/prod_pooling_attention.toml"
model_weight = "../m6anet/model/model_states/attention_pooling_pr_auc.pt"
training_config_path = "../m6anet/model/configs/training_configs/m6a_classification_nanopolish/1_neighbor/oversampled.toml"

# Parsed configs keep the names `model_config` / `training_config` (dicts) for later cells.
model_config = toml.load(model_config_path)
training_config = toml.load(training_config_path)
model = MILModel(model_config)
# Map the checkpoint onto the CPU: load_state_dict copies the tensors into the
# (CPU) model parameters anyway, and the evaluation cells move the model to their
# chosen device afterwards. This avoids failing on hosts that lack the GPU the
# weights were saved from.
model.load_state_dict(torch.load(model_weight, map_location="cpu"))

# NOTE(review): the meaning of the positional 25 (presumably num_workers) is not
# visible here — confirm against build_dataloader's signature.
_, _, val_dl = build_dataloader(training_config, 25, verbose=True)

There are 81628 train sites
There are 20091 test sites
There are 25713 val sites


In [21]:
# Evaluate the model's site-level output on the validation set.
# The loader is iterated `n_iterations` times and per-site scores are averaged
# across passes. Presumably each pass sees a different random subset of reads per
# site (TODO confirm in build_dataloader). Averaging per site is only valid if
# every pass yields the sites in the same order, so that is checked below.
n_iterations = 5

# The original run targeted the second GPU; fall back so the cell still runs on
# machines with fewer GPUs (or none).
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
model.eval()

all_y_true = None  # labels collected on the first pass (list, later np.ndarray)
all_y_pred = []    # one list of per-site scores per pass

with torch.no_grad():
    for pass_idx in range(n_iterations):
        y_true_pass = []
        y_pred_pass = []
        for batch in tqdm(val_dl, desc=f"pass {pass_idx + 1}/{n_iterations}"):
            # Labels are only used on the host, so skip the device round-trip.
            y_true_pass.extend(batch.pop('y').flatten().cpu().numpy())
            X = {key: val.to(device) for key, val in batch.items()}
            y_pred = model(X).detach().cpu().numpy()

            # 1-D or single-column output is used as the score directly; otherwise
            # column 1 is taken (presumably the positive-class probability).
            if y_pred.ndim == 1 or y_pred.shape[1] == 1:
                y_pred_pass.extend(y_pred.flatten())
            else:
                y_pred_pass.extend(y_pred[:, 1])

        if all_y_true is None:
            all_y_true = y_true_pass
        elif not np.array_equal(all_y_true, y_true_pass):
            # A label mismatch proves the loader reordered sites (e.g. shuffle=True).
            # Equal labels cannot prove identical order, but this catches the common case.
            raise RuntimeError(
                "val_dl yielded sites in a different order between passes; "
                "averaging per-site predictions across passes would be invalid"
            )
        all_y_pred.append(y_pred_pass)

# Per-site mean score across passes.
y_pred_avg = np.mean(all_y_pred, axis=0)
all_y_true = np.array(all_y_true).flatten()

# Accuracy uses a fixed 0.5 decision threshold; the AUCs are threshold-free.
accuracy_score = get_accuracy(all_y_true, (y_pred_avg.flatten() > 0.5) * 1)
roc_auc = get_roc_auc(all_y_true, y_pred_avg)
pr_auc = get_pr_auc(all_y_true, y_pred_avg)

val_results = {
    'y_pred': all_y_pred,  # raw per-pass scores, not the average
    'y_true': all_y_true,
    'accuracy': accuracy_score,
    'roc_auc': roc_auc,
    'pr_auc': pr_auc,
}

print(accuracy_score, roc_auc, pr_auc)

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

In [23]:
# Evaluate noisy-OR pooling of per-read probabilities on the validation set.
# Instead of the model's own site output, each site's score is
# P(site modified) = 1 - prod_r (1 - p_r) over its reads.
# The loader is iterated `n_iterations` times and per-site scores are averaged
# across passes. Presumably each pass sees a different random subset of reads per
# site (TODO confirm in build_dataloader). Averaging per site is only valid if
# every pass yields the sites in the same order, so that is checked below.
n_iterations = 5

# The original run targeted the second GPU; fall back so the cell still runs on
# machines with fewer GPUs (or none).
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
model.eval()

all_y_true = None  # labels collected on the first pass (list, later np.ndarray)
all_y_pred = []    # one list of per-site scores per pass

with torch.no_grad():
    for pass_idx in range(n_iterations):
        y_true_pass = []
        y_pred_pass = []
        for batch in tqdm(val_dl, desc=f"pass {pass_idx + 1}/{n_iterations}"):
            # Labels are only used on the host, so skip the device round-trip.
            y_true_pass.extend(batch.pop('y').flatten().cpu().numpy())
            X = {key: val.to(device) for key, val in batch.items()}
            read_probs = model.get_read_probability(X)

            # Noisy-OR over reads, evaluated as 1 - exp(sum log(1 - p)) = -expm1(sum log1p(-p)).
            # This is identical to 1 - prod(1 - p) but avoids float32 cancellation when all
            # p are tiny (1 - p would round to 1 and collapse the score to 0).
            # Assumes reads lie along dim 1, as in the original prod — TODO confirm the
            # shape returned by get_read_probability.
            y_pred = -torch.expm1(torch.log1p(-read_probs).sum(dim=1))
            y_pred = y_pred.detach().cpu().numpy()

            # 1-D or single-column output is used as the score directly; otherwise
            # column 1 is taken.
            if y_pred.ndim == 1 or y_pred.shape[1] == 1:
                y_pred_pass.extend(y_pred.flatten())
            else:
                y_pred_pass.extend(y_pred[:, 1])

        if all_y_true is None:
            all_y_true = y_true_pass
        elif not np.array_equal(all_y_true, y_true_pass):
            # A label mismatch proves the loader reordered sites (e.g. shuffle=True).
            # Equal labels cannot prove identical order, but this catches the common case.
            raise RuntimeError(
                "val_dl yielded sites in a different order between passes; "
                "averaging per-site predictions across passes would be invalid"
            )
        all_y_pred.append(y_pred_pass)

# Per-site mean score across passes.
y_pred_avg = np.mean(all_y_pred, axis=0)
all_y_true = np.array(all_y_true).flatten()

# Accuracy uses a fixed 0.5 decision threshold; the AUCs are threshold-free.
accuracy_score = get_accuracy(all_y_true, (y_pred_avg.flatten() > 0.5) * 1)
roc_auc = get_roc_auc(all_y_true, y_pred_avg)
pr_auc = get_pr_auc(all_y_true, y_pred_avg)

val_results = {
    'y_pred': all_y_pred,  # raw per-pass scores, not the average
    'y_true': all_y_true,
    'accuracy': accuracy_score,
    'roc_auc': roc_auc,
    'pr_auc': pr_auc,
}

print(accuracy_score, roc_auc, pr_auc)

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

0.8897833780577917 0.9161007223550838 0.4447791167585492
