In [1]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Super-resolution using Stable Diffusion v2 Upscalers

Tutorial to illustrate the super-resolution task on medical images using Latent Diffusion Models (LDMs) [1]. For that, we will use an autoencoder to obtain a latent representation of the high-resolution images. Then, we train a diffusion model to infer this latent representation when conditioned on a low-resolution image.

To improve the performance of our models, we will use a method called "noise conditioning augmentation" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). During the training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion models on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that aids in removing artefacts in the samples.


[1] - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752

[2] - Ho et al. "Cascaded diffusion models for high fidelity image generation" https://arxiv.org/abs/2106.15282

[3] - Ho et al. "High Definition Video Generation with Diffusion Models" https://arxiv.org/abs/2210.02303

## Set up environment using Colab


In [1]:
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Set up imports

In [1]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

# Print the MONAI version together with the versions of its core and optional dependencies.
print_config()

  from .autonotebook import tqdm as notebook_tqdm


    PyTorch 1.13.1+cu117 with CUDA 1107 (you have 1.12.1)
    Python  3.8.16 (you have 3.8.16)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
2023-04-06 16:56:04,614 - Created a temporary directory at /tmp/tmp0zwlum7_
2023-04-06 16:56:04,617 - Writing /tmp/tmp0zwlum7_/_remote_module_non_scriptable.py
MONAI version: 1.2.dev2304
Numpy version: 1.23.5
Pytorch version: 1.12.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0
MONAI __file__: /home/ol18/miniconda3/envs/monai_generative/lib/python3.8/site-packages/monai_weekly-1.2.dev2304-py3.8.egg/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
ITK version: 5.3.0
Nibabel version: 5.1.0
scikit-image version: 0.20.0
Pillow version: 9.4.0
Tensorboard version: 2.12.1
gdown 

In [2]:
# Fix the random seed so that repeated runs of the notebook are reproducible.
set_determinism(seed=42)

## Setup a data directory and download dataset
Specify a `MONAI_DATA_DIRECTORY` environment variable to choose where the data will be downloaded. If it is not specified, a temporary directory will be used.

In [3]:
# Use the directory named by MONAI_DATA_DIRECTORY when the variable is set;
# otherwise create a fresh temporary directory for the data.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

/tmp/tmpqpsolhvx


## Define the LightningModule for AutoEncoder (transforms, network, loaders, etc)
The LightningModule contains a refactoring of your training code. The following module is a refactoring of the code in 2d_stable_diffusion_v2_super_resolution-lightning


