# Prerequisites

In [1]:
import sys

# Make the parent directory importable so the local project modules imported
# in the next cell (presumably dataset3d / model3d / util3d / transforms3d live
# there) can be found. The relative path is resolved against the kernel's
# working directory. Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
# Reload edited modules automatically before each cell runs, so changes to the
# local project modules below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Third-party imports. Several of these names are not used in the visible
# cells; presumably later cells of the notebook use them.
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms.v2 as transforms
from IPython.display import clear_output
from monai.networks.nets import resnet18
from skimage.io import imread, imsave
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# Local project modules, importable once the parent directory is on sys.path.
# NOTE(review): sys.path.append puts ".." *last*, so an installed PyPI package
# named `transforms3d` would shadow a local transforms3d.py — confirm which
# module is actually imported here.
import transforms3d as T
from dataset3d import BNSet, BNSetMasks, get_dloader_mask, get_dloader_noise
from model3d import CNN3d
from util3d import get_obj_score3d, get_saliency3d, show_volume

# Apply seaborn's theme to all subsequent matplotlib figures.
sns.set_theme()

In [3]:
# Root of the BugNIST data, relative to the kernel's working directory.
data_dir = "../data/bugNIST_DATA"

# Two-letter BugNIST class code -> human-readable species name (display label).
name_legend = {
    "ac": "brown_cricket",
    "bc": "black_cricket",
    "bf": "blow_fly",
    "bl": "buffalo_beetle_larva",
    "bp": "blow_fly_pupa",
    "cf": "curly-wing_fly",
    "gh": "grasshopper",
    "ma": "maggot",
    "ml": "mealworm",
    "pp": "green_bottle_fly_pupa",
    "sl": "soldier_fly_larva",
    "wo": "woodlice",
}

# Seed torch and numpy so weight initialisation and data shuffling/augmentation
# are repeatable across Restart & Run All (cuDNN kernels may still be
# nondeterministic on GPU).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Train on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Training loop

In [24]:
# ---- Hyper-parameters ----------------------------------------------------
lr = 1e-4

# NOTE(review): the last run of the training cell hit CUDA OOM on a 4 GiB GPU
# with batch_size=2; consider batch_size=1 or mixed precision if that recurs.
batch_size = 2
num_workers = 2

# Class codes to train on. This must be a concrete list of codes — the bound
# method `name_legend.keys` (without the call) is neither iterable nor
# len()-able. Use e.g. ["ac", "bc"] for a quick two-class run.
subset = list(name_legend.keys())

# Worker persistence and pinned host memory only make sense when there are
# worker processes / a CUDA device respectively.
persistent_workers = num_workers > 0
pin_memory = device == "cuda"

trainloader = get_dloader_noise(
    "train",
    batch_size,
    data_dir=data_dir,
    subset=subset,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
)
valloader = get_dloader_noise(
    "val",
    batch_size,
    data_dir=data_dir,
    subset=subset,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
)

# ---- Model, loss, optimiser ----------------------------------------------
# One output logit per class in `subset`. CrossEntropyLoss requires every
# target index to be < num_classes; presumably the loaders label samples
# 0..len(subset)-1 — confirm in dataset3d.
# Alternative architecture from model3d: model = CNN3d(len(subset))
model = resnet18(
    spatial_dims=3,
    n_input_channels=1,
    num_classes=len(subset),
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)

In [25]:
num_epochs = 50


def _summarize(metrics):
    """Return ``(mean_loss, accuracy)`` for one epoch's accumulated batch metrics.

    ``metrics`` holds per-batch lists under ``"loss"`` (Python floats) and
    ``"preds"`` / ``"labels"`` (numpy arrays that are concatenated and compared
    element-wise). Returns ``(nan, nan)`` when no batch was recorded instead of
    letting ``np.concatenate`` raise on an empty list.
    """
    if not metrics["loss"]:
        return float("nan"), float("nan")
    all_preds = np.concatenate(metrics["preds"])
    all_labels = np.concatenate(metrics["labels"])
    return float(np.mean(metrics["loss"])), float(np.mean(all_preds == all_labels))


# Per-epoch summary: epoch index -> {"train_loss", "train_accuracy",
# "val_loss", "val_accuracy"}.
stats = {}
for epoch in range(num_epochs):
    metrics_train = {"loss": [], "preds": [], "labels": []}
    metrics_val = {"loss": [], "preds": [], "labels": []}

    print(f"Epoch {epoch}")

    # ---- Training pass ---------------------------------------------------
    model.train()
    for volumes, labels, masks, noise in tqdm(trainloader):
        # Keep voxels inside the mask and replace the background with the
        # supplied noise. Assumes `masks` is a bool tensor (`~` is logical NOT).
        inputs = (volumes * masks + ~masks * noise).to(device)

        # Clear gradients before the forward pass so stale gradients (e.g. left
        # over from an interrupted earlier run of this cell) never accumulate.
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, labels.long().to(device))
        loss.backward()
        optimizer.step()

        metrics_train["loss"].append(loss.item())
        metrics_train["preds"].append(out.argmax(dim=1).cpu().numpy())
        metrics_train["labels"].append(labels.numpy())

    # ---- Validation pass -------------------------------------------------
    model.eval()
    # The whole forward pass runs under no_grad so no autograd graph or
    # activations are kept for validation batches.
    with torch.no_grad():
        for batch in tqdm(valloader):
            # valloader comes from get_dloader_noise like trainloader, so it may
            # also yield a noise tensor; validation uses the plain masked volume
            # and ignores any extra items.
            volumes, labels, masks, *_ = batch
            out = model((volumes * masks).to(device))
            loss = criterion(out, labels.long().to(device))

            metrics_val["loss"].append(loss.item())
            metrics_val["preds"].append(out.argmax(dim=1).cpu().numpy())
            metrics_val["labels"].append(labels.numpy())

    train_loss, train_accuracy = _summarize(metrics_train)
    val_loss, val_accuracy = _summarize(metrics_val)
    performance = {
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy,
    }
    print(performance)
    stats[epoch] = performance

Epoch 0


  0%|          | 0/446 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 2.35 GiB is allocated by PyTorch, and 18.03 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)