## AE notebook
This notebook goes through the training steps of the 3D AE model. It may not be possible to run the model inside the notebook; to test the model yourself, we therefore refer to multi_big_vali.py and ./bash/multi_big.sh.


### Important libs and custom functions


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
import os
import sys
import itertools
import csv
from pathlib import Path
import matplotlib.pyplot as plt

# --- Custom Imports ---
# Resolve the project root so the local `func` package can be imported.
try:
    PROJECT_ROOT = Path(__file__).resolve().parent.parent
except NameError:
    # `__file__` is not defined inside a notebook kernel. Fall back to the
    # parent of the working directory -- assumes the notebook is launched from
    # a sub-directory one level below the project root; TODO confirm.
    PROJECT_ROOT = Path.cwd().resolve().parent
# Guard against duplicate entries when the cell is executed more than once.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from func.utill import save_predictions, plot_learning_curves
from func.loss import DiceLoss, ComboLoss, TverskyLoss, FocalLoss
from func.Models import MultiTaskNet_big
from func.dataloaders import VolumetricPatchDataset

### Defining hyperparameters. 
We also switch to the GPU via CUDA when available and set up the important paths. The data is split into ~90% train, ~5% validation, and ~5% test. We also initialise the CSV file in which losses and mIoU are saved.

In [None]:
# --- CONFIGURATION ---
# Read from the BLACKHOLE environment variable; '.' when it is unset.
BLACKHOLE_PATH = os.environ.get('BLACKHOLE', '.')
INPUT_SHAPE = (128, 128, 128)
NUM_CLASSES = 4
LATENT_DIM = 256
BATCH_SIZE = 3
SAVE_INTERVAL = 20  # epochs between saved prediction visualisations
NUM_EPOCHS = 400
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
# Per-class weights passed to the focal loss; the last class is weighted 3x.
class_weight = torch.tensor([1.0, 1.0, 1.0, 3.0])