In [21]:
class AutoEnconder(pl.LightningModule):
    """LightningModule that trains a KL autoencoder together with a patch discriminator.

    The module bundles the networks, the losses, the MedNIST "HeadCT" datasets and
    the data loaders. Optimisation is manual (``automatic_optimization = False``):
    each training step updates the autoencoder ("generator") first and, once the
    warm-up epochs are over, the discriminator as well.

    The class name keeps its original spelling because other cells instantiate it
    under that name.
    """

    def __init__(self):
        super().__init__()
        self.data_dir = root_dir
        self.autoencoderkl = AutoencoderKL(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            num_channels=(256, 512, 512),
            latent_channels=3,
            num_res_blocks=2,
            norm_num_groups=32,
            attention_levels=(False, False, True),
        )
        self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, num_channels=64)
        self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="alex")
        self.perceptual_weight = 0.002
        # The adversarial term is only used while current_epoch is greater than this value.
        self.autoencoder_warm_up_n_epochs = 10
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.adv_loss = PatchAdversarialLoss(criterion="least_squares")
        self.adv_weight = 0.005
        self.kl_weight = 1e-6

    def forward(self, z):
        """Run the autoencoder on ``z`` and return its output unchanged.

        ``training_step`` unpacks the result as ``(reconstruction, z_mu, z_sigma)``.
        """
        return self.autoencoderkl(z)

    def prepare_data(self):
        """Download MedNIST and build the cached training and validation datasets.

        Only the "HeadCT" images are kept. Every sample receives a ``"low_res_image"``
        entry, a copy of ``"image"`` resized to 16x16. Random affine augmentation is
        applied to the training data only. The datasets are stored on
        ``self.train_ds`` and ``self.val_ds``.
        """
        image_size = 64
        train_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image"]),
                transforms.EnsureChannelFirstd(keys=["image"]),
                transforms.ScaleIntensityRanged(
                    keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True
                ),
                transforms.RandAffined(
                    keys=["image"],
                    rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
                    translate_range=[(-1, 1), (-1, 1)],
                    scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
                    spatial_size=[image_size, image_size],
                    padding_mode="zeros",
                    prob=0.5,
                ),
                transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
                transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
            ]
        )

        val_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image"]),
                transforms.EnsureChannelFirstd(keys=["image"]),
                transforms.ScaleIntensityRanged(
                    keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True
                ),
                transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
                transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
            ]
        )

        train_data = MedNISTDataset(root_dir=self.data_dir, section="training", download=True, seed=0)
        train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "HeadCT"]
        val_data = MedNISTDataset(root_dir=self.data_dir, section="validation", download=True, seed=0)
        val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "HeadCT"]

        self.train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
        self.val_ds = CacheDataset(data=val_datalist, transform=val_transforms)

    def train_dataloader(self):
        """Return the shuffled training loader built on ``self.train_ds``."""
        return DataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)

    def val_dataloader(self):
        """Return the unshuffled validation loader built on ``self.val_ds``."""
        return DataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)

    def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):
        """Compute the non-adversarial part of the autoencoder loss.

        Args:
            images: input images.
            reconstruction: autoencoder output for ``images``.
            z_mu: latent mean returned by the autoencoder.
            z_sigma: latent standard deviation returned by the autoencoder.

        Returns:
            Tuple ``(loss_g, recons_loss)``: the weighted sum of the L1, KL and
            perceptual terms, and the plain L1 reconstruction loss.
        """
        recons_loss = F.l1_loss(reconstruction.float(), images.float())
        p_loss = self.perceptual_loss(reconstruction.float(), images.float())
        # KL term summed over every non-batch dimension, then averaged over the batch.
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss_g = recons_loss + (self.kl_weight * kl_loss) + (self.perceptual_weight * p_loss)
        return loss_g, recons_loss

    def _compute_loss_discriminator(self, reconstruction, images):
        """Compute the discriminator loss on reconstructed (fake) and input (real) images.

        Both inputs are detached, so this loss sends no gradient to the autoencoder.

        Args:
            reconstruction: autoencoder output, used as the "fake" sample.
            images: original images, used as the "real" sample.

        Returns:
            Tuple ``(loss_d, discriminator_loss)``: the loss scaled by
            ``self.adv_weight`` and the unscaled mean of the fake and real terms.
        """
        logits_fake = self.discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = self.discriminator(images.contiguous().detach())[-1]
        loss_d_real = self.adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
        loss_d = self.adv_weight * discriminator_loss
        return loss_d, discriminator_loss

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step for the autoencoder and the discriminator.

        The autoencoder is always updated. The adversarial term and the
        discriminator update are only active once ``self.current_epoch`` exceeds
        ``self.autoencoder_warm_up_n_epochs``.

        Logged metrics: ``recons_loss`` and ``loss_g`` on every step; ``gen_loss``,
        ``loss_d`` and ``disc_loss`` only after the warm-up.
        """
        optimizer_g, optimizer_d = self.optimizers()
        images = batch["image"]
        adversarial_phase = self.current_epoch > self.autoencoder_warm_up_n_epochs

        # ---- autoencoder ("generator") update ----
        reconstruction, z_mu, z_sigma = self.forward(images)
        loss_g, recons_loss = self._compute_loss_generator(images, reconstruction, z_mu, z_sigma)
        self.log("recons_loss", recons_loss, prog_bar=True)

        if adversarial_phase:
            logits_fake = self.discriminator(reconstruction.contiguous().float())[-1]
            generator_loss = self.adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g = loss_g + self.adv_weight * generator_loss
            self.log("gen_loss", generator_loss, prog_bar=True)

        self.log("loss_g", loss_g, prog_bar=True)
        optimizer_g.zero_grad()
        self.manual_backward(loss_g)
        optimizer_g.step()

        # ---- discriminator update ----
        if adversarial_phase:
            loss_d, discriminator_loss = self._compute_loss_discriminator(reconstruction, images)
            self.log("loss_d", loss_d, prog_bar=True)
            self.log("disc_loss", discriminator_loss, prog_bar=True)
            # The backward pass of the generator's adversarial term also left gradients on
            # the discriminator parameters; clear them so that only loss_d drives this update.
            # (No toggle_optimizer call is made anywhere, so no untoggle_optimizer call is
            # needed either.)
            optimizer_d.zero_grad()
            self.manual_backward(loss_d)
            optimizer_d.step()

    def validation_step(self, batch, batch_idx):
        """Log the L1 reconstruction loss of a validation batch as ``val_recons_loss``."""
        images = batch["image"]
        reconstruction, z_mu, z_sigma = self.autoencoderkl(images)
        recons_loss = F.l1_loss(images.float(), reconstruction.float())
        # Logged under its own key so it cannot be confused with the training "loss_d".
        self.log("val_recons_loss", recons_loss, prog_bar=True)

    def configure_optimizers(self):
        """Return one Adam optimizer for the autoencoder and one for the discriminator."""
        optimizer_g = torch.optim.Adam(self.autoencoderkl.parameters(), lr=5e-5)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return [optimizer_g, optimizer_d], []
                          


## Train Autoencoder

In [22]:
# Training schedule: total number of epochs and how often (in epochs) validation runs.
n_epochs = 75
val_interval = 10

# Instantiate the autoencoder LightningModule.
net = AutoEnconder()

# Checkpoints are written into the data directory under a fixed file name.
checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename="best_metric_model")

# Configure Lightning's trainer to run on a single device.
trainer = pl.Trainer(
    devices=1,
    max_epochs=n_epochs,
    check_val_every_n_epoch=val_interval,
    callbacks=[checkpoint_callback],
    default_root_dir=root_dir,
)

# Start training.
trainer.fit(net)

The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.


2023-04-06 17:32:20,399 - GPU available: True (cuda), used: True
2023-04-06 17:32:20,401 - TPU available: False, using: 0 TPU cores
2023-04-06 17:32:20,402 - IPU available: False, using: 0 IPUs
2023-04-06 17:32:20,403 - HPU available: False, using: 0 HPUs
2023-04-06 17:32:20,576 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-06 17:32:20,578 - INFO - File exists: /tmp/tmpqpsolhvx/MedNIST.tar.gz, skipped downloading.
2023-04-06 17:32:20,579 - INFO - Non-empty folder exists in /tmp/tmpqpsolhvx/MedNIST, skipped extracting.



Loading dataset:   0%|                                                                                                                               | 0/47164 [00:00<?, ?it/s][A
Loading dataset:   0%|▎                                                                                                                  | 119/47164 [00:00<00:39, 1189.23it/s][A
Loading dataset:   1%|▌                                                                                                                  | 249/47164 [00:00<00:37, 1249.81it/s][A
Loading dataset:   1%|▉                                                                                                                  | 381/47164 [00:00<00:36, 1278.57it/s][A
Loading dataset:   1%|█▏                                                                                                                 | 509/47164 [00:00<00:38, 1204.00it/s][A
Loading dataset:   1%|█▌                                                                                

Loading dataset:  12%|█████████████▎                                                                                                    | 5483/47164 [00:04<00:40, 1029.00it/s][A
Loading dataset:  12%|█████████████▋                                                                                                     | 5588/47164 [00:05<00:43, 954.05it/s][A
Loading dataset:  12%|█████████████▊                                                                                                    | 5702/47164 [00:05<00:41, 1002.25it/s][A
Loading dataset:  12%|██████████████                                                                                                    | 5809/47164 [00:05<00:40, 1020.07it/s][A
Loading dataset:  13%|██████████████▎                                                                                                   | 5913/47164 [00:05<00:40, 1021.07it/s][A
Loading dataset:  13%|██████████████▋                                                                    

Epoch 0:   0%|                                                                                                                                         | 0/500 [07:41<?, ?it/s]


Loading dataset:  16%|██████████████████▍                                                                                                | 7586/47164 [00:07<01:18, 501.56it/s][A





Loading dataset:  16%|██████████████████▊                                                                                                | 7723/47164 [00:07<01:01, 636.97it/s][A
Loading dataset:  17%|███████████████████▏                                                                                               | 7865/47164 [00:07<00:50, 779.83it/s][A
Loading dataset:  17%|███████████████████▌                                                                                               | 7998/47164 [00:07<00:43, 892.77it/s][A
Loading dataset:  17%|███████████████████▊                                                                                               | 8117/47164 [00:07<00:43, 891.53it/s][A
Loading dataset:  17%|████████████████████                                                                                               | 8227/47164 [00:08<00:41, 933.56it/s][A
Loading dataset:  18%|████████████████████▎                                                             

Loading dataset:  28%|███████████████████████████████▌                                                                                 | 13158/47164 [00:12<00:28, 1177.42it/s][A
Loading dataset:  28%|███████████████████████████████▊                                                                                 | 13278/47164 [00:12<00:30, 1094.41it/s][A
Loading dataset:  28%|████████████████████████████████▎                                                                                 | 13390/47164 [00:12<00:35, 943.43it/s][A
Loading dataset:  29%|████████████████████████████████▌                                                                                 | 13490/47164 [00:13<00:40, 823.02it/s][A
Loading dataset:  29%|████████████████████████████████▊                                                                                 | 13578/47164 [00:13<00:44, 759.30it/s][A
Loading dataset:  29%|█████████████████████████████████                                                  

Loading dataset:  38%|███████████████████████████████████████████▍                                                                     | 18146/47164 [00:17<00:26, 1082.64it/s][A
Loading dataset:  39%|███████████████████████████████████████████▋                                                                     | 18255/47164 [00:17<00:27, 1035.66it/s][A
Loading dataset:  39%|████████████████████████████████████████████▍                                                                     | 18360/47164 [00:17<00:31, 911.52it/s][A
Loading dataset:  39%|████████████████████████████████████████████▌                                                                     | 18455/47164 [00:18<00:32, 878.87it/s][A
Loading dataset:  39%|████████████████████████████████████████████▊                                                                     | 18545/47164 [00:18<00:32, 881.60it/s][A
Loading dataset:  40%|█████████████████████████████████████████████                                      

Loading dataset:  50%|███████████████████████████████████████████████████████▉                                                         | 23355/47164 [00:22<00:21, 1090.64it/s][A
Loading dataset:  50%|████████████████████████████████████████████████████████▎                                                        | 23495/47164 [00:22<00:20, 1178.60it/s][A
Loading dataset:  50%|████████████████████████████████████████████████████████▌                                                        | 23619/47164 [00:22<00:19, 1195.39it/s][A
Loading dataset:  50%|████████████████████████████████████████████████████████▉                                                        | 23740/47164 [00:23<00:22, 1042.65it/s][A
Loading dataset:  51%|█████████████████████████████████████████████████████████▋                                                        | 23849/47164 [00:23<00:23, 988.48it/s][A
Loading dataset:  51%|█████████████████████████████████████████████████████████▉                         

Loading dataset:  61%|█████████████████████████████████████████████████████████████████████▍                                            | 28752/47164 [00:27<00:19, 953.01it/s][A
Loading dataset:  61%|█████████████████████████████████████████████████████████████████████▋                                            | 28849/47164 [00:27<00:19, 925.73it/s][A
Loading dataset:  61%|█████████████████████████████████████████████████████████████████████▉                                            | 28943/47164 [00:28<00:37, 491.27it/s][A
Loading dataset:  62%|██████████████████████████████████████████████████████████████████████▎                                           | 29072/47164 [00:28<00:28, 628.63it/s][A
Loading dataset:  62%|██████████████████████████████████████████████████████████████████████▌                                           | 29210/47164 [00:28<00:23, 776.30it/s][A
Loading dataset:  62%|██████████████████████████████████████████████████████████████████████▉            

Loading dataset:  72%|█████████████████████████████████████████████████████████████████████████████████▊                               | 34147/47164 [00:33<00:11, 1125.56it/s][A
Loading dataset:  73%|██████████████████████████████████████████████████████████████████████████████████▏                              | 34285/47164 [00:33<00:10, 1197.21it/s][A
Loading dataset:  73%|██████████████████████████████████████████████████████████████████████████████████▍                              | 34411/47164 [00:33<00:10, 1212.96it/s][A
Loading dataset:  73%|██████████████████████████████████████████████████████████████████████████████████▊                              | 34544/47164 [00:33<00:10, 1244.73it/s][A
Loading dataset:  74%|███████████████████████████████████████████████████████████████████████████████████                              | 34670/47164 [00:33<00:10, 1168.39it/s][A
Loading dataset:  74%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  85%|███████████████████████████████████████████████████████████████████████████████████████████████▌                 | 39870/47164 [00:37<00:07, 1021.49it/s][A
Loading dataset:  85%|████████████████████████████████████████████████████████████████████████████████████████████████▌                 | 39975/47164 [00:38<00:07, 972.36it/s][A
Loading dataset:  85%|████████████████████████████████████████████████████████████████████████████████████████████████                 | 40084/47164 [00:38<00:07, 1002.00it/s][A
Loading dataset:  85%|█████████████████████████████████████████████████████████████████████████████████████████████████▏                | 40186/47164 [00:38<00:07, 994.19it/s][A
Loading dataset:  85%|█████████████████████████████████████████████████████████████████████████████████████████████████▍                | 40287/47164 [00:38<00:07, 931.06it/s][A
Loading dataset:  86%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▍     | 44864/47164 [00:42<00:02, 895.46it/s][A
Loading dataset:  95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▋     | 44971/47164 [00:42<00:02, 942.98it/s][A
Loading dataset:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▉     | 45067/47164 [00:43<00:02, 829.99it/s][A
Loading dataset:  96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏    | 45158/47164 [00:43<00:02, 849.65it/s][A
Loading dataset:  96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▍    | 45259/47164 [00:43<00:02, 893.14it/s][A
Loading dataset:  96%|███████████████████████████████████████████████████████████████████████████████████

2023-04-06 17:33:06,469 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-06 17:33:06,470 - INFO - File exists: /tmp/tmpqpsolhvx/MedNIST.tar.gz, skipped downloading.
2023-04-06 17:33:06,472 - INFO - Non-empty folder exists in /tmp/tmpqpsolhvx/MedNIST, skipped extracting.



Loading dataset:   0%|                                                                                                                                | 0/5895 [00:00<?, ?it/s][A
Loading dataset:   2%|██▎                                                                                                                 | 117/5895 [00:00<00:04, 1166.49it/s][A
Loading dataset:   4%|████▊                                                                                                               | 247/5895 [00:00<00:04, 1240.65it/s][A
Loading dataset:   6%|███████▍                                                                                                            | 379/5895 [00:00<00:04, 1274.04it/s][A
Loading dataset:   9%|██████████                                                                                                          | 514/5895 [00:00<00:04, 1301.81it/s][A
Loading dataset:  11%|████████████▋                                                                     

Loading dataset:  87%|████████████████████████████████████████████████████████████████████████████████████████████████████               | 5131/5895 [00:05<00:00, 1001.93it/s][A
Loading dataset:  89%|██████████████████████████████████████████████████████████████████████████████████████████████████████▋            | 5265/5895 [00:05<00:00, 1091.64it/s][A
Loading dataset:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▏         | 5392/5895 [00:05<00:00, 1139.75it/s][A
Loading dataset:  93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▍       | 5509/5895 [00:05<00:00, 1076.84it/s][A
Loading dataset:  95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▋     | 5620/5895 [00:05<00:00, 1043.52it/s][A
Loading dataset:  97%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  36%|█████████████████████████████████████████▊                                                                          | 2883/7991 [00:04<00:07, 729.11it/s][A
Loading dataset:  37%|██████████████████████████████████████████▉                                                                         | 2957/7991 [00:04<00:06, 729.69it/s][A
Loading dataset:  38%|███████████████████████████████████████████▉                                                                        | 3031/7991 [00:04<00:07, 703.04it/s][A
Loading dataset:  39%|█████████████████████████████████████████████▏                                                                      | 3111/7991 [00:04<00:06, 730.37it/s][A
Loading dataset:  40%|██████████████████████████████████████████████▎                                                                     | 3188/7991 [00:04<00:06, 740.48it/s][A
Loading dataset:  41%|███████████████████████████████████████████████▍                                   

Loading dataset:  77%|█████████████████████████████████████████████████████████████████████████████████████████▏                          | 6141/7991 [00:08<00:02, 634.13it/s][A
Loading dataset:  78%|██████████████████████████████████████████████████████████████████████████████████████████                          | 6207/7991 [00:09<00:02, 638.86it/s][A
Loading dataset:  79%|███████████████████████████████████████████████████████████████████████████████████████████▏                        | 6285/7991 [00:09<00:02, 676.93it/s][A
Loading dataset:  80%|████████████████████████████████████████████████████████████████████████████████████████████▏                       | 6354/7991 [00:09<00:02, 651.66it/s][A
Loading dataset:  80%|█████████████████████████████████████████████████████████████████████████████████████████████▏                      | 6423/7991 [00:09<00:02, 657.01it/s][A
Loading dataset:  81%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▏       | 908/972 [00:02<00:00, 458.87it/s][A
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 972/972 [00:02<00:00, 413.47it/s][A
Checkpoint directory /tmp/tmpqpsolhvx exists and is not empty.


2023-04-06 17:33:27,505 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
2023-04-06 17:33:27,519 - 
  | Name            | Type                 | Params
---------------------------------------------------------
0 | autoencoderkl   | AutoencoderKL        | 75.1 M
1 | discriminator   | PatchDiscriminator   | 2.8 M 
2 | perceptual_loss | PerceptualLoss       | 2.5 M 
3 | adv_loss        | PatchAdversarialLoss | 0     
---------------------------------------------------------
77.8 M    Trainable params
2.5 M     Non-trainable params
80.3 M    Total params
321.225   Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]

Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Epoch 1:   2%|█▊                                                                                  | 11/500 [00:07<05:36,  1.45it/s, v_num=7, recons_loss=0.0389, loss_g=0.0403]

Detected KeyboardInterrupt, attempting graceful shutdown...


## Rescaling factor

As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) becomes crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we will compute the component-wise standard deviation to be used as the scaling factor.

In [30]:
# Encode a single training batch with the autoencoder, without tracking gradients.
train_loader = net.train_dataloader()
check_data = first(train_loader)
with torch.no_grad(), autocast(enabled=True):
    z = net.autoencoderkl.encode_stage_2_inputs(check_data["image"].to(net.device))

# The scaling factor is the reciprocal of the standard deviation over all latent values.
scale_factor = 1 / torch.std(z)
print(f"Scaling factor set to {scale_factor}")

Scaling factor set to 0.8888660669326782


## Define the LightningModule for the Diffusion Model (transforms, network, loaders, etc)
The LightningModule contains a refactoring of your training code. The following module is a refactoring of the code in 2d_stable_diffusion_v2_super_resolution-lightning

In [None]:
class DiffusionUNET(pl.LightningModule):
    """LightningModule that trains the diffusion UNet for latent super-resolution.

    The module predicts the noise added to the autoencoder latent of a
    high-resolution image, given that noisy latent concatenated with a
    noise-augmented low-resolution image.

    Besides its own state, the module reads three notebook-level names:
    ``root`` (data directory), ``net`` (the module holding the trained
    ``autoencoderkl``) and ``scale_factor`` (latent rescaling factor). They
    must be defined before the module is instantiated / trained.
    """

    def __init__(self):
        super().__init__()
        self.data_dir = root
        # 4 input channels = 3 latent channels + 1 low-resolution image channel.
        # NOTE(review): `class_labels` is forwarded to the UNet in `forward`; confirm that
        # DiffusionModelUNet uses it without `num_class_embeds` being set here.
        self.unet = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=4,
            out_channels=3,
            num_res_blocks=2,
            num_channels=(256, 256, 512, 1024),
            attention_levels=(False, False, True, True),
            num_head_channels=(0, 0, 64, 64),
        )
        # Upper bound (exclusive) for the timestep used to noise the low-resolution image.
        self.max_noise_level = 350
        # The noise schedulers are part of the model state, not learning-rate schedulers,
        # so they live on the module instead of being returned by `configure_optimizers`.
        self.scheduler = DDPMScheduler(
            num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195
        )
        self.low_res_scheduler = DDPMScheduler(
            num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195
        )

    def forward(self, x, timesteps, class_labels=None):
        """Run the UNet on ``x`` for the given ``timesteps`` and optional ``class_labels``."""
        return self.unet(x=x, timesteps=timesteps, class_labels=class_labels)

    def prepare_data(self):
        """Download MedNIST, keep the HeadCT images and build the cached train/val datasets.

        NOTE(review): Lightning recommends assigning state in ``setup`` rather than
        ``prepare_data`` when training on several devices; kept here to preserve the
        existing call pattern of the notebook - confirm.
        """
        image_size = 64
        train_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image"]),
                transforms.EnsureChannelFirstd(keys=["image"]),
                transforms.ScaleIntensityRanged(
                    keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True
                ),
                transforms.RandAffined(
                    keys=["image"],
                    rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
                    translate_range=[(-1, 1), (-1, 1)],
                    scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
                    spatial_size=[image_size, image_size],
                    padding_mode="zeros",
                    prob=0.5,
                ),
                # The low-resolution conditioning image is a 16x16 resized copy of the input.
                transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
                transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
            ]
        )

        val_transforms = transforms.Compose(
            [
                transforms.LoadImaged(keys=["image"]),
                transforms.EnsureChannelFirstd(keys=["image"]),
                transforms.ScaleIntensityRanged(
                    keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True
                ),
                transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
                transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
            ]
        )

        train_data = MedNISTDataset(root_dir=self.data_dir, section="training", download=True, seed=0)
        train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "HeadCT"]
        val_data = MedNISTDataset(root_dir=self.data_dir, section="validation", download=True, seed=0)
        val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "HeadCT"]

        self.train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
        self.val_ds = CacheDataset(data=val_datalist, transform=val_transforms)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)

    def val_dataloader(self):
        """Return the validation loader (not shuffled, as Lightning recommends for val/test)."""
        return DataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)

    def _diffusion_loss(self, batch):
        """Compute the noise-prediction MSE for one batch (shared by train and validation)."""
        images = batch["image"]
        low_res_image = batch["low_res_image"]

        # The autoencoder is not optimised here, so no gradients are needed for the encoding.
        # NOTE(review): `net.autoencoderkl` and `scale_factor` are notebook globals and must be
        # on the same device as the batch - confirm when training on more than one device.
        with torch.no_grad():
            latent = net.autoencoderkl.encode_stage_2_inputs(images) * scale_factor

        # Noise augmentation: independent noise and timesteps for the latent and the low-res image.
        noise = torch.randn_like(latent)
        low_res_noise = torch.randn_like(low_res_image)
        timesteps = torch.randint(
            0, self.scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device
        ).long()
        low_res_timesteps = torch.randint(
            0, self.max_noise_level, (low_res_image.shape[0],), device=low_res_image.device
        ).long()

        noisy_latent = self.scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
        # `low_res_scheduler` is configured identically to `scheduler` in `__init__`.
        noisy_low_res_image = self.low_res_scheduler.add_noise(
            original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
        )

        latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)
        noise_pred = self(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)
        return F.mse_loss(noise_pred.float(), noise.float())

    def training_step(self, batch, batch_idx):
        """Log and return the training loss; Lightning runs backward and the optimizer step."""
        loss = self._diffusion_loss(batch)
        self.log("loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss = self._diffusion_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return the Adam optimizer over the UNet parameters only."""
        return torch.optim.Adam(self.unet.parameters(), lr=5e-5)


