# 3D Latent Diffusion Model

In [None]:
# TODO: Add an "Open in Colab" button

## Set up environment using Colab


In [None]:
# Install MONAI (weekly build with the tqdm extra) and matplotlib, but only when
# the corresponding import fails in the `python` found on the shell PATH.
# NOTE(review): the `generative` package imported further down has no install
# step in this cell — confirm it is provided some other way.
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Set up imports

In [1]:
import os
import tempfile
# Restrict the process to GPU 6 only when the caller has not already chosen a
# device through CUDA_VISIBLE_DEVICES; an explicit setting is left untouched.
# Kept ahead of the torch import so it takes effect before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first, set_determinism
from tqdm import tqdm

from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, LatentDiffusionModel
from generative.schedulers import DDPMScheduler
from generative.losses import perceptual, adversarial_loss

# Print MONAI's configuration report (presumably library/dependency versions —
# useful to record alongside the results).
print_config()

In [None]:
# Fix the random seed so that repeated runs of the notebook are reproducible.
set_determinism(seed=42)

## Set up a data directory and download the dataset
Specify a `MONAI_DATA_DIRECTORY` environment variable to choose where the data will be downloaded. If it is not specified, a temporary directory will be used.

In [None]:
# Resolve the directory that holds (or will hold) the dataset.
# MONAI_DATA_DIRECTORY is honoured when it is set. Otherwise the current working
# directory is used, which is the location this notebook has been reading from.
# No throwaway temporary directory is created as a side effect.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = "./" if directory is None else directory
print(root_dir)

## Download the training set

In [None]:
# Index of the image channel used for training and plotting.
channel = 0  # 0 = Flair
assert channel in [0, 1, 2, 3], 'Choose a valid channel'

# Pre-processing applied to every training volume. All channels are loaded here;
# the single `channel` selected above is picked out later, when batches are used.
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.EnsureTyped(keys=["image"]),
        # Bring every volume to a common RAS orientation.
        transforms.Orientationd(keys=["image"], axcodes="RAS"),
        # Resample to a coarser 3.0 x 3.0 x 2.2 spacing to keep the volumes small.
        transforms.Spacingd(keys=["image"], pixdim=(3.0, 3.0, 2.2), mode="bilinear"),
        # Crop the centre of the volume to a fixed 120 x 120 x 64 size.
        transforms.CenterSpatialCropd(keys=["image"], roi_size=(120, 120, 64)),
        # Map the 0-99.5 percentile intensity range to [0, 1].
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
    ]
)

# Download only when the extracted task folder is missing, so a fresh run works
# out of the box and an existing copy is never fetched again.
dataset_is_present = os.path.isdir(os.path.join(root_dir, "Task01_BrainTumour"))
train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    section="training",
    cache_rate=1.0,  # you may need a few GB of RAM... Set to 0 otherwise
    num_workers=4,
    download=not dataset_is_present,
    seed=0,
    transform=train_transforms,
)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
print(f'Image shape {train_ds[0]["image"].shape}')

## Visualise examples from the training set

In [None]:
# Show the middle slice along each of the three spatial axes of one training
# volume (axial, coronal and sagittal views) and save the figure to disk.
check_data = first(train_loader)
idx = 0

img = check_data["image"][idx, channel]
fig, axs = plt.subplots(nrows=1, ncols=3)
middle_slices = (
    img[..., img.shape[2] // 2],
    img[:, img.shape[1] // 2, ...],
    img[img.shape[0] // 2, ...],
)
for ax, middle_slice in zip(axs, middle_slices):
    ax.axis("off")
    ax.imshow(middle_slice, cmap="gray")
plt.savefig("training_examples.png")

## Download the validation set

# NOTE(review): this snippet looks like a leftover from a 2D MedNIST tutorial.
# `MedNISTDataset` is not among the imports visible at the top of the notebook,
# and `val_loader` is not used by the training loop below — confirm whether it
# should be replaced by a validation split of the dataset used for training.
val_data = MedNISTDataset(root_dir=root_dir, section="validation", download=True, seed=0)
# Build the list from the validation split itself (it previously read from an
# undefined `train_data`), keeping only the "Hand" class.
val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "Hand"]
val_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        # Rescale intensities from [0, 255] to [0, 1], clipping values outside the range.
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ]
)
val_ds = Dataset(data=val_datalist, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=True, num_workers=4)

