# Autoencoder Training - Simplified Version

This notebook contains a simplified version of the autoencoder training code for easier debugging and experimentation.

## 1. Import Libraries

In [1]:
import logging
import os
import time

import matplotlib.pyplot as plt
import torch
from monai.data import CacheDataset, DataLoader
from monai.inferers import sliding_window_inference
from monai.losses import PatchAdversarialLoss, PerceptualLoss
from monai.networks.nets import AutoencoderKL, PatchDiscriminator
from monai.utils import set_determinism
from torch.nn import L1Loss
from tqdm import tqdm

from config import ConfigParser
from constants import TASK1_HN
from utils.utils import (
    get_data_paths,
    get_vae_train_transforms,
    get_vae_val_transforms
)

# Set deterministic behavior
set_determinism(42)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


## 2. Configuration and Hyperparameters

In [2]:
patch_size = [64, 64, 64]
train_val_split = 0.8
task1 = True

debug_mode = False

batch_size = 1
num_epochs = 10
learning_rate = 1e-4
adv_weight = 0.01
perceptual_weight = 0.001
kl_weight = 1e-6
autoencoder_warm_up_n_epochs = 3
val_interval = 2
save_interval = 5

sw_batch_size = 1
overlap = 0.5
mode = "gaussian"

print(f"Configuration loaded:")
print(f"  Patch size: {patch_size}")
print(f"  Batch size: {batch_size}")
print(f"  Epochs: {num_epochs}")
print(f"  Learning rate: {learning_rate}")
print(f"  Debug mode: {debug_mode}")

Configuration loaded:
  Patch size: [64, 64, 64]
  Batch size: 1
  Epochs: 10
  Learning rate: 0.0001
  Debug mode: False


## 3. KL Loss Function

In [3]:
def kl_loss(z_mu, z_sigma):
    """
    KL divergence loss for VAE
    kl_loss = 0.5 * sum(z_mu^2 + z_sigma^2 - log(z_sigma^2) - 1)
    """
    klloss = 0.5 * torch.sum(
        z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1,
        dim=list(range(1, len(z_mu.shape)))
    )
    return torch.sum(klloss) / klloss.shape[0]

## 4. Data Loading

In [None]:
# Load data
DATASET = "task1_hn"
print(f"Loading data from {TASK1_HN}...")
_, cts_paths, masks_paths = get_data_paths(TASK1_HN, task1=task1, debug=debug_mode)
data = [{"image": ct, "mask": mask} for ct, mask in zip(cts_paths, masks_paths)]
print(f"Total samples: {len(data)}")
print(f"First data sample: {data[0]}")

# Split data
train_data_split = int(len(data) * train_val_split)
train_data = data[:train_data_split]
val_data = data[train_data_split:]

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

# Create datasets
train_ds = CacheDataset(data=train_data, transform=get_vae_train_transforms())
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) #, collate_fn=pad_list_data_collate)

val_ds = CacheDataset(data=val_data, transform=get_vae_val_transforms())
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False) #, collate_fn=pad_list_data_collate)

Loading data from [WindowsPath('data/SynthRAD2025/synthRAD2025_Task1_Train_D/synthRAD2025_Task1_Train_D/Task1/HN'), WindowsPath('data/synthRAD2025_Task1_Train/Task1/HN')]...
Total samples: 221
First data sample: {'image': 'data\\SynthRAD2025\\synthRAD2025_Task1_Train_D\\synthRAD2025_Task1_Train_D\\Task1\\HN\\1HND001\\ct.mha', 'mask': 'data\\SynthRAD2025\\synthRAD2025_Task1_Train_D\\synthRAD2025_Task1_Train_D\\Task1\\HN\\1HND001\\mask.mha'}
Training samples: 176
Validation samples: 45


Loading dataset: 100%|██████████| 176/176 [00:57<00:00,  3.04it/s]
monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
Loading dataset: 100%|██████████| 45/45 [00:08<00:00,  5.29it/s]


## 4.1 GENERAR JSON CON DATOS DEL DATASET

In [None]:
import json
import numpy as np
from collections import defaultdict
import os

# Function to convert Windows path to Linux format
def convert_to_linux_path(windows_path):
    """Convert Windows path to Linux-compatible format"""
    if isinstance(windows_path, str):
        # Convert backslashes to forward slashes
        linux_path = windows_path.replace('\\', '/')
        # Remove drive letter if present (C:/ -> /)
        if len(linux_path) > 1 and linux_path[1] == ':':
            linux_path = linux_path[2:]  # Remove "C:"
        return linux_path
    return windows_path

