<a href="https://colab.research.google.com/github/NickAI2811/DataToolKit/blob/main/Stable_Diffusion_LNMIIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Topic:** Stable Diffusion on Medical Datasets

**Agenda:**
1.  A gentle introduction to the stable diffusion process (for beginners)
2.  Implementation of stable (latent) diffusion on the MedMNIST dataset using PyTorch


**Diffusion Models**
1. An emerging class of generative models, inspired by non-equilibrium thermodynamics
2. Generic pipeline has three stages: Forward process, Reverse process and Sampling process

**Why Stable Diffusion?**
1. The forward, reverse and sampling steps operate on compact, information-rich representations in a learned latent space rather than on raw pixels
2. Consequently, training and sampling remain reliable — without divergence or mode collapse — and avoid the extreme computational cost of working at full pixel resolution









In [None]:
#1 Install necessary packages
# Each line starting with "!" is an IPython shell escape: it runs pip in a subshell, not as Python code.

!pip install torch                                       # the core deep learning engine (tensors, autograd, neural networks, GPU support)
!pip install torchvision                                 # computer-vision add-on for PyTorch (datasets, image transforms, pretrained vision models)
!pip install medmnist                                    # to download medmnist dataset
!pip install tqdm                                        # to display live progress bars for loops (training, data loading, downloads, etc.)
!pip install pillow                                      # for loading, saving, and preprocessing images

In [None]:
#2 Import necessary packages

import torch
torch.set_printoptions(sci_mode=False, precision=4)                 # Setting printoptions, optional step

import numpy as np
np.set_printoptions(suppress=True, precision=4)                     # Setting printoptions, optional step

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils                           # mostly focused on visualization, image manipulation, and debugging
from medmnist import INFO                                           # a dictionary (with dataset names as keys) object to store the metadata
from tqdm import tqdm                                               # package to visualize the progress
import os
import random
import medmnist

In [None]:
#3 Create a separate directory for the results

# Output folder for the per-epoch sample grids and the final UNet checkpoint
SAVE_DIR = "./results"
os.makedirs(name=SAVE_DIR, exist_ok=True)

# Fix the seeds of Python's `random` module and of PyTorch for reproducibility
random.seed(42)
torch.manual_seed(42)

In [None]:
#4 Setting hyperparameters and computation device

# Diffusion chain length and compute device
TIMESTEPS = 1000                     # number of discrete steps T in the forward (noising) and reverse (denoising) chains
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")   # GPU when available, CPU otherwise

# Optimisation settings for the denoising UNet
BATCH_SIZE = 64                      # images per training step
LR = 1e-5                            # Adam learning rate for the UNet
EPOCHS = 2                           # passes over the training set for the diffusion model

In [None]:
#5 Look up the ChestMNIST metadata and resolve the dataset class it names

info = INFO["chestmnist"]                             # metadata dict for the ChestMNIST subset
DataClass = getattr(medmnist, info["python_class"])   # class name from the metadata -> actual class in the medmnist package

In [None]:
#6 Define the transformations to be applied to the image; additional transformations may be explored as a homework

transform = transforms.Compose(
    [
        transforms.ToTensor(),                          # image -> float tensor in [0, 1], shape (C, H, W)
        transforms.Normalize(mean=[0.5], std=[0.5]),    # (x - 0.5) / 0.5: rescales [0, 1] to [-1, 1] (matches the decoder's Tanh range)
    ]
)

In [None]:
#7 Get the training dataset and initialize the trainloader

# Training split; download=True lets medmnist fetch the data files if needed
train_dataset = DataClass(split="train", transform=transform, download=True)
print("type(train_dataset): ", type(train_dataset), "\n")
print("train_dataset: \n\n", train_dataset, "\n")

# Mini-batches of BATCH_SIZE, reshuffled every epoch
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
#8 Exploring a single batch of scans

import matplotlib.pyplot as plt

# Take one shuffled batch and drop singleton axes so each scan is a 2-D array
images, labels = next(iter(train_loader))
images = images.squeeze().numpy()
labels = labels.squeeze().numpy()

subset = images[:16]                                   # first 16 scans of the batch, shown as a 4x4 grid

plt.figure(figsize=(6, 6))
for idx in range(16):
    ax = plt.subplot(4, 4, idx + 1)
    ax.imshow(subset[idx].squeeze(), cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()

**Generic Pipeline of Stable Diffusion**

1. The encoder maps the input image $x$ to its latent embedding $z$
\begin{equation}
x → Encoder → z
\end{equation}
2. Forward Process:
\begin{equation}
z_{t=0} → Forward → z_{t=T}
\end{equation}
3. Reverse (Sampling) Process:
\begin{equation}
z_{t=T} → Reverse (Sampling)→ \hat{z}_{t=0}
\end{equation}
4. The decoder maps the denoised latent $\hat{z}$ back to image space, producing the reconstruction $\hat{x}$
\begin{equation}
\hat{z} → Decoder → \hat{x}
\end{equation}

In [None]:
# Creating the encoder and decoder network to obtain the latent embeddings (stable diffusion works on latent space)
# Encoder and Decoder can also be one of the existing networks such as ResNet etc.

# original image (x) ---Encoder--> Embeddings (z) ---Decoder--> Reconstructed Image (x')

class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel image to a ``latent_dim``-channel latent map.

    Each stride-2, kernel-4, padding-1 convolution halves the spatial size,
    so a (batch_size, 1, 28, 28) input becomes (batch_size, latent_dim, 7, 7).
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1),           # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=latent_dim, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # (batch_size, 1, 28, 28) -> (batch_size, latent_dim, 7, 7)
        return self.net(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent map back to a 1-channel image.

    Each stride-2, kernel-4, padding-1 transposed convolution doubles the spatial size,
    so (batch_size, latent_dim, 7, 7) becomes (batch_size, 1, 28, 28); the final Tanh
    bounds outputs to [-1, 1], the same range as the normalised input images.
    """

    def __init__(self, latent_dim=4):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=32, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=4, stride=2, padding=1),           # 14x14 -> 28x28
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        # (batch_size, latent_dim, 7, 7) -> (batch_size, 1, 28, 28)
        return self.net(z)

In [None]:
# Getting the model summary of the encoder-decoder networks

!pip install torchsummary
from torchsummary import summary

latent_dim = 4                                                                   # Number of latent channels in the encoder-decoder network

encoder = Encoder(latent_dim).to(DEVICE)
decoder = Decoder(latent_dim).to(DEVICE)

# torchsummary's summary() builds its dummy input on `device`, which defaults to "cuda";
# pass the active device explicitly so this cell also runs on CPU-only runtimes.
print("Encoder Summary:")
summary(encoder, input_size=(1, 28, 28), device=str(DEVICE))                     # (batch_size, 1, 28, 28) ---Encoder--> (batch_size, 4, 7, 7)

print("Decoder Summary:")
summary(decoder, input_size=(latent_dim, 7, 7), device=str(DEVICE))              # (batch_size, 4, 7, 7) ---Decoder--> (batch_size, 1, 28, 28)

In [None]:
# Pretrain the encoder and decoder networks on the original dataset images (x)

# Fresh encoder/decoder pair trained as a plain autoencoder (reconstruction only; no diffusion yet).
encoder = Encoder(latent_dim=4).to(DEVICE)
decoder = Decoder(latent_dim=4).to(DEVICE)

criterion = nn.MSELoss()                                                         # pixel-wise reconstruction loss between x_hat and x

# One Adam optimizer updates both networks jointly.
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=1e-3
)

ED_EPOCHS = 1                                                                    # number of autoencoder pretraining epochs

for epoch in range(ED_EPOCHS):
    total_loss = 0.0                                                             # running sum of per-batch losses, averaged at epoch end

    pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{ED_EPOCHS}]", leave=True)

    for x, _ in pbar:  # x in [-1, 1] after Normalize; labels are unused here
        x = x.to(DEVICE)

        z = encoder(x)                                                           # (batch_size, 4, 7, 7) latent
        x_hat = decoder(z)                                                       # (batch_size, 1, 28, 28) reconstruction

        loss = criterion(x_hat, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        pbar.set_postfix(AE_Loss=f"{loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)                                    # mean batch loss over the epoch
    print(f"Epoch [{epoch+1}/{ED_EPOCHS}] | Avg AE Loss: {avg_loss:.4f}")

In [None]:
# Save the encoder and decoder network weights

# Persist only the learned parameters (state_dicts) so the autoencoder can be reloaded later.
torch.save(obj=encoder.state_dict(), f="encoder_pretrained.pth")
torch.save(obj=decoder.state_dict(), f="decoder_pretrained.pth")

In [None]:
# Freeze the encoder-decoder network

# Restore the pretrained weights saved in the previous cell.
encoder.load_state_dict(torch.load("encoder_pretrained.pth"))
decoder.load_state_dict(torch.load("decoder_pretrained.pth"))

# Switch both halves to evaluation mode.
encoder.eval()
decoder.eval()

# Freeze every parameter (recursively, via Module.requires_grad_) so the
# autoencoder acts as a fixed feature extractor from here on.
encoder.requires_grad_(False)
decoder.requires_grad_(False)

**Forward Process**

1.  The forward diffusion process perturbs a latent representation $z_0$ into progressively noisier latents $z_1, \dots, z_T$ as the timestep increases.
2.  Each forward transition $p(z_t \mid z_{t-1})$ describes one such perturbation step, in which noise $\epsilon_t$ is added at timestep $(t)$.

\begin{equation}
p(z_{1:T} \mid z_0)
= p(z_1 \mid z_0)\; \dots\; p(z_t \mid z_{t-1})\; \dots\; p(z_T \mid z_{T-1})
= \prod_{t=1}^{T} p(z_t \mid z_{t-1})
\end{equation}

3.  Through multiple timesteps, the original latent distribution $p(z_0)$ is eventually perturbed to a tractable terminal distribution $p(z_T)$.


**Implementation of Forward Process**

1. Typically, stable diffusion models are implemented as Denoising Diffusion Probabilistic Models (DDPMs).
2. The forward process gradually adds Gaussian noise to the latent over $T$ timesteps; because every step is an explicit Gaussian transition, the process is well-grounded probabilistically:

\begin{equation}
p(z_t \mid z_{t-1}) =
\mathcal{N}\Big(
z_t;\; \sqrt{1-\beta_t}\, z_{t-1},\; \beta_t I
\Big)
\end{equation}

where:   
*   $z_t$ is noisy latent representation at timestep $(t)$
*   $\beta_t$ is noise variance added at step $(t)$

3.  The above conditional distribution implies that $z_{t}$ can be sampled as

\begin{equation}
z_t
=
\sqrt{1-\beta_t}\, z_{t-1}
+
\sqrt{\beta_t}\, \epsilon_t,
\qquad
\epsilon_t \sim \mathcal{N}(0, I)
\end{equation}

where:
*  $\sqrt{1-\beta_t}\, z_{t-1}$ is the deterministic / mean component
*  $\sqrt{\beta_t}\, \epsilon_t$ is the noise
*  $\Sigma_t = \beta_t I$ is the covariance
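4. This notebook uses a cosine-shaped noise schedule $\beta_{t}$ that rises smoothly from $\beta_{\min}$ to $\beta_{\max}$ (note: this differs both from the "cosine schedule" of Nichol & Dhariwal, which is defined on $\bar{\alpha}_t$, and from the scaled-linear schedule used by the original Stable Diffusion). It is given by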
\begin{equation}
\beta_t
=
\beta_{\min}
+
\frac{1}{2}
\left(\beta_{\max} - \beta_{\min}\right)
\left(
1 - \cos\left(\frac{\pi\, t}{T - 1}\right)
\right),
\quad t = 0, 1, \dots, T-1
\end{equation}
where $\beta_{\max}$ and $\beta_{\min}$ are hyperparameters
5. The schedule is fixed and not learned; a carefully chosen noise schedule balances exploding noise against a vanishing signal

In [None]:
#10 Setting the noise variance to be added at each time step.

import math

def cosine_beta_schedule(T, beta_min=1e-4, beta_max=2e-2):
    """Return a length-``T`` tensor of per-step noise variances beta_t.

    beta_t rises from ``beta_min`` (t = 0) to ``beta_max`` (t = T - 1) along a
    half cosine:

        beta_t = beta_min + 0.5 * (beta_max - beta_min) * (1 - cos(pi * t / (T - 1)))

    Args:
        T: number of diffusion timesteps; must be >= 1.
        beta_min: noise variance at the first timestep.
        beta_max: noise variance at the last timestep.

    Returns:
        1-D float tensor of shape (T,), created on the CPU.

    Raises:
        ValueError: if ``T < 1``.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    # Normalised position t / (T - 1) in [0, 1]. linspace yields [0.] for T == 1,
    # which avoids the 0 / 0 -> NaN that dividing by (T - 1) produced.
    progress = torch.linspace(0.0, 1.0, T)
    return beta_min + 0.5 * (beta_max - beta_min) * (1 - torch.cos(math.pi * progress))

# Computing the noise variance (beta) values

betas = cosine_beta_schedule(TIMESTEPS).to(DEVICE)          # (TIMESTEPS,) noise variances on the compute device
print("betas.size(): \t", betas.size(), "\n")

# Plot beta_t against the step index; matplotlib needs the values back on the CPU.
import matplotlib.pyplot as plt

x_values = np.linspace(1, TIMESTEPS, TIMESTEPS)
y_values = betas.cpu()
plt.plot(x_values, y_values)
plt.xlabel("Number of time steps")
plt.ylabel("Noise variance (beta) added")
plt.grid(True)
plt.title("Variation of noise variance (beta) against time-steps", y=-0.20)
plt.show()

We define:

\begin{equation}
\alpha_t = 1 - \beta_t
\end{equation}

where:  

*  $\beta_t$ is the variance of the Gaussian noise added at timestep $(t)$
*  $\alpha_t$ is the fraction of the signal variance retained at timestep $(t)$ (the previous latent is scaled by $\sqrt{\alpha_t}$)

Using the above substitution, the forward process can be written as:

\begin{equation}
z_t = \sqrt{\alpha_t} \, z_{t-1} + \sqrt{\beta_t} \, \epsilon_{t}, \quad \epsilon_{t} \sim \mathcal{N}(0,I)
\end{equation}

Impact of $\alpha_{t}$ on the noising process:
1.   Early timesteps: $\alpha_t \approx 1$ → most of the latent representation is preserved
2.   Later timesteps:  $\alpha_t < 1$ → more noise is added



In [None]:
#11 Computing alpha values from beta values

alphas = 1 - betas                          # alpha_t = 1 - beta_t: fraction of the previous latent's variance kept at step t
print("alphas.size(): \t", alphas.size(), "\n")

# Plot alpha_t against the step index. alpha_t is the signal-retention factor,
# not a noise variance, so the axis label and title describe it as such.
import matplotlib.pyplot as plt

x_values = np.linspace(1, TIMESTEPS, TIMESTEPS)
y_values = alphas.cpu()                     # matplotlib needs the values on the CPU
plt.plot(x_values, y_values)
plt.xlabel("Number of time steps")
plt.ylabel("Signal retained per step (alpha = 1 - beta)")
plt.grid(True)
plt.title("Variation of alpha (= 1 - beta) against time-steps", y=-0.20)
plt.show()

The cumulative product of alphas gives the fraction of the original latent's variance remaining after $(t)$ steps:

\begin{equation}
\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s
\end{equation}

Then, $z_t$ can be sampled directly from $z_0$ as:

\begin{equation}
z_t = \sqrt{\bar{\alpha}_t} \, z_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
\end{equation}

In [None]:
#12 Redefining noise variance added at each step

# alpha_bar[t] = alpha[0] * ... * alpha[t] (0-based index): how much of z_0 survives up to step t
alphas_cumprod = alphas.cumprod(dim=0)

print("alphas_cumprod.size(): \t", alphas_cumprod.size(), "\n")

# Precomputed coefficients of the closed-form forward process
#   z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * eps
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()

In [None]:
#13 Generate a noisy latent embedding at timestep t using the forward diffusion formula

def q_sample(z, t, noise):
    """Sample z_t from the forward process in closed form.

        z_t = sqrt(alpha_bar_t) * z_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        z: clean latents z_0, batch-first tensor of any rank >= 1.
        t: integer timestep indices, one per sample (shape (batch,)).
        noise: Gaussian noise with the same shape as ``z``.

    Returns:
        Tensor shaped like ``z`` holding the noised latents z_t.
    """
    # Reshape the per-sample coefficients to (batch, 1, ..., 1) so they broadcast
    # over every non-batch axis of z, whatever its rank (not only 4-D latents).
    coeff_shape = (-1,) + (1,) * (z.dim() - 1)
    sqrt_a = sqrt_alphas_cumprod[t].view(coeff_shape)
    sqrt_one_minus_a = sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)
    return sqrt_a * z + sqrt_one_minus_a * noise

**Implementation of the Reverse Process**

**Reverse Process**
1.  Trains a denoising network to remove the noise added in the forward process
2.  Specifically, the reverse process moves on the chain in the opposite direction and iteratively removes noise between two consecutive timesteps as $(t)$ decreases from $T$ to $0$

\begin{equation}
p_{\theta}(z_{0:T})
= p(z_{T})\, p_{\theta}(z_{T-1}\mid z_{T}) \;\dots\; p_{\theta}(z_{t-1}\mid z_{t}) \;\dots\; p_{\theta}(z_{0}\mid z_{1})
= p(z_{T}) \prod_{t=1}^{T} p_{\theta}(z_{t-1}\mid z_{t})
\end{equation}

3.  The reverse process in the above step is modelled as
\begin{equation}
 p_{\theta}(z_{t-1}|z_{t}) = \mathcal{N}(z_{t-1}; \mu_{\theta}(z_{t},t), \Sigma_{\theta}(z_{t},t))
 \end{equation}

In this implementation, UNet is used as the denoising network

1. The UNet combines multi-scale information through skip connections; this eases denoising because both low-level (fine detail) and high-level (global structure) features reach the output
2. UNet is parameterized using $\theta$ described above.



In [None]:
#14 Model the noise added in the forward process using UNet architecture; UNet acts as a denoising model

class SimpleUNet(nn.Module):
    """Small time-conditioned UNet that predicts the noise present in a noisy latent.

    Two encoder levels, a bottleneck and two decoder levels with skip connections;
    every level is a ``Block`` that adds a projection of the timestep embedding.

    Args:
        in_ch: channels of the latent (both input and output); defaults to ``latent_dim``.
        base_ch: channels of the first level; the second level and bottleneck use ``2 * base_ch``.
        time_dim: width of the learned timestep embedding.
    """

    def __init__(self, in_ch=latent_dim, base_ch=64, time_dim=128):
        super().__init__()

        # Scalar timestep -> learned time_dim-wide embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )

        # Encoder blocks inject the time embedding to condition feature extraction
        self.enc1 = Block(in_ch, base_ch, time_dim)
        self.enc2 = Block(base_ch, base_ch * 2, time_dim)
        self.pool = nn.AvgPool2d(2)

        # Bottleneck: coarsest resolution, global structure + timestep information
        self.bot = Block(base_ch * 2, base_ch * 2, time_dim)

        # Decoder inputs are [upsampled features, skip] concatenated on the channel axis
        self.dec2 = Block(base_ch * 4, base_ch, time_dim)    # (2*base_ch from bottleneck) + (2*base_ch skip from enc2)
        self.dec1 = Block(base_ch * 2, base_ch, time_dim)    # (base_ch from dec2) + (base_ch skip from enc1)

        self.out = nn.Conv2d(base_ch, in_ch, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timestep ``t``.

        Args:
            x: noisy latents, (batch_size, in_ch, H, W).
            t: one timestep value per sample; any shape that views to (batch_size, 1).

        Returns:
            Tensor of shape (batch_size, in_ch, H, W).
        """
        t = self.time_mlp(t.view(-1, 1))                         # (B, 1) -> (B, time_dim)

        e1 = self.enc1(x, t)                                     # (B, in_ch, H, W) -> (B, base_ch, H, W)
        e2 = self.enc2(self.pool(e1), t)                         # pool -> (B, base_ch, H/2, W/2) -> (B, 2*base_ch, H/2, W/2)

        b = self.bot(self.pool(e2), t)                           # pool -> (B, 2*base_ch, H/4, W/4)

        # Upsample to the skip tensor's exact size (not a fixed factor of 2) so odd
        # sizes such as 7 -> 3 -> 1 are restored exactly. nn.functional is used
        # directly so forward() does not depend on a separately imported alias.
        up_b = nn.functional.interpolate(b, size=e2.shape[-2:], mode='nearest')
        d2 = self.dec2(torch.cat([up_b, e2], dim=1), t)          # (B, 4*base_ch, H/2, W/2) -> (B, base_ch, H/2, W/2)

        up_d2 = nn.functional.interpolate(d2, size=e1.shape[-2:], mode='nearest')
        d1 = self.dec1(torch.cat([up_d2, e1], dim=1), t)         # (B, 2*base_ch, H, W) -> (B, base_ch, H, W)

        return self.out(d1)                                      # (B, base_ch, H, W) -> (B, in_ch, H, W)


**Time-Injection**

1.    Time injection is the process of feeding the timestep $(t)$ into the denoising network so it knows the noise level at that stage of the reverse diffusion process.
2.    Without this, the model cannot distinguish early noisy steps from later ones; reverse diffusion becomes unstable and inconsistent
3.  The time-embedding width (`time_dim`) can be chosen freely; larger values give richer embeddings but increase the number of parameters in the Linear layers.


In [None]:
# Inject timestep information into every feature, letting the network know which diffusion step it is processing, without adding extra noise.

class Block(nn.Module):
    """Two 3x3 convolutions with a timestep-dependent per-channel bias in between.

    The time embedding ``t`` (batch_size, time_dim) is linearly projected to
    ``out_ch`` values and added at every spatial position of the first conv's
    activations; padding=1 keeps H and W unchanged.
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.time = nn.Linear(time_dim, out_ch)
        self.act = nn.SiLU()

    def forward(self, x, t):
        features = self.act(self.conv1(x))             # (B, in_ch, H, W) -> (B, out_ch, H, W)
        time_bias = self.time(t)[:, :, None, None]     # (B, time_dim) -> (B, out_ch, 1, 1), broadcast over H, W
        return self.act(self.conv2(features + time_bias))

In [None]:
# Get the model summary for the above UNET model

import torch.nn.functional as F

latent_size = 7                 # spatial side of the latent fed into SimpleUNet; (4, 7, 7) --> (4, 7, 7)
time_dim = 128                  # width of the UNet's timestep embedding

model = SimpleUNet(in_ch=latent_dim, base_ch=64, time_dim=time_dim).to(DEVICE)

class ModelWrapper(nn.Module):
    """Single-input adapter around a two-argument ``model(x, t)``.

    ``forward(x)`` calls the wrapped model with a zero timestep of shape
    (batch_size, 1) on x's device, so tools that feed a single input tensor
    (such as torchsummary) can run it.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        batch = x.size(0)
        dummy_t = torch.zeros((batch, 1), device=x.device)
        return self.model(x, dummy_t)

wrapped_model = ModelWrapper(model).to(DEVICE)

# torchsummary's summary() prints the table itself and returns None, so call it
# directly; wrapping it in print() only appended a stray "None" to the output.
summary(
    wrapped_model,
    input_size=(latent_dim, latent_size, latent_size),
    device=str(DEVICE)
)


 **Sampling Process**
 1.  Leverages the optimized denoising network parameters $\theta^{*}$ to generate novel latent embeddings $\hat{z}_{0}$.
2. Sampling obtains a latent sample $z_{T}$ from the terminal distribution $p(z_{T})$ and then uses the trained network to iteratively remove noise according to the transition $p_{\theta^{*}}(z_{t-1}\mid z_{t})$.

\begin{equation}
p_{\theta^{*}}(z_{0:T}) = p(z_{T})\,p_{\theta^{*}}(z_{T-1}\mid z_{T}) \cdots p_{\theta^{*}}(z_{0}\mid z_{1})
= p(z_{T}) \prod_{t=1}^{T} p_{\theta^{*}}(z_{t-1}\mid z_{t})
\end{equation}

**Steps**

1. Start from a pure Gaussian noise in latent space $z_T$.  
2. At each step, remove the predicted noise using the UNet and rescale the latent embedding.  
3. Repeat until $t = 0$, producing a clean latent representation $\hat{z}_0$, which can then be decoded into the image space.

\begin{equation}
z_{t-1}
=
\frac{1}{\sqrt{\alpha_t}}
\left(
z_t
-
\frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}
\,\hat{\epsilon}_\theta(z_t, t)
\right)
+
\mathbf{1}_{t>0}\,\sqrt{\beta_t}\,\epsilon,
\qquad
\epsilon \sim \mathcal{N}(0, I)
\end{equation}


In [None]:
#17 Sampling process; Reverse Diffusion (Latent Sampling)

@torch.no_grad()                                                                   # no gradients are computed (inference mode)
def sample_latent(model, decoder, n_samples=16, latent_shape=(4, 7, 7)):
    """Generate images by running the DDPM reverse chain in latent space, then decoding.

    Starts from z_T ~ N(0, I) and, for t = TIMESTEPS-1 down to 0, applies

        z_{t-1} = (z_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(z_t, t)) / sqrt(alpha_t)
                  + 1[t > 0] * sqrt(beta_t) * eps,   eps ~ N(0, I)

    Args:
        model: noise-prediction network, called as ``model(z, t / TIMESTEPS)``.
        decoder: maps latents back to image space.
        n_samples: number of images to generate.
        latent_shape: per-sample latent shape (channels, height, width); the
            default matches the 4x7x7 output of the notebook's encoder.

    Returns:
        Decoded images clamped to [-1, 1]. The model's train/eval mode is
        restored to what it was before the call.
    """
    was_training = model.training
    model.eval()
    try:
        z = torch.randn(n_samples, *latent_shape, device=DEVICE)                   # z_T, drawn directly on the compute device

        for t in reversed(range(TIMESTEPS)):                                       # t = TIMESTEPS-1, ..., 0
            # Same timestep normalisation (t / TIMESTEPS) as used during training
            t_batch = torch.full((n_samples,), t, device=DEVICE).float() / TIMESTEPS
            pred_noise = model(z, t_batch)                                         # predicted noise component in z at timestep t

            alpha = alphas[t]
            alpha_bar = alphas_cumprod[t]
            beta = betas[t]
            z = (1 / torch.sqrt(alpha)) * (z - beta * pred_noise / torch.sqrt(1 - alpha_bar))  # remove predicted noise

            if t > 0:                                                              # no fresh noise on the final step
                z += torch.sqrt(beta) * torch.randn_like(z)

        x = decoder(z)                                                             # latent -> image space
    finally:
        # The training loop calls this between epochs; leaving the model in eval
        # mode would silently continue optimisation in eval mode.
        model.train(was_training)
    return x.clamp(-1, 1)

In [None]:
#18 Initializing the models' hyperparameters and setting the loss function

loss_fn = nn.MSELoss()                                  # epsilon-prediction objective: MSE between predicted and true noise

# Reuse the autoencoder pretrained earlier instead of fresh, randomly initialised
# networks: the UNet must learn in the latent space of the *trained* encoder, and
# sampled latents must be decoded by the *trained* decoder. Reloading from the
# saved checkpoints (which must exist) keeps this cell correct even if re-run on its own.
encoder = Encoder().to(DEVICE)
decoder = Decoder().to(DEVICE)
encoder.load_state_dict(torch.load("encoder_pretrained.pth", map_location=DEVICE))
decoder.load_state_dict(torch.load("decoder_pretrained.pth", map_location=DEVICE))

# Keep the autoencoder frozen: only the UNet's parameters are optimised below.
encoder.eval()
decoder.eval()
encoder.requires_grad_(False)
decoder.requires_grad_(False)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
#19 Train the denoising model (UNet) in the reverse direction

for epoch in range(EPOCHS):

    # sample_latent() switches the UNet to eval mode; make sure every epoch trains in train mode.
    model.train()

    pbar = tqdm(train_loader)
    for x, _ in pbar:

        x = x.to(DEVICE)

        # The encoder is a fixed feature extractor: encode without recording an
        # autograd graph (its gradients would be computed but never used).
        with torch.no_grad():
            z = encoder(x)
        noise = torch.randn_like(z)

        t = torch.randint(0, TIMESTEPS, (z.size(0),), device=DEVICE)  # independent random timestep per sample, so all timesteps are seen
        z_t = q_sample(z, t, noise)                                    # closed-form forward process at timestep t
        t_norm = t.float() / TIMESTEPS                                 # scale timesteps to [0, 1) for the UNet's time MLP

        pred_noise = model(z_t, t_norm)
        loss = loss_fn(pred_noise, noise)                              # learn to predict the injected noise

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch+1} | Loss: {loss.item():.4f}")

    # After each epoch, save a 4x4 grid of samples mapped from [-1, 1] to [0, 1] for save_image.
    samples = sample_latent(model, decoder, 16)
    samples = (samples + 1) / 2
    utils.save_image(samples, f"{SAVE_DIR}/epoch_{epoch+1}.png", nrow=4)

In [None]:
# Save the trained UNet's weights (state_dict only); the per-epoch sample grids were already written as PNG files

torch.save(obj=model.state_dict(), f=f"{SAVE_DIR}/ddpm_chestmnist.pt")
print("Training complete.")