# Prerequisites

In [1]:
import os
import sys

# Expose only GPU 2 to this process; set before torch is imported further down.
os.environ.update(CUDA_VISIBLE_DEVICES="2")
# Project modules (dataset3d, model3d, util3d) live one directory above notebooks/.
sys.path.append(os.pardir)

In [2]:
# Reload edited modules before each cell runs, so changes to the project
# modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms.v2 as transforms
from IPython.display import clear_output
from monai.networks.nets import resnet10
from skimage.io import imread, imsave
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# Modules resolved via the ".." entry appended to sys.path in the first cell
# (transforms3d presumably being the project-local one, not the PyPI package — TODO confirm).
# NOTE(review): transforms, clear_output, imread/imsave, DataLoader/Dataset, T,
# BNSet, BNSetMasks, get_dloader_mask and CNN3d appear unused in the cells shown
# here; consider removing them.
import transforms3d as T
from dataset3d import BNSet, BNSetMasks, get_dloader_mask, get_dloader_noise
from model3d import CNN3d
from util3d import get_obj_score3d, get_saliency3d, show_volume

# Seaborn styling for the training-curve plots (the inspection cell later
# switches back with sns.reset_orig()).
sns.set_theme()

In [3]:
# Root of the BugNIST data. Override with the BUGNIST_DATA_DIR environment
# variable (e.g. "../data/bugNIST_DATA" for a local checkout) instead of
# editing this absolute cluster path.
data_dir = os.environ.get("BUGNIST_DATA_DIR", "/work3/s191510/data/BugNIST_DATA")

# Class abbreviation -> human-readable name (underscores become spaces in plot
# titles). Insertion order matters: the inspection cells map a class index to a
# name via list(name_legend.values())[index], so this order is assumed to match
# the dataset's label encoding — TODO confirm against dataset3d.
name_legend = {
    "ac": "brown_cricket",
    "bc": "black_cricket",
    "bf": "blow_fly",
    "bl": "buffalo_beetle_larva",
    "bp": "blow_fly_pupa",
    "cf": "curly-wing_fly",
    "gh": "grasshopper",
    "ma": "maggot",
    "ml": "mealworm",
    "pp": "green_bottle_fly_pupa",
    "sl": "soldier_fly_larva",
    "wo": "woodlice",
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training loop

In [4]:
lr = 1e-4

batch_size = 16
num_workers = 16

# Classes to train on. Narrow this for quick experiments, e.g. ["bc", "wo"]
# or ["ac", "bc", "ml"].
subset = list(name_legend.keys())

# Keeping workers alive between epochs only makes sense when there are workers
# (presumably forwarded to torch DataLoader, which rejects
# persistent_workers=True together with num_workers=0).
persistent_workers = num_workers > 0
_loader_kwargs = dict(
    data_dir=data_dir,
    subset=subset,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=persistent_workers,
)
trainloader = get_dloader_noise("train", batch_size, **_loader_kwargs)
# Validation runs with batch size 1.
valloader = get_dloader_noise("val", batch_size=1, **_loader_kwargs)

# 3-D ResNet-10 classifier over single-channel volumes, one output per class.
# Previously tried with the same role: CNN3d(len(name_legend)) and monai's
# resnet18(spatial_dims=3, n_input_channels=1, num_classes=len(name_legend)).
model = resnet10(
    spatial_dims=3,
    n_input_channels=1,
    no_max_pool=False,
    conv1_t_stride=2,
    num_classes=len(name_legend),
).to(device)  # nn.Module.to moves parameters in place and returns the module

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)

def get_input(volumes, masks, noise):
    """Build the network input: keep voxels inside the mask and fill the rest with noise.

    Args:
        volumes: volume tensor.
        masks: foreground mask broadcastable to ``volumes``. It is cast to bool
            first. Applying bitwise ``~`` to a uint8 0/1 mask gives 255/254
            rather than 1/0, and it raises for float masks.
        noise: background tensor broadcastable to ``volumes``.

    Returns:
        Tensor equal to ``volumes`` where ``masks`` is set and ``noise`` elsewhere.
    """
    masks = masks.bool()
    # Mask-only variant (zero background) for comparison: volumes * masks
    return volumes * masks + ~masks * noise

In [5]:
NUM_EPOCHS = 50

# NOTE(review): the saved run of this cell crashed with "index 1098 is out of
# bounds for axis 0 with size 1098". The error came from dataset3d's noise
# sampler, which wraps the index with `item % len(self)` rather than by the
# number of noise volumes (1098). The loader that failed iterates 1831 items,
# so the index overruns. This must be fixed in dataset3d.py; it cannot be
# worked around in this cell.