# Function to extract metadata from a dataset
def extract_metadata_from_dataset(dataset):
    """Extract metadata from all samples in a dataset"""
    metadata_list = []

    for i in range(len(dataset)):
        sample = dataset[i]
        # Get the metadata from the sample
        img_meta = sample['image_meta_dict'] if 'image_meta_dict' in sample else sample.get('image').meta if hasattr(sample.get('image'), 'meta') else {}
        mask_meta = sample['mask_meta_dict'] if 'mask_meta_dict' in sample else sample.get('mask').meta if hasattr(sample.get('mask'), 'meta') else {}

        # Combine metadata (prioritize image metadata)
        combined_meta = {**mask_meta, **img_meta}
        metadata_list.append(combined_meta)

    return metadata_list

# Function to calculate mean and std for numerical fields
def calculate_stats(metadata_list):
    """Calculate mean and std for numerical metadata fields"""
    # Fields to calculate statistics for
    numerical_fields = ['spacing', 'spatial_shape']

    stats = {'mean': {}, 'std': {}}

    # Collect all values for each field
    field_values = defaultdict(list)

    for meta in metadata_list:
        for field in numerical_fields:
            if field in meta:
                if field == 'spacing' and hasattr(meta[field], 'tolist'):
                    field_values[field].append(meta[field].tolist())
                elif field == 'spatial_shape' and hasattr(meta[field], 'tolist'):
                    field_values[field].append(meta[field].tolist())
                elif isinstance(meta[field], (list, tuple, np.ndarray)):
                    field_values[field].append(list(meta[field]))
                else:
                    field_values[field].append(meta[field])

    # Calculate mean and std
    for field, values in field_values.items():
        if values:
            if field in ['spacing', 'spatial_shape']:
                # For arrays, calculate element-wise mean and std
                values_array = np.array(values)
                stats['mean'][field] = np.mean(values_array, axis=0).tolist()
                stats['std'][field] = np.std(values_array, axis=0).tolist()

    # Handle other metadata fields (take first available or create defaults)
    sample_meta = metadata_list[0] if metadata_list else {}

    # Add other fields with default values or from first sample
    for field in ['original_affine', 'affine']:
        if field in sample_meta:
            if hasattr(sample_meta[field], 'tolist'):
                stats['mean'][field] = sample_meta[field].tolist()
                stats['std'][field] = np.zeros_like(sample_meta[field]).tolist()
            else:
                stats['mean'][field] = sample_meta[field]
                stats['std'][field] = 0
        else:
            # Default 4x4 identity matrix for affine transforms
            default_affine = np.eye(4).tolist()
            stats['mean'][field] = default_affine
            stats['std'][field] = np.zeros((4, 4)).tolist()

    # Add string fields
    for field in ['space', 'original_channel_dim']:
        if field in sample_meta:
            stats['mean'][field] = sample_meta[field]
            stats['std'][field] = 0  # No std for categorical data
        else:
            stats['mean'][field] = 'unknown'
            stats['std'][field] = 0

    return stats

print("Extracting metadata from datasets...")

# Extract metadata from train and validation datasets
train_metadata = extract_metadata_from_dataset(train_ds)
val_metadata = extract_metadata_from_dataset(val_ds)

# Combine all metadata for statistics calculation
all_metadata = train_metadata + val_metadata

print(f"Extracted metadata from {len(train_metadata)} training samples and {len(val_metadata)} validation samples")

# Calculate statistics
stats = calculate_stats(all_metadata)

print("Calculated metadata statistics")

# Create the dataset dump structure
dataset_dump = {
    "metadata": {
        "name": "SynthRAD2025_Task1_HN",
        "mean": stats['mean'],
        "std": stats['std']
    },
    "train": [],
    "validation": []
}

# Add training data paths (converted to Linux format)
for data_item in train_data:
    dataset_dump["train"].append({
        "image": convert_to_linux_path(data_item["image"]),
        "mask": convert_to_linux_path(data_item["mask"])
    })

# Add validation data paths (converted to Linux format)
for data_item in val_data:
    dataset_dump["validation"].append({
        "image": convert_to_linux_path(data_item["image"]),
        "mask": convert_to_linux_path(data_item["mask"])
    })

