## Attention gated AE notebook
This notebook goes through training steps of the attention gated 3D AE. It may not be possible to run the model in the notebook. Therefore, we refer to AG_vali.py and ./bash/AG_run.sh, to test the model yourself.


### Important libs and custom functions

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler 
import os
import sys
import csv
import itertools
from pathlib import Path

# --- PROJECT SETUP ---
PROJECT_ROOT = Path.cwd().parent.parent
sys.path.append(str(PROJECT_ROOT))
print(PROJECT_ROOT)

from func.utill import save_predictions
from func.dataloaders import VolumetricPatchDataset 
from func.loss import ComboLoss, TverskyLoss, DiceLoss, FocalLoss
from func.Models import MultiTaskNet_ag as MultiTaskNet 

/zhome/d2/4/167803/Desktop/Deep_project/02456-final-project


### Defining hyperparameters. 
We also switch to gpu using cuda and save important paths. Splitting data into ~90% train, ~5% validation, ~5% test. Initiation of csv file where we save losses and mIoU.

In [None]:
INPUT_SHAPE = (128, 128, 128) 
NUM_CLASSES = 4
LATENT_DIM = 256 
BATCH_SIZE = 2
SAVE_INTERVAL = 20
NUM_EPOCHS = 400
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.001

# Weights for Multi-Task Loss
SEG_WEIGHT = 100
RECON_WEIGHT = 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

CLASS_WEIGHTS = torch.tensor([0.5, 1.5, 1.0, 4.0]).to(device) 
print(f"Using Class Weights: {CLASS_WEIGHTS}")

OUTPUT_DIR = PROJECT_ROOT / "Output_AG_vali"
CSV_PATH =  PROJECT_ROOT / "stats" / "training_log_ag.csv"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
SAVE_PATH_FINAL = PROJECT_ROOT / "Trained_models" / "AG_val_final.pth"
SAVE_PATH = PROJECT_ROOT / "Trained_models" / "AG_val_best.pth"
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

test_cols = [1,2, 33, 34]      
val_cols = [27, 28, 29, 30]
labeled_train_cols = [3,4,5,6,7,8 , 35,36,36,37,38]
unlabeled_train_cols = list(range(9, 27)) + list(range(40, 44))

with open(CSV_PATH, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['Epoch', 'Train_Loss', 'Val_Loss', 'Val_mIoU'])

print(f"--- Data Splits ---")
print(f"Test (Reserved): {test_cols}")
print(f"Validation: {val_cols}")
print(f"Labeled Train: {labeled_train_cols}")
print(f"Unlabeled Train: {unlabeled_train_cols}")




Using device: cpu
Using Class Weights: tensor([0.5000, 1.5000, 1.0000, 4.0000])
--- Data Splits ---
Test (Reserved): [1, 2, 33, 34]
Validation: [27, 28, 29, 30]
Labeled Train: [3, 4, 5, 6, 7, 8, 35, 36, 36, 37, 38]
Unlabeled Train: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]


## Dataloader
Here we define dataloader able to split the $256^3$ into 8 patches of 128^3, which is the largers size we could use, memory limitations. Also the dimensions have the input volumes should be divisable by 8, ensuring that downsampling will yield integer sized feature maps.

For the labeled dataset we use data augmentation with gaussian noise and flips. No data augmentation on the unlabeled dataset

In [8]:
try:
    # 1. Labeled Dataset
    labeled_dataset = VolumetricPatchDataset(
        selected_columns=labeled_train_cols, 
        augment=True, 
        is_labeled=True
    )
    
    labeled_loader = DataLoader(
        dataset=labeled_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4
    )
    print("--- Labeled Loader Ready ---")

    # 2. Unlabeled Dataset
    unlabeled_dataset = VolumetricPatchDataset(
        selected_columns=unlabeled_train_cols,
        augment=False, 
        is_labeled=False
    )
    
    unlabeled_loader = DataLoader(
        dataset=unlabeled_dataset,
        batch_size=BATCH_SIZE, 
        shuffle=True,
        num_workers=4
    )
    print("--- Unlabeled Loader Ready ---")

    # 3. Validation Loader
    val_dataset = VolumetricPatchDataset(
        selected_columns=val_cols,
        augment=False,
        is_labeled=True
    )
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    print("--- Validation Loader Ready ---")
    print(f"Train Batches: {len(labeled_loader)}")
    print(f"Val Batches: {len(val_loader)}")

except Exception as e:
    print(f"Error creating Datasets: {e}")
    exit()


--- Labeled Loader Ready ---
--- Unlabeled Loader Ready ---
--- Validation Loader Ready ---
Train Batches: 88
Val Batches: 24


### Train - Attention Gated AE

In [None]:
model = MultiTaskNet(
    in_channels=1, 
    num_classes=NUM_CLASSES, 
    latent_dim=LATENT_DIM  
).to(device)

