In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as Tf
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.data import Cataract101
from src.model.phase_classifier_model import PhaseClassifier

In [None]:
# Device for the model and metrics; fall back to CPU so the notebook still runs
# on machines without a CUDA GPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

DATA_ROOT = '/local/scratch/cataract-101-processed/'


def make_cataract_split(split: str) -> Cataract101:
    """Build one Cataract-101 split with the preprocessing shared by every split.

    All splits use n_seq_frames=3, dt=1, phase annotations enabled, and the same
    transforms: resize to 128x128 followed by Normalize(mean=0.5, std=0.5).

    Args:
        split: split name passed through to ``Cataract101`` (e.g. "Training",
            "Validation", "Test").
    """
    return Cataract101(root=DATA_ROOT,
                       n_seq_frames=3,
                       dt=1,
                       transforms=Tf.Compose([
                           Tf.Resize((128, 128)),
                           Tf.Normalize(0.5, 0.5)
                       ]),
                       sample_phase_annotations=True,
                       split=split)


train_ds = make_cataract_split("Training")
val_ds = make_cataract_split("Validation")
test_ds = make_cataract_split("Test")

# Split evaluated by the cells below. NOTE(review): this loader used to be built
# from `val_ds` even though it is named `test_dl` and `test_ds` was otherwise
# unused; set `eval_ds = val_ds` to reproduce validation-set numbers.
eval_ds = test_ds
test_dl = DataLoader(eval_ds, batch_size=1, num_workers=1, shuffle=False)

In [None]:
# Phase classifier over 3-frame input sequences with 11 output classes,
# restored from a saved state dict and switched to inference mode.
checkpoint_path = '../../../results/phase_model/phase_model.pth'
# Alternative checkpoint: '../../../results/phase_model/phase_model_extended1.pth'
m = PhaseClassifier(n_seq_frames=3, n_classes=11).to(dev)
# Weights are loaded onto the CPU first; load_state_dict copies them into the
# model's parameters on `dev`. NOTE: torch.load unpickles -- only load trusted files.
m.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
m.eval()

In [None]:
# Run the model over every batch of `test_dl`. Keep the raw model outputs
# (`phase_predictions`, on `dev`) and the phase label of the last frame of each
# input sequence (`phase_target`, left on the loader's device).
# Per-batch results are gathered in lists and concatenated once at the end.
# Calling torch.cat inside the loop made the cell quadratic in the number of
# batches.
prediction_chunks = []
target_chunks = []
with torch.no_grad():
    for sample in tqdm(test_dl):
        img = sample['img_seq']
        # img_seq is (N, T, C, H, W); the classifier receives the T frames
        # stacked along the channel axis, i.e. (N, T*C, H, W).
        N, T, C, H, W = img.shape
        img = img.reshape(N, T * C, H, W).to(dev)
        # Annotation of the final frame of each sequence (presumably the frame
        # whose phase is being predicted -- confirm against training code).
        target_chunks.append(sample['phase_seq'][:, -1])
        prediction_chunks.append(m(img))

assert prediction_chunks, "test_dl yielded no batches; nothing to evaluate"
phase_predictions = torch.cat(prediction_chunks, dim=0)
phase_target = torch.cat(target_chunks, dim=0)

In [None]:
# Reduce model outputs and targets to class indices via argmax over the last
# axis (targets are presumably one-hot per sample -- confirm in Cataract101).
_phase_predictions = torch.argmax(phase_predictions, dim=-1)
_phase_target = torch.argmax(phase_target, dim=-1)

# Number of phase classes, taken from the width of the model output instead of
# a hard-coded 11.
n_phase_classes = phase_predictions.shape[-1]

fig, ax = plt.subplots()
ax.grid()
# Targets are drawn 0.2 below their class so they stay visible when the
# prediction for the same sample agrees with them.
ax.scatter(np.arange(_phase_target.shape[0]), _phase_target.cpu().numpy() - 0.2, label='target')
ax.scatter(np.arange(_phase_predictions.shape[0]), _phase_predictions.cpu().numpy(), label='prediction')
ax.set_ylim(-1, n_phase_classes)
ax.set_yticks(np.arange(n_phase_classes))
ax.set(xlabel='sample index', ylabel='phase class', title='Predicted vs. target phase per sample')
ax.legend()
plt.show()

In [None]:
from torchmetrics import F1Score, AUROC, Accuracy, AveragePrecision
from src.metrics.temporal_consistency import time_seg_cluster_metric

In [None]:
# Per-class metrics (average=None). Each metric is printed, followed by its
# variance across classes.
# Each metric is evaluated once and reused for the variance. Previously every
# metric was computed twice on the full evaluation set.
n_classes = phase_predictions.shape[-1]
target_on_dev = _phase_target.to(dev)

ap_score = AveragePrecision(num_classes=n_classes, average=None).to(dev)
f1_score = F1Score(num_classes=n_classes, average=None).to(dev)
auroc_score = AUROC(num_classes=n_classes, average=None).to(dev)

# AveragePrecision with average=None yields a sequence of per-class tensors
# here, hence the stack.
per_class_ap = torch.stack(ap_score(phase_predictions, target_on_dev))
per_class_f1 = f1_score(phase_predictions, target_on_dev)
per_class_auroc = auroc_score(phase_predictions, target_on_dev)

for per_class_values in (per_class_ap, per_class_f1, per_class_auroc):
    print(per_class_values)
    print(per_class_values.var())

In [None]:
# Default-averaged (single-number) versions of the AP, F1 and AUROC metrics.
ap_score, f1_score, auroc_score = (
    metric_cls(num_classes=11).to(dev)
    for metric_cls in (AveragePrecision, F1Score, AUROC)
)
target_on_dev = _phase_target.to(dev)
for metric in (ap_score, f1_score, auroc_score):
    print(metric(phase_predictions, target_on_dev))

In [None]:
# Number of evaluated samples (shape of the argmax'd targets).
_phase_target.size()

In [None]:
# Temporal-consistency metric on predicted vs. target phase indices.
predicted_phase_idx = phase_predictions.argmax(dim=-1).cpu().numpy()
target_phase_idx = _phase_target.numpy()
print(time_seg_cluster_metric(predicted_phase_idx, target_phase_idx))

In [None]:
# NOTE(review): this cell repeats an earlier macro-metric cell verbatim and can
# be removed when the notebook is cleaned up.
macro_target = _phase_target.to(dev)
ap_score = AveragePrecision(num_classes=11).to(dev)
f1_score = F1Score(num_classes=11).to(dev)
auroc_score = AUROC(num_classes=11).to(dev)
for score_fn in [ap_score, f1_score, auroc_score]:
    macro_value = score_fn(phase_predictions, macro_target)
    print(macro_value)

In [None]:
# NOTE(review): duplicate of an earlier shape-inspection cell.
_phase_target.size()

In [None]:
# NOTE(review): duplicate of an earlier temporal-consistency cell.
pred_idx_np = torch.argmax(phase_predictions, -1).cpu().numpy()
print(time_seg_cluster_metric(pred_idx_np, _phase_target.numpy()))