In [2]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from diffusers import UNet2DModel
import matplotlib.pyplot as plt
import os

# Define the training configuration
@dataclass
class TrainingConfig:
    image_size = 256  # the generated image resolution
    batch_size = 16
    num_epochs = 50
    learning_rate = 1e-4
    save_model_epochs = 10
    output_dir = "generated"
    seed = 0

config = TrainingConfig()

# Dataset loader (assuming datasetloader2.py defines CTInpaintingDataset)
from datasetloader import CTInpaintingDataset
dataset = CTInpaintingDataset("data", augment=True)
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

# Model definition with 3 input channels: corrupted image, mask, and tissue segmentation as input
model = UNet2DModel(
    sample_size=config.image_size,  # image resolution
    in_channels=3,  # 3 channels for corrupted, mask, and tissue
    out_channels=1,  # output channel is the CT scan (grayscale)
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=(8, 8, 8, 8),  # UNet block channels
    norm_num_groups=8,  # group normalization
    down_block_types=(
        "DownBlock2D", 
        "DownBlock2D", 
        "DownBlock2D", 
        "AttnDownBlock2D"  # includes attention in last down block
    ),
    up_block_types=(
        "AttnUpBlock2D", 
        "UpBlock2D", 
        "UpBlock2D", 
        "UpBlock2D"
    ),
)


def generate_inpainting_plot(model, dataloader, device, output_dir="data/generated"):
    # Ensure the model is in evaluation mode
    model.eval()

    # Create directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get 8 samples from the dataloader
    samples = next(iter(dataloader))
    corrupted_images = samples['corrupted'][:8].to(device)  # masked images
    mask_images = samples['mask'][:8].to(device)            # masks
    tissue_images = samples['tissue'][:8].to(device)        # tissue segmentations
    true_images = samples['ct'][:8].to(device)              # original unmasked images

    # Concatenate inputs
    inputs = torch.cat([corrupted_images, mask_images, tissue_images], dim=1)

    # Forward pass through the model
    with torch.no_grad():
        inpainted_images = model(inputs, 0).sample  # get inpainted outputs from the model

    # Prepare the plot
    fig, axes = plt.subplots(3, 8, figsize=(20, 8))  # 3 rows, 8 columns
    
    for i in range(8):
        # Row 1: Masked images (corrupted + mask)
        axes[0, i].imshow(corrupted_images[i, 0].cpu(), cmap='gray', vmin=0, vmax=1)
        axes[0, i].axis('off')

        # Row 2: Reconstructed / Inpainted images
        axes[1, i].imshow(inpainted_images[i, 0].cpu(), cmap='gray', vmin=0, vmax=1)
        axes[1, i].axis('off')

        # Row 3: True images (unmasked CT)
        axes[2, i].imshow(true_images[i, 0].cpu(), cmap='gray', vmin=0, vmax=1)
        axes[2, i].axis('off')

    # Save the figure
    plt.tight_layout()
    plot_path = os.path.join(output_dir, "inpainting_comparison.png")
    plt.savefig(plot_path)
    plt.close()


# L1 Loss function
criterion = nn.L1Loss()

# Optimizer
optimizer = Adam(model.parameters(), lr=config.learning_rate)

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

import tqdm

for epoch in range(config.num_epochs):
    model.train()
    running_loss = 0.0
    batchtqdm = tqdm.tqdm(train_dataloader)
    for batch in batchtqdm:
        corrupted = batch['corrupted'].to(device) # is in [0, 1]
        mask = batch['mask'].to(device) # is in [0, 1] (not binary, has fractional smooth values too at the boundary)
        tissue = batch['tissue'].to(device) # is in {0, 1, 2} (3 classes) (0: background, 1: soft tissue, 2: fat)
        ct = batch['ct'].to(device) #is in [0, 1]
        vertebrae = batch['vertebrae'].to(device) #is a tensor list of integers, each integer is the vertebrae #

        # Concatenate the 3 inputs into a 3-channel input for the UNet
        inputs = torch.cat([corrupted, mask, tissue], dim=1)

        # Forward pass
        outputs = model(inputs, 0).sample  # model's predicted ct image

        # Compute the loss
        loss = criterion(outputs, ct)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batchtqdm.set_postfix(loss=loss.item())
        break

    generate_inpainting_plot(model, train_dataloader, device, config.output_dir)

    # Log epoch statistics
    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{config.num_epochs}], Loss: {avg_loss:.4f}")

    # Save model every few epochs
    if (epoch + 1) % config.save_model_epochs == 0:
        torch.save(model.state_dict(), f"{config.output_dir}/unet_epoch_{epoch+1}.pth")

print("Training complete.")

  0%|          | 0/296 [00:05<?, ?it/s, loss=0.267]


Plot saved to generated/inpainting_comparison.png
Epoch [1/50], Loss: 0.0009


  0%|          | 0/296 [00:05<?, ?it/s, loss=0.251]


Plot saved to generated/inpainting_comparison.png
Epoch [2/50], Loss: 0.0008


  0%|          | 0/296 [00:05<?, ?it/s, loss=0.239]


Plot saved to generated/inpainting_comparison.png
Epoch [3/50], Loss: 0.0008


  0%|          | 0/296 [00:05<?, ?it/s, loss=0.228]


Plot saved to generated/inpainting_comparison.png
Epoch [4/50], Loss: 0.0008


  0%|          | 0/296 [00:05<?, ?it/s, loss=0.219]


Plot saved to generated/inpainting_comparison.png
Epoch [5/50], Loss: 0.0007


  0%|          | 0/296 [00:05<?, ?it/s, loss=0.215]


Plot saved to generated/inpainting_comparison.png
Epoch [6/50], Loss: 0.0007


  0%|          | 0/296 [00:06<?, ?it/s, loss=0.194]


Plot saved to generated/inpainting_comparison.png
Epoch [7/50], Loss: 0.0007


  0%|          | 0/296 [00:00<?, ?it/s]


KeyboardInterrupt: 