Tversky = TverskyLoss(num_classes=NUM_CLASSES, alpha=0.6, beta=0.4)
focal = FocalLoss(gamma=2.0).to(device)
loss_fn_recon = nn.MSELoss().to(device)

loss_fn_seg = ComboLoss(
    dice_loss_fn=Tversky, # normally dice
    wce_loss_fn=focal, # normally cross
    alpha=0.6, beta=0.4
).to(device)

optimizer_model = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer_model,
    mode='max',      
    factor=0.5,      
    patience=20,     
    verbose=True
)

best_val_iou = 0.0

# --- LOSS HISTORY ---
train_loss_history = []
val_loss_history = []
val_iou_history = []
print("--- Starting Training ---")

for epoch in range(NUM_EPOCHS):
    
    # --- TRAINe ---
    model.train() 
    epoch_train_loss = 0.0
    epoch_seg_loss = 0.0
    epoch_recon_loss = 0.0
    
    # Visualization variables
    last_x, last_y, last_recon, last_seg = None, None, None, None

    for batch_idx, ((x, y_seg_target), (x_unlabeled)) in \
            enumerate(zip(labeled_loader, itertools.cycle(unlabeled_loader))):
        
        x = x.to(device)       
        y_seg_target = y_seg_target.to(device).squeeze(1) # Squeeze for CrossEntropy
        x_unlabeled = x_unlabeled.to(device) 
        
        optimizer_model.zero_grad()
        
        # Forward Labeled
        seg_out, recon_out_labeled = model(x)
        
        # Forward Unlabeled (with noise)
        noise = torch.randn_like(x_unlabeled) * 0.1
        x_unlabeled_noisy = x_unlabeled + noise
        _ , recon_out_unlabeled = model(x_unlabeled_noisy)
                    
        # Loss Calculation
        loss_seg = loss_fn_seg(seg_out, y_seg_target)
        loss_recon = loss_fn_recon(recon_out_labeled, x) + \
                        loss_fn_recon(recon_out_unlabeled, x_unlabeled)
        
        total_loss = (loss_seg*SEG_WEIGHT) + (loss_recon*RECON_WEIGHT)
        
        total_loss.backward()
        optimizer_model.step()
        
        epoch_train_loss += total_loss.item()
        epoch_seg_loss += loss_seg.item()
        epoch_recon_loss += loss_recon.item()

        # Save last batch
        if batch_idx == len(labeled_loader) - 1:
            last_x = x.detach()
            last_y = y_seg_target.detach()
            last_recon = recon_out_labeled.detach()
            last_seg = seg_out.detach()
            
    avg_train_loss = epoch_train_loss / len(labeled_loader)
    avg_seg_loss = epoch_seg_loss / len(labeled_loader)
    avg_recon_loss = epoch_recon_loss / len(labeled_loader)
    train_loss_history.append(avg_train_loss)
    
    # --- VALIDATION PHASE ---
    model.eval()
    #total_val_iou = 0.0
    class_inter = np.zeros(NUM_CLASSES)
    class_union = np.zeros(NUM_CLASSES)
    loss_val = 0.0
    with torch.no_grad():
        for (x_val, y_val_seg) in val_loader:
            x_val = x_val.to(device)
            y_val_seg = y_val_seg.to(device).squeeze(1).long()
            
            val_seg_out, _ = model(x_val)
            val_preds = torch.argmax(val_seg_out, dim=1)

            loss = loss_fn_seg(val_seg_out, y_val_seg)
            loss_val += loss.item()
            for c in range(NUM_CLASSES):
                pred_c = (val_preds == c)
                true_c = (y_val_seg == c)

                inter = (pred_c & true_c).sum().item()
                union = (pred_c | true_c).sum().item()

                class_inter[c] += inter
                class_union[c] += union

    avg_val_loss = loss_val / len(val_loader)
    val_loss_history.append(avg_val_loss)
    class_iou =[]

    for c in range(NUM_CLASSES):
        if class_union[c] > c:
            iou = class_inter[c]/class_union[c]
        else:
            iou = 0.0
        class_iou.append(iou)
    
    mIoU = np.mean(class_iou)
    val_iou_history.append(mIoU)

    with open(CSV_PATH, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch + 1, avg_train_loss, avg_val_loss, mIoU])
        
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {avg_train_loss:.4f}")
    print(f" Avg Train Loss: {avg_train_loss:.4f} | Seg Loss: {avg_seg_loss:.4f} | Recon Loss: {avg_recon_loss:.4f}")
    print(f"  Val mIoU: {mIoU:.4f} (Best: {best_val_iou:.4f})")
    print(f" [Class IoU] C0: {class_iou[0]:.4f} C1: {class_iou[1]:.4f} | C2: {class_iou[2]:.4f} | C3: {class_iou[3]:.4f}")
    # Scheduler Step
    scheduler.step(mIoU)

    # Save Best Model
    if mIoU > best_val_iou:
        best_val_iou = mIoU
        patience_count = 0
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"  --> New Best Model Saved!")
    else: 
        patience_count +=1
        (f"Patience count: {patience_count:.3f}")

    # Visualization
    if (epoch + 1) % SAVE_INTERVAL == 0:
        print(f"  Saving visuals for Epoch {epoch+1}...")
        save_predictions(epoch, last_x, last_y, last_recon, last_seg, OUTPUT_DIR)