# Create output directory if it doesn't exist
os.makedirs("notebook_outputs", exist_ok=True)

# Save the dataset dump
output_file = os.path.join("notebook_outputs", f"dataset_{DATASET}.json")
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(dataset_dump, f, indent=2, ensure_ascii=False)

print(f"Dataset metadata saved to: {output_file}")
print(f"Total training samples: {len(dataset_dump['train'])}")
print(f"Total validation samples: {len(dataset_dump['validation'])}")

# Display sample of the metadata
print("\nSample metadata structure:")
print(f"Mean spacing: {stats['mean'].get('spacing', 'N/A')}")
print(f"Mean spatial shape: {stats['mean'].get('spatial_shape', 'N/A')}")
print(f"Space: {stats['mean'].get('space', 'N/A')}")
print(f"\nFirst training sample (Linux format):")
if dataset_dump['train']:
    print(f"Image: {dataset_dump['train'][0]['image']}")
    print(f"Mask: {dataset_dump['train'][0]['mask']}")
print(f"\nFirst validation sample (Linux format):")
if dataset_dump['validation']:
    print(f"Image: {dataset_dump['validation'][0]['image']}")
    print(f"Mask: {dataset_dump['validation'][0]['mask']}")

Extracting metadata from datasets...
Extracted metadata from 176 training samples and 45 validation samples
Calculated metadata statistics
Dataset metadata saved to: notebook_outputs\dataset_metadata.json
Total training samples: 176
Total validation samples: 45

Sample metadata structure:
Mean spacing: [1.0, 1.0, 3.0]
Mean spatial shape: [340.47058823529414, 308.4117647058824, 90.13122171945702]
Space: RAS

First training sample (Linux format):
Image: data/SynthRAD2025/synthRAD2025_Task1_Train_D/synthRAD2025_Task1_Train_D/Task1/HN/1HND001/ct.mha
Mask: data/SynthRAD2025/synthRAD2025_Task1_Train_D/synthRAD2025_Task1_Train_D/Task1/HN/1HND001/mask.mha

First validation sample (Linux format):
Image: data/synthRAD2025_Task1_Train/Task1/HN/1HNC037/ct.mha
Mask: data/synthRAD2025_Task1_Train/Task1/HN/1HNC037/mask.mha
Extracted metadata from 176 training samples and 45 validation samples
Calculated metadata statistics
Dataset metadata saved to: notebook_outputs\dataset_metadata.json
Total traini

## 5. Model Setup

In [None]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Autoencoder
autoencoder = AutoencoderKL(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    latent_channels=3,
    num_res_blocks=1,
    norm_num_groups=16,
    attention_levels=(False, False, True),
).to(device)

# Discriminator
discriminator = PatchDiscriminator(
    spatial_dims=3,
    num_layers_d=3,
    channels=32,
    in_channels=1,
    out_channels=1,
    norm="INSTANCE",
).to(device)

print(f"Autoencoder parameters: {sum(p.numel() for p in autoencoder.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
print(f"Total models size: {(sum(p.numel() for p in autoencoder.parameters()) + sum(p.numel() for p in discriminator.parameters())) / 1e6:.2f} MB")

Using device: cuda
Autoencoder parameters: 2,299,208
Discriminator parameters: 2,770,977
Total size: 5.07 MB


## 6. Loss Functions and Optimizers

In [15]:
# Loss functions
l1_loss = L1Loss()
adv_loss = PatchAdversarialLoss(criterion="least_squares")
loss_perceptual = PerceptualLoss(
    spatial_dims=3,
    network_type="squeeze",
    is_fake_3d=True,
    fake_3d_ratio=0.2
).to(device)

# Optimizers
optimizer_g = torch.optim.Adam(params=autoencoder.parameters(), lr=learning_rate)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=learning_rate)

print("Loss functions and optimizers initialized")

