In [1]:
#imports
import numpy as np
import scipy
from scipy.stats import norm, multivariate_normal

import urllib.request

from PIL import Image

import matplotlib.pyplot as plt 
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8*" and
# deprecated the old names. Prefer the new name when this installation
# provides it and fall back to the legacy name on older Matplotlib versions.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

# Seed NumPy's global random state so the sampled noise is reproducible.
seed=192022
np.random.seed(seed)

  plt.style.use("seaborn")


In [None]:
#Improved Forward Process

def forward_improved_DDPM(original_img, alpha_bar, t):
    """Improved Forward Diffusion Process

    Samples the noised image at time-step ``t`` directly from the clean image:
    ``img_t = sqrt(alpha_bar[t]) * original_img + sqrt(1 - alpha_bar[t]) * eps``,
    where ``eps`` is standard normal noise drawn from NumPy's global random
    state (``np.random.randn``).

    Args:
        original_img : Image at time-step zero (t = 0), as a NumPy array.
        alpha_bar    : The reparameterized version of beta, indexed by time-step.
        t            : Current timestep. Either a single integer index, or an
                       array of indices, in which case the selected
                       coefficients are reshaped to ``(-1, 1, 1)``.

    Returns:
        Image obtained at current time-step. For an integer ``t`` the result
        has the same shape as ``original_img``.
    """

    alpha_bar_t = np.asarray(alpha_bar)[t]       # alpha_bar at step t (not beta_t)
    if np.ndim(alpha_bar_t) == 0:
        # Single timestep: give the coefficient one singleton axis per image
        # axis, so broadcasting never prepends an extra axis to the result
        # (a fixed (1, 1, 1) shape turned an (H, W) image into (1, H, W)).
        alpha_bar_t = np.reshape(alpha_bar_t, (1,) * np.ndim(original_img))
    else:
        # Several timesteps: keep the original (-1, 1, 1) layout.
        alpha_bar_t = np.reshape(alpha_bar_t, (-1, 1, 1))

    mu = np.sqrt(alpha_bar_t) * original_img     # mean
    sigma = np.sqrt(1.0 - alpha_bar_t)           # standard deviation (not variance)

    img_t = mu + sigma * np.random.randn(*original_img.shape)
    return img_t

    
# Download the sample image. The request sets an explicit User-Agent because
# Wikimedia's User-Agent policy allows requests that carry only a generic
# library default (such as urllib's) to be rejected.
IMG_URL = 'https://upload.wikimedia.org/wikipedia/commons/e/ea/Dog_coat_variation.png'
IMG_PATH = "Dog_coat_variation.png"

request = urllib.request.Request(
    IMG_URL,
    headers={"User-Agent": "ddpm-forward-demo/1.0 (educational notebook)"},
)
with urllib.request.urlopen(request, timeout=30) as response, open(IMG_PATH, "wb") as out_file:
    out_file.write(response.read())

IMG_SIZE = (128, 128)

# Convert to RGB before resizing so the array built from `img` always has three
# colour channels: a palette PNG would otherwise yield 2-D index data, and an
# RGBA PNG would get noise added to its alpha channel as well.
with Image.open(IMG_PATH) as raw_img:
    img = raw_img.convert("RGB").resize(size=IMG_SIZE)

# Linear noise schedule: `timesteps` beta values running from beta_start to beta_end.
timesteps = 100
beta_start, beta_end = 0.0001, 0.05
beta = np.linspace(beta_start, beta_end, num=timesteps, dtype=np.float32)

# alpha = 1 - beta; alpha_bar is the running product of alpha over the steps.
alpha = 1.0 - beta
alpha_bar = alpha.cumprod()

# Float copy of the image with its 8-bit values divided by 255.
original_img = np.array(img, dtype=np.float32) / 255.

specific_timestep = [20, 40, 60, 80, 99]

# Frames to display: the untouched image first, then one noised sample per
# selected timestep, clipped to [0, 1] and converted back to uint8.
processed_img = [img]
for step in specific_timestep:
    img_t = forward_improved_DDPM(original_img, alpha_bar, step)
    img_t = np.clip(img_t, 0, 1)
    img_t = (img_t * 255.0).astype(np.uint8)
    processed_img += [img_t]


# Plot
# processed_img[0] is the clean image (t = 0) and the remaining entries line up
# one-to-one with specific_timestep, so each title is taken from the timestep
# that was actually sampled. Assuming a fixed spacing of 20 labelled the last
# panel "100" although it was drawn at t = 99.
timestep_labels = [0] + list(specific_timestep)

_, ax = plt.subplots(1 , len(processed_img), figsize=(15,5))
# np.atleast_1d keeps the loop valid when there is only a single panel, in
# which case plt.subplots returns one Axes rather than an array.
for panel, sample, step_label in zip(np.atleast_1d(ax), processed_img, timestep_labels):
    panel.imshow(sample)
    panel.set_title(f"Timestep: {step_label}")
    panel.axis("off")
    panel.grid(False)
plt.suptitle("Efficient Forward process in DDPMs", y=0.85)
plt.tight_layout()