# Guiding DIP Early Stopping with DDPM-inspired Supervision

In [1]:
%%HTML
<style>
    /* Sets the --vscode-font-family custom property on <body>;
       presumably read by VS Code's notebook renderer - verify. */
    body {
        --vscode-font-family: "ComicShannsMono Nerd Font";
    }
</style>

## Import the libraries

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
from skimage.transform import resize
from tqdm import tqdm

import torch
import torch.nn as nn

from deepCNN import DeepCNN

## Define some utility functions

In [3]:
def get_image(
        image_path: str,
        image_width: int=256,
        image_height: int=256,
    ) -> np.ndarray:
    """Read an image file, resize it, and min-max normalize it to [0, 1].

    Args:
        image_path: Path of the image file, read with ``plt.imread``.
        image_width: First entry of the output shape handed to ``resize``.
        image_height: Second entry of the output shape handed to ``resize``.

    Returns:
        The resized image rescaled so that its minimum is 0 and its maximum
        is 1. A constant image (max == min) is returned as all zeros rather
        than NaNs.
    """
    image = plt.imread(image_path)
    # NOTE(review): skimage's ``resize`` interprets the output shape as
    # (rows, cols). The (width, height) order is kept on purpose because the
    # rest of this notebook builds its tensors as (image_w, image_h) too;
    # it only matters for non-square sizes.
    image_np = resize(image, (image_width, image_height), anti_aliasing=True)
    low = image_np.min()
    value_range = image_np.max() - low
    if value_range == 0:
        # A flat image would otherwise divide by zero and produce NaNs.
        return np.zeros_like(image_np)
    return (image_np - low) / value_range

def show_image(image_np: np.ndarray) -> None:
    """Display an array as a 600x600 grayscale Plotly image.

    The color scale bar and the tick labels on both axes are hidden.
    """
    # Plotly's update_* methods return the figure itself, so they can be chained.
    figure = (
        px.imshow(image_np, color_continuous_scale='gray', height=600, width=600)
        .update_layout(
            coloraxis_showscale=False,
            margin={"l": 50, "r": 50, "b": 50, "t": 50},
        )
        .update_xaxes(showticklabels=False)
        .update_yaxes(showticklabels=False)
    )
    figure.show()

def get_scheduler(
        type: str="scaled_linear",
        beta_start: float=0.0001,
        beta_end: float=0.02,
        num_train_timesteps: int=100
    ) -> torch.Tensor:
    """Return a schedule of beta values for DDPM.

    Args:
        type: Schedule name, one of ``"linear"``, ``"scaled_linear"`` or
            ``"sigmoid"``.
        beta_start: Beta value at the start of the schedule.
        beta_end: Beta value at the end of the schedule.
        num_train_timesteps: Number of entries in the returned schedule.

    Returns:
        A 1-D tensor with ``num_train_timesteps`` beta values.

    Raises:
        ValueError: If ``type`` is not one of the supported schedule names.
    """
    if type == "linear":
        return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    if type == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly in sqrt(beta) space, then square.
        return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    if type == "sigmoid":
        # GeoDiff sigmoid schedule: squash a linear ramp over [-6, 6] and
        # rescale it into [beta_start, beta_end].
        betas = torch.linspace(-6, 6, num_train_timesteps)
        return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
    # Previously an unknown name fell through and returned None, which only
    # surfaced later as an unrelated AttributeError at the call site.
    raise ValueError(
        f"Unknown scheduler type {type!r}; expected 'linear', 'scaled_linear' or 'sigmoid'."
    )

def ddpm_forward(
        x0: torch.Tensor,
        step: int,
        alpha_bars: torch.Tensor
    ) -> torch.Tensor:
    """Return a noisy image for the DDPM forward pass at a given step.

    Computes ``sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps`` where
    ``a_bar = alpha_bars[step]`` and ``eps`` is standard normal noise with
    the same shape as ``x0``, then min-max rescales the result to [0, 1].

    Args:
        x0: Clean input tensor (any shape).
        step: Index into ``alpha_bars`` selecting the noise level.
        alpha_bars: Tensor of cumulative alpha products, indexed by step.

    Returns:
        The noised tensor, rescaled so its minimum is 0 and its maximum is 1.
    """
    a_bar = alpha_bars[step]
    # Draw the noise on x0's own device instead of probing for CUDA globally:
    # a CPU tensor on a CUDA-capable machine would otherwise hit a device
    # mismatch. Using x0.shape also lifts the 2-D-only restriction.
    noise = torch.randn(x0.shape, device=x0.device)
    noisy_image = x0 * a_bar.sqrt() + (1 - a_bar).sqrt() * noise
    lowest = noisy_image.min()
    noisy_image = (noisy_image - lowest) / (noisy_image.max() - lowest)
    return noisy_image

