# The Third Step: Conditioning LDM

## Setup imports

In [None]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks import nets
from generative.networks.schedulers import DDPMScheduler

# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import generative
# Notebook cell output: the version string of the installed `generative` package.
generative.__version__

## Prepare Kvasir-SEG Dataset

In [None]:
# Build the list of image/mask pairs. Each mask is looked up in `path_msk`
# under the same file name as its image.
path_img = './Kvasir-SEG/images/'
path_msk = './Kvasir-SEG/masks/'
# os.listdir() returns entries in arbitrary, filesystem-dependent order, so the
# names are sorted to make the fixed-index train/validation split below
# reproducible across machines and runs. endswith() keeps only real ".jpg"
# files instead of any name that merely contains ".jpg" (e.g. "x.jpg.bak").
fnames_img = sorted(f for f in os.listdir(path_img) if f.endswith('.jpg'))
datalist = [
    {'image': os.path.join(path_img, fname), 'seg': os.path.join(path_msk, fname)}
    for fname in fnames_img
]

# Shuffle
# np.random.shuffle(datalist)
# Split the datalist to train and validation
# NOTE: with fewer than 951 images the validation list is empty.
train_datalist = datalist[:950]
val_datalist = datalist[950:]
# Notebook cell output: the first three entries, as a quick sanity check.
datalist[:3]

## Conditioning LDM

### Transforms and Dataloader Setup 

In [None]:
# Settings shared by the training and validation pipelines.
batch_size = 1
target_key = ['image', 'seg']
shape = [128, 128]


def _deterministic_transforms():
    """Return fresh instances of the preprocessing steps used by both splits.

    Load image and mask, move channels first, resize to `shape`, and map
    intensities from [0, 255] to [0, 1] with clipping.
    """
    return [
        transforms.LoadImaged(keys=target_key),
        transforms.EnsureChannelFirstd(keys=target_key),
        transforms.Resized(keys=target_key, spatial_size=shape),
        transforms.ScaleIntensityRanged(
            keys=target_key, a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True
        ),
    ]


# Training pipeline: the shared preprocessing followed by a random affine
# augmentation (applied with probability 0.5 to image and mask together).
max_angle = np.pi / 36
max_scale = 0.05
random_affine = transforms.RandAffined(
    keys=target_key,
    rotate_range=[(-max_angle, max_angle), (-max_angle, max_angle)],
    scale_range=[(-max_scale, max_scale), (-max_scale, max_scale)],
    padding_mode="zeros",
    prob=0.5,
)
train_transforms = transforms.Compose(_deterministic_transforms() + [random_affine])
train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
)

# Validation pipeline: the shared preprocessing only, in a fixed order.
val_transforms = transforms.Compose(_deterministic_transforms())
val_ds = CacheDataset(data=val_datalist, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
)

In [None]:
# Pull one batch from the training loader. `d` is read again further down,
# where the latent scaling factor is estimated from it.
d = first(train_loader)
# Notebook cell output: the shapes of the image batch and the mask batch.
d['image'].shape, d['seg'].shape

### Load Pre-trained AutoEncoderKL

In [None]:
# Autoencoder for the RGB images. Its 3 latent channels are what the
# diffusion UNet (`ldm`, in_channels=3 / out_channels=3) is configured for.
ae_img = nets.AutoencoderKL(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=(32, 64, 64),
    latent_channels=3,
    num_res_blocks=1,
    norm_num_groups=16,
    attention_levels=(False, False, True),
).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be restored when this notebook runs on the CPU
# fallback (without it torch.load tries to restore to the original device).
state_dict_img = torch.load(os.path.join('models', 'AE_img.pt'), map_location=device)
ae_img.load_state_dict(state_dict_img)

