## Imports etc.

In [None]:
import os

import timm
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.utils.pre_train import get_configs
from lib.utils.misc import WrappedModel
from lib.dataset.synthetic_dataset import SyntheticCATARACTSDataset

## Load configs

In [None]:
# Directory of generated samples to evaluate (uncomment exactly one DATA_PATH).
# DATA_PATH = '/home/yannik/CataractsConditionalDiffusion/results/ddim_gen_samples/'
# DATA_PATH = '/home/yannik/CataractsConditionalDiffusion/results/cgan_gen_samples/'
DATA_PATH = '/home/yannik/CataractsConditionalDiffusion/results/vqvae2_gen_samples/'
# Training run of the tool classifier; its config.yaml and ckpt.pth are read from here.
LOG_PATH = 'results/tool_classifier_resnet50/2023.02.28 07_50_36/'
TARGET_PATH = os.path.join(LOG_PATH, "eval/")
# Fall back to CPU when no GPU is visible; the model is wrapped differently
# (WrappedModel instead of DataParallel) when DEV == 'cpu'.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'
data_conf, model_conf, diffusion_conf, train_conf = get_configs(os.path.join(LOG_PATH, "config.yaml"))
os.makedirs(TARGET_PATH, exist_ok=True)
STEPS = 1
BATCH_SIZE = 64
print("Avail. GPUs: ", torch.cuda.device_count())

## Load data

In [None]:
# Synthetic (generated) evaluation samples, resized/normalised with the settings
# from the classifier's data config.
# NOTE(review): `eval` executes the config strings as Python. That is acceptable for
# our own config.yaml, but prefer ast.literal_eval if SHAPE/NORM are plain literals.
shape_cfg = eval(data_conf['SHAPE'])
norm_cfg = eval(data_conf['NORM'])
test_ds = SyntheticCATARACTSDataset(
    root=DATA_PATH,
    resize_shape=shape_cfg[1:],  # drops the leading dim of SHAPE (presumably channels) — TODO confirm
    normalize=norm_cfg,
)

# NOTE(review): drop_last=True silently excludes up to BATCH_SIZE - 1 samples from
# the evaluation. Switching to drop_last=False requires every per-sample loop to
# iterate over the actual batch length rather than BATCH_SIZE.
# TODO: Weighted sampling / sampling from p(toolset|phase)
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=8,
    drop_last=True,
    shuffle=True,
    pin_memory=False,
)
dl = DataLoader(test_ds, **loader_kwargs)
print(f"{len(test_ds)} samples")

## Load model

In [None]:
# Configured model type -> timm architecture name.
TIMM_ARCHS = {
    'INCEPTIONV4': 'inception_v4',
    'RESNET18': 'resnet18',
    'RESNET50': 'resnet50',
}
model_type = model_conf['TYPE'].upper()
if model_type not in TIMM_ARCHS:
    raise NotImplementedError(f"Unsupported model type: {model_conf['TYPE']}")

# The dataset exposes its label metadata directly (test_ds.num_phases_classes,
# test_ds.phase_label_names); only descend into `.dataset` if such a wrapper exists.
num_tool_classes = getattr(test_ds, 'dataset', test_ds).num_tool_classes

# pretrained=False: the strict checkpoint load below overwrites every weight, so
# downloading ImageNet weights first is wasted work (and fails offline).
m = timm.create_model(TIMM_ARCHS[model_type],
                      pretrained=False,
                      num_classes=num_tool_classes).to(DEV)
m = torch.nn.DataParallel(m, device_ids=[DEV]) if DEV != 'cpu' else WrappedModel(m)

# The checkpoint's first element holds the model state dict; load it once.
state_dict = torch.load(os.path.join(LOG_PATH, "ckpt.pth"), map_location='cpu')[0]
try:
    # Keys match the wrapper (checkpoint presumably saved from a wrapped model).
    m.load_state_dict(state_dict)
except RuntimeError:
    # Key mismatch: the checkpoint was saved from the bare model, so load it into
    # the wrapped inner module instead.
    m.module.load_state_dict(state_dict)
m.eval()

In [None]:
import gc

import torch
from torchmetrics.functional import f1_score, auroc, accuracy

from lib.utils.misc import label_vectors_to_names

gc.collect()  # free unreachable Python objects (and any GPU tensors they still hold)
torch.cuda.empty_cache()  # release the caching allocator's unoccupied CUDA memory

## Test performance

In [None]:
# Run the classifier over the whole loader, keeping sigmoid probabilities and
# targets both globally and per surgical phase. Per-batch / per-sample tensors are
# collected in lists and concatenated once at the end; calling torch.cat inside
# the loop would re-copy everything seen so far (quadratic).
num_phases = test_ds.num_phases_classes
target_batches, prediction_batches = [], []
phase_wise_target_chunks = [[] for _ in range(num_phases)]
phase_wise_prediction_chunks = [[] for _ in range(num_phases)]
tool_count, no_tool_count = [0] * num_phases, [0] * num_phases

