In [1]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Super-resolution using Stable Diffusion v2 Upscalers

Tutorial to illustrate the super-resolution task on medical images using Latent Diffusion Models (LDMs) [1]. For that, we will use an autoencoder to obtain a latent representation of the high-resolution images. Then, we train a diffusion model to infer this latent representation when conditioned on a low-resolution image.

To improve the performance of our models, we will use a method called "noise conditioning augmentation" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). During the training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion models on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that aids in removing artefacts in the samples.


[1] - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752

[2] - Ho et al. "Cascaded diffusion models for high fidelity image generation" https://arxiv.org/abs/2106.15282

[3] - Ho et al. "High Definition Video Generation with Diffusion Models" https://arxiv.org/abs/2210.02303

## Set up environment using Colab


In [2]:
!python -c "import monai" || pip install -q "monai-weekly[tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Set up imports

In [3]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import CacheDataset, DataLoader
from monai.utils import first, set_determinism
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

print_config()

  from .autonotebook import tqdm as notebook_tqdm


    PyTorch 1.13.1+cu117 with CUDA 1107 (you have 1.12.1)
    Python  3.8.16 (you have 3.8.16)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484810403/work/c10/cuda/CUDAFunctions.cpp:109.)
User provided device_type of 'cuda', but CUDA is not available. Disabling


2023-04-10 18:37:54,723 - Created a temporary directory at /tmp/tmpk_8ni2di
2023-04-10 18:37:54,724 - Writing /tmp/tmpk_8ni2di/_remote_module_non_scriptable.py
MONAI version: 1.2.dev2304
Numpy version: 1.23.5
Pytorch version: 1.12.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0
MONAI __file__: /home/ol18/miniconda3/envs/monai_generative/lib/python3.8/site-packages/monai_weekly-1.2.dev2304-py3.8.egg/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
ITK version: 5.3.0
Nibabel version: 5.1.0
scikit-image version: 0.20.0
Pillow version: 9.4.0
Tensorboard version: 2.12.1
gdown version: 4.7.1
TorchVision version: 0.13.1
tqdm version: 4.65.0
lmdb version: 1.4.0
psutil version: 5.9.4
pandas version: 2.0.0
einops version: 0.6.0
transformers version: 4.21.3
mlflow version: 2.2.2
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.mona

In [4]:
# for reproducibility purposes set a seed
set_determinism(42)

## Setup a data directory and download dataset
Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used.

In [5]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpa48mavq5


## Setup utils functions

In [6]:
def get_train_transforms():
    image_size = 64
    train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0,
                                        b_max=1.0, clip=True),
        transforms.RandAffined(
            keys=["image"],
            rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
            translate_range=[(-1, 1), (-1, 1)],
            scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
            spatial_size=[image_size, image_size],
            padding_mode="zeros",
            prob=0.5,
        ),
        transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
        transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
    ]
    )
    return train_transforms

def get_val_transforms():
    val_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"]),
            transforms.EnsureChannelFirstd(keys=["image"]),
            transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, 
                                            b_max=1.0, clip=True),
            transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
            transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
        ]
    )
    return val_transforms



## Define the LightningModule for AutoEncoder (transforms, network, loaders, etc)
The LightningModule contains a refactoring of your training code. The following module is a reformatiing of the code in 2d_stable_diffusion_v2_super_resolution.


