# Step 1: Training the Image AutoencoderKL

## Environment Preparation

In [None]:
# Install MONAI (with its tqdm extra) and the lpips package.
!pip install monai[tqdm] lpips

In [None]:
# Fetch the MONAI GenerativeModels source, which provides the `generative` package imported below.
!git clone https://github.com/Project-MONAI/GenerativeModels.git

In [None]:
# Install GenerativeModels from the cloned source tree, then return to the parent directory.
# NOTE(review): `setup.py install` is deprecated by setuptools; `!pip install .` is the usual replacement.
%cd GenerativeModels/
!python setup.py install
%cd ..

In [None]:
# Delete the cloned source tree (the package has been installed above), then hard-exit
# the Python process. The notebook runtime is expected to restart the kernel, so the
# newly installed `generative` package becomes importable. Run the cells below afterwards.
!rm -r GenerativeModels/
import os
os._exit(00)

## Setup imports

In [None]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks import nets
from generative.networks.schedulers import DDPMScheduler

In [None]:
# Display the installed version of the `generative` package as a quick install check.
import generative
generative.__version__

## Prepare Kvasir-SEG Dataset

In [None]:
# Seed the global RNGs via MONAI's set_determinism for reproducible runs, then
# download and extract the Kvasir-SEG archive into the working directory.
set_determinism(42)

# NOTE(review): re-running this cell makes `unzip` ask before overwriting existing
# files; consider `unzip -o` (and `wget -nc`) for idempotent re-runs.
!wget https://datasets.simula.no/downloads/kvasir-seg.zip
!unzip kvasir-seg.zip

In [None]:
# Build {"image": ..., "seg": ...} path pairs for MONAI dictionary transforms.
# Each mask is expected under the same file name as its image in the masks folder.
path_img = './Kvasir-SEG/images/'
path_msk = './Kvasir-SEG/masks/'
# os.listdir() returns entries in an arbitrary, filesystem-dependent order. Sort them so
# the positional train/validation split below is reproducible across machines and runs.
fnames_img = sorted(f for f in os.listdir(path_img) if f.endswith('.jpg'))
datalist = [
    {'image': os.path.join(path_img, fname), 'seg': os.path.join(path_msk, fname)}
    for fname in fnames_img
]

# Optional shuffle before splitting. The list is sorted first, so a seeded shuffle
# (set_determinism above) gives the same split every run.
# np.random.shuffle(datalist)
# First 950 pairs for training, the remainder for validation.
train_datalist = datalist[:950]
val_datalist = datalist[950:]
datalist[:3]

## Image Autoencoder

### Transforms and Dataloader Setup 

In [None]:
# Shared settings for both image pipelines.
batch_size = 16
target_key = 'image'
shape = [128, 128]


def _load_and_normalise():
    """Return a fresh list of deterministic transforms.

    The transforms load the image, move channels first, resize to `shape`, and
    linearly map 0..255 to 0..1 with clipping.
    """
    return [
        transforms.LoadImaged(keys=target_key),
        transforms.EnsureChannelFirstd(keys=target_key),
        transforms.Resized(keys=target_key, spatial_size=shape),
        transforms.ScaleIntensityRanged(keys=target_key, a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ]


# Training adds a mild random affine: rotation within +/- pi/36 rad and +/-5% scaling
# on both spatial axes, zero padding, applied with probability 0.5.
train_transforms = transforms.Compose(
    _load_and_normalise()
    + [
        transforms.RandAffined(
            keys=target_key,
            rotate_range=[(-np.pi / 36, np.pi / 36)] * 2,
            scale_range=[(-0.05, 0.05)] * 2,
            padding_mode="zeros",
            prob=0.5,
        ),
    ]
)
train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size, num_workers=4, persistent_workers=True)

# Validation uses only the deterministic preprocessing.
val_transforms = transforms.Compose(_load_and_normalise())
val_ds = CacheDataset(data=val_datalist, transform=val_transforms)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=batch_size, num_workers=4, persistent_workers=True)

### Training Data Preview