def _train_one_epoch(loader):
    """Run one optimisation pass over ``loader``.

    Returns:
        dict of per-batch lists: "loss" (float), "preds" and "labels"
        (numpy arrays of class indices).
    """
    metrics = {"loss": [], "preds": [], "labels": []}
    model.train()
    for volumes, labels, masks, noise in tqdm(loader):
        out = model(get_input(volumes, masks, noise).to(device))
        loss = criterion(out, labels.long().to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metrics["loss"].append(loss.item())
        metrics["preds"].append(out.detach().argmax(1).cpu().numpy())
        metrics["labels"].append(labels.numpy())
    return metrics


def _validate(loader):
    """Evaluate on ``loader`` and collect saliency-based object scores.

    get_saliency3d is called outside ``torch.no_grad`` (presumably it needs
    autograd to build the saliency map). Only the loss is computed without
    autograd.

    Returns:
        dict of per-batch lists: "loss", "preds", "labels", "object_scores".
    """
    metrics = {"loss": [], "preds": [], "labels": [], "object_scores": []}
    model.eval()
    for volumes, labels, masks, noise in tqdm(loader):
        slc, _score, _indices, out = get_saliency3d(
            model, get_input(volumes, masks, noise), device=device
        )
        metrics["object_scores"].append(get_obj_score3d(slc, masks))

        with torch.no_grad():
            loss = criterion(out.to(device), labels.long().to(device))

        metrics["loss"].append(loss.item())
        metrics["preds"].append(out.detach().argmax(1).cpu().numpy())
        metrics["labels"].append(labels.numpy())
    return metrics


def _accuracy(metrics):
    """Fraction of collected predictions that equal their labels, as a Python float."""
    preds = np.concatenate(metrics["preds"])
    targets = np.concatenate(metrics["labels"])
    return np.mean(preds == targets).item()


def _summarise(metrics_train, metrics_val):
    """Reduce one epoch's train/val metrics to the scalars stored in ``stats``."""
    return {
        "train_loss": np.mean(metrics_train["loss"]),
        "train_accuracy": _accuracy(metrics_train),
        "val_loss": np.mean(metrics_val["loss"]),
        "val_accuracy": _accuracy(metrics_val),
        "obj_score": np.mean(metrics_val["object_scores"]),
    }


stats = {}  # epoch -> summary dict; plotted by the next cell
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch}")
    metrics_train = _train_one_epoch(trainloader)
    metrics_val = _validate(valloader)
    performance = _summarise(metrics_train, metrics_val)
    print(performance)
    stats[epoch] = performance
    

Epoch 0


  0%|          | 0/344 [00:00<?, ?it/s]

  0%|          | 0/1831 [00:00<?, ?it/s]

IndexError: Caught IndexError in DataLoader worker process 10.
Original Traceback (most recent call last):
  File "/zhome/28/4/143111/context/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/zhome/28/4/143111/context/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/zhome/28/4/143111/context/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/zhome/28/4/143111/gitrepos/cnn-context/notebooks/../dataset3d.py", line 157, in __getitem__
    noise = self.noise_sampler(item).copy()[np.newaxis]
  File "/zhome/28/4/143111/gitrepos/cnn-context/notebooks/../dataset3d.py", line 148, in sampler
    return self.noise[item % len(self)]
  File "/zhome/28/4/143111/context/lib/python3.10/site-packages/numpy/core/memmap.py", line 335, in __getitem__
    res = super().__getitem__(index)
IndexError: index 1098 is out of bounds for axis 0 with size 1098


In [None]:
# Re-apply seaborn styling in case the inspection cell (sns.reset_orig()) ran first.
sns.set_theme()
epochs = list(stats)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(15, 3), dpi=150)

ax0.plot(epochs, [stats[e]["train_loss"] for e in epochs], label="train")
ax0.plot(epochs, [stats[e]["val_loss"] for e in epochs], label="val")
ax0.set(title="CE Loss", xlabel="Epoch", ylabel="Loss")
ax0.legend()

ax1.plot(epochs, [stats[e]["train_accuracy"] * 100 for e in epochs], label="train")
ax1.plot(epochs, [stats[e]["val_accuracy"] * 100 for e in epochs], label="val")
ax1.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy (%)")
ax1.legend()

# The object score is only computed on the validation set.
ax2.plot(epochs, [stats[e]["obj_score"] * 100 for e in epochs])
ax2.set(title="Object score", xlabel="Epoch", ylabel="Object score (x100)")
plt.show()

In [None]:
# Validation loader for one-by-one inspection: batch size 1, loaded in the main process.
dloader = get_dloader_noise(
    "val", batch_size=1, data_dir=data_dir, subset=subset, num_workers=0
)
model.eval();

In [None]:
# Fresh iterator over the inspection loader; each run of the next cell draws
# the next sample with next(dloader_iter). Re-run this cell to start over.
dloader_iter = iter(dloader)

In [None]:
sns.reset_orig()
volumes, labels, masks, noise = next(dloader_iter)
# Same composited input the model is trained on.
volumes = get_input(volumes, masks, noise)

slc, score, indices, out = get_saliency3d(model, volumes, device=device)
obj_score = get_obj_score3d(slc, masks)

# Saliency magnitude; divided by its maximum below for display.
slc_abs = np.abs(slc)

size = 3
fig, axs = plt.subplots(2, 3, figsize=(3 * size, 2 * size), tight_layout=True)

class_names = [name.replace("_", " ") for name in name_legend.values()]
label = labels.item()
gt = class_names[label]
pred = class_names[indices.item()]
conf = torch.softmax(out, 1)[0, indices]

title = f'Ground truth: {gt}\nPrediction: {pred} ({conf.item():.1%})\nObj. score: {obj_score:.2f}'
# Top row: composited input volume; bottom row: normalised saliency magnitude.
show_volume(volumes.detach().numpy(), title=title, cmap="viridis", fig_axs=(fig, axs[0]))
show_volume(slc_abs / slc_abs.max(), fig_axs=(fig, axs[1]), cmap="inferno")
plt.show()

In [None]:
def _object_score_of(batch):
    """Saliency-based object score for one (volumes, labels, masks, noise) batch."""
    vols, _labels, msk, nse = batch
    saliency, *_rest = get_saliency3d(model, get_input(vols, msk, nse), device=device)
    return get_obj_score3d(saliency, msk)


obj_scores = [_object_score_of(batch) for batch in tqdm(dloader)]

np.mean(obj_scores)

Result for the mealworm vs. brown cricket subset:
mean object score = 0.831