## Imports

In [1]:
import torch
import yaml
from data.dataloader import load_data
from model.network import create_model
from data.analysis import get_confusion_matrix, fisher_test

## Config

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}\n')

# Read the experiment configuration. The context manager guarantees the file
# handle is closed as soon as parsing finishes instead of leaking it.
with open('./config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Store the chosen device in the config so helpers that receive `config`
# (load_data, create_model) can read it.
config['device'] = device

## Load Data

In [None]:
# Build the data loaders for all three splits; load_data hands them back in
# (train, validation, test) order.
(
    train_data_loader,
    val_data_loader,
    test_data_loader,
) = load_data(config)

## Create Model and Load Checkpoint

In [None]:
model = create_model(config)

# Best checkpoint for the configured task; edit this path to evaluate a different run.
checkpoint_path = f'./checkpoints/best/{config["task"]}/model.pt'

# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine (and vice versa).
checkpoint = torch.load(checkpoint_path, map_location=device)

# strict=False tolerates key mismatches between the checkpoint and the model,
# but does so silently. Report any mismatch so a partially-loaded model is noticed
# instead of quietly producing results from randomly-initialised weights.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'Warning: missing keys in checkpoint: {load_result.missing_keys}')
    print(f'Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}')

# Put the model on the evaluation device and into eval mode (disables dropout,
# uses running batch-norm statistics). Both are no-ops if create_model /
# get_confusion_matrix already take care of this.
model.to(device)
model.eval()

## Confusion Matrix

In [None]:
# Evaluate the model on every split (train, then val, then test) and collect a
# confusion matrix for each. save=False keeps the results in memory only; fname
# is a "<task>-<split>" tag (presumably used as the output name when saving).
split_loaders = {
    'train': train_data_loader,
    'val': val_data_loader,
    'test': test_data_loader,
}
conf_mats = {
    split: get_confusion_matrix(
        model, loader,
        device, save=False,
        fname=f'{config["task"]}-{split}'
    )
    for split, loader in split_loaders.items()
}

train_conf_mat = conf_mats['train']
val_conf_mat = conf_mats['val']
test_conf_mat = conf_mats['test']

print(train_conf_mat, val_conf_mat, test_conf_mat)

## P-values using Fisher's Exact Test (Two-sided)

In [None]:
# Fisher's exact test (two-sided, per the section header) on each split's
# confusion matrix, evaluated in train -> val -> test order.
train_pvalue, val_pvalue, test_pvalue = (
    fisher_test(conf_mat)
    for conf_mat in (train_conf_mat, val_conf_mat, test_conf_mat)
)

print(train_pvalue, val_pvalue, test_pvalue)