In [7]:
class AutoEnconder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.data_dir = root_dir
        self.autoencoderkl = AutoencoderKL(spatial_dims=2,
                                           in_channels=1,
                                           out_channels=1,
                                           num_channels=(256, 512, 512),
                                           latent_channels=3,
                                           num_res_blocks=2,
                                           norm_num_groups=32,
                                           attention_levels=(False, False, True))
        self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1,
                                                num_layers_d=3, num_channels=64)
        self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="alex")
        self.perceptual_weight = 0.002
        self.autoencoder_warm_up_n_epochs = 10
        self.automatic_optimization = False
        self.adv_loss = PatchAdversarialLoss(criterion="least_squares")
        self.adv_weight = 0.005
        self.kl_weight = 1e-6
        
    def forward(self, z):
        return self.autoencoderkl(z)

    def prepare_data(self):
        train_transforms = get_train_transforms()
        val_transforms = get_val_transforms()
        
        train_data = MedNISTDataset(root_dir=self.data_dir, section="training", download=True, seed=0)
        train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "HeadCT"]
        val_data = MedNISTDataset(root_dir=self.data_dir, section="validation", download=True, seed=0)
        val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "HeadCT"]
        
        self.train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
        self.val_ds = CacheDataset(data=val_datalist, transform=val_transforms)
        
    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=16, shuffle=True,
                          num_workers=4, persistent_workers=True)
        
    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=16, shuffle=False,
                          num_workers=4)
                          
    def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):
        recons_loss = F.l1_loss(reconstruction.float(), images.float())
        p_loss = self.perceptual_loss(reconstruction.float(), images.float())
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss_g = recons_loss + (self.kl_weight * kl_loss) + (self.perceptual_weight * p_loss)
        return loss_g,recons_loss
    
    def _compute_loss_discriminator(self, reconstruction):
        logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = discriminator(images.contiguous().detach())[-1]
        loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
        loss_d = self.adv_weight * discriminator_loss
        return loss_d, discriminator_loss
        
    def training_step(self, batch, batch_idx):
        optimizer_g, optimizer_d = self.optimizers()
        images = batch["image"]
        reconstruction, z_mu, z_sigma = self.forward(images)
        loss_g, recons_loss = self._compute_loss_generator(images, reconstruction, z_mu, z_sigma)
        self.log("recons_loss", recons_loss, batch_size=16, prog_bar=True)

        if self.current_epoch > self.autoencoder_warm_up_n_epochs:
            logits_fake = discriminator(reconstruction.contiguous().float())[-1]
            generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g += adv_weight * generator_loss
                          

        self.log("loss_g", loss_g, batch_size=16, prog_bar=True)
        self.manual_backward(loss_g)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        if self.current_epoch > self.autoencoder_warm_up_n_epochs:
            loss_d, discriminator_loss = self._compute_loss_discriminator(reconstruction)
            self.log("train_loss_d", loss_d, batch_size=16, prog_bar=True)
            self.manual_backward(loss_d)
            optimizer_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)

                          
        if self.current_epoch > self.autoencoder_warm_up_n_epochs:
            gen_epoch_loss += generator_loss.item()
            disc_epoch_loss += discriminator_loss.item()
            self.log("gen_loss", gen_loss, batch_size=16, prog_bar=True)
            self.log("disc_loss", disc_loss, batch_size=16, prog_bar=True)
                          
                                    
    def validation_step(self, batch, batch_idx):
        images = batch["image"]
        reconstruction, z_mu, z_sigma = self.autoencoderkl(images)
        recons_loss = F.l1_loss(images.float(), reconstruction.float())
        self.log("val_loss_d", recons_loss, prog_bar=True)
                              

    def configure_optimizers(self):
        optimizer_g = torch.optim.Adam(self.autoencoderkl.parameters(), lr=5e-5)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return [optimizer_g, optimizer_d], []
                          


## Train Autoencoder

In [8]:
n_epochs = 75
val_interval = 10

                          
# initialise the LightningModule
ae_net = AutoEnconder()

# set up checkpoints

checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename="best_metric_model")

                         
# initialise Lightning's trainer.
trainer = pl.Trainer(devices=1,
                     max_epochs=n_epochs,
                     check_val_every_n_epoch=val_interval,
                     callbacks=checkpoint_callback,
                     default_root_dir=root_dir)

# train
trainer.fit(ae_net)

The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.


2023-04-10 18:37:58,137 - GPU available: False, used: False
2023-04-10 18:37:58,138 - TPU available: False, using: 0 TPU cores
2023-04-10 18:37:58,140 - IPU available: False, using: 0 IPUs
2023-04-10 18:37:58,141 - HPU available: False, using: 0 HPUs


Can't initialize NVML
MedNIST.tar.gz: 59.0MB [00:02, 23.5MB/s]                                                                                                                                       

2023-04-10 18:38:00,799 - INFO - Downloaded: /tmp/tmpa48mavq5/MedNIST.tar.gz





