In [None]:
from train_img_clfs import *

In [None]:
def eval(modality: str, img_size: int) -> None:
    """Evaluate a saved image classifier on the 'eval' split and print per-label metrics.

    Builds the evaluation dataset for ``modality`` at ``img_size``, loads the
    weights stored at ``state_dicts/{modality}_clf_{img_size}.pth`` into a fresh
    ``LM`` module, runs inference without gradients over the whole split, and
    prints AUROC, average precision, accuracy, precision, recall and the
    predicted/actual positive counts for every label in ``eval_ds.str_labels``.

    NOTE: the name shadows the builtin ``eval``; it is kept unchanged so
    existing callers keep working.

    Args:
        modality: Modality identifier forwarded to ``MimicIMG`` and used in the
            checkpoint file name.
        img_size: Image size forwarded to ``MimicIMG`` and used in the
            checkpoint file name.

    Raises:
        ValueError: If the evaluation loader yields no batches.
    """
    eval_ds = MimicIMG(modality=modality, split='eval', img_size=img_size, transform=False, undersample_dataset=False)
    eval_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=DL_WORKERS)

    lightning_module = LM(str_labels=eval_ds.str_labels)
    lightning_module.model.load_state_dict(torch.load(f'state_dicts/{modality}_clf_{img_size}.pth', map_location=DEVICE))
    # load_state_dict copies values into the existing parameters without moving
    # them, so place the model on the same device as the inputs explicitly
    # (this is a no-op when the model already lives there).
    lightning_module.model.to(DEVICE)
    lightning_module.model.eval()

    # Collect per-batch outputs in lists and concatenate once afterwards:
    # torch.cat only accepts tensors, so seeding the accumulator with a Python
    # list (as the previous version did) raises a TypeError on the first batch.
    prediction_batches, target_batches = [], []
    with torch.no_grad():
        for x, y in eval_loader:
            logits = lightning_module(x.to(DEVICE))
            target_batches.append(y.cpu())
            prediction_batches.append(logits.cpu())

    if not prediction_batches:
        raise ValueError(f'evaluation loader for modality {modality!r} produced no batches')

    predictions = torch.cat(prediction_batches)
    targets = torch.cat(target_batches)

    for idx, label in enumerate(eval_ds.str_labels):
        # Column idx is used as the output for this label.
        preds_label = predictions[:, idx]
        y_label = targets[:, idx].int()

        # Metrics computed on the raw scores.
        auroc_score = auroc(preds_label, y_label)
        av_precision_score = average_precision(preds_label, y_label)

        # NOTE(review): the 0.5 cut-off is only meaningful if the module's
        # forward pass returns probabilities; if it returns raw logits the
        # threshold should be 0 (or a sigmoid applied first) -- confirm.
        preds_thr = (preds_label > 0.5).int()
        acc = accuracy_metric(preds_thr, y_label)
        prec = precision(preds_thr, y_label)
        rec = recall(preds_thr, y_label)

        print(f'{label}_auroc', auroc_score)
        print(f'{label}_avg_precision', av_precision_score)
        print(f'{label}_val_acc', acc)
        print(f'{label}_val_precision', prec)
        print(f'{label}_val_recall', rec)
        # Tensor.sum() reduces in native code; the builtin sum() would iterate
        # over the tensor element by element in Python.
        print(f'{label}_pred_pos', preds_thr.sum().item())
        print(f'{label}_true_pos', y_label.sum().item())

In [None]:
# Run the evaluation for the 'pa' modality at image size 256.
eval(modality='pa', img_size=256)

In [None]:
# Run the evaluation for the 'lat' modality at image size 256.
eval(modality='lat', img_size=256)