In [1]:
import os
import random
import numpy as np
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset, Dataset

from torchvision import transforms

import medmnist
from medmnist import NoduleMNIST3D
from utils import set_seet, ToTensor3D, SyntheticVAEDataset, train_classifier, evaluate_classifier, train_vae
from cnn_architecture import CNN3D
from genai_architecture import VAE3D

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
data_flag = "nodulemnist3d"
download = True
batch_size = 64
set_seet(42)

transform = ToTensor3D()

In [4]:
train_dataset = NoduleMNIST3D(split="train", transform=transform, download=download)
val_dataset = NoduleMNIST3D(split="val", transform=transform, download=download)
test_dataset = NoduleMNIST3D(split="test", transform=transform, download=download)

print("Train Size:", len(train_dataset))
print("Val Size:", len(val_dataset))
print("Test size:", len(test_dataset))

Train Size: 1158
Val Size: 165
Test size: 310


In [5]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [6]:
print("\n=== Training baseline classifier (no GenAI) ===")
baseline_model = CNN3D(num_classes=2)
baseline_model, baseline_val_acc = train_classifier(
    model=baseline_model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=30
)
baseline_test_acc, _ = evaluate_classifier(baseline_model, device, test_loader)
print(f"Baseline Test Acc: {baseline_test_acc:.4f}")


=== Training baseline classifier (no GenAI) ===
[Epoch 01] Train Loss: 0.5390 Acc: 0.7539 | Val Loss: 0.6050 Acc: 43.1273
[Epoch 02] Train Loss: 0.4700 Acc: 0.7832 | Val Loss: 0.5954 Acc: 43.0121
[Epoch 03] Train Loss: 0.4365 Acc: 0.8083 | Val Loss: 0.5032 Acc: 40.9091
[Epoch 04] Train Loss: 0.4162 Acc: 0.8161 | Val Loss: 0.6934 Acc: 24.1455
[Epoch 05] Train Loss: 0.4286 Acc: 0.8212 | Val Loss: 0.4894 Acc: 34.5394
[Epoch 06] Train Loss: 0.4182 Acc: 0.8143 | Val Loss: 0.4219 Acc: 39.6545
[Epoch 07] Train Loss: 0.3962 Acc: 0.8273 | Val Loss: 0.4155 Acc: 39.0970
[Epoch 08] Train Loss: 0.3963 Acc: 0.8247 | Val Loss: 0.7519 Acc: 22.9394
[Epoch 09] Train Loss: 0.4106 Acc: 0.8230 | Val Loss: 0.4043 Acc: 38.8788
[Epoch 10] Train Loss: 0.4032 Acc: 0.8290 | Val Loss: 0.4199 Acc: 38.1697
[Epoch 11] Train Loss: 0.3992 Acc: 0.8195 | Val Loss: 0.4766 Acc: 40.6970
[Epoch 12] Train Loss: 0.3782 Acc: 0.8472 | Val Loss: 0.7041 Acc: 41.2909
[Epoch 13] Train Loss: 0.3943 Acc: 0.8377 | Val Loss: 0.4243 Ac

In [7]:
print("\n=== Training 3D VAE (GenAI) on nodules ===")
vae = VAE3D(latent_dim=64)
vae = train_vae(vae, device, train_loader, epochs=30, lr=1e-3)


=== Training 3D VAE (GenAI) on nodules ===
[VAE Epoch 01] Loss: 18540.1284
[VAE Epoch 02] Loss: 11283.6282
[VAE Epoch 03] Loss: 11274.8632
[VAE Epoch 04] Loss: 11273.5960
[VAE Epoch 05] Loss: 11272.3972
[VAE Epoch 06] Loss: 11272.0250
[VAE Epoch 07] Loss: 11278.5184
[VAE Epoch 08] Loss: 11271.7751
[VAE Epoch 09] Loss: 11272.0153
[VAE Epoch 10] Loss: 11277.5482
[VAE Epoch 11] Loss: 11275.8354
[VAE Epoch 12] Loss: 11273.2453
[VAE Epoch 13] Loss: 11271.2042
[VAE Epoch 14] Loss: 11272.7120
[VAE Epoch 15] Loss: 11272.3623
[VAE Epoch 16] Loss: 11271.0447
[VAE Epoch 17] Loss: 11276.4660
[VAE Epoch 18] Loss: 11272.4534
[VAE Epoch 19] Loss: 11271.7535
[VAE Epoch 20] Loss: 11270.9862
[VAE Epoch 21] Loss: 11270.0983
[VAE Epoch 22] Loss: 11269.7990
[VAE Epoch 23] Loss: 11273.2150
[VAE Epoch 24] Loss: 11274.2627
[VAE Epoch 25] Loss: 11271.9031
[VAE Epoch 26] Loss: 11269.9094
[VAE Epoch 27] Loss: 11276.6574
[VAE Epoch 28] Loss: 11273.1808
[VAE Epoch 29] Loss: 11270.1599
[VAE Epoch 30] Loss: 11269.6

In [8]:
labels = [train_dataset[i][1].item() for i in range(len(train_dataset))]
labels = np.array(labels)
num_classes = 2
counts = np.bincount(labels, minlength=num_classes)
label_probs = counts / counts.sum()
print("Label distribution in train set:", counts, "-> probs:", label_probs)

Label distribution in train set: [863 295] -> probs: [0.74525043 0.25474957]


In [9]:
num_synthetic = len(train_dataset)  # e.g. 1x the real data size
synthetic_dataset = SyntheticVAEDataset(vae, device, length=num_synthetic,
                                        label_probs=label_probs,
                                        num_classes=num_classes)

augmented_train_dataset = ConcatDataset([train_dataset, synthetic_dataset])
aug_train_loader = DataLoader(augmented_train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=2)

In [10]:
print("\n=== Training classifier with GenAI-augmented data ===")
genai_model = CNN3D(num_classes=2)
genai_model, genai_val_acc = train_classifier(
    genai_model,
    device,
    aug_train_loader,
    val_loader,
    epochs=30,
    lr=1e-3
)

genai_test_acc, _ = evaluate_classifier(genai_model, device, test_loader)
print(f"GenAI-Augmented Test Acc: {genai_test_acc:.4f}")


=== Training classifier with GenAI-augmented data ===


AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ~~~~~~~~~~~~~~~^^^^^^
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\collate.py", line 212, in collate
    collate(samples, collate_fn_map=collate_fn_map)
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\collate.py", line 269, in collate_tensor_fn
    numel = sum(x.numel() for x in batch)
  File "c:\Users\Dominik Hahn\anaconda3\Lib\site-packages\torch\utils\data\_utils\collate.py", line 269, in <genexpr>
    numel = sum(x.numel() for x in batch)
                ^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'numel'
