In [5]:
import csv, logging, os, sys, torch, time

import torch.nn.functional as F

from torch.amp                import autocast, GradScaler
from torch.utils.tensorboard  import SummaryWriter
from tqdm                     import tqdm
from monai.networks.nets      import AutoencoderKL
from generative.networks.nets import PatchDiscriminator
from monai.losses             import PatchAdversarialLoss, PerceptualLoss
from torch.nn                 import MSELoss
from config_ldm_ddpm          import *
from dataset                  import *
from utils                    import *

In [4]:
import random, shutil
import numpy as np

from torchvision        import transforms
from torch.utils.data   import DataLoader, Dataset
from PIL                import Image, ImageEnhance
from skimage.morphology import disk, erosion, dilation, opening, closing
from custom_transforms  import *

In [6]:
# Dataset location. The default is the original author's local Kvasir-SEG
# folder; set the KVASIR_SEG_DIR environment variable to run this notebook on
# another machine without editing the cell. The working directory is still
# switched to the dataset folder, exactly as before.
os.chdir(os.environ.get(
    'KVASIR_SEG_DIR',
    'C:/Users/Talha/OneDrive - Higher Education Commission/Desktop/Dr. Hassan Summer Work/Datasets/Kvasir-SEG',
))
base_dir = os.getcwd()

# NOTE(review): the dataloader cell passes SPLIT_RATIOS (upper-case, not defined
# in this cell), so this lower-case tuple looks unused -- confirm which one is
# intended.
split_ratios = (600, 200, 200)

# Echo both values as the cell output.
base_dir, split_ratios

('C:\\Users\\Talha\\OneDrive - Higher Education Commission\\Desktop\\Dr. Hassan Summer Work\\Datasets\\Kvasir-SEG',
 (600, 200, 200))

In [7]:
# Build the training and validation dataloaders.
# Both calls share every argument except `split`; "train" is built first, then "val".
train_loader, val_loader = [
    get_dataloaders(
        base_dir,
        split_ratio=SPLIT_RATIOS,
        split=split_name,
        trainsize=TRAINSIZE,
        batch_size=BATCH_SIZE,
        format=FORMAT,
    )
    for split_name in ("train", "val")
]

Dataset already split into train, val and test directories
Dataset already split into train, val and test directories


In [None]:
# Pull a single batch from the training loader for interactive inspection.
batch = next(iter(train_loader))

# Show a compact per-key summary (shape for array-like values, type name
# otherwise) instead of echoing every tensor value into the notebook output.
{
    key: (tuple(value.shape) if hasattr(value, 'shape') else type(value).__name__)
    for key, value in batch.items()
}

In [None]:
# Pick the mask entries out of the inspected batch.
clean_mask = batch['clean_mask']
noisy_mask = batch['noisy_mask']

# Display both shapes as the cell output.
clean_mask.shape, noisy_mask.shape

(torch.Size([4, 1, 256, 256]), torch.Size([4, 1, 256, 256]))

In [10]:
# Pick the image entries out of the inspected batch.
clean_image = batch['clean_image']
noisy_image = batch['noisy_image']

# Display both shapes as the cell output.
clean_image.shape, noisy_image.shape

(torch.Size([4, 3, 256, 256]), torch.Size([4, 3, 256, 256]))

In [44]:
# Constructor arguments for the image autoencoder (3 channels in, 3 channels out).
DAE_IMAGE_PARAMS = dict(
    spatial_dims=2,
    in_channels=3,
    latent_channels=4,  # (= Z in SDSeg paper)
    out_channels=3,
    channels=(128, 256, 512, 512),
    num_res_blocks=2,
    attention_levels=(False, False, False, False),
    # Non-local attention is enabled in both middle blocks, as per the SDSeg paper.
    with_encoder_nonlocal_attn=True,
    with_decoder_nonlocal_attn=True,
    use_flash_attention=True,
)

# Instantiate the image autoencoder from the settings above.
autoencoderkl = AutoencoderKL(**DAE_IMAGE_PARAMS)

In [45]:
# Encode the clean images into the autoencoder's latent space.
latent_image = autoencoderkl.encode_stage_2_inputs(clean_image)

# Display the latent shape as the cell output.
latent_image.size()

torch.Size([4, 4, 32, 32])

