In [1]:
import random, os, torch, numpy as np


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
#seed_everything()

In [2]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
import pandas as pd
#import numpy as np
import torch.fft
import subprocess
import logging
#import random
import shutil
import psutil
import sklearn
import scipy
#import torch
import copy
import yaml
import time
import tqdm
import sys
import gc

import segmentation_models_pytorch as smp

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
#import torch

from holodecml.data import PickleReader, UpsamplingReader, XarrayReader, XarrayReaderLabels
from holodecml.propagation import InferencePropagator
from holodecml.transforms import LoadTransformations
from holodecml.models import load_model
from holodecml.losses import load_loss

#import os
import warnings
warnings.filterwarnings("ignore")
import lpips

In [3]:
import sklearn, sklearn.metrics

def man_metrics(results):
    result = {}
    for metric in ["f1", "auc", 'pod', "far", "csi"]: #"man_prec", "man_recall",
        if metric == 'f1':
            score = sklearn.metrics.f1_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'prec':
            score = sklearn.metrics.precision_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'recall':
            score = sklearn.metrics.recall_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'auc':
            try:
                score = sklearn.metrics.roc_auc_score(results["true"], results["pred"], average = "weighted")
            except:
                score = 1.0
        elif metric ==  "csi":
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN + FP)
            except:
                score = 1
        elif metric == 'far':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = FP / (TP + FP)
            except:
                score = 1
        elif metric == 'pod':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN)
            except: 
                score = 1
        result[metric] = score
        #print(metric, round(score, 3))
    return result

In [4]:
def apply_transforms(transforms, image):
    im = {"image": image}
    for image_transform in transforms:
        im = image_transform(im)
    image = im["image"]
    return image

In [5]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

available_ncpus = len(psutil.Process().cpu_affinity())

# Set up the GPU
is_cuda = torch.cuda.is_available()
device = torch.device("cpu") if not is_cuda else torch.device("cuda:0")

In [6]:
# ### Set seeds for reproducibility
# def seed_everything(seed=1234):
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.backends.cudnn.benchmark = True
#         torch.backends.cudnn.deterministic = True
        
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [7]:
config = "../results/random_test/model.yml" #"../config/gan.yml"
with open(config) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [8]:
# Set seeds for reproducibility
seed = 1000 if "seed" not in conf else conf["seed"]
seed_everything(seed)

save_loc = conf["save_loc"]
os.makedirs(save_loc, exist_ok = True)
os.makedirs(os.path.join(save_loc, "images"), exist_ok = True)
if not os.path.isfile(os.path.join(save_loc, "model.yml")):
    shutil.copyfile(config, os.path.join(save_loc, "model.yml"))

tile_size = int(conf["data"]["tile_size"])
step_size = int(conf["data"]["step_size"])
data_path = conf["data"]["output_path"]
data_path_raw = conf["data"]["output_path_raw"]

total_positive = int(conf["data"]["total_positive"])
total_negative = int(conf["data"]["total_negative"])
total_examples = int(conf["data"]["total_training"])

transform_mode = "None" if "transform_mode" not in conf["data"] else conf["data"]["transform_mode"]
config_ncpus = int(conf["data"]["cores"])
use_cached = False if "use_cached" not in conf["data"] else conf["data"]["use_cached"]

name_tag = f"{tile_size}_{step_size}_{total_positive}_{total_negative}_{total_examples}_{transform_mode}"
fn_train = f"{data_path}/training_{name_tag}.nc"
fn_valid = f"{data_path}/validation_{name_tag}.nc"
fn_train_raw = data_path_raw
#fn_train_raw = f"{data_path_raw}/training_{name_tag}.nc"
#fn_valid_raw = f"{data_path_raw}/validation_{name_tag}.nc"

# Trainer params
train_batch_size = conf["trainer"]["train_batch_size"]
valid_batch_size = conf["trainer"]["valid_batch_size"]

epochs = conf["trainer"]["epochs"]
batches_per_epoch = conf["trainer"]["batches_per_epoch"]
Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor
adv_loss = conf["trainer"]["adv_loss"]
lambda_gp = conf["trainer"]["lambda_gp"]
mask_penalty = conf["trainer"]["mask_penalty"]
regression_penalty = conf["trainer"]["regression_penalty"]
train_gen_every = conf["trainer"]["train_gen_every"]
train_disc_every = conf["trainer"]["train_disc_every"]
threshold = conf["trainer"]["threshold"]

