In [None]:
# Use if main package is not in python path
# import os, sys
# sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

import time
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from utils.datasets import TrainSetArtefacts
from models.models_UNet import UNet
from training_funcs import sample_P

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Reproducibility: seed NumPy's global RNG and torch (torch.manual_seed covers every device).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the first GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
n_gpus = torch.cuda.device_count()
device = torch.device("cuda:0") if cuda else torch.device("cpu")

for message in (f"Using GPU: {cuda}", f"Available GPUs: {n_gpus}", "Only tested with CUDA enabled"):
    print(message)

In [None]:
crop_dim = 384  # Quadratic dimension of the image crops

wu = 16  # Width unit
bs = 16  # Evaluation batch size

predictor = UNet(1, 1, wu, crop_dim, 1, dilation=False, kernel_size=3, padding=1, bias=False, padding_mode="reflect")
print(f"Predictor Number of Parameters: {predictor.count_parameters():,}")

# map_location lets a checkpoint that was saved on a GPU be restored on whatever `device` is.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
predictor.load_state_dict(torch.load("model_saves/predictor.pt", map_location=device))
predictor.to(device)  # one move covers both the CUDA and the CPU case
predictor.eval()
predictor.requires_grad_(False)  # inference only: parameter gradients are never needed

test_data = TrainSetArtefacts(10000, crop_dim, n_test=10, test=True)
# Pinned host memory only speeds up host->GPU copies; without CUDA it just triggers a warning.
test_loader = DataLoader(test_data, batch_size=bs, shuffle=False, num_workers=0, drop_last=False, pin_memory=cuda)

In [None]:
def evaluate_segmentation(loader, model, device, criterion=None, use_amp=True):
    """Run `model` over `loader` and collect per-batch losses and per-sample pixel accuracies.

    Args:
        loader: iterable of (crop, anno) pairs. Crops and annotations are expected as
            [bs, 1, H, W] tensors with 0/1 annotations (assumption -- confirm with the dataset).
        model: network returning one logit map per input crop.
        device: torch.device every batch is moved to before the forward pass.
        criterion: loss applied to (logits, targets); defaults to nn.BCEWithLogitsLoss().
        use_amp: run the forward pass under CUDA mixed precision. Ignored on non-CUDA devices.

    Returns:
        (loss_hist, acc_hist): a list of per-batch mean losses, and a flat float64 NumPy array
        with one pixel accuracy per sample (empty if the loader yields nothing).
    """
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    amp_enabled = use_amp and device.type == "cuda"

    losses, acc_batches = [], []
    with torch.no_grad():
        for crop, anno in loader:  # Shapes: [bs, 1, crop_dim, crop_dim]
            crop = crop.to(device)
            anno = anno.to(device).float()
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                logits = model(crop)
            # Score in float32, so that loss and accuracy do not depend on AMP being on or off.
            logits = logits.float()
            losses.append(criterion(logits, anno).item())
            # Equivalent to round(sigmoid(logits)), but exact, with no fp16 rounding near 0.5.
            pred = (logits > 0).float()
            acc = (pred == anno).float().mean(dim=(-1, -2))
            # Collect per-batch arrays; one concatenate at the end avoids quadratic np.append.
            acc_batches.append(acc.cpu().numpy().astype(np.float64).ravel())

    acc_hist = np.concatenate(acc_batches) if acc_batches else np.zeros(0)
    return losses, acc_hist


criterion = nn.BCEWithLogitsLoss()
start_time = time.time()
loss_hist, acc_hist = evaluate_segmentation(test_loader, predictor, device, criterion=criterion)
print(f"Evaluation finished after {time.time()-start_time:.1f}s.")

In [None]:
ACC_HIST_PATH = "model_saves/eval_acc_hist"

# Write the per-sample accuracies to disk and read them back.
torch.save(acc_hist, ACC_HIST_PATH)
# The file holds a pickled NumPy array. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle it. This file is trusted, because it
# was written just above, so full unpickling is safe.
try:
    acc_hist = torch.load(ACC_HIST_PATH, weights_only=False)
except TypeError:  # PyTorch < 1.13 has no `weights_only` argument
    acc_hist = torch.load(ACC_HIST_PATH)

In [None]:
save_path = None  # e.g. "hparam_tuning_imgs/AL_eval.png" to also write the figure to disk

bin_size = 1/500
acc_min = 0.8  # lower edge of the plotted range; samples below it fall outside the histogram
# linspace pins the last edge at exactly 1.0, so perfect samples (acc == 1.0) are always counted.
bins = np.linspace(acc_min, 1, round((1 - acc_min) / bin_size) + 1)
n_below = int(np.sum(acc_hist < acc_min))

fig, ax = plt.subplots(figsize=(16, 5))
ax.set_title("Segmentation Accuracy", size=15)
ax.hist(acc_hist, bins, label="Sample Accuracy")
# axvline spans the full plot height regardless of where the y-ticks end up.
ax.axvline(np.mean(acc_hist), linestyle="--", color="black", label="Mean Accuracy")
ax.set_xlabel(f"Pixel Accuracy (bin size {bin_size})", size=12)
ax.set_ylabel("No. of Occurrences", size=12)
# A formatter relabels whatever ticks matplotlib picks, unlike set_xticklabels on unfixed ticks.
ax.xaxis.set_major_formatter(lambda x, _pos: f"{x*100:.1f}%")
ax.legend()
if n_below:
    print(f"{n_below} of {len(acc_hist)} samples have accuracy below {acc_min*100:.1f}% and are not shown.")
if save_path is not None:
    fig.savefig(save_path, dpi=200, bbox_inches='tight')
plt.show()

In [None]:
# Call the project helper `sample_P` on the test loader with plotting enabled.
# CUDA autocast is only enabled when CUDA is available; on CPU-only machines it would just warn.
with torch.cuda.amp.autocast(enabled=cuda):
    sample_P(test_loader, predictor, cuda=cuda, device=device, plot=True, save_as=None, seed=None)