# Prerequisites

In [1]:
import os
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
sys.path.append("..")

In [2]:
%load_ext autoreload
%autoreload 2


import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms.v2 as transforms
from IPython.display import clear_output
from monai.networks.nets import resnet10
from skimage.io import imread, imsave
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

import transforms3d as T
from dataset3d import BNSet, BNSetMasks, get_dloader_mask, get_dloader_noise
from model3d import CNN3d
from util3d import get_obj_score3d, get_saliency3d, show_volume

sns.set_theme()

In [3]:
# data_dir = "../data/bugNIST_DATA"
data_dir = "/work3/s191510/data/BugNIST_DATA"

name_legend = {
    "ac": "brown_cricket",
    "bc": "black_cricket",
    "bf": "blow_fly",
    "bl": "buffalo_bettle_larva",
    "bp": "blow_fly_pupa",
    "cf": "curly-wing_fly",
    "gh": "grasshopper",
    "ma": "maggot",
    "ml": "mealworm",
    "pp": "green_bottle_fly_pupa",
    "sl": "soldier_fly_larva",
    "wo": "woodlice",
}

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training loop

In [4]:
lr = 1e-4

batch_size = 8
num_workers = 16

# subset = ["ac", "bc"]
subset = ["ac", "ml"]
# subset = list(name_legend.keys())

persistent_workers = num_workers > 0
trainloader = get_dloader_noise(
    "train",
    batch_size,
    data_dir=data_dir,
    subset=subset,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=persistent_workers,
)
valloader = get_dloader_noise(
    "val",
    batch_size,
    data_dir=data_dir,
    subset=subset,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=persistent_workers,
)

model = CNN3d(len(subset))
# model = resnet18(
#     spatial_dims=3,
#     n_input_channels=1,
#     num_classes=len(subset),
# )
# model = resnet10(
#     spatial_dims=3,
#     n_input_channels=1,
#     num_classes=len(subset),
# )
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)

get_input = lambda volumes, masks, noise: volumes * masks + ~masks * noise
# get_input = lambda volumes, masks, noise: volumes * masks

In [5]:
stats = {}
for epoch in range(0, 50):
    metrics_train = {
        "loss": [],
        "preds": [],
        "labels": [],
    }
    metrics_val = {
        "loss": [],
        "preds": [],
        "labels": [],
    }

    print(f"Epoch {epoch}")
    model.train()
    for volumes, labels, masks, noise in tqdm(trainloader):
        out = model(get_input(volumes, masks, noise).to(device))

        loss = criterion(out, labels.type(torch.LongTensor).to(device))
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        _, indices = torch.max(out.cpu(), 1)

        metrics_train["loss"].append(loss.cpu().detach().item())
        metrics_train["preds"].append(indices.detach().numpy())
        metrics_train["labels"].append(labels.numpy())

    model.eval()
    for volumes, labels, masks, noise in tqdm(valloader):
        out = model(get_input(volumes, masks, noise).to(device))

        with torch.no_grad():
            loss = criterion(out, labels.type(torch.LongTensor).to(device))

        _, indices = torch.max(out.cpu(), 1)

        metrics_val["loss"].append(loss.cpu().detach().item())
        metrics_val["preds"].append(indices.detach().numpy())
        metrics_val["labels"].append(labels.numpy())

    performance = {
        "train_loss": np.mean(metrics_train["loss"]),
        "train_accuracy": np.mean(
            np.concatenate(metrics_train["preds"])
            == np.concatenate(metrics_train["labels"])
        ).item(),
        "val_loss": np.mean(metrics_val["loss"]),
        "val_accuracy": np.mean(
            np.concatenate(metrics_val["preds"])
            == np.concatenate(metrics_val["labels"])
        ).item(),
    }
    print(performance)
    stats[epoch] = performance
    break

Epoch 0


  0%|          | 0/110 [00:00<?, ?it/s]

../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [5,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [6,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [7,0,0] Assertion `t >= 0 && t < n_classes` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# metrics_train["labels"], 
metrics_val["labels"]

In [None]:
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 3), dpi=150)

ax0.plot(stats.keys(), [stats[epoch]["train_loss"] for epoch in stats])
ax0.plot(stats.keys(), [stats[epoch]["val_loss"] for epoch in stats])
ax0.set_title("CE Loss")

ax1.plot(stats.keys(), [stats[epoch]["train_accuracy"] * 100 for epoch in stats])
ax1.plot(stats.keys(), [stats[epoch]["val_accuracy"] * 100 for epoch in stats])
ax1.set_title("Accuracy")
plt.show()

In [None]:
dloader = get_dloader_noise("val", 1, data_dir=data_dir, subset=subset, num_workers=0)
model.eval();

In [None]:
dloader_iter = iter(dloader)

In [None]:
sns.reset_orig()
volumes, labels, masks, noise = next(dloader_iter)
volumes = get_input(volumes, masks, noise)

slc, score, indices, out = get_saliency3d(model, volumes, device=device)
obj_score = get_obj_score3d(slc, masks)

slc_abs = np.abs(slc)

show_volume(volumes.detach().numpy(), labels.item(), cmap="viridis")
show_volume(slc_abs / slc_abs.max(), cmap="inferno")
obj_score, indices.item(), labels.item()