## Testing the utility functions

* Read an image

In [4]:
# Load the test image at 128x128 and display it.
image_w = image_h = 128
image_np = get_image('data/boat.png', image_width=image_w, image_height=image_h)
show_image(image_np)

* Show the difference between schedulers

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the RNG so the noisy images in this figure are reproducible.
torch.manual_seed(0)

beta_start = 0.0001
beta_end = 0.02
num_of_steps = 100
num_of_cols = 10  # number of diffusion steps displayed per scheduler
fig = make_subplots(rows=3, cols=num_of_cols, row_titles=["Linear", "Scaled Linear", "Sigmoid"])
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.update_layout(width=1400, height=500)

# The clean image is the same for every scheduler, so build the tensor once.
raw_image = torch.tensor(image_np).to(device)
# Integer stride between displayed steps. A float stride combined with
# `idx % stride == 0` only selects the intended steps when num_of_steps is
# an exact multiple of the column count.
select_step = max(1, num_of_steps // num_of_cols)

for row, scheduler_type in enumerate(["linear", "scaled_linear", "sigmoid"]):
    betas = get_scheduler(scheduler_type, beta_start, beta_end, num_of_steps).to(device)
    alphas = 1-betas
    # Cumulative product of alphas in one pass instead of re-multiplying
    # every prefix in a Python loop.
    alpha_bars = torch.cumprod(alphas, dim=0)
    noisy_image_seq = [ddpm_forward(raw_image, i, alpha_bars) for i in range(num_of_steps)]
    noisy_image_seq = [noisy_image.cpu().detach().numpy() for noisy_image in noisy_image_seq]

    # Every select_step-th image, capped at the number of subplot columns.
    for col, noisy_image in enumerate(noisy_image_seq[::select_step][:num_of_cols], start=1):
        fig.add_trace(
            px.imshow(noisy_image, binary_string=True, color_continuous_scale='gray', width=200, height=200).data[0],
            row=row+1, col=col
        )
fig.show()

In [6]:
# Build the tensors used for training: the clean image, the noisy target and
# a fixed random network input.
# NOTE: noisy_image_seq is whatever the scheduler-comparison cell left behind,
# i.e. the sequence produced by the last scheduler in its loop ("sigmoid").
noise_step = 5  # index of the forward-diffusion step used as the target
# NOTE(review): dtype is pinned to float32 so the target matches the default
# dtype of torch.rand below, whatever dtype the image loader produced -
# confirm this is the precision the model is meant to train in.
ground_truth = torch.tensor(image_np, dtype=torch.float32).to(device)
noisy_image = torch.tensor(noisy_image_seq[noise_step], dtype=torch.float32).to(device)
input_image = torch.rand(image_w, image_h).to(device)
# show_image is annotated to take a numpy array, and a CUDA tensor cannot be
# converted to numpy implicitly, so pass the numpy source of the target.
show_image(noisy_image_seq[noise_step])

In [11]:
# Build the network on the available device, with its loss and optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DeepCNN(num_of_layers=8, num_of_channels=128, kernel_size=3)
model = model.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [12]:
# Fit the network to the noisy target, keeping a numpy snapshot of the
# network output from every epoch.
epoch_num = 4000
output_image_list = []

with tqdm(range(epoch_num)) as progress_bar:
    for epoch in progress_bar:
        optimizer.zero_grad()

        output = model(input_image)
        loss = criterion(output, noisy_image)

        loss.backward()
        optimizer.step()

        progress_bar.set_postfix(Loss=f"{loss.item():.6f}")
        # Detach from the autograd graph before copying to host memory.
        snapshot = output.detach().cpu().numpy()
        output_image_list.append(snapshot)


100%|██████████| 4000/4000 [15:51<00:00,  4.20it/s, Loss=0.002739]


In [17]:
# Animate every 10th stored network output; axis 0 of the stacked array
# is used as the animation frame.
frames = np.array(output_image_list[::10])
fig = px.imshow(
    frames,
    animation_frame=0,
    binary_string=True,
    color_continuous_scale='gray',
    width=700,
    height=700,
)
fig.show()