2023-04-10 18:38:00,916 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-10 18:38:00,918 - INFO - Writing into directory: /tmp/tmpa48mavq5.


Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47164/47164 [00:45<00:00, 1047.66it/s]


2023-04-10 18:38:55,238 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-10 18:38:55,240 - INFO - File exists: /tmp/tmpa48mavq5/MedNIST.tar.gz, skipped downloading.
2023-04-10 18:38:55,242 - INFO - Non-empty folder exists in /tmp/tmpa48mavq5/MedNIST, skipped extracting.


Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5895/5895 [00:05<00:00, 1061.53it/s]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7991/7991 [00:12<00:00, 637.95it/s]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 972/972 [00:02<00:00, 414.79it/s]

2023-04-10 18:39:16,217 - Missing logger folder: /tmp/tmpa48mavq5/lightning_logs



Checkpoint directory /tmp/tmpa48mavq5 exists and is not empty.


2023-04-10 18:39:16,255 - 
  | Name            | Type                 | Params
---------------------------------------------------------
0 | autoencoderkl   | AutoencoderKL        | 75.1 M
1 | discriminator   | PatchDiscriminator   | 2.8 M 
2 | perceptual_loss | PerceptualLoss       | 2.5 M 
3 | adv_loss        | PatchAdversarialLoss | 0     
---------------------------------------------------------
77.8 M    Trainable params
2.5 M     Non-trainable params
80.3 M    Total params
321.225   Total estimated model params size (MB)
Sanity Checking DataLoader 0:  50%|███████████████████████████████████████████████████████                                                       | 1/2 [00:04<00:04,  4.34s/it]

Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 0:   2%|█▎                                                                                   | 8/500 [01:59<2:02:20, 14.92s/it, v_num=0, recons_loss=0.212, loss_g=0.215]

Detected KeyboardInterrupt, attempting graceful shutdown...


## Rescaling factor

As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) became crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we will compute the component-wise standard deviation to be used as scaling factor.

In [9]:
train_loader = ae_net.train_dataloader()
check_data = first(train_loader)
z = ae_net.autoencoderkl.train(mode=False).encode_stage_2_inputs(check_data["image"].to(ae_net.device))
print(f"Scaling factor set to {1/torch.std(z)}")
scale_factor = 1 / torch.std(z)

Scaling factor set to 0.3829643428325653


## Define the LightningModule for DiffusionModelUnet (transforms, network, loaders, etc)
The LightningModule contains a refactoring of your training code. The following module is a reformatiing of the code in 2d_stable_diffusion_v2_super_resolution.

In [10]:
class DiffusionUNET(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.data_dir = root_dir
        self.unet = DiffusionModelUNet(
                                spatial_dims=2,
                                in_channels=4,
                                out_channels=3,
                                num_res_blocks=2,
                                num_channels=(256, 256, 512, 1024),
                                attention_levels=(False, False, True, True),
                                num_head_channels=(0, 0, 64, 64),
                            )
        self.max_noise_level = 350
        self.scheduler = DDPMScheduler(num_train_timesteps=1000, 
                          beta_schedule="linear",
                          beta_start=0.0015,
                          beta_end=0.0195)
        self.z = ae_net.autoencoderkl.train(mode=False)


    def forward(self, x, timesteps, low_res_timesteps):
        return self.unet(x=x, 
                         timesteps=timesteps,
                         class_labels=low_res_timesteps)
    
    
    def prepare_data(self):
        train_transforms = get_train_transforms()
        val_transforms = get_val_transforms()
        
        train_data = MedNISTDataset(root_dir=self.data_dir, section="training",
                                    download=True, seed=0)
        train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "HeadCT"]
        val_data = MedNISTDataset(root_dir=self.data_dir, section="validation",
                                  download=True, seed=0)
        val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "HeadCT"]
        
        self.train_ds = CacheDataset(data=train_datalist, transform=train_transforms)
        self.val_ds = CacheDataset(data=val_datalist, transform=val_transforms)
        
    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=16, shuffle=True,
                          num_workers=4, persistent_workers=True)
        
    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=16, shuffle=True,
                          num_workers=4)
    
    def _calculate_loss(self, batch, batch_idx):
        images = batch["image"]
        low_res_image = batch["low_res_image"]       
        latent = self.z.encode_stage_2_inputs(images) * scale_factor
        latent =  latent.detach() # avoid adding this to graph.
        optimizer = self.optimizers()
        
        # Noise augmentation
        noise = torch.randn_like(latent)
        low_res_noise = torch.randn_like(low_res_image)
        timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (latent.shape[0],),
                                  device=latent.device).long()
        low_res_timesteps = torch.randint(
            0, self.max_noise_level, (low_res_image.shape[0],), device=latent.device
        ).long()

        noisy_latent = self.scheduler.add_noise(original_samples=latent, 
                                           noise=noise, timesteps=timesteps)
        noisy_low_res_image = self.scheduler.add_noise(
            original_samples=low_res_image, noise=low_res_noise, 
            timesteps=low_res_timesteps
        )

        latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)

        noise_pred = self.forward(latent_model_input, timesteps, low_res_timesteps)
        loss = F.mse_loss(noise_pred.float(), noise.float())
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, batch_idx)
        self.log("train_loss", loss, batch_size=16, prog_bar=True)
        return loss

        
    def validation_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, batch_idx)
        self.log("val_loss", loss, batch_size=16, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.unet.parameters(), lr=5e-5)
        return optimizer