In [46]:
# Constructor arguments for the mask autoencoder (1 channel in, 1 channel out).
DAE_MASK_PARAMS = dict(
    spatial_dims=2,
    in_channels=1,
    latent_channels=4,  # (= Z in SDSeg paper)
    out_channels=1,
    channels=(128, 256, 512, 512),
    num_res_blocks=2,
    attention_levels=(False, False, False, False),
    # Non-local attention is enabled in both middle blocks, as per the SDSeg paper.
    with_encoder_nonlocal_attn=True,
    with_decoder_nonlocal_attn=True,
    use_flash_attention=True,
)

# Instantiate the mask autoencoder and encode the clean masks with it.
dae_mask = AutoencoderKL(**DAE_MASK_PARAMS)
latent_mask = dae_mask.encode_stage_2_inputs(clean_mask)

# Display the latent shape as the cell output.
latent_mask.size()

torch.Size([4, 4, 32, 32])

In [55]:
import torch
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler, DDIMScheduler
from generative.inferers import LatentDiffusionInferer

# Constructor arguments for the denoising UNet that operates on latents.
MODEL_PARAMS = dict(
    spatial_dims=2 if DIMENSION == "2d" else 3,
    # 8 input channels: the 4-channel latent (z = 4) concatenated with a
    # 4-channel conditioning latent, so latent dimensions match the autoencoder.
    in_channels=8,
    out_channels=4,  # latent-space output, before the decoder
    num_channels=(192, 384, 384, 768, 768),
    attention_levels=(True, True, True, True, True),
    num_res_blocks=2,
    num_head_channels=24,
)

# Define model
unet = DiffusionModelUNet(**MODEL_PARAMS)




In [56]:
# Choose the noise scheduler class from the SCHEDULER setting; anything other
# than "DDIM" falls back to DDPM. Both receive the same arguments.
if SCHEDULER == "DDIM":
    scheduler = DDIMScheduler(
        num_train_timesteps=NUM_TRAIN_TIMESTEPS, schedule=NOISE_SCHEDULER
    )
else:
    scheduler = DDPMScheduler(
        num_train_timesteps=NUM_TRAIN_TIMESTEPS, schedule=NOISE_SCHEDULER
    )

# Latent inferer built around the chosen scheduler, with a scale factor of 1.
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1)

In [None]:
# Conditioning latent from the image autoencoder (gradients are kept, as before).
latent_images = autoencoderkl.encode_stage_2_inputs(clean_image).to('cpu')

# Encode the masks exactly once and reuse this latent everywhere below.
# `encode_stage_2_inputs` draws a fresh random sample on every call, so letting
# `inferer(inputs=clean_mask, ...)` re-encode the masks internally would noise a
# *different* latent than `latent_masks`, and `z_T` would not be the tensor the
# UNet actually saw. The encoding runs without gradients and is multiplied by
# the inferer's scale factor, mirroring what the latent inferer does internally.
with torch.no_grad():
    latent_masks = dae_mask.encode_stage_2_inputs(clean_mask).to('cpu') * inferer.scale_factor

# Sample the target noise and one random timestep per batch element.
noise = torch.randn_like(latent_masks).to('cpu')  # (B, C, H, W)
timesteps = torch.randint(
    0, scheduler.num_train_timesteps, (latent_masks.size(0),), device='cpu'
).long()

# Forward diffusion: noised mask latent at the sampled timesteps.
z_T = scheduler.add_noise(original_samples=latent_masks, noise=noise, timesteps=timesteps)

# "concat" conditioning done explicitly: the noised mask latent and the image
# latent are stacked along the channel axis and fed to the UNet. This guarantees
# that `noise_pred` corresponds to the same `z_T` that later cells use.
noise_pred = unet(x=torch.cat([z_T, latent_images], dim=1), timesteps=timesteps)

In [58]:
# Display the shape of the UNet's noise prediction as the cell output.
noise_pred.size()

torch.Size([4, 4, 32, 32])

In [59]:
# Noise term: L1 distance between the predicted noise and the sampled noise.
loss_noise = F.l1_loss(noise_pred.float(), noise.float())

# Cumulative alpha for each sample's timestep, reshaped so it broadcasts over
# the remaining three dimensions of the latents.
alpha_bar_T = scheduler.alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)

# Recover an estimate of the clean latent from the noised latent and the
# predicted noise, then compare it with the encoded masks.
z_0_pred = (1 / alpha_bar_T.sqrt()) * (z_T - ((1 - alpha_bar_T).sqrt() * noise_pred))
loss_latent = F.l1_loss(z_0_pred.float(), latent_masks.float())

# Total objective: sum of both terms; this is the value to backpropagate.
loss = loss_noise + loss_latent