In [9]:
# Load the preprocessing transforms
if "Normalize" in conf["transforms"]["training"]:
    conf["transforms"]["validation"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]
    conf["transforms"]["inference"]["Normalize"]["mode"] = conf["transforms"]["training"]["Normalize"]["mode"]

train_transforms = LoadTransformations(conf["transforms"]["training"])
valid_transforms = LoadTransformations(conf["transforms"]["validation"])

In [10]:
train_synthetic_dataset = XarrayReader(fn_train, train_transforms)
test_synthetic_dataset = XarrayReader(fn_valid, valid_transforms)

# train_synthetic_dataset = PickleReader(
#     fn_train,
#     transform=train_transforms,
#     max_images=int(0.8 * conf["data"]["total_training"]),
#     max_buffer_size=int(0.1 * conf["data"]["total_training"]),
#     color_dim=conf["discriminator"]["in_channels"],
#     shuffle=True
# )

# test_synthetic_dataset = PickleReader(
#     fn_valid,
#     transform=valid_transforms,
#     max_images=int(0.1 * conf["data"]["total_training"]),
#     max_buffer_size=int(0.1 * conf["data"]["total_training"]),
#     color_dim=conf["discriminator"]["in_channels"],
#     shuffle=False
# )

In [11]:
train_synthetic_loader = torch.utils.data.DataLoader(
    train_synthetic_dataset,
    batch_size=train_batch_size,
    num_workers=available_ncpus//2,
    pin_memory=True,
    shuffle=True)

# test_synthetic_loader = torch.utils.data.DataLoader(
#     test_synthetic_dataset,
#     batch_size=valid_batch_size,
#     num_workers=0,  # 0 = One worker with the main process
#     pin_memory=True,
#     shuffle=False)

In [12]:
train_holodec_dataset = XarrayReaderLabels(fn_train_raw, train_transforms)
#test_holodec_dataset = XarrayReader(fn_valid_raw, valid_transforms)

# train_holodec_dataset = PickleReader(
#     fn_train_raw,
#     transform=train_transforms,
#     max_images=int(0.8 * conf["data"]["total_training"]),
#     max_buffer_size=int(0.1 * conf["data"]["total_training"]),
#     color_dim=conf["discriminator"]["in_channels"],
#     shuffle=True
# )

# test_holodec_dataset = PickleReader(
#     fn_valid_raw,
#     transform=valid_transforms,
#     max_images=int(0.1 * conf["data"]["total_training"]),
#     max_buffer_size=int(0.1 * conf["data"]["total_training"]),
#     color_dim=conf["discriminator"]["in_channels"],
#     shuffle=False
# )

In [13]:
train_holodec_loader = torch.utils.data.DataLoader(
    train_holodec_dataset,
    batch_size=train_batch_size,
    num_workers=available_ncpus//2,
    pin_memory=True,
    shuffle=True)

# test_holodec_loader = torch.utils.data.DataLoader(
#     test_holodec_dataset,
#     batch_size=valid_batch_size,
#     num_workers=0,  # 0 = One worker with the main process
#     pin_memory=True,
#     shuffle=False)

### Load models

In [14]:
generator = load_model(conf["generator"]).to(device) 
discriminator = load_model(conf["discriminator"]).to(device)
model = load_model(conf["model"]).to(device)

In [15]:
adv_loss = conf["trainer"]["adv_loss"]
if adv_loss == "bce":
    adversarial_loss = torch.nn.BCELoss().to(device)
    
perceptual_alex = lpips.LPIPS(net='alex').cuda()

Mask_Loss = load_loss("focal-tyversky")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /glade/work/schreck/py37/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


In [16]:
optimizer_G = torch.optim.Adam(
    filter(lambda p: p.requires_grad, generator.parameters()),
    lr = conf["optimizer_G"]["learning_rate"],
    betas = (conf["optimizer_G"]["b0"], conf["optimizer_G"]["b1"]))

optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, discriminator.parameters()), 
    lr = conf["optimizer_D"]["learning_rate"], 
    betas = (conf["optimizer_D"]["b0"], conf["optimizer_D"]["b1"]))

optimizer_M = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), 
    lr =  conf["optimizer_M"]["learning_rate"], 
    betas = (conf["optimizer_M"]["b0"], conf["optimizer_M"]["b1"]))

In [17]:
def compute_gradient_penalty(discriminator, real_imgs, gen_imgs):
    """Calculates the gradient penalty loss for WGAN GP"""
    alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
    interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
    out = discriminator(interpolated)[1]
    grad = torch.autograd.grad(outputs=out,
                               inputs=interpolated,
                               grad_outputs=torch.ones(out.size()).cuda(),
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    grad = grad.view(grad.size(0), -1)
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
    return d_loss_gp

In [18]:
lr_G_decay = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.2)
lr_D_decay = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.2)
lr_M_decay = torch.optim.lr_scheduler.StepLR(optimizer_M, step_size=10, gamma=0.2)

In [None]:
results = defaultdict(list)
for epoch in range(epochs):
    
    ### Train
    real_images = iter(train_holodec_loader)
    synthethic_images = iter(train_synthetic_loader)
    dual_iter = tqdm.tqdm(
        enumerate(zip(real_images, synthethic_images)),
        total = batches_per_epoch, 
        leave = True)
    
    train_results = defaultdict(list)
    for i, ((holo_img, holo_label), (synth_img, synth_label)) in dual_iter:
            
        if holo_img.shape[0] != synth_img.shape[0]:
            continue
                            
        # Adversarial ground truths
        valid = Variable(Tensor(holo_img.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(holo_img.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(holo_img.type(Tensor))
        synthethic_imgs = Variable(synth_img.type(Tensor))
                
        # Sample noise as generator input
        #z = Variable(Tensor(np.random.normal(0, 1.0, (holo_img.shape[0], 512))))
        z = Variable(Tensor(np.random.normal(0, 1.0, synthethic_imgs.shape)))
        # C-GAN-like input using the synthethic image as conditional input
        gen_input = torch.cat([synthethic_imgs, z], 1)
        # Generate a batch of images
        gen_imgs = generator(gen_input)
        # Discriminate the fake images
        pred_masks, verdict = discriminator(gen_imgs)
        
        # -----------------
        #  Train mask model
        # -----------------
        
        optimizer_M.zero_grad()
        requires_grad(model, True)
        pred_masks = model(gen_imgs.detach())
        mask_loss = Mask_Loss(pred_masks, synth_label.to(device))
        train_results["mask_loss"].append(mask_loss.item())
        mask_loss.backward()
        optimizer_M.step()
        requires_grad(model, False)
        
        # -----------------
        #  Train Generator
        # -----------------
            
        if (i + 1) % train_gen_every == 0:
            
            optimizer_G.zero_grad()
            requires_grad(generator, True)
            requires_grad(discriminator, False)
            
            # Loss measures generator's ability to fool the discriminator
            if adv_loss == 'wgan-gp':
                g_loss = -verdict.mean()
            elif adv_loss == 'hinge':
                g_loss = -verdict.mean()
            elif adv_loss == 'bce':
                g_loss = adversarial_loss(verdict, valid)
                
            # compute L1/L2 reg term
#             mask = synth_label.to(device).bool()
#             reg_loss = torch.nn.MSELoss()(
#                 torch.masked_select(gen_imgs, mask),
#                 torch.masked_select(synthethic_imgs, mask)
#             )

            mask = synth_label.to(device)
            reg_loss = torch.nn.MSELoss()(gen_imgs, synthethic_imgs)

            train_results["g_reg"].append(reg_loss.item())
            g_loss += regression_penalty * reg_loss
            
            # compute mask loss reg term
            pred_masks = model(gen_imgs)
            mask_loss = Mask_Loss(pred_masks, synth_label.to(device))
            train_results["fake_mask_loss"].append(mask_loss.item())
            g_loss += mask_penalty * mask_loss
            train_results["g_loss"].append(g_loss.item())
                
            g_loss.backward()
            optimizer_G.step()
            
            # compute perception scores
            p_score_syn = perceptual_alex(gen_imgs, synthethic_imgs).mean()
            p_score_real = perceptual_alex(gen_imgs, real_imgs).mean()
            train_results["p_syn"].append(p_score_syn.item())
            train_results["p_real"].append(p_score_real.item())
            
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        if (i + 1) % train_disc_every == 0:
        
            optimizer_D.zero_grad()
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            # Measure discriminator's ability to classify real from generated samples
            _, disc_real = discriminator(real_imgs)
            _, disc_synth = discriminator(gen_imgs.detach())
            _, disc_synth_true = discriminator(synthethic_imgs)
            
            train_results["real_acc"].append(((disc_real > threshold) == valid).float().mean().item())
            train_results["syn_acc"].append(((disc_synth > threshold) == fake).float().mean().item())
            train_results["syn_true_acc"].append(((disc_synth_true > threshold) == valid).float().mean().item())
            
            if adv_loss == 'wgan-gp':
                real_loss = -torch.mean(disc_real) 
                fake_loss = disc_synth.mean() 
            elif adv_loss == 'hinge':
                real_loss = torch.nn.ReLU()(1.0 - disc_real).mean() 
                fake_loss = torch.nn.ReLU()(1.0 + disc_synth).mean()             
            elif adv_loss == 'bce':
                real_loss = adversarial_loss(disc_real, valid) 
                fake_loss = adversarial_loss(disc_synth, fake) 
                
            d_loss = real_loss + fake_loss 
            train_results["d_loss"].append(d_loss.item())
            
            if adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_imgs.size(0), 1, 1, 1).cuda().expand_as(real_imgs)
                interpolated = Variable(alpha * real_imgs.data + (1 - alpha) * gen_imgs.data, requires_grad=True)
                out = discriminator(interpolated)[1]

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
                d_loss_reg = lambda_gp * d_loss_gp
                d_loss += d_loss_reg
                train_results["d_reg"].append(d_loss_reg.item())
            
            d_loss.backward()
            optimizer_D.step()

        print_str =  f'Epoch {epoch}'
        print_str += f' D_loss {np.mean(train_results["d_loss"]):.6f}'
        if adv_loss == 'wgan-gp':
            print_str += f' D_reg {np.mean(train_results["d_reg"]):.6f}'
        print_str += f' G_loss {np.mean(train_results["g_loss"]):6f}'
        print_str += f' G_reg {np.mean(train_results["g_reg"]):6f}'
        print_str += f' mask_loss {np.mean(train_results["mask_loss"]):6f}'
        #print_str += f' fake_mask_loss {np.mean(train_results["fake_mask_loss"]):6f}'
        print_str += f' h_acc {np.mean(train_results["real_acc"]):.4f}'
        print_str += f' f_acc {np.mean(train_results["syn_acc"]):.4f}'
        print_str += f' s_acc {np.mean(train_results["syn_true_acc"]):.4f}'
        print_str += f' p_syn {np.mean(train_results["p_syn"]):.4f}'
        print_str += f' p_real {np.mean(train_results["p_real"]):.4f}'
        dual_iter.set_description(print_str)
        dual_iter.refresh()
        
        if i == batches_per_epoch and i > 0:
            break
        
    # Validate 
    requires_grad(model, False)
    inputs = torch.from_numpy(np.load(os.path.join(
        data_path, f'manual_images_{transform_mode}_test.npy'))).float()
    labels = torch.from_numpy(np.load(os.path.join(
        data_path, f'manual_labels_{transform_mode}_test.npy'))).float()
    inputs = torch.from_numpy(np.expand_dims(
        np.vstack([apply_transforms(valid_transforms, x) for x in inputs.numpy()]), 1))
    
    results_dict_set1 = defaultdict(list)
    my_iter = tqdm.tqdm(enumerate(zip(inputs, labels)), total = inputs.shape[0], leave = True)
    for k, (x, y) in my_iter:
        pred_label = model(x.unsqueeze(0).to(device))
        arr, n = scipy.ndimage.label(pred_label.cpu() > 0.5)
        centroid = scipy.ndimage.find_objects(arr)
        pred_label = len(centroid)
        if pred_label > 0:
            pred_label = 1
        else:
            pred_label = 0
        results_dict_set1["pred"].append(pred_label)
        results_dict_set1["true"].append(y[0].item())
        mets = man_metrics(results_dict_set1)
        f1, pod, far, csi = mets["f1"], mets["pod"], mets["far"], mets["csi"]
        my_iter.set_description(f"Epoch {epoch} F1: {f1:.3f} POD: {pod:.3f} FAR: {far:.3f} CSI: {csi:3f}")
        my_iter.refresh()
    
        
    # Epoch is over. Save some stuff.
    save_image(synthethic_imgs.data[:16], f'{conf["save_loc"]}/images/synth_{epoch}.png', nrow=4, normalize=True)
    save_image(real_imgs.data[:16], f'{conf["save_loc"]}/images/real_{epoch}.png', nrow=4, normalize=True)
    save_image(gen_imgs.data[:16], f'{conf["save_loc"]}/images/pred_{epoch}.png', nrow=4, normalize=True)

    # Save the dataframe to disk
    results["epoch"].append(epoch)
    results["d_loss"].append(np.mean(train_results["d_loss"]))
    if adv_loss == 'wgan-gp':
        results["d_loss_reg"].append(np.mean(train_results["d_reg"]))
    results["g_loss"].append(np.mean(train_results["g_loss"]))
    results["g_reg"].append(np.mean(train_results["g_reg"]))
    results["mask_loss"].append(np.mean(train_results["mask_loss"]))
    #results["fake_mask_loss"].append(np.mean(train_results["fake_mask_loss"]))
    results["holo_acc"].append(np.mean(train_results["real_acc"]))
    results["fake_acc"].append(np.mean(train_results["syn_acc"]))
    results["synth_acc"].append(np.mean(train_results["syn_true_acc"]))
    results["perception_syn"].append(np.mean(train_results["p_syn"]))
    results["perception_holo"].append(np.mean(train_results["p_real"]))
    
    df = pd.DataFrame.from_dict(results).reset_index()
    df.to_csv(f'{conf["save_loc"]}/training_log.csv', index=False)

    # Save the model
    state_dict = {
        'epoch': epoch,
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'generator_state_dict': generator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'model_state_dict': model.state_dict(),
        'model_optimizer_state_dict': optimizer_M.state_dict()
    }
    torch.save(state_dict, f'{conf["save_loc"]}/best.pt')
    
    # Anneal learning rates 
    lr_G_decay.step(epoch)
    lr_D_decay.step(epoch)
    lr_M_decay.step(epoch)

Epoch 0 D_loss 1.574306 G_loss 35.002747 G_reg 0.078397 mask_loss 0.927094 h_acc 0.6875 f_acc 0.6875 s_acc 0.4375 p_syn 0.6728 p_real 0.9966: 100%|██████████| 250/250 [06:04<00:00,  1.46s/it]
Epoch 0 F1: 0.668 POD: 0.939 FAR: 0.238 CSI: 0.725906: 100%|██████████| 1154/1154 [00:27<00:00, 41.88it/s]
Epoch 1 D_loss 0.266075 G_loss 39.314060 G_reg 0.068820 mask_loss 0.995426 h_acc 1.0000 f_acc 0.9875 s_acc 0.0250 p_syn 0.7815 p_real 1.0689:  38%|███▊      | 95/250 [02:19<03:45,  1.46s/it]

###  Validate on the manual data

In [None]:
import sklearn, sklearn.metrics

def man_metrics(results):
    result = {}
    for metric in ["f1", "auc", 'pod', "far", "csi"]: #"man_prec", "man_recall",
        if metric == 'f1':
            score = sklearn.metrics.f1_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'prec':
            score = sklearn.metrics.precision_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'recall':
            score = sklearn.metrics.recall_score(results["true"], results["pred"], average = "weighted")
        elif metric == 'auc':
            try:
                score = sklearn.metrics.roc_auc_score(results["true"], results["pred"], average = "weighted")
            except:
                score = 1.0
        elif metric ==  "csi":
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN + FP)
            except:
                score = 1
        elif metric == 'far':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = FP / (TP + FP)
            except:
                score = 1
        elif metric == 'pod':
            try:
                TN, FP, FN, TP = sklearn.metrics.confusion_matrix(results["true"], results["pred"]).ravel()
                score = TP / (TP + FN)
            except: 
                score = 1
        result[metric] = score
        #print(metric, round(score, 3))
    return result

In [None]:
def apply_transforms(transforms, image):
    im = {"image": image}
    for image_transform in transforms:
        im = image_transform(im)
    image = im["image"]
    return image

In [None]:
model = load_model(conf["model"]).to(device)

In [None]:
checkpoint = torch.load(
    f"{save_loc}/best.pt",
    map_location=lambda storage, loc: storage
)

In [None]:
checkpoint["epoch"]

In [None]:
model.load_state_dict(checkpoint["model_state_dict"])
model = model.eval()

### Performance on set 1

In [None]:
inputs = torch.from_numpy(np.load(os.path.join(
    data_path, f'manual_images_{transform_mode}.npy'))).float()
labels = torch.from_numpy(np.load(os.path.join(
    data_path, f'manual_labels_{transform_mode}.npy'))).float()
inputs = torch.from_numpy(np.expand_dims(
    np.vstack([apply_transforms(valid_transforms, x) for x in inputs.numpy()]), 1))

In [None]:
results_dict_set1 = defaultdict(list)
my_iter = tqdm.tqdm(enumerate(zip(inputs, labels)), total = inputs.shape[0], leave = True)
for k, (x, y) in my_iter:
    pred_label = model(x.unsqueeze(0).to(device))
    arr, n = scipy.ndimage.label(pred_label.cpu() > 0.5)
    centroid = scipy.ndimage.find_objects(arr)
    pred_label = len(centroid)
    if pred_label > 0:
        pred_label = 1
    else:
        pred_label = 0
    results_dict_set1["pred"].append(pred_label)
    results_dict_set1["true"].append(y[0].item())
    mets = man_metrics(results_dict_set1)
    f1, pod, far, csi = mets["f1"], mets["pod"], mets["far"], mets["csi"]
    my_iter.set_description(f"F1: {f1:.3f} POD: {pod:.3f} FAR: {far:.3f} CSI: {csi:3f}")
    #my_iter.set_description(f"Accuracy: {np.mean(results_dict_set1['accuracy'])}")
    my_iter.refresh()

In [None]:
set1_metrics = man_metrics(results_dict_set1)

In [None]:
for key, val in set1_metrics.items():
    print(key, val)

### Performance on set 2

In [None]:
inputs_2 = torch.from_numpy(np.load(os.path.join(
    data_path, f'manual_images_{transform_mode}_test.npy'))).float()
labels_2 = torch.from_numpy(np.load(os.path.join(
    data_path, f'manual_labels_{transform_mode}_test.npy'))).float()
confs_2 = torch.from_numpy(np.load(os.path.join(
    data_path, f'manual_conf_{transform_mode}_test.npy'))).float()
inputs_2 = torch.from_numpy(np.expand_dims(
    np.vstack([apply_transforms(valid_transforms, x) for x in inputs_2.numpy()]), 1))

In [None]:
results_dict_set2 = defaultdict(list)
my_iter = tqdm.tqdm(enumerate(zip(inputs_2, labels_2)), total = inputs_2.shape[0], leave = True)
for k, (x, y) in my_iter:
    pred_label = model(x.unsqueeze(0).to(device))
    arr, n = scipy.ndimage.label(pred_label.cpu() > 0.5)
    centroid = scipy.ndimage.find_objects(arr)
    pred_label = len(centroid)
    if pred_label > 0 and pred_label <= 10000:
        pred_label = 1
    else:
        pred_label = 0
    results_dict_set2["pred"].append(pred_label)
    results_dict_set2["true"].append(y[0].item())
    mets = man_metrics(results_dict_set2)
    f1, pod, far, csi = mets["f1"], mets["pod"], mets["far"], mets["csi"]
    my_iter.set_description(f"F1: {f1:.3f} POD: {pod:.3f} FAR: {far:.3f} CSI: {csi:3f}")
    my_iter.refresh()

In [None]:
set2_metrics = man_metrics(results_dict_set2)

In [None]:
for key, val in set2_metrics.items():
    print(key, val)