with torch.no_grad():
    for img, _, _file_name, phase_label, tool_label in tqdm(dl):
        phase_label = phase_label.int()
        img, tool_label = img.to(DEV), tool_label.to(DEV)

        pred = torch.sigmoid(m(img))

        target_batches.append(tool_label)
        prediction_batches.append(pred)

        # Use the actual batch length: the last batch is smaller than BATCH_SIZE
        # whenever the loader keeps incomplete batches.
        for n in range(len(phase_label)):
            phase_id_n = phase_label[n].item()
            if (tool_label[n] == 0.).all():
                no_tool_count[phase_id_n] += 1
            else:
                tool_count[phase_id_n] += 1
            phase_wise_target_chunks[phase_id_n].append(tool_label[n])
            phase_wise_prediction_chunks[phase_id_n].append(pred[n])

if not target_batches:
    raise RuntimeError(
        f"Data loader yielded no batches ({len(test_ds)} samples, batch size {BATCH_SIZE}); "
        "with drop_last enabled the dataset needs at least one full batch."
    )

test_targets = torch.cat(target_batches, dim=0)
test_predictions = torch.cat(prediction_batches, dim=0)
# Per-phase results are flattened 1-D tensors (rows concatenated along dim 0);
# phases without any samples stay None.
test_phase_wise_targets = [torch.cat(chunks, dim=0) if chunks else None
                           for chunks in phase_wise_target_chunks]
test_phase_wise_predictions = [torch.cat(chunks, dim=0) if chunks else None
                               for chunks in phase_wise_prediction_chunks]

# torchmetrics' binary metrics expect integer targets.
test_targets_int = test_targets.int()
print("F1 Score: ", f1_score(test_predictions, test_targets_int, task='binary', threshold=0.5))
print("AUROC: ", auroc(test_predictions, test_targets_int, task='binary'))
print("Accuracy: ", accuracy(test_predictions, test_targets_int, task='binary', threshold=0.5))

## Plotting

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Per-phase binary metrics on the flattened tool-presence predictions. Phases
# with no test samples get NaN (drawn as no bar) instead of crashing on None.
num_phase_names = len(test_ds.phase_label_names)
f1_score_per_phase = [0] * num_phase_names
auroc_per_phase = [0] * num_phase_names
acc_per_phase = [0] * num_phase_names
plt.figure(figsize=(30, 5))
for phase_id, phase_name in enumerate(test_ds.phase_label_names):
    print(f"########## ---------- {phase_name}")

    preds = test_phase_wise_predictions[phase_id]
    if preds is None:
        print("No samples for this phase.")
        f1 = aur = acc = float('nan')
    else:
        targs = test_phase_wise_targets[phase_id].int()
        # task='binary' is required by current torchmetrics and matches the
        # overall metrics computed after the evaluation loop.
        f1 = f1_score(preds, targs, task='binary', threshold=0.5).item()
        aur = auroc(preds, targs, task='binary').item()
        acc = accuracy(preds, targs, task='binary', threshold=0.5).item()
    print("F1 Score: ", f1)
    print("AUROC: ", aur)
    print("Accuracy: ", acc)
    f1_score_per_phase[phase_id] = f1
    auroc_per_phase[phase_id] = aur
    acc_per_phase[phase_id] = acc

phase_positions = np.arange(0, test_ds.num_phases_classes)
plt.bar(x=phase_positions - .2, width=0.2, height=f1_score_per_phase, label='F1')
# Uncomment to add AUROC / accuracy bars next to F1:
# plt.bar(x=phase_positions, width=0.2, height=auroc_per_phase, label='AUROC')
# plt.bar(x=phase_positions + .2, width=0.2, height=acc_per_phase, label='Acc.')
plt.legend()
plt.xticks(ticks=phase_positions, labels=test_ds.phase_label_names, rotation=45)
plt.autoscale()
# plt.savefig('plots/val_set_phase_wise_performance.svg', format='svg', bbox_inches='tight')
# plt.savefig('plots/test_set_phase_wise_performance.svg', format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Side-by-side bars per phase: samples with at least one tool vs. samples with none.
phase_positions = np.arange(0, test_ds.num_phases_classes)
plt.figure(figsize=(30, 5))
plt.bar(x=phase_positions, width=0.2, height=tool_count, label='tools present')
plt.bar(x=phase_positions + .2, width=0.2, height=no_tool_count, label='no tools')
plt.legend()
plt.xticks(ticks=phase_positions, labels=test_ds.phase_label_names, rotation=45)
plt.autoscale()
plt.show()