## Define the network

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device}")

In [None]:
# First stage: 3D KL autoencoder mapping a single-channel volume to a
# 3-channel latent representation and back.
autoencoder_config = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_channels=32,
    latent_channels=3,
    ch_mult=(1, 2, 2),
    num_res_blocks=1,
    norm_num_groups=16,
    attention_levels=(False, False, True),
)
stage1_model = AutoencoderKL(**autoencoder_config)

# Diffusion UNet; its input/output channel count matches `latent_channels` above.
unet_config = dict(
    spatial_dims=3,
    in_channels=3,
    out_channels=3,
    num_res_blocks=1,
    attention_resolutions=[4, 2],
    channel_mult=[1, 2, 2],
    model_channels=64,
    # TODO: play with this number
    num_heads=1,
)
unet = DiffusionModelUNet(**unet_config)

# DDPM noise schedule: 1000 training timesteps with a linear beta schedule.
scheduler_config = dict(
    num_train_timesteps=1000,
    beta_schedule="linear",
    beta_start=0.0015,
    beta_end=0.0195,
)
scheduler = DDPMScheduler(**scheduler_config)

# Bundle the three components into one model and move it to the target device.
model = LatentDiffusionModel(first_stage=stage1_model, unet_network=unet, scheduler=scheduler).to(device)

In [None]:
# Perceptual loss for 3D volumes, configured with the 'squeeze' network type in
# "fake 3D" mode (fake_3d_ratio=0.2).
loss_perceptual = perceptual.PerceptualLoss(
    spatial_dims=3,
    network_type='squeeze',
    is_fake_3d=True,
    fake_3d_ratio=0.2,
)

## Train AutoEncoder

In [None]:
# Weight of the KL term in the autoencoder loss.
kl_weight = 1e-6
# Adam over all parameters of the combined model, learning rate 2.5e-5.
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)

In [None]:
# Train the first-stage autoencoder with an L1 + perceptual + KL objective.
# Only `model.first_stage` is run in this loop; the diffusion UNet is not used here.
# TODO: Add lr_scheduler with warm-up
# TODO: Add EMA model

n_epochs = 1
# NOTE(review): val_interval is assigned but never read in this cell.
val_interval = 2
for epoch in range(n_epochs):
    model.train()
    # Running sum of the per-step losses, used for the progress-bar average.
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        # Select the chosen channel; indexing with a list keeps the channel axis.
        images = batch["image"][:,[channel],...].to(device)
        optimizer.zero_grad(set_to_none=True)
        
        reconstruction, z_mu, z_sigma = model.first_stage(images)

        # Reconstruction terms: voxel-wise L1 plus perceptual loss.
        l1_loss = F.l1_loss(reconstruction.float(), images.float())
        perceptual_loss = loss_perceptual(reconstruction.float(), images.float())

        # KL divergence of N(z_mu, z_sigma^2) from a standard normal, summed over
        # the channel and spatial axes of each sample...
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim = [1, 2, 3, 4])
        # ...then averaged over the batch.
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        # TODO: adversarial loss
        #
        loss = l1_loss + perceptual_loss + kl_weight * kl_loss        

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        # Show the mean loss over the steps completed so far in this epoch.
        progress_bar.set_postfix(
            {
                "loss": epoch_loss / (step + 1),
            }
        )

### Visualise reconstructions

In [None]:
# Plot axial, coronal and sagittal middle slices of a reconstruction from the
# last training batch and save the figure to disk.
idx = 0
# The autoencoder is built with out_channels=1, so the reconstruction has a
# single channel: always read channel 0. Indexing with `channel` would raise an
# IndexError for any input channel other than 0.
img = reconstruction[idx, 0].detach().cpu().numpy()
fig, axs = plt.subplots(nrows=1, ncols=3)
for ax in axs:
    ax.axis("off")
ax = axs[0]
ax.imshow(img[..., img.shape[2] // 2], cmap="gray")
ax = axs[1]
ax.imshow(img[:, img.shape[1] // 2, ...], cmap="gray")
ax = axs[2]
ax.imshow(img[img.shape[0] // 2, ...], cmap="gray")
plt.savefig("reconstruction_examples.png")