## Train Diffusion Model

In order to train the diffusion model to perform super-resolution, we need to concatenate the latent representation of the high-resolution image with the low-resolution image. For this, we create a diffusion model with `in_channels=4`. Since we are only interested in the predicted latent representation, we set `out_channels=3`.

As mentioned, we will use noise conditioning augmentation (introduced in [2] Section 3 and used in Stable Diffusion Upscalers and Imagen Video [3] Section 2.5), as it has been shown to be critical for cascaded diffusion models, as well as for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We will use a scheduler `low_res_scheduler` to add this noise, with the `t` step defining the signal-to-noise ratio, and use the `t` value to condition the diffusion model (inputted using the `class_labels` argument).

In [None]:
# Only the diffusion UNet is optimised; the autoencoder is used as a fixed encoder/decoder.
# This cell relies on `unet`, `autoencoderkl`, `scheduler`, `max_noise_level`, `train_loader`,
# `val_loader`, `device` and `scale_factor` being defined by earlier cells.
optimizer = torch.optim.Adam(unet.parameters(), lr=5e-5)

scaler_diffusion = GradScaler()

n_epochs = 200
val_interval = 20
epoch_loss_list = []
val_epoch_loss_list = []

for epoch in range(n_epochs):
    unet.train()
    autoencoderkl.eval()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        low_res_image = batch["low_res_image"].to(device)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=True):
            # No gradients are needed through the autoencoder.
            with torch.no_grad():
                latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor

            # Noise augmentation (tensors are created directly on the device of their reference tensor)
            noise = torch.randn_like(latent)
            low_res_noise = torch.randn_like(low_res_image)
            timesteps = torch.randint(0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device)
            low_res_timesteps = torch.randint(
                0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device
            )

            noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
            noisy_low_res_image = scheduler.add_noise(
                original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
            )

            latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)

            noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)
            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler_diffusion.scale(loss).backward()
        scaler_diffusion.step(optimizer)
        scaler_diffusion.update()

        epoch_loss += loss.item()

        progress_bar.set_postfix({"loss": epoch_loss / (step + 1)})
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        unet.eval()
        val_loss = 0
        for val_step, batch in enumerate(val_loader, start=1):
            images = batch["image"].to(device)
            low_res_image = batch["low_res_image"].to(device)

            with torch.no_grad():
                with autocast(enabled=True):
                    latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor
                    # Noise augmentation
                    noise = torch.randn_like(latent)
                    low_res_noise = torch.randn_like(low_res_image)
                    timesteps = torch.randint(
                        0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device
                    )
                    low_res_timesteps = torch.randint(
                        0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device
                    )

                    noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
                    noisy_low_res_image = scheduler.add_noise(
                        original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
                    )

                    latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)
                    noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)
                    loss = F.mse_loss(noise_pred.float(), noise.float())

            val_loss += loss.item()
        val_loss /= val_step
        val_epoch_loss_list.append(val_loss)
        print(f"Epoch {epoch} val loss: {val_loss:.4f}")

        # Sampling image during training: upscale the first low-res image of the last validation batch.
        sampling_image = low_res_image[0].unsqueeze(0)
        latents = torch.randn((1, 3, 16, 16)).to(device)
        low_res_noise = torch.randn((1, 1, 16, 16)).to(device)
        # Fixed, small amount of noise augmentation at sampling time; built once as a long tensor
        # and reused both for noising the low-res image and for conditioning the UNet.
        noise_level = torch.tensor([20], dtype=torch.long, device=device)
        noisy_low_res_image = scheduler.add_noise(
            original_samples=sampling_image,
            noise=low_res_noise,
            timesteps=noise_level,
        )

        scheduler.set_timesteps(num_inference_steps=1000)
        for t in tqdm(scheduler.timesteps, ncols=110):
            with torch.no_grad():
                with autocast(enabled=True):
                    latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
                    noise_pred = unet(
                        x=latent_model_input, timesteps=torch.Tensor((t,)).to(device), class_labels=noise_level
                    )
                latents, _ = scheduler.step(noise_pred, t, latents)

        with torch.no_grad():
            # Undo the latent rescaling before decoding back to image space.
            decoded = autoencoderkl.decode_stage_2_outputs(latents / scale_factor)

        # Show original | bicubic upsampling of the low-res input | model output, side by side.
        low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode="bicubic")
        plt.figure(figsize=(2, 2))
        plt.style.use("default")
        plt.imshow(
            torch.cat([images[0, 0].cpu(), low_res_bicubic[0, 0].cpu(), decoded[0, 0].cpu()], dim=1),
            vmin=0,
            vmax=1,
            cmap="gray",
        )
        plt.tight_layout()
        plt.axis("off")
        plt.show()