In [None]:
# Autoencoder for the segmentation masks (read as 3-channel inputs). Its 32
# latent channels equal the `cross_attention_dim` of the diffusion UNet, which
# receives the encoded mask as its conditioning input.
ae_mask = nets.AutoencoderKL(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=(32, 64, 64),
    latent_channels=32,
    num_res_blocks=1,
    norm_num_groups=16,
    attention_levels=(False, False, True),
).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be restored when this notebook runs on the CPU
# fallback (without it torch.load tries to restore to the original device).
state_dict_mask = torch.load(os.path.join('models', 'AE_mask.pt'), map_location=device)
ae_mask.load_state_dict(state_dict_mask)

### Define LDM model, scaling factor, scheduler, inferer and optimizer

In [None]:
# Configuration of the conditional denoising UNet.
ldm_kwargs = dict(
    spatial_dims=2,
    # Same channel count as the latent of `ae_img` (latent_channels=3).
    in_channels=3,
    out_channels=3,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    # Attention is enabled on the two deepest levels only.
    attention_levels=(False, True, True),
    num_head_channels=(0, 64, 64),
    # Cross-attention on an external condition whose feature size is 32,
    # i.e. the number of latent channels of `ae_mask`.
    with_conditioning=True,
    cross_attention_dim=32,
)
ldm = nets.DiffusionModelUNet(**ldm_kwargs).to(device)

#### Scaling factor
As mentioned in Rombach et al. [1] (Sections 4.3.2 and D.1), the signal-to-noise ratio induced by the scale of the latent space can affect the results obtained with the LDM if the standard deviation of the latent distribution drifts too far from that of a standard Gaussian (i.e., from 1). For this reason, it is best practice to rescale the latents by a scaling factor; here it is estimated as the reciprocal of the standard deviation of the latents of one encoded training batch.

Note: When the latent space is already close to a standard Gaussian distribution, the scaling factor will be close to one, and the results will not differ noticeably from those obtained without it.

In [None]:
# Encode the previously fetched batch `d` without tracking gradients and
# under mixed precision to obtain a sample of latents.
with torch.no_grad(), autocast(enabled=True):
    z = ae_img.encode_stage_2_inputs(d['image'].to(device))

# The scaling factor is the reciprocal of the latents' standard deviation.
scale_factor = 1 / torch.std(z)
print(f"Scaling factor set to {scale_factor}")

In [None]:
# Number of diffusion steps used for training (and reused later for sampling).
num_train_timesteps = 1000

# DDPM noise schedule.
scheduler = DDPMScheduler(
    num_train_timesteps=num_train_timesteps,
    schedule="scaled_linear_beta",
    beta_start=0.0015,
    beta_end=0.0195,
)

# The inferer is given the scheduler and the latent scaling factor computed above.
inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)

# Only the parameters of the diffusion UNet are handed to the optimizer.
optimizer_diff = torch.optim.Adam(ldm.parameters(), lr=1e-4)

### Training

In [None]:
# ---- Training configuration ----
n_epochs = 1
val_interval = 1
# Upper bound on optimisation steps per epoch. The loop used to stop
# unconditionally after the first batch; that quick smoke-test behaviour is
# kept as the default but is now explicit. Set to None to train on the whole
# training set each epoch.
max_train_steps = 1
epoch_loss_list = []
val_epoch_loss_list = []
# Both autoencoders are only used for encoding/decoding here; the optimizer
# holds the parameters of `ldm` alone.
ae_img.eval()
ae_mask.eval()
scaler = GradScaler()

# Reference latent. Only its per-sample shape is needed (to draw noise of the
# right size), so it is encoded without building an autograd graph. `z` stays
# a notebook-level name because later cells read it.
first_batch = first(train_loader)
with torch.no_grad():
    z = ae_img.encode_stage_2_inputs(first_batch['image'].to(device))
latent_shape = tuple(z.shape[1:])


def _encode_condition(masks):
    """Turn a batch of masks into the conditioning sequence for `ldm`.

    The mask latent is averaged over dim 3 and dims 1 and 2 are swapped, so
    the latent channels of `ae_mask` end up on the last axis (assumes the
    latent is laid out as (batch, channels, dim2, dim3)). `ae_mask` is not
    being optimised, so no gradients are tracked.
    """
    with torch.no_grad():
        latent = ae_mask.encode_stage_2_inputs(masks)
    return latent.mean(dim=3).transpose(1, 2).to(device)