## Train Diffusion Model

In order to train the diffusion model to perform super-resolution, we will need to concatenate the latent representation of the high-resolution with the low-resolution image. For this, we create a Diffusion model with `in_channels=4`. Since only the outputted latent representation is interesting, we set `out_channels=3`.

As mentioned, we will use the conditioned augmentation (introduced in [2] section 3 and used on Stable Diffusion Upscalers and Imagen Video [3] Section 2.5) as it has been shown critical for cascaded diffusion models, as well for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We will use a scheduler low_res_scheduler to add this noise, with the t step defining the signal-to-noise ratio and use the t value to condition the diffusion model (inputted using class_labels argument).

In [None]:
n_epochs = 200
val_interval = 20

                          
# initialise the LightningModule
d_net = DiffusionUNET()

# set up checkpoints

checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename="best_metric_model_dunet")

                         
# initialise Lightning's trainer.
trainer = pl.Trainer(devices=1,
                     max_epochs=n_epochs,
                     check_val_every_n_epoch=val_interval,
                     callbacks=checkpoint_callback,
                     default_root_dir=root_dir)

# train
trainer.fit(d_net)

2023-04-10 18:41:46,760 - GPU available: False, used: False
2023-04-10 18:41:46,761 - TPU available: False, using: 0 TPU cores
2023-04-10 18:41:46,762 - IPU available: False, using: 0 IPUs
2023-04-10 18:41:46,762 - HPU available: False, using: 0 HPUs
2023-04-10 18:41:46,869 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-10 18:41:46,869 - INFO - File exists: /tmp/tmpa48mavq5/MedNIST.tar.gz, skipped downloading.
2023-04-10 18:41:46,870 - INFO - Non-empty folder exists in /tmp/tmpa48mavq5/MedNIST, skipped extracting.



