In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset, Subset
import torchvision.transforms as Tf
from torchmetrics import F1Score, AUROC, Accuracy
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.data import Cataract101
from src.model.phase_classifier_model import PhaseClassifier

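## Load datasets

The validation and test splits are concatenated into a single evaluation set, `test_ds`, which both model evaluations below use.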

In [None]:
# Fall back to CPU so the notebook still runs on a machine without a GPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_ROOT = '/local/scratch/cataract-101-processed/'


def _make_split(split):
    """Build one Cataract-101 split with the settings shared by all splits.

    Every split uses n_seq_frames=3, dt=1, sample_phase_annotations=True and
    the same transforms: resize to 128x128, then Normalize(0.5, 0.5).

    Args:
        split: split name passed through to ``Cataract101`` ("Training",
            "Validation" or "Test").
    """
    return Cataract101(root=DATA_ROOT,
                       n_seq_frames=3,
                       dt=1,
                       transforms=Tf.Compose([
                           Tf.Resize((128, 128)),
                           Tf.Normalize(0.5, 0.5)
                       ]),
                       sample_phase_annotations=True,
                       split=split)


train_ds = _make_split("Training")
val_ds = _make_split("Validation")
test_split_ds = _make_split("Test")

# Validation and test are merged into one evaluation set; the evaluation cells
# read it as `test_ds`. A separate name for the raw Test split keeps this cell
# from reassigning `test_ds` to a different kind of object.
test_ds = ConcatDataset([val_ds, test_split_ds])

## Evaluate the extended model (`phase_model_extended1.pth`)

In [None]:
# Extended phase classifier checkpoint (n_seq_frames=3, n_classes=11).
m = PhaseClassifier(n_seq_frames=3, n_classes=11).to(dev)
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
m.load_state_dict(torch.load('../../../results/phase_model/phase_model_extended1.pth', map_location='cpu'))
m.eval();  # trailing ';' suppresses printing the full module repr

In [None]:
print(len(test_ds))
N = 30  # number of evaluation windows
acc_scores_per_split = []
f1_scores_per_split = []
auroc_scores_per_split = []
acc_score = Accuracy(num_classes=11).to(dev)
# NOTE(review): in older torchmetrics F1Score defaults to average='micro', which for
# single-label multiclass equals accuracy -- confirm whether 'macro' was intended.
f1_score = F1Score(num_classes=11).to(dev)
# NOTE(review): AUROC ranks per-class scores; confirm whether m outputs logits or
# probabilities (softmax changes per-class rankings across samples).
auroc_score = AUROC(num_classes=11).to(dev)

# Run inference once over the whole evaluation set. The windows below overlap
# heavily (each covers 1/5 of the set), so predicting per window would repeat
# most of the work. Batches are collected in lists and concatenated once.
test_dl = DataLoader(test_ds, batch_size=16, num_workers=4, shuffle=False)
prediction_batches = []
target_batches = []
with torch.no_grad():
    for sample in tqdm(test_dl):
        img = sample['img_seq']
        # Use B for the batch dim so the window count N is not overwritten.
        B, T, C, H, W = img.shape
        # Fold the sequence axis into channels: (B, T, C, H, W) -> (B, T*C, H, W).
        img = img.view((B, T*C, H, W)).to(dev)
        prediction_batches.append(m(img))
        # Target is the phase annotation at the last sequence position.
        target_batches.append(sample['phase_seq'][:, -1])
phase_predictions = torch.cat(prediction_batches, dim=0)
phase_target = torch.cat(target_batches, dim=0)
# argmax over the last dim turns the phase annotation into a class index.
_phase_target = torch.argmax(phase_target, dim=-1).to(dev)

n_eval = len(test_ds)
window = n_eval // 5
assert window > 0, "evaluation set too small for 1/5-sized windows"
# N equally sized windows, evenly spaced so the last one ends exactly at n_eval.
# (A stride of n_eval / N with a window of n_eval // 5 ran past the end of the
# dataset for the last windows and raised IndexError.)
starts = np.linspace(0, n_eval - window, N).round().astype(int).tolist()
for start in starts:
    preds = phase_predictions[start:start + window]
    target = _phase_target[start:start + window]
    for metric, scores in ((acc_score, acc_scores_per_split),
                           (f1_score, f1_scores_per_split),
                           (auroc_score, auroc_scores_per_split)):
        scores.append(metric(preds, target).item())
        # forward() also accumulates into the metric's global state (AUROC keeps
        # every prediction); clear it so memory does not grow across windows.
        metric.reset()

In [None]:
# Per-window accuracy, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(acc_scores_per_split)
print(f"accuracy: mean {np.mean(acc_scores_per_split):.4f}, std {np.std(acc_scores_per_split):.4f}")

In [None]:
# Per-window F1 score, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(f1_scores_per_split)
print(f"F1: mean {np.mean(f1_scores_per_split):.4f}, std {np.std(f1_scores_per_split):.4f}")

In [None]:
# Per-window AUROC, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(auroc_scores_per_split)
print(f"AUROC: mean {np.mean(auroc_scores_per_split):.4f}, std {np.std(auroc_scores_per_split):.4f}")

## Evaluate the normal model (`phase_model.pth`)

In [None]:
# Normal (non-extended) phase classifier checkpoint (n_seq_frames=3, n_classes=11).
m = PhaseClassifier(n_seq_frames=3, n_classes=11).to(dev)
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
m.load_state_dict(torch.load('../../../results/phase_model/phase_model.pth', map_location='cpu'))
m.eval();  # trailing ';' suppresses printing the full module repr

In [None]:
N = 30  # number of evaluation windows
acc_scores_per_split = []
f1_scores_per_split = []
auroc_scores_per_split = []
# Fresh metric objects, so this section does not depend on metric objects (and
# their accumulated state) created by the extended-model evaluation cell.
# TODO: the extended-model section repeats this loop; factor both into one helper.
acc_score = Accuracy(num_classes=11).to(dev)
# NOTE(review): in older torchmetrics F1Score defaults to average='micro', which for
# single-label multiclass equals accuracy -- confirm whether 'macro' was intended.
f1_score = F1Score(num_classes=11).to(dev)
# NOTE(review): AUROC ranks per-class scores; confirm whether m outputs logits or
# probabilities (softmax changes per-class rankings across samples).
auroc_score = AUROC(num_classes=11).to(dev)

# Run inference once over the whole evaluation set. The windows below overlap
# heavily (each covers 1/5 of the set), so predicting per window would repeat
# most of the work. Batches are collected in lists and concatenated once.
test_dl = DataLoader(test_ds, batch_size=16, num_workers=4, shuffle=False)
prediction_batches = []
target_batches = []
with torch.no_grad():
    for sample in tqdm(test_dl):
        img = sample['img_seq']
        # Use B for the batch dim so the window count N is not overwritten.
        B, T, C, H, W = img.shape
        # Fold the sequence axis into channels: (B, T, C, H, W) -> (B, T*C, H, W).
        img = img.view((B, T*C, H, W)).to(dev)
        prediction_batches.append(m(img))
        # Target is the phase annotation at the last sequence position.
        target_batches.append(sample['phase_seq'][:, -1])
phase_predictions = torch.cat(prediction_batches, dim=0)
phase_target = torch.cat(target_batches, dim=0)
# argmax over the last dim turns the phase annotation into a class index.
_phase_target = torch.argmax(phase_target, dim=-1).to(dev)

n_eval = len(test_ds)
window = n_eval // 5
assert window > 0, "evaluation set too small for 1/5-sized windows"
# N equally sized windows, evenly spaced so the last one ends exactly at n_eval.
# (A stride of n_eval / N with a window of n_eval // 5 ran past the end of the
# dataset for the last windows and raised IndexError.)
starts = np.linspace(0, n_eval - window, N).round().astype(int).tolist()
for start in starts:
    preds = phase_predictions[start:start + window]
    target = _phase_target[start:start + window]
    for metric, scores in ((acc_score, acc_scores_per_split),
                           (f1_score, f1_scores_per_split),
                           (auroc_score, auroc_scores_per_split)):
        scores.append(metric(preds, target).item())
        # forward() also accumulates into the metric's global state (AUROC keeps
        # every prediction); clear it so memory does not grow across windows.
        metric.reset()

In [None]:
# Per-window accuracy, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(acc_scores_per_split)
print(f"accuracy: mean {np.mean(acc_scores_per_split):.4f}, std {np.std(acc_scores_per_split):.4f}")

In [None]:
# Per-window F1 score, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(f1_scores_per_split)
print(f"F1: mean {np.mean(f1_scores_per_split):.4f}, std {np.std(f1_scores_per_split):.4f}")

In [None]:
# Per-window AUROC, then a single aggregate for comparing models.
# (Windows overlap, so the std understates independent-sample variability.)
print(auroc_scores_per_split)
print(f"AUROC: mean {np.mean(auroc_scores_per_split):.4f}, std {np.std(auroc_scores_per_split):.4f}")