for epoch in range(n_epochs):
    ldm.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=70)
    progress_bar.set_description(f"Epoch {epoch}")
    # Training
    for step, batch in progress_bar:
        images = batch['image'].to(device)
        masks = batch['seg'].to(device)
        optimizer_diff.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # Generate random noise in latent space, one sample per image of
            # the *current* batch (so a differently sized batch still works).
            noise = torch.randn((images.shape[0], *latent_shape), device=device)

            # Get encoded condition
            condition = _encode_condition(masks)

            # Create timesteps
            timesteps = torch.randint(
                0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
            ).long()

            # Get model prediction
            noise_pred = inferer(
                inputs=images,
                autoencoder_model=ae_img,
                diffusion_model=ldm,
                noise=noise,
                timesteps=timesteps,
                condition=condition,
            )

            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer_diff)
        scaler.update()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
        # Checked at the end of the body so no extra batch is fetched.
        if max_train_steps is not None and step + 1 >= max_train_steps:
            break
    epoch_loss_list.append(epoch_loss / (step + 1))
    # Validation
    if (epoch + 1) % val_interval == 0:
        ldm.eval()
        val_epoch_loss = 0
        for step, batch in enumerate(val_loader):
            images = batch['image'].to(device)
            masks = batch['seg'].to(device)
            with torch.no_grad(), autocast(enabled=True):
                noise = torch.randn((images.shape[0], *latent_shape), device=device)
                condition = _encode_condition(masks)
                timesteps = torch.randint(
                    0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                ).long()

                noise_pred = inferer(
                    inputs=images,
                    autoencoder_model=ae_img,
                    diffusion_model=ldm,
                    noise=noise,
                    timesteps=timesteps,
                    condition=condition,
                )

                val_loss = F.mse_loss(noise_pred.float(), noise.float())

            val_epoch_loss += val_loss.item()
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))
        print({"val_loss": val_epoch_loss / (step + 1)})

        # Sampling image during training.
        # The sampler denoises in *latent* space and decodes with `ae_img`
        # afterwards, so the starting noise must have the latent shape (as in
        # `vis()`), not the image shape. One sample is generated, conditioned
        # on the first mask of the last validation batch, which is also the
        # image/mask pair shown next to it.
        noise = torch.randn((1, *latent_shape), device=device)
        sample_condition = condition[:1]
        scheduler.set_timesteps(num_inference_steps=num_train_timesteps)
        with autocast(enabled=True):
            generated, intermediates = inferer.sample(
                input_noise=noise,
                diffusion_model=ldm,
                autoencoder_model=ae_img,
                scheduler=scheduler,
                save_intermediates=True,
                intermediate_steps=num_train_timesteps // 10,
                conditioning=sample_condition,
            )

        # figsize is (width, height): three panels side by side need a wide figure.
        plt.figure(figsize=(12, 4))
        panels = (
            (131, images[0], 'Original'),
            (132, masks[0], 'Conditioning'),
            # Sampling ran under autocast, so cast back to float32 for plotting.
            (133, generated[0].float(), 'Generated'),
        )
        for position, tensor, title in panels:
            plt.subplot(position)
            # Reverse the axis order so the channel axis is last for imshow.
            plt.imshow(tensor.detach().cpu().numpy().transpose([2, 1, 0]), vmin=0, vmax=1)
            plt.tight_layout()
            plt.axis("off")
            plt.title(title)
        plt.show()

In [None]:
# Learning curves: mean training loss of every epoch and the validation loss,
# which is only recorded every `val_interval` epochs. Epochs are numbered
# from 1 on the x-axis.
train_epochs = list(range(1, len(epoch_loss_list) + 1))
val_epochs = [val_interval * (i + 1) for i in range(len(val_epoch_loss_list))]