Loss functions and optimizers initialized


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=SqueezeNet1_1_Weights.IMAGENET1K_V1`. You can also use `weights=SqueezeNet1_1_Weights.DEFAULT` to get the most up-to-date weights.


## 7. Training Loop

In [None]:
# Training variables
train_losses = []
val_losses = []
best_val_loss = float('inf')

# Create output directory
output_dir = "notebook_outputs"
os.makedirs(output_dir, exist_ok=True)

print("Starting training...")

for epoch in range(num_epochs):
    # Training phase
    autoencoder.train()
    discriminator.train()
    epoch_start = time.time()
    epoch_g_loss = 0
    epoch_d_loss = 0
    epoch_recon_loss = 0
    epoch_kl_loss = 0
    step = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch_data in progress_bar:
        step += 1
        images = batch_data["image"].to(device)

        # Train Generator (Autoencoder)
        optimizer_g.zero_grad(set_to_none=True)

        reconstruction, z_mu, z_sigma = autoencoder(images)

        # Calculate losses
        loss_recon = l1_loss(reconstruction, images)
        loss_kl = kl_loss(z_mu, z_sigma)
        loss_perc = loss_perceptual(reconstruction.float(), images.float())

        # Base generator loss
        loss_g_base = loss_recon + kl_weight * loss_kl + perceptual_weight * loss_perc

        # Add adversarial loss after warmup
        if epoch >= autoencoder_warm_up_n_epochs:
            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
            loss_g_adv = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g = loss_g_base + adv_weight * loss_g_adv
        else:
            loss_g = loss_g_base

        loss_g.backward()
        optimizer_g.step()

        # Train Discriminator (after warmup)
        loss_d = torch.tensor(0.0)
        if epoch >= autoencoder_warm_up_n_epochs:
            optimizer_d.zero_grad(set_to_none=True)

            # Fake samples
            logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
            loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)

            # Real samples
            logits_real = discriminator(images.contiguous().detach())[-1]
            loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)

            loss_d = adv_weight * (loss_d_fake + loss_d_real) * 0.5
            loss_d.backward()
            optimizer_d.step()

        # Accumulate losses
        epoch_g_loss += loss_g.item()
        epoch_d_loss += loss_d.item()
        epoch_recon_loss += loss_recon.item()
        epoch_kl_loss += loss_kl.item()

        # Update progress bar
        progress_bar.set_postfix({
            "G_loss": f"{loss_g.item():.4f}",
            "D_loss": f"{loss_d.item():.4f}",
            "Recon": f"{loss_recon.item():.4f}"
        })

    # Calculate average losses
    avg_g_loss = epoch_g_loss / step
    avg_d_loss = epoch_d_loss / step
    avg_recon_loss = epoch_recon_loss / step
    avg_kl_loss = epoch_kl_loss / step

    train_losses.append({
        "epoch": epoch,
        "g_loss": avg_g_loss,
        "d_loss": avg_d_loss,
        "recon_loss": avg_recon_loss,
        "kl_loss": avg_kl_loss
    })

    epoch_time = time.time() - epoch_start
    print(f"Epoch {epoch + 1} completed in {epoch_time:.2f}s")
    print(f"  G Loss: {avg_g_loss:.4f}, D Loss: {avg_d_loss:.4f}")
    print(f"  Recon Loss: {avg_recon_loss:.4f}, KL Loss: {avg_kl_loss:.4f}")

    # Validation
    if (epoch + 1) % val_interval == 0:
        autoencoder.eval()
        val_recon_loss = 0
        val_step = 0

        with torch.no_grad():
            for val_batch in tqdm(val_loader, desc="Validation", leave=False):
                val_step += 1
                val_images = val_batch["image"].to(device)

                val_outputs = sliding_window_inference(
                    inputs=val_images,
                    roi_size=patch_size,
                    sw_batch_size=sw_batch_size,
                    predictor=autoencoder.reconstruct,
                    overlap=overlap,
                    mode=mode,
                    device=device,
                )

                batch_recon_loss = l1_loss(val_outputs, val_images)
                val_recon_loss += batch_recon_loss.item()

        avg_val_loss = val_recon_loss / val_step
        val_losses.append({"epoch": epoch, "val_loss": avg_val_loss})

        print(f"  Validation Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(autoencoder.state_dict(),
                      os.path.join(output_dir, "best_autoencoder.pth"))
            torch.save(discriminator.state_dict(),
                      os.path.join(output_dir, "best_discriminator.pth"))
            print(f"  New best model saved! Loss: {best_val_loss:.4f}")

    # Save checkpoint
    if (epoch + 1) % save_interval == 0:
        checkpoint = {
            "epoch": epoch,
            "autoencoder_state_dict": autoencoder.state_dict(),
            "discriminator_state_dict": discriminator.state_dict(),
            "optimizer_g_state_dict": optimizer_g.state_dict(),
            "optimizer_d_state_dict": optimizer_d.state_dict(),
            "train_losses": train_losses,
            "val_losses": val_losses
        }
        torch.save(checkpoint, os.path.join(output_dir, f"checkpoint_epoch_{epoch+1}.pth"))
        print(f"  Checkpoint saved at epoch {epoch+1}")

print("Training completed!")

## 8. Training Results Visualization

In [None]:
# Plot training losses
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

epochs = [loss["epoch"] for loss in train_losses]
g_losses = [loss["g_loss"] for loss in train_losses]
d_losses = [loss["d_loss"] for loss in train_losses]
recon_losses = [loss["recon_loss"] for loss in train_losses]
kl_losses = [loss["kl_loss"] for loss in train_losses]

# Generator loss
axes[0, 0].plot(epochs, g_losses, 'b-', label='Generator Loss')
axes[0, 0].set_title('Generator Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True)

# Discriminator loss
axes[0, 1].plot(epochs, d_losses, 'r-', label='Discriminator Loss')
axes[0, 1].set_title('Discriminator Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].grid(True)

# Reconstruction loss
axes[1, 0].plot(epochs, recon_losses, 'g-', label='Reconstruction Loss')
axes[1, 0].set_title('Reconstruction Loss')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].grid(True)

# KL loss
axes[1, 1].plot(epochs, kl_losses, 'm-', label='KL Loss')
axes[1, 1].set_title('KL Loss')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].grid(True)

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'training_losses.png'), dpi=300, bbox_inches='tight')
plt.show()

# Plot validation loss if available
if val_losses:
    val_epochs = [loss["epoch"] for loss in val_losses]
    val_loss_values = [loss["val_loss"] for loss in val_losses]

    plt.figure(figsize=(10, 6))
    plt.plot(val_epochs, val_loss_values, 'o-', label='Validation Loss')
    plt.title('Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'validation_loss.png'), dpi=300, bbox_inches='tight')
    plt.show()

## 9. Save Final Models and Results

In [None]:
# Save final models
torch.save(autoencoder.state_dict(), os.path.join(output_dir, "final_autoencoder.pth"))
torch.save(discriminator.state_dict(), os.path.join(output_dir, "final_discriminator.pth"))

# Save training history
import json

training_history = {
    "config": {
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "patch_size": patch_size,
        "warmup_epochs": autoencoder_warm_up_n_epochs
    },
    "train_losses": train_losses,
    "val_losses": val_losses,
    "best_val_loss": best_val_loss
}

with open(os.path.join(output_dir, "training_history.json"), "w") as f:
    json.dump(training_history, f, indent=2)

print(f"All outputs saved to: {output_dir}")
print(f"Best validation loss: {best_val_loss:.4f}")

## 10. Model Evaluation (Optional)

In [None]:
# Load best model for evaluation
autoencoder.load_state_dict(torch.load(os.path.join(output_dir, "best_autoencoder.pth")))
autoencoder.eval()

# Generate sample reconstruction
with torch.no_grad():
    sample_batch = next(iter(val_loader))
    sample_image = sample_batch["image"][:1].to(device)  # Take first sample

    reconstruction = sliding_window_inference(
        inputs=sample_image,
        roi_size=patch_size,
        sw_batch_size=sw_batch_size,
        predictor=autoencoder.reconstruct,
        overlap=overlap,
        mode=mode,
        device=device,
    )

    # Show middle slice
    middle_slice = sample_image.shape[-1] // 2

    original_slice = sample_image[0, 0, :, :, middle_slice].cpu().numpy()
    recon_slice = reconstruction[0, 0, :, :, middle_slice].cpu().numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    ax1.imshow(original_slice, cmap='gray')
    ax1.set_title('Original')
    ax1.axis('off')

    ax2.imshow(recon_slice, cmap='gray')
    ax2.set_title('Reconstruction')
    ax2.axis('off')

    # Difference
    diff = abs(original_slice - recon_slice)
    ax3.imshow(diff, cmap='hot')
    ax3.set_title('Difference')
    ax3.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'sample_reconstruction.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # Calculate reconstruction error
    mse = torch.mean((sample_image - reconstruction) ** 2).item()
    mae = torch.mean(torch.abs(sample_image - reconstruction)).item()

    print(f"Sample reconstruction metrics:")
    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")