In [None]:
# Sanity-check the training pipeline: report the batch shape and show its first image.
check_data = first(train_loader)
print(f"batch shape: {check_data[target_key].shape}")
image_visualisation = check_data[target_key][0].numpy()
plt.figure(num="training images", figsize=(3, 3))
# Channel-first array -> channel-last layout for imshow.
plt.imshow(np.transpose(image_visualisation, (2, 1, 0)), vmin=0, vmax=1)
plt.axis("off")
plt.tight_layout()
plt.show()

### Define models, optimizer and loss

In [None]:
# Use the GPU when one is available; both networks are moved onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D KL-regularised autoencoder for 3-channel images with a 3-channel latent.
# Channel widths per level are (32, 64, 64); attention is enabled only at the last level.
autoencoder = nets.AutoencoderKL(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    latent_channels=3,
    num_channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    norm_num_groups=16,
).to(device)

# Patch discriminator used for the adversarial loss: 3-channel input, 1 output channel.
discriminator = nets.PatchDiscriminator(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    num_channels=32,
    num_layers_d=3,
).to(device)

In [None]:
# Loss terms for training the autoencoder: pixel-wise L1 reconstruction,
# least-squares patch adversarial loss, and a perceptual loss with a SqueezeNet backbone.
l1_loss = torch.nn.L1Loss()
adv_loss = PatchAdversarialLoss(criterion="least_squares")
# `is_fake_3d` / `fake_3d_ratio` configure the 2.5D approach for 3D volumes and have no
# effect when spatial_dims=2, so they are omitted rather than suggesting they matter here.
loss_perceptual = PerceptualLoss(spatial_dims=2, network_type="squeeze").to(device)

def KL_loss(z_mu, z_sigma):
    """Return the batch-averaged KL divergence between N(z_mu, z_sigma^2) and N(0, I).

    The per-sample divergence is summed over every non-batch dimension. This works
    for 2D latents (B, C, H, W) and also for 3D latents (B, C, D, H, W).

    Args:
        z_mu: latent means, batch dimension first.
        z_sigma: latent standard deviations, same shape as ``z_mu``; must be > 0.

    Returns:
        Scalar tensor: the mean over the batch of the per-sample KL divergence.
    """
    kl_elementwise = 0.5 * (z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1)
    # Flatten all non-batch dims so the reduction does not depend on the latent rank.
    kl_per_sample = kl_elementwise.reshape(kl_elementwise.shape[0], -1).sum(dim=1)
    return torch.sum(kl_per_sample) / kl_per_sample.shape[0]

# Weights of the auxiliary terms added to the L1 reconstruction loss.
kl_weight = 1e-6  # KL regulariser on the latent distribution
perceptual_weight = 1e-3  # perceptual loss
adv_weight = 1e-2  # generator adversarial term; also scales the discriminator loss

In [None]:
# Independent Adam optimisers (lr 1e-4) for the autoencoder ("generator") and the discriminator.
optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

### Training

In [None]:
n_epochs = 100
autoencoder_warm_up_n_epochs = 5  # adversarial terms are enabled only after this epoch
val_interval = 10  # evaluate on the validation split every `val_interval` epochs
epoch_recon_loss_list = []
epoch_gen_loss_list = []
epoch_disc_loss_list = []
val_recon_epoch_loss_list = []  # one mean validation L1 loss per validation run
intermediary_images = []  # first `n_example_images` validation reconstructions per run (CPU)
n_example_images = 4