plt.title("Learning Curves", fontsize=20)
# Markers keep the curves visible when only a single epoch has been run.
plt.plot(train_epochs, epoch_loss_list, color="C0", linewidth=2.0, marker="o", label="Train")
plt.plot(val_epochs, val_epoch_loss_list, color="C1", linewidth=2.0, marker="o", label="Validation")
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
# Both curves are labelled, so the legend has entries to show.
plt.legend(prop={"size": 14})
plt.show()

In [None]:
# Generator/discriminator loss curves. No cell shown in this notebook creates
# `epoch_gen_loss_list` or `epoch_disc_loss_list` (the diffusion training
# above has no adversarial loss), so running this cell as-is would raise a
# NameError. Plot them only when they exist in the session, e.g. when they
# were carried over from an autoencoder training run.
try:
    gen_losses = epoch_gen_loss_list
    disc_losses = epoch_disc_loss_list
except NameError:
    print("Adversarial loss histories are not defined in this session; skipping the plot.")
else:
    plt.title("Adversarial Training Curves", fontsize=20)
    plt.plot(gen_losses, color="C0", linewidth=2.0, label="Generator")
    plt.plot(disc_losses, color="C1", linewidth=2.0, label="Discriminator")
    plt.yticks(fontsize=12)
    plt.xticks(fontsize=12)
    plt.xlabel("Epochs", fontsize=16)
    plt.ylabel("Loss", fontsize=16)
    plt.legend(prop={"size": 14})
    plt.show()

## Visualization

In [None]:
def vis():
    """Sample from the LDM conditioned on the first validation batch and plot the result.

    Reads the notebook-level `ldm`, `ae_img`, `ae_mask`, `inferer`,
    `scheduler`, `val_loader`, `z`, `device` and `num_train_timesteps`.
    Shows the original image next to its conditioning mask, then the saved
    intermediate samples of the reverse process.

    Returns:
        The intermediate samples concatenated along the last dimension.
    """
    ldm.eval()
    batch = first(val_loader)
    origin = batch['image'].to(device)
    masks = batch['seg'].to(device)
    # One latent-shaped noise sample per conditioning mask; `z` only supplies
    # the per-sample latent shape.
    noise = torch.randn((origin.shape[0], *z.shape[1:]), device=device)
    scheduler.set_timesteps(num_inference_steps=num_train_timesteps)
    with torch.no_grad(), autocast(enabled=True):
        # Encode the condition the same way as during training/validation
        # (under autocast), and without building an autograd graph.
        condition = ae_mask.encode_stage_2_inputs(masks).mean(dim=3).transpose(1, 2).to(device)
        image, intermediates = inferer.sample(
            input_noise=noise,
            diffusion_model=ldm,
            autoencoder_model=ae_img,
            scheduler=scheduler,
            save_intermediates=True,
            intermediate_steps=num_train_timesteps // 10,
            conditioning=condition,
        )

    chain = torch.cat(intermediates, dim=-1)

    # figsize is (width, height): two panels side by side need a wide figure.
    plt.figure(figsize=(6, 3))
    plt.subplot(121)
    # Show the image that belongs to the mask used above (`origin`), not the
    # `images` global left over from the training loop.
    plt.imshow(origin[0].cpu().numpy().transpose([2, 1, 0]), vmin=0, vmax=1)
    plt.tight_layout()
    plt.axis("off")
    plt.title('Original')
    plt.subplot(122)
    plt.imshow(masks[0].cpu().numpy().transpose([2, 1, 0]), vmin=0, vmax=1)
    plt.tight_layout()
    plt.axis("off")
    plt.title('Conditioning')
    plt.show()

    plt.figure(figsize=(10, 50))
    plt.style.use("default")
    # Same axis order as the two panels above, so every frame of the chain has
    # the orientation of the 'Original' image; the frames then stack along the
    # vertical axis, which is what the tall figure is sized for.
    plt.imshow(chain[0].float().cpu().numpy().transpose([2, 1, 0]), vmin=0, vmax=1)
    plt.tight_layout()
    plt.axis("off")
    plt.show()
    return chain

_ = vis()