OUTPUT_DIR = PROJECT_ROOT / "output_big_vali_res"
CSV_PATH = PROJECT_ROOT / "stats" / "training_log_big.csv"
print(f"Logging metrics to: {CSV_PATH}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# The log directory must exist before the CSV is opened for writing below,
# otherwise open() raises FileNotFoundError on a fresh checkout.
CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
SAVE_PATH = PROJECT_ROOT / "Trained_models" / "multi_big_res_best.pth"
SAVE_PATH_FINAL = PROJECT_ROOT / "Trained_models" / "multi_big_res_final.pth"
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

# --- DATA SPLITS ---
test_cols = [1, 2, 33, 34]
val_cols = [27, 28, 29, 30]
# Column 36 was listed twice, which made that volume appear twice per epoch.
# NOTE(review): the results tables further down report test columns 1, 2, 37
# and 38, while 37 and 38 are listed as labeled training columns here --
# confirm the intended train/test split.
labeled_cols = [3, 4, 5, 6, 7, 8, 35, 36, 37, 38]
unlabeled_cols = list(range(9, 27)) + list(range(40, 44))

# Start a fresh metrics log: mode 'w' truncates any previous file and only the
# header row is written here; one row per epoch is appended during training.
with open(CSV_PATH, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['Epoch', 'Train_Loss', 'Val_Loss', 'Val_mIoU'])

print("--- Data Splits ---")
print(f"Labeled Train: {labeled_cols}")
print(f"Unlabeled Train: {unlabeled_cols}")
print(f"Validation: {val_cols}")


## Dataloader
Here we define a dataloader that splits each $256^3$ volume into 8 patches of $128^3$, the largest patch size we could use given memory limitations. The dimensions of the input volumes should also be divisible by 8, ensuring that downsampling yields integer-sized feature maps.

For the labeled dataset we use data augmentation with Gaussian noise and flips. No data augmentation is applied to the unlabeled dataset.

In [None]:
# Build the labeled-train, unlabeled-train and validation datasets and wrap
# each one in a DataLoader. Only the labeled training set is augmented, and
# only the two training loaders shuffle.
try:
    labeled_dataset = VolumetricPatchDataset(selected_columns=labeled_cols, augment=True, is_labeled=True)
    labeled_loader = DataLoader(labeled_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

    unlabeled_dataset = VolumetricPatchDataset(selected_columns=unlabeled_cols, augment=False, is_labeled=False)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    val_dataset = VolumetricPatchDataset(selected_columns=val_cols, augment=False, is_labeled=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    print("--- Loaders Ready ---")

except Exception as e:
    print(f"Error creating Datasets: {e}")
    # Re-raise instead of calling exit(): exit() terminates with status 0, so a
    # batch job would be reported as successful, and the traceback is lost.
    raise

### Train - AE

In [None]:
# Model, losses, optimiser and scheduler for the training loop.
model = MultiTaskNet_big(in_channels=1, num_classes=NUM_CLASSES, latent_dim=LATENT_DIM).to(DEVICE)

# Segmentation loss: a ComboLoss built from a Tversky term and a class-weighted
# focal term. Reconstruction loss: plain MSE.
Tversky = TverskyLoss(num_classes=NUM_CLASSES, alpha=0.6, beta=0.4).to(DEVICE)
# Move the class weights to the target device explicitly. Module.to() only
# moves parameters and registered buffers, and whether FocalLoss registers its
# weight that way is not visible from here.
focal = FocalLoss(gamma=2.0, weight=class_weight.to(DEVICE)).to(DEVICE)
loss_seg_fn = ComboLoss(dice_loss_fn=Tversky, wce_loss_fn=focal).to(DEVICE)
loss_fn_recon = nn.MSELoss().to(DEVICE)

optimizer_model = optim.Adam(model.parameters(), lr=1e-4) # Lower LR for stability

# mode="max" because the scheduler is stepped with the validation mIoU: the
# learning rate is halved after 20 epochs without improvement.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it and
# log the learning rate explicitly if the installed version rejects it.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer_model, mode="max", factor=0.5, patience=20, verbose=True
)

# Best-model tracking and early-stopping bookkeeping.
best_val_iou = 0.0
patience_counter = 0
EARLY_STOPPING_PATIENCE = 50

# --- LOSS HISTORY ---
train_loss_history = []
val_loss_history = []
val_iou_history = []

# Main loop. Each epoch runs one pass over the labeled loader (paired with
# unlabeled batches), validates, logs a CSV row, steps the LR scheduler,
# checkpoints the best model by validation mIoU, and periodically saves
# prediction visualisations. Training stops early once the mIoU has not
# improved for EARLY_STOPPING_PATIENCE consecutive epochs.
print("--- Starting Training ---")
for epoch in range(NUM_EPOCHS):

    # === TRAINING ===
    model.train()
    train_loss = 0.0
    epoch_seg_loss = 0.0
    epoch_recon_loss = 0.0

    # Tensors from the final labeled batch, kept for the periodic visualisation.
    last_x, last_y, last_recon, last_seg = None, None, None, None

    # zip() ends with the labeled loader, so one epoch is one pass over the
    # labeled data; the unlabeled loader is cycled in case it is the shorter.
    # NOTE: itertools.cycle keeps a copy of every batch it has yielded.
    for batch_idx, ((x, y_seg_target), x_unlabeled) in \
                    enumerate(zip(labeled_loader, itertools.cycle(unlabeled_loader))):

        x = x.to(DEVICE)
        # Cast to long exactly as the validation path does before calling the
        # same segmentation loss.
        y_seg_target = y_seg_target.to(DEVICE).squeeze(1).long()
        x_unlabeled = x_unlabeled.to(DEVICE)

        optimizer_model.zero_grad()

        # 1. Labeled Forward
        seg_out, recon_out_labeled = model(x)
        total_loss_seg = loss_seg_fn(seg_out, y_seg_target)
        loss_recon_labeled = loss_fn_recon(recon_out_labeled, x)

        # 2. Unlabeled Forward
        # The input is corrupted with Gaussian noise (std 0.1), while the
        # reconstruction target is the clean volume.
        noise = torch.randn_like(x_unlabeled) * 0.1
        x_unlabeled_noisy = x_unlabeled + noise
        _, recon_out_unlabeled = model(x_unlabeled_noisy)

        loss_recon_unlabeled = loss_fn_recon(recon_out_unlabeled, x_unlabeled)

        total_loss_recon = loss_recon_labeled + loss_recon_unlabeled

        # Weighted Sum
        total_loss = (total_loss_seg * 100.0) + (total_loss_recon * 1.0)

        total_loss.backward()
        optimizer_model.step()

        train_loss += total_loss.item()
        epoch_seg_loss += total_loss_seg.item()
        epoch_recon_loss += total_loss_recon.item()

        if batch_idx == len(labeled_loader) - 1:
            last_x = x.detach()
            last_y = y_seg_target.detach()
            last_recon = recon_out_labeled.detach()
            last_seg = seg_out.detach()

    avg_train_loss = train_loss / len(labeled_loader)
    avg_seg_loss = epoch_seg_loss / len(labeled_loader)
    avg_recon_loss = epoch_recon_loss / len(labeled_loader)
    train_loss_history.append(avg_train_loss)

    # === VALIDATION ===
    model.eval()
    # Intersection and union voxel counts are accumulated over the whole
    # validation set, so the IoU is computed per dataset rather than per batch.
    class_inter = np.zeros(NUM_CLASSES)
    class_union = np.zeros(NUM_CLASSES)
    loss_val = 0.0

    with torch.no_grad():
        for vx, vy_seg in val_loader:
            vx = vx.to(DEVICE)
            vy_seg = vy_seg.to(DEVICE).squeeze(1).long()

            val_seg_out, _ = model(vx)
            val_preds = torch.argmax(val_seg_out, dim=1)

            loss = loss_seg_fn(val_seg_out, vy_seg)
            loss_val += loss.item()
            for c in range(NUM_CLASSES):
                pred_c = (val_preds == c)
                true_c = (vy_seg == c)

                inter = (pred_c & true_c).sum().item()
                union = (pred_c | true_c).sum().item()

                class_inter[c] += inter
                class_union[c] += union

    avg_val_loss = loss_val / len(val_loader)
    val_loss_history.append(avg_val_loss)
    class_iou = []

    for c in range(NUM_CLASSES):
        # A class that is neither predicted nor present scores 0.0.
        if class_union[c] > 0:
            iou = class_inter[c] / class_union[c]
        else:
            iou = 0.0
        class_iou.append(iou)

    mIoU = np.mean(class_iou)
    val_iou_history.append(mIoU)

    with open(CSV_PATH, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, avg_train_loss, avg_val_loss, mIoU])

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f}")
    print(f"  Seg: {avg_seg_loss:.4f} | Recon: {avg_recon_loss:.4f}")
    # "Best" is the best value before this epoch; it is updated just below.
    print(f"  Val mIoU: {mIoU:.4f} (Best: {best_val_iou:.4f})")
    print(f"  [Class IoU] C0: {class_iou[0]:.4f} | C1: {class_iou[1]:.4f} | C2: {class_iou[2]:.4f} | C3: {class_iou[3]:.4f}")

    scheduler.step(mIoU)

    if mIoU > best_val_iou:
        best_val_iou = mIoU
        patience_counter = 0
        torch.save(model.state_dict(), SAVE_PATH)
        print("  --> New Best Model Saved!")
    else:
        patience_counter += 1
        print(f"  Patience count: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

    if (epoch + 1) % SAVE_INTERVAL == 0:
        print(f"  Saving visuals for Epoch {epoch +1}...")
        save_predictions(epoch, last_x, last_y, last_recon, last_seg, OUTPUT_DIR)

    # The patience counter was tracked but never acted upon; stop once the
    # validation mIoU has stalled for the configured number of epochs.
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print(f"--- Early stopping at epoch {epoch + 1}: no Val mIoU improvement for {EARLY_STOPPING_PATIENCE} epochs ---")
        break

print("--- Training Finished ---")
torch.save(model.state_dict(), SAVE_PATH_FINAL)
print(f"Best model saved {SAVE_PATH}")
print(f"Final model saved {SAVE_PATH_FINAL}")
print("Done.")




### Model measurements 
Below are the model's measurements and the model comparison, computed on the test data.
> Performance of Multi_big_final and Multi_big_best on each input column

| Model             | Column | ME     | Mean IoU | Mean Dice | AUROC  |
|------------------|--------|--------|----------|-----------|--------|
| **Multi_big_final** | 1      | 0.0824 | 0.6776   | 0.7193    | 0.9392 |
|                  | 2      | 0.0770 | 0.7526   | 0.8526    | 0.9759 |
|                  | 37     | 0.1629 | 0.7151   | 0.8194    | 0.9683 |
|                  | 38     | 0.0717 | 0.8399   | 0.9108    | 0.9911 |
| **Avg**          | --     | **0.0985** | **0.7463** | **0.8255** | **0.9686** |
| **Multi_big_best**  | 1      | 0.0425 | 0.6917   | 0.7197    | 0.9246 |
|                  | 2      | 0.0778 | 0.7498   | 0.8517    | 0.9702 |
|                  | 37     | 0.1409 | 0.7330   | 0.8354    | 0.9628 |
|                  | 38     | 0.0734 | 0.8468   | 0.9156    | 0.9903 |
| **Avg**          | --     | **0.0837** | **0.7553** | **0.8306** | **0.9620** |


> Average performance across all input volumes for Multi_big models

| Model            | ME     | Mean IoU | Mean Dice | AUROC  |
|------------------|--------|----------|-----------|--------|
| Multi_big_final  | 0.0985 | 0.7463   | 0.8255    | 0.9686 |
| Multi_big_best   | 0.0837 | 0.7553   | 0.8306    | 0.9620 |


> Per-class IoU and Dice for each input volume (Multi_big_final)

| Column | C0 IoU | C0 Dice | C1 IoU | C1 Dice | C2 IoU | C2 Dice | C3 IoU | C3 Dice |
|--------|--------|---------|--------|---------|--------|---------|--------|---------|
| 1      | 0.8974 | 0.9459  | 0.9096 | 0.9527  | 0.8826 | 0.9376  | 0.0210 | 0.0411  |
| 2      | 0.8054 | 0.8922  | 0.7279 | 0.8425  | 0.9142 | 0.9552  | 0.5630 | 0.7204  |
| 37     | 0.7560 | 0.8610  | 0.9079 | 0.9518  | 0.7789 | 0.8757  | 0.4175 | 0.5891  |
| 38     | 0.8588 | 0.9240  | 0.9004 | 0.9476  | 0.8976 | 0.9460  | 0.7027 | 0.8254  |

> Per-class IoU and Dice for each input volume (Multi_big_best)

| Column | C0 IoU | C0 Dice | C1 IoU | C1 Dice | C2 IoU | C2 Dice | C3 IoU | C3 Dice |
|--------|--------|---------|--------|---------|--------|---------|--------|---------|
| 1      | 0.8951 | 0.9447  | 0.9286 | 0.9630  | 0.9429 | 0.9706  | 0.0002 | 0.0004  |
| 2      | 0.7974 | 0.8873  | 0.7080 | 0.8290  | 0.9095 | 0.9526  | 0.5845 | 0.7378  |
| 37     | 0.7469 | 0.8551  | 0.9081 | 0.9518  | 0.8009 | 0.8894  | 0.4762 | 0.6452  |
| 38     | 0.8524 | 0.9203  | 0.9085 | 0.9521  | 0.8883 | 0.9408  | 0.7381 | 0.8493  |


> **Table 4:** Comparison of average performance for the final models across all input volumes.

| **Model (final)** | **ME** | **Mean IoU** | **Mean Dice** | **AUROC** |
| :--- | :---: | :---: | :---: | :---: |
| Multi\_big\_final | 0.0985 | 0.7463 | 0.8255 | 0.9686 |
| AG\_final | 0.0623 | 0.8164 | 0.8896 | 0.9886 |
| VAE\_final | 0.1020 | 0.7424 | 0.8208 | 0.9541 |

> **Table 5:** Comparison of average performance for the best models across all input volumes.

| **Model (best)** | **ME** | **Mean IoU** | **Mean Dice** | **AUROC** |
| :--- | :---: | :---: | :---: | :---: |
| Multi\_big\_best | 0.0837 | 0.7553 | 0.8306 | 0.9620 |
| AG\_best | 0.0614 | 0.8053 | 0.8759 | 0.9871 |
| VAE\_best | 0.0816 | 0.7669 | 0.8385 | 0.9571 |