for epoch in range(n_epochs):
    autoencoder.train()
    discriminator.train()
    # The discriminator and the generator's adversarial term only kick in after the
    # warm-up, so the autoencoder first learns plain reconstruction.
    adversarial_on = epoch > autoencoder_warm_up_n_epochs
    epoch_loss = 0
    gen_epoch_loss = 0
    disc_epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        # 3-channel images, intensities scaled to [0, 1] by the transforms.
        images = batch[target_key].to(device)

        # Generator (autoencoder) update: L1 + KL + perceptual (+ adversarial after warm-up).
        optimizer_g.zero_grad(set_to_none=True)
        reconstruction, z_mu, z_sigma = autoencoder(images)
        kl_loss = KL_loss(z_mu, z_sigma)

        recons_loss = l1_loss(reconstruction.float(), images.float())
        p_loss = loss_perceptual(reconstruction.float(), images.float())
        loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss

        if adversarial_on:
            # [-1]: last element of the discriminator's output list (presumably the final patch logits).
            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
            generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g += adv_weight * generator_loss

        loss_g.backward()
        optimizer_g.step()

        if adversarial_on:
            # Discriminator update. Inputs are detached so no gradient reaches the autoencoder.
            optimizer_d.zero_grad(set_to_none=True)
            logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
            loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
            logits_real = discriminator(images.contiguous().detach())[-1]
            loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
            discriminator_loss = (loss_d_fake + loss_d_real) * 0.5

            loss_d = adv_weight * discriminator_loss

            loss_d.backward()
            optimizer_d.step()

        epoch_loss += recons_loss.item()
        if adversarial_on:
            gen_epoch_loss += generator_loss.item()
            disc_epoch_loss += discriminator_loss.item()

        progress_bar.set_postfix(
            {
                "recons_loss": epoch_loss / (step + 1),
                "gen_loss": gen_epoch_loss / (step + 1),
                "disc_loss": disc_epoch_loss / (step + 1),
            }
        )
    epoch_recon_loss_list.append(epoch_loss / (step + 1))
    epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))
    epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))

    # Periodic validation on the held-out split. The val_* names are deliberately separate
    # so `images` / `reconstruction` keep referring to the last training batch, which the
    # visualisation cell at the end of the notebook uses.
    if (epoch + 1) % val_interval == 0:
        autoencoder.eval()
        val_loss = 0
        n_val_batches = 0
        with torch.no_grad():
            for val_batch in val_loader:
                val_images = val_batch[target_key].to(device)
                val_reconstruction, _, _ = autoencoder(val_images)
                if n_val_batches == 0:
                    # Keep a few reconstructions from the first batch to track progress.
                    intermediary_images.append(val_reconstruction[:n_example_images].cpu())
                val_loss += l1_loss(val_reconstruction.float(), val_images.float()).item()
                n_val_batches += 1
        val_recon_epoch_loss_list.append(val_loss / max(n_val_batches, 1))
        print(f"Epoch {epoch} val_recons_loss: {val_recon_epoch_loss_list[-1]:.4f}")

In [None]:
# Persist the trained autoencoder weights so later steps can reload them.
# torch.save does not create missing parent directories, so ensure `models/` exists first.
os.makedirs('models', exist_ok=True)
torch.save(autoencoder.state_dict(), os.path.join('models', 'AE_img.pt'))

In [None]:
# Per-epoch mean training L1 reconstruction loss.
plt.style.use("ggplot")
plt.title("Learning Curves", fontsize=20)
# The curve needs a label: plt.legend() with no labelled artists draws no legend and warns.
plt.plot(epoch_recon_loss_list, color="C0", linewidth=2.0, label="Train reconstruction (L1)")
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()

In [None]:
# Per-epoch generator vs. discriminator adversarial losses (0 during the warm-up epochs).
plt.title("Adversarial Training Curves", fontsize=20)
for curve, colour, name in (
    (epoch_gen_loss_list, "C0", "Generator"),
    (epoch_disc_loss_list, "C1", "Discriminator"),
):
    plt.plot(curve, color=colour, linewidth=2.0, label=name)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()

## Visualization

In [None]:
# Plot axial, coronal and sagittal slices of a training sample
# Compare one input image with its reconstruction, side by side. `images` and
# `reconstruction` are left over from the last training batch of the loop above.
idx = 0
# Channel-first -> channel-last, using the same axis order as the training-data preview.
original_img = images[idx].detach().cpu().numpy().transpose([2, 1, 0])
# For RGB input, imshow ignores vmin/vmax and clips out-of-range floats with a warning.
# The decoder output is not guaranteed to lie in [0, 1], so clip it explicitly.
reconstructed_img = np.clip(reconstruction[idx].detach().cpu().numpy().transpose([2, 1, 0]), 0, 1)

plt.figure(figsize=(8, 4))  # two panels in one row: wide rather than tall
plt.subplot(121)
plt.imshow(original_img, vmin=0, vmax=1)
plt.title('Original')
plt.axis('off')
plt.subplot(122)
plt.imshow(reconstructed_img, vmin=0, vmax=1)
plt.title('Reconstruction')
plt.axis('off')
plt.show()