### Plotting sampling example

In [None]:
# Take one validation batch and keep its first `num_samples` low-resolution images
# as conditioning inputs for the final sampling example.
num_samples = 3
unet.eval()
validation_batch = first(val_loader)

images = validation_batch["image"].to(device)
low_res_batch = validation_batch["low_res_image"].to(device)
sampling_image = low_res_batch[:num_samples]

In [None]:
# Start from pure Gaussian noise in latent space, one latent per image to upscale.
latents = torch.randn((num_samples, 3, 16, 16)).to(device)
low_res_noise = torch.randn((num_samples, 1, 16, 16)).to(device)
# Fixed noise-augmentation level used at sampling time; built once as a long tensor and
# reused both for noising the low-res images and for conditioning the UNet.
noise_level = torch.tensor([10], dtype=torch.long, device=device)
noisy_low_res_image = scheduler.add_noise(
    original_samples=sampling_image, noise=low_res_noise, timesteps=noise_level
)
scheduler.set_timesteps(num_inference_steps=1000)
for t in tqdm(scheduler.timesteps, ncols=110):
    with torch.no_grad():
        with autocast(enabled=True):
            # 1. predict the noise from the current latents and the noisy low-res conditioning
            latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
            noise_pred = unet(x=latent_model_input, timesteps=torch.Tensor((t,)).to(device), class_labels=noise_level)

        # 2. compute previous image: x_t -> x_t-1
        latents, _ = scheduler.step(noise_pred, t, latents)