Loading dataset:   0%|                                                                                                                               | 0/47164 [00:00<?, ?it/s][A
Loading dataset:   0%|▎                                                                                                                  | 137/47164 [00:00<00:34, 1366.48it/s][A
Loading dataset:   1%|▋                                                                                                                  | 274/47164 [00:00<00:41, 1135.35it/s][A
Loading dataset:   1%|▉                                                                                                                  | 390/47164 [00:00<00:41, 1134.05it/s][A
Loading dataset:   1%|█▏                                                                                                                 | 505/47164 [00:00<00:44, 1039.25it/s][A
Loading dataset:   1%|█▍                                                                                

Loading dataset:  11%|████████████                                                                                                      | 4981/47164 [00:04<00:36, 1142.44it/s][A
Loading dataset:  11%|████████████▎                                                                                                     | 5097/47164 [00:05<00:39, 1059.14it/s][A
Loading dataset:  11%|████████████▋                                                                                                      | 5205/47164 [00:05<00:42, 998.35it/s][A
Loading dataset:  11%|████████████▉                                                                                                      | 5307/47164 [00:05<00:43, 973.33it/s][A
Loading dataset:  11%|█████████████▏                                                                                                     | 5406/47164 [00:05<00:44, 945.74it/s][A
Loading dataset:  12%|█████████████▎                                                                     

Loading dataset:  21%|████████████████████████                                                                                          | 9939/47164 [00:09<00:35, 1035.99it/s][A
Loading dataset:  21%|████████████████████████▎                                                                                         | 10045/47164 [00:10<00:37, 996.80it/s][A
Loading dataset:  22%|████████████████████████▌                                                                                         | 10146/47164 [00:10<00:39, 937.82it/s][A
Loading dataset:  22%|████████████████████████▊                                                                                         | 10242/47164 [00:10<00:39, 935.56it/s][A
Loading dataset:  22%|█████████████████████████                                                                                         | 10354/47164 [00:10<00:37, 986.77it/s][A
Loading dataset:  22%|█████████████████████████▎                                                         

Loading dataset:  31%|███████████████████████████████████▍                                                                             | 14814/47164 [00:14<00:32, 1000.96it/s][A
Loading dataset:  32%|███████████████████████████████████▊                                                                             | 14944/47164 [00:15<00:29, 1084.02it/s][A
Loading dataset:  32%|████████████████████████████████████                                                                             | 15057/47164 [00:15<00:30, 1054.98it/s][A
Loading dataset:  32%|████████████████████████████████████▎                                                                            | 15166/47164 [00:15<00:31, 1023.10it/s][A
Loading dataset:  32%|████████████████████████████████████▉                                                                             | 15271/47164 [00:15<00:32, 972.06it/s][A
Loading dataset:  33%|█████████████████████████████████████▏                                             

Loading dataset:  42%|███████████████████████████████████████████████▏                                                                 | 19705/47164 [00:19<00:27, 1001.61it/s][A
Loading dataset:  42%|███████████████████████████████████████████████▉                                                                  | 19807/47164 [00:19<00:27, 978.43it/s][A
Loading dataset:  42%|████████████████████████████████████████████████                                                                  | 19906/47164 [00:20<00:27, 981.50it/s][A
Loading dataset:  42%|████████████████████████████████████████████████▎                                                                 | 20005/47164 [00:20<00:29, 911.79it/s][A
Loading dataset:  43%|████████████████████████████████████████████████▌                                                                 | 20111/47164 [00:20<00:28, 952.39it/s][A
Loading dataset:  43%|████████████████████████████████████████████████▍                                  

Loading dataset:  53%|███████████████████████████████████████████████████████████▍                                                     | 24825/47164 [00:24<00:20, 1099.06it/s][A
Loading dataset:  53%|███████████████████████████████████████████████████████████▊                                                     | 24957/47164 [00:24<00:19, 1160.15it/s][A
Loading dataset:  53%|████████████████████████████████████████████████████████████                                                     | 25075/47164 [00:25<00:20, 1070.05it/s][A
Loading dataset:  53%|████████████████████████████████████████████████████████████▎                                                    | 25184/47164 [00:25<00:20, 1065.39it/s][A
Loading dataset:  54%|█████████████████████████████████████████████████████████████▏                                                    | 25292/47164 [00:25<00:22, 965.64it/s][A
Loading dataset:  54%|█████████████████████████████████████████████████████████████▎                     

Loading dataset:  65%|█████████████████████████████████████████████████████████████████████████▋                                       | 30749/47164 [00:29<00:11, 1424.21it/s][A
Loading dataset:  66%|██████████████████████████████████████████████████████████████████████████                                       | 30902/47164 [00:29<00:11, 1453.36it/s][A
Loading dataset:  66%|██████████████████████████████████████████████████████████████████████████▍                                      | 31048/47164 [00:29<00:11, 1401.46it/s][A
Loading dataset:  66%|██████████████████████████████████████████████████████████████████████████▋                                      | 31189/47164 [00:29<00:12, 1315.82it/s][A
Loading dataset:  66%|███████████████████████████████████████████████████████████████████████████                                      | 31322/47164 [00:30<00:12, 1248.01it/s][A
Loading dataset:  67%|███████████████████████████████████████████████████████████████████████████▎       

Loading dataset:  78%|███████████████████████████████████████████████████████████████████████████████████████▋                         | 36587/47164 [00:34<00:08, 1279.91it/s][A
Loading dataset:  78%|███████████████████████████████████████████████████████████████████████████████████████▉                         | 36717/47164 [00:34<00:08, 1185.00it/s][A
Loading dataset:  78%|████████████████████████████████████████████████████████████████████████████████████████▎                        | 36838/47164 [00:35<00:08, 1167.00it/s][A
Loading dataset:  78%|████████████████████████████████████████████████████████████████████████████████████████▌                        | 36957/47164 [00:35<00:09, 1059.20it/s][A
Loading dataset:  79%|████████████████████████████████████████████████████████████████████████████████████████▊                        | 37075/47164 [00:35<00:09, 1090.10it/s][A
Loading dataset:  79%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  90%|█████████████████████████████████████████████████████████████████████████████████████████████████████▊           | 42515/47164 [00:39<00:03, 1353.30it/s][A
Loading dataset:  90%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏          | 42653/47164 [00:39<00:03, 1359.09it/s][A
Loading dataset:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████████▌          | 42799/47164 [00:39<00:03, 1388.21it/s][A
Loading dataset:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████████▉          | 42961/47164 [00:39<00:02, 1454.33it/s][A
Loading dataset:  91%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎         | 43122/47164 [00:40<00:02, 1498.44it/s][A
Loading dataset:  92%|███████████████████████████████████████████████████████████████████████████████████

2023-04-10 18:42:30,195 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2023-04-10 18:42:30,196 - INFO - File exists: /tmp/tmpa48mavq5/MedNIST.tar.gz, skipped downloading.
2023-04-10 18:42:30,197 - INFO - Non-empty folder exists in /tmp/tmpa48mavq5/MedNIST, skipped extracting.



Loading dataset:   0%|                                                                                                                                | 0/5895 [00:00<?, ?it/s][A
Loading dataset:   3%|██▉                                                                                                                 | 149/5895 [00:00<00:03, 1489.61it/s][A
Loading dataset:   5%|██████                                                                                                              | 311/5895 [00:00<00:03, 1564.14it/s][A
Loading dataset:   8%|█████████▏                                                                                                          | 468/5895 [00:00<00:03, 1424.68it/s][A
Loading dataset:  10%|████████████                                                                                                        | 612/5895 [00:00<00:04, 1206.41it/s][A
Loading dataset:  13%|██████████████▌                                                                   

Loading dataset:  89%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 5248/5895 [00:04<00:00, 989.63it/s][A
Loading dataset:  91%|████████████████████████████████████████████████████████████████████████████████████████████████████████▋          | 5365/5895 [00:05<00:00, 1035.91it/s][A
Loading dataset:  93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▋        | 5470/5895 [00:05<00:00, 1025.43it/s][A
Loading dataset:  95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▋      | 5574/5895 [00:05<00:00, 962.79it/s][A
Loading dataset:  96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 5672/5895 [00:05<00:00, 966.62it/s][A
Loading dataset:  98%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  33%|██████████████████████████████████████▋                                                                             | 2666/7991 [00:04<00:07, 714.66it/s][A
Loading dataset:  34%|███████████████████████████████████████▊                                                                            | 2739/7991 [00:04<00:07, 691.39it/s][A
Loading dataset:  35%|████████████████████████████████████████▊                                                                           | 2813/7991 [00:04<00:07, 700.64it/s][A
Loading dataset:  36%|█████████████████████████████████████████▉                                                                          | 2885/7991 [00:04<00:07, 704.97it/s][A
Loading dataset:  37%|██████████████████████████████████████████▉                                                                         | 2956/7991 [00:04<00:07, 698.83it/s][A
Loading dataset:  38%|███████████████████████████████████████████▉                                       

Loading dataset:  72%|███████████████████████████████████████████████████████████████████████████████████▋                                | 5763/7991 [00:09<00:03, 645.15it/s][A
Loading dataset:  73%|████████████████████████████████████████████████████████████████████████████████████▋                               | 5834/7991 [00:09<00:03, 659.74it/s][A
Loading dataset:  74%|█████████████████████████████████████████████████████████████████████████████████████▋                              | 5901/7991 [00:09<00:03, 618.71it/s][A
Loading dataset:  75%|██████████████████████████████████████████████████████████████████████████████████████▌                             | 5967/7991 [00:09<00:03, 629.02it/s][A
Loading dataset:  76%|███████████████████████████████████████████████████████████████████████████████████████▊                            | 6047/7991 [00:09<00:02, 675.49it/s][A
Loading dataset:  77%|███████████████████████████████████████████████████████████████████████████████████

Loading dataset:  56%|██████████████████████████████████████████████████████████████████▋                                                   | 549/972 [00:01<00:01, 385.14it/s][A
Loading dataset:  60%|███████████████████████████████████████████████████████████████████████▍                                              | 588/972 [00:01<00:01, 376.97it/s][A
Loading dataset:  65%|████████████████████████████████████████████████████████████████████████████▍                                         | 630/972 [00:01<00:00, 387.53it/s][A
Loading dataset:  69%|█████████████████████████████████████████████████████████████████████████████████▌                                    | 672/972 [00:01<00:00, 396.56it/s][A
Loading dataset:  73%|██████████████████████████████████████████████████████████████████████████████████████▍                               | 712/972 [00:01<00:00, 393.11it/s][A
Loading dataset:  78%|███████████████████████████████████████████████████████████████████████████████████

2023-04-10 18:42:51,417 - 
  | Name      | Type               | Params
-------------------------------------------------
0 | unet      | DiffusionModelUNet | 266 M 
1 | scheduler | DDPMScheduler      | 0     
2 | z         | AutoencoderKL      | 75.1 M
-------------------------------------------------
342 M     Trainable params
0         Non-trainable params
342 M     Total params
1,368.189 Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]

Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Epoch 0:   0%|▍                                                                                                   | 2/500 [00:16<1:08:40,  8.27s/it, v_num=1, train_loss=0.980]

### Plotting sampling example

In [None]:
# Sampling image during training
num_samples = 3
val_loader = d_net.val_dataloader()
check_data = first(val_loader)
images = check_data["image"]
sampling_image = check_data["low_res_image"][:num_samples]

In [None]:
latents = torch.randn((num_samples, 3, 16, 16)).to(images.device)
low_res_noise = torch.randn((num_samples, 1, 16, 16)).to(images.device)
noise_level = 10
noise_level = torch.Tensor((noise_level,)).long().to(images.device)
scheduler = d_net.scheduler
noisy_low_res_image = scheduler.add_noise(original_samples=sampling_image, 
                                          noise=low_res_noise,
                                          timesteps=torch.Tensor((noise_level,)).long())

scheduler.set_timesteps(num_inference_steps=1000)
for t in tqdm(scheduler.timesteps, ncols=110):
    with torch.no_grad():
        with autocast(enabled=True):
            latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
            noise_pred = d_net.forward(x=latent_model_input,
                                       timesteps=torch.Tensor((t,)),
                                       low_res_timesteps=noise_level)

        # 2. compute previous image: x_t -> x_t-1
        latents, _ = scheduler.step(noise_pred, t, latents)

with torch.no_grad():
    decoded = ae_net.autoencoderkl.decode_stage_2_outputs(latents / scale_factor)

In [None]:
low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode="bicubic")
fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8))
axs[0, 0].set_title("Original image")
axs[0, 1].set_title("Low-resolution Image")
axs[0, 2].set_title("Outputted image")
for i in range(0, num_samples):
    axs[i, 0].imshow(images[i, 0].cpu(), vmin=0, vmax=1, cmap="gray")
    axs[i, 0].axis("off")
    axs[i, 1].imshow(low_res_bicubic[i, 0].cpu(), vmin=0, vmax=1, cmap="gray")
    axs[i, 1].axis("off")
    axs[i, 2].imshow(decoded[i, 0].cpu(), vmin=0, vmax=1, cmap="gray")
    axs[i, 2].axis("off")
plt.tight_layout()

### Clean-up data directory

In [None]:
if directory is None:
    shutil.rmtree(root_dir)