print("--- Training Finished ---")
print("Saving model weights...")
torch.save(model.state_dict(), SAVE_PATH_FINAL)
print(f"Best model saved {SAVE_PATH}")
print(f"Final model saved {SAVE_PATH_FINAL}")
print("Done.")

## Model measurements
Below are the models measurements and model comparison. From test data
> Performance of AG_final and AG_best on each input column

| Model       | Column | ME     | Mean IoU | Mean Dice | AUROC  |
|-------------|--------|--------|----------|-----------|--------|
| **AG_final** | 1      | 0.0295 | 0.7969   | 0.8610    | 0.9946 |
|             | 2      | 0.0682 | 0.7787   | 0.8717    | 0.9860 |
|             | 37     | 0.0939 | 0.8026   | 0.8857    | 0.9801 |
|             | 38     | 0.0576 | 0.8874   | 0.9401    | 0.9938 |
| **Avg**     | --     | **0.0623** | **0.8164** | **0.8896** | **0.9886** |
| **AG_best**  | 1      | 0.0333 | 0.7487   | 0.8044    | 0.9883 |
|             | 2      | 0.0692 | 0.7704   | 0.8656    | 0.9839 |
|             | 37     | 0.0888 | 0.8097   | 0.8905    | 0.9822 |
|             | 38     | 0.0543 | 0.8924   | 0.9429    | 0.9939 |
| **Avg**     | --     | **0.0614** | **0.8053** | **0.8759** | **0.9871** |


> Average performance across all input volumes for AG models

| Model     | ME     | Mean IoU | Mean Dice | AUROC  |
|-----------|--------|----------|-----------|--------|
| AG_final  | 0.0623 | 0.8164   | 0.8896    | 0.9886 |
| AG_best   | 0.0614 | 0.8053   | 0.8759    | 0.9871 |


> Per-class IoU and Dice for each input volume (AG_final)

| Column | C0 IoU | C0 Dice | C1 IoU | C1 Dice | C2 IoU | C2 Dice | C3 IoU | C3 Dice |
|--------|--------|---------|--------|---------|--------|---------|--------|---------|
| 1      | 0.9238 | 0.9604  | 0.9373 | 0.9676  | 0.9600 | 0.9796  | 0.3666 | 0.5365  |
| 2      | 0.8095 | 0.8947  | 0.7526 | 0.8589  | 0.9216 | 0.9592  | 0.6312 | 0.7739  |
| 37     | 0.7933 | 0.8847  | 0.9271 | 0.9622  | 0.8710 | 0.9310  | 0.6191 | 0.7647  |
| 38     | 0.8690 | 0.9299  | 0.9177 | 0.9571  | 0.9092 | 0.9525  | 0.8538 | 0.9211  |

> Per-class IoU and Dice for each input volume (AG_best)

| Column | C0 IoU | C0 Dice | C1 IoU | C1 Dice | C2 IoU | C2 Dice | C3 IoU | C3 Dice |
|--------|--------|---------|--------|---------|--------|---------|--------|---------|
| 1      | 0.9141 | 0.9551  | 0.9357 | 0.9668  | 0.9556 | 0.9773  | 0.1893 | 0.3183  |
| 2      | 0.8112 | 0.8958  | 0.7390 | 0.8499  | 0.9218 | 0.9592  | 0.6098 | 0.7576  |
| 37     | 0.8074 | 0.8935  | 0.9230 | 0.9600  | 0.8771 | 0.9345  | 0.6314 | 0.7741  |
| 38     | 0.8784 | 0.9353  | 0.9200 | 0.9584  | 0.9133 | 0.9547  | 0.8577 | 0.9234  |


> **Table 4:** Comparison of average performance for the final models across all input volumes.

| **Model (final)** | **ME** | **Mean IoU** | **Mean Dice** | **AUROC** |
| :--- | :---: | :---: | :---: | :---: |
| Multi\_big\_final | 0.0985 | 0.7463 | 0.8255 | 0.9686 |
| AG\_final | 0.0623 | 0.8164 | 0.8896 | 0.9886 |
| VAE\_final | 0.1020 | 0.7424 | 0.8208 | 0.9541 |

> **Table 5:** Comparison of average performance for the best models across all input volumes.

| **Model (best)** | **ME** | **Mean IoU** | **Mean Dice** | **AUROC** |
| :--- | :---: | :---: | :---: | :---: |
| Multi\_big\_best | 0.0837 | 0.7553 | 0.8306 | 0.9620 |
| AG\_best | 0.0614 | 0.8053 | 0.8759 | 0.9871 |
| VAE\_best | 0.0816 | 0.7669 | 0.8385 | 0.9571 |