with torch.no_grad():
    # Undo the latent rescaling before decoding back to image space.
    decoded = autoencoderkl.decode_stage_2_outputs(latents / scale_factor)

In [None]:
# Upsample the low-resolution inputs to 64x64 with bicubic interpolation for display.
low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode="bicubic")
# squeeze=False keeps `axs` two-dimensional, so the indexing below also works for num_samples == 1.
fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8), squeeze=False)
column_titles = ("Original image", "Low-resolution Image", "Outputted image")
for col, title in enumerate(column_titles):
    axs[0, col].set_title(title)
for i in range(num_samples):
    # One row per sample: original, bicubic-upsampled low-res input, model output.
    panels = (images[i, 0], low_res_bicubic[i, 0], decoded[i, 0])
    for col, panel in enumerate(panels):
        axs[i, col].imshow(panel.cpu(), vmin=0, vmax=1, cmap="gray")
        axs[i, col].axis("off")
plt.tight_layout()

### Clean-up data directory

In [None]:
# Delete `root_dir` (and everything under it) only when `directory` is None.
# NOTE(review): presumably `directory` is None when `root_dir` is a temporary directory created
# by this notebook - confirm against the setup cell; the LightningModule above reads `root`.
if directory is None:
    shutil.rmtree(root_dir)