## **Sprites Generation using Diffusion Models**

#### Imports

In [10]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

from IPython.display import HTML
from typing import Dict, Tuple
from tqdm.auto import tqdm
from Model_utils import *
from Diffusion_utils import *
from Dataset import *

#### Neural Network Model for Generation

In [25]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):
        '''
        U-Net style network that maps a noisy image, a timestep and a context vector to a tensor
        with the same number of channels as the input.

        in_channels: number of channels of the input image (also the number of output channels)
        n_feat: base number of intermediate feature maps (channels) used throughout the network
        n_cfeat: length of the context vector fed to the context embeddings
        height: spatial size of the input, h == w is assumed; it sets the kernel size and stride of the first up-sampling layer
        '''
        super(ContextUnet, self).__init__()
        
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height # assuming h == w. this must be divisible by 4
        
        # initial convolutional block: in_channels -> n_feat intermediate feature maps
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
    
        # down-sampling path of the U-Net with 2 levels
        self.down1 = UnetDown(n_feat, n_feat)    # n_feat -> n_feat channels
        self.down2 = UnetDown(n_feat, 2*n_feat)  # n_feat -> 2*n_feat channels
        # NOTE(review): each UnetDown presumably halves h and w (it is defined in another module) - confirm
        # pooling + activation that turns the deepest feature maps into the "hidden vector"
        self.to_vec = nn.Sequential(
            nn.AvgPool2d((4)), 
            nn.GELU()     # applies the GELU activation function element-wise to the pooled features
        )
        
        # Embed the timestep (a single value per sample) and the context vector with fully-connected networks;
        # one embedding per up-sampling level so the widths match 2*n_feat and n_feat
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        
        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)
        
        # init the up-sampling path with three levels (up0, up1, up2)
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),   # kernel_size = stride = h // 4
            nn.GroupNorm(8, 2 * n_feat), # normalize with 8 groups
            nn.ReLU(),
        ) 
        
        # Up-sampling path
        # NOTE(review): 4*n_feat presumably = 2*n_feat (embedded up0 output) + 2*n_feat (down2 skip) - confirm in UnetUp
        self.up1 = UnetUp(4 * n_feat, n_feat)
        # NOTE(review): 2*n_feat presumably = n_feat (embedded up1 output) + n_feat (down1 skip) - confirm in UnetUp
        self.up2 = UnetUp(2 * n_feat, n_feat)
        
        # output head; its input is the channel-wise concatenation built in forward()
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps 
            nn.GroupNorm(8, n_feat), # normalize 
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )
        '''map to the same number of channels as input, as we will subtract this and give this as input again'''

    
    def forward(self, x, t, c=None):
        '''
        x : (batch, in_channels, h, w) : input image
        t : timestep, embedded as one value per sample; the callers in this file pass it divided by the number of timesteps
        c : (batch, n_cfeat) : context vector (one-hot style labels in this file); an all-zero context is used when None
        returns : tensor with in_channels channels produced by the output head
        '''
        # initial conv layer
        x = self.init_conv(x)
        # down-sampling path
        down1 = self.down1(x)   
        down2 = self.down2(down1)
        
        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)
        
        # no context supplied: fall back to an all-zero context vector
        if c is None: # c is the context vector, not a mask
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)  # .to(x) matches dtype and device of x
        
        # embed context and timestep, reshaped so they broadcast over the spatial dimensions
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1,1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        
        up1 = self.up0(hiddenvec)
        # the context embedding scales the features, the time embedding is added to them
        up2 = self.up1(cemb1*up1 + temb1, down2)
        up3 = self.up2(cemb2*up2 + temb2, down1)
        # concatenate with the initial-conv features along the channel axis before the output head
        out = self.out(torch.cat((up3, x), 1))

        return out 

#### Hyperparameters

In [14]:
# diffusion hyperparameters
timesteps = 500   # number of diffusion steps
beta1 = 1e-4      # first value of the linear beta schedule
beta2 = 0.02      # last value of the linear beta schedule

# network hyperparameters
# torch.device takes the device string directly, so both branches are plain strings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64       # hidden feature dimension
n_cfeat = 5       # length of the categorical context vector
height = 16       # images are 16 x 16
save_dir = './weights/'

# training hyperparameters
batch_size = 100
n_epoch = 32
lrate = 1e-3

**DDPM noise Schedule**

In [5]:
# demo: 10 evenly spaced values in [0, 1], rescaled to run from 5 to 8
torch.linspace(0, 1, steps=10) * (8 - 5) + 5

tensor([5.0000, 5.3333, 5.6667, 6.0000, 6.3333, 6.6667, 7.0000, 7.3333, 7.6667,
        8.0000])

In [6]:
# toy schedule: values evenly spaced over [5, 8] and their complements with respect to 1
b_t_dummy = torch.linspace(0, 1, steps=10) * (8 - 5) + 5
a_t_dummy = -b_t_dummy + 1
a_t_dummy

tensor([-4.0000, -4.3333, -4.6667, -5.0000, -5.3333, -5.6667, -6.0000, -6.3333,
        -6.6667, -7.0000])

In [7]:
# Build the DDPM (Denoising Diffusion Probabilistic Model) noise schedule,
# i.e. the noise level associated with every timestep.

# beta_t: timesteps + 1 evenly spaced values running from beta1 up to beta2
b_t = torch.linspace(0, 1, timesteps + 1, device=device) * (beta2 - beta1) + beta1

# alpha_t is the complement of beta_t, so alpha_t + beta_t = 1 at every timestep
a_t = 1 - b_t

# alpha_bar_t: cumulative product of alpha_t, evaluated as exp(cumsum(log(alpha_t)))
ab_t = torch.exp(torch.cumsum(torch.log(a_t), dim=0))
ab_t[0] = 1  # index 0 is pinned to exactly 1

In [8]:
# Model: build the network for 3-channel images, then move it to the selected device
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height)
nn_model = nn_model.to(device)

## **Training**

In [8]:
# peek at the label array: its shape first, then its contents
labels_data = np.load('sprite_labels_nc_1788_16x16.npy')
print(labels_data.shape, labels_data, sep='\n')

(89400, 5)
[[1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0.]
 ...
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]]


In [11]:
# sprites and labels wrapped in the project's dataset class
dataset = CustomDataset(
    sfilename="./sprites_1788_16x16.npy",
    lfilename="./sprite_labels_nc_1788_16x16.npy",
    transform=transform,
    null_context=False,
)

# shuffled mini-batches loaded by a single worker process
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Adam over all model parameters
optim = torch.optim.Adam(nn_model.parameters(), lr=lrate)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


Perturb the data (forward diffusion)

In [16]:
def perturb_input(x, t, noise, alpha_bar=None):
    """Apply the DDPM forward (noising) process to a batch of images.

    Computes ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``. The
    noise coefficient is the square root of ``1 - alpha_bar_t``, which keeps
    the result at unit variance and matches the ``(1 - ab_t[t]).sqrt()`` term
    used by ``denoise_add_noise`` when sampling.

    NOTE: checkpoints trained with the previous ``(1 - alpha_bar_t)``
    coefficient saw a different noise level and should be retrained.

    x: 4-D batch of clean images, first axis is the batch
    t: integer tensor of per-sample timesteps used to index the schedule
    noise: tensor broadcastable to ``x``
    alpha_bar: optional 1-D tensor of cumulative alpha products; defaults to
        the module-level ``ab_t``
    returns: the perturbed images, same shape as ``x``
    """
    schedule = ab_t if alpha_bar is None else alpha_bar
    # x has 4 dims, so add 3 trailing singleton dims to broadcast per sample
    ab = schedule[t, None, None, None]
    return ab.sqrt() * x + (1 - ab).sqrt() * noise

In [None]:

import os  # used for the checkpoint directory and paths below

nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:   # x: images; the labels are ignored (training without context)
        optim.zero_grad()
        x = x.to(device)

        # perturb data: random noise at a random timestep per sample
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device)
        x_pert = perturb_input(x, t, noise)

        # use network to recover noise; the timestep is normalised to (0, 1]
        pred_noise = nn_model(x_pert, t / timesteps)

        # loss is the mean squared error between predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    # checkpoint every 4 epochs and after the last one
    if ep % 4 == 0 or ep == n_epoch - 1:
        # makedirs tolerates an existing directory and creates missing parents
        os.makedirs(save_dir, exist_ok=True)
        ckpt_path = os.path.join(save_dir, f"model_{ep}.pth")
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)

### Sampling without the context

In [18]:
def denoise_add_noise(x, t, pred_noise, z=None):
    """Remove the predicted noise from ``x`` and add some extra noise back.

    The extra noise is added back to avoid mode collapse.

    x: current noisy sample
    t: timestep used to index the pre-computed schedules b_t, a_t and ab_t
    pred_noise: noise predicted by the network for ``x``
    z: noise to re-inject; drawn with the same shape as ``x`` when None
    """
    if z is None:
        z = torch.randn_like(x)

    # coefficient of the predicted noise inside the denoised mean
    eps_coeff = (1 - a_t[t]) / torch.sqrt(1 - ab_t[t])
    denoised_mean = (x - eps_coeff * pred_noise) / torch.sqrt(a_t[t])

    # re-injected noise is scaled by sqrt(beta_t)
    return denoised_mean + torch.sqrt(b_t)[t] * z

In [19]:
# restore the weights of the model without the context, then switch to evaluation mode
nn_model.load_state_dict(
    torch.load(f"{save_dir}/model_trained.pth", map_location=device)
)
nn_model.eval()


ContextUnet(
  (init_conv): ResidualConvBlock(
    (conv1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (conv2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
  )
  (down1): UnetDown(
    (model): Sequential(
      (0): ResidualConvBlock(
        (conv1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none')
        )
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-0

In [20]:
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Sampling using DDPM standard algorithm.

    n_sample: number of images to generate
    save_rate: a snapshot is kept whenever the timestep is a multiple of this
        value (plus the first step and the last 7 steps)
    returns: the final samples and the snapshots stacked along a new leading axis
    """
    # start from pure Gaussian noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    snapshots = []  # generated steps kept for plotting the process
    for step in reversed(range(1, timesteps + 1)):
        print(f'sampling timestep {step:3d}', end='\r')

        # normalised timestep, reshaped to (1, 1, 1, 1)
        t = torch.tensor([step / timesteps])[:, None, None, None].to(device)

        # random noise to inject back in, except on the very last step
        if step > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        eps = nn_model(samples, t)    # predicted noise
        samples = denoise_add_noise(samples, step, eps, z)

        keep = step % save_rate == 0 or step == timesteps or step < 8
        if keep:
            snapshots.append(samples.detach().cpu().numpy())

    return samples, np.stack(snapshots)

In [None]:
# visualize samples: start from a cleared figure and generate 32 samples
plt.clf()
samples, intermediate_ddpm = sample_ddpm(32)

# animate the saved steps (32 samples, 4 rows)
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4, save_dir, "ani_run", None, save=False
)

# write the animation to a video file through ffmpeg
video_filename = "Process_DDPM_noise_added.mp4"
animation_ddpm.save(video_filename, writer="ffmpeg", fps=5)
print(f"Animation saved as {video_filename}")

# convert the animation to HTML so it is displayed in the cell output
HTML(animation_ddpm.to_jshtml())



#### Sampling without Adding Extra Noise

Skipping the extra noise at each step results in mode collapse.

In [21]:
@torch.no_grad()
def sample_ddpm_incorrect(n_sample, save_rate=20):
    """DDPM sampling without adding the noise back (deliberately incorrect).

    Identical to the standard DDPM sampling loop except that no noise is
    re-injected after each denoising step.

    n_sample: number of images to generate
    save_rate: a snapshot is kept whenever the timestep is a multiple of this
        value (plus the first step and the last 7 steps); defaults to 20, the
        interval that was previously hard-coded
    returns: the final samples and the snapshots stacked along a new leading axis
    """
    samples = torch.randn(n_sample, 3, height, height).to(device)

    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')
        # normalised timestep, reshaped to (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # no noise is injected back in: this is what makes the sampler "incorrect"
        z = 0

        eps = nn_model(samples, t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [None]:
# visualize samples produced without re-injected noise
plt.clf()
samples, intermediate = sample_ddpm_incorrect(32)
animation_no_noise = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)

# write the animation to a video file through ffmpeg
video_filename = "Process_DDPM_with_no_noise.mp4"
animation_no_noise.save(video_filename, writer="ffmpeg", fps=5)
print(f"Animation saved as {video_filename}")

# show the animation in the cell output
HTML(animation_no_noise.to_jshtml())

## Sampling with Context

In [22]:
# fresh network for the context-conditioned checkpoint
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
# map_location keeps the load working when the checkpoint was saved on a
# device that is not available here (same as the other load cells in this file)
nn_model.load_state_dict(torch.load(f"{save_dir}/context_model_31.pth", map_location=device))
nn_model.eval()

ContextUnet(
  (init_conv): ResidualConvBlock(
    (conv1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
    (conv2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
    )
  )
  (down1): UnetDown(
    (model): Sequential(
      (0): ResidualConvBlock(
        (conv1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none')
        )
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-0

In [23]:
@torch.no_grad()
def sample_ddpm_context(n_sample, context, save_rate=20):
    """Sample DDPM with Context.

    n_sample: number of images to generate
    context: context tensor handed to the model as ``c`` at every step
    save_rate: a snapshot is kept whenever the timestep is a multiple of this
        value (plus the first step and the last 7 steps)
    returns: the final samples and the snapshots stacked along a new leading axis
    """
    # (n_sample, channels, h, w) of pure Gaussian noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    snapshots = []
    for step in reversed(range(1, timesteps + 1)):  # travel back in time
        print(f"Sampling timestep {step:3d}", end='\r')

        # normalised timestep, reshaped to (1, 1, 1, 1)
        t = torch.tensor([step / timesteps])[:, None, None, None].to(device)

        # noise to inject back in, except on the very last step
        if step > 1:
            z = torch.randn_like(samples)
        else:
            z = 0

        eps = nn_model(samples, t, c=context)
        samples = denoise_add_noise(samples, step, eps, z)

        keep = step % save_rate == 0 or step == timesteps or step < 8
        if keep:
            snapshots.append(samples.detach().cpu().numpy())

    return samples, np.stack(snapshots)

In [None]:
plt.clf()

# 32 random one-hot contexts over the 5 classes
ctx = F.one_hot(torch.randint(0, 5, (32,)), num_classes=5).float().to(device)
samples, intermediate = sample_ddpm_context(32, ctx)
animation_ddpm_context = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)

# write the animation to a video file through ffmpeg
file_name = 'DDPM_sample_with_context.mp4'
animation_ddpm_context.save(file_name, writer="ffmpeg", fps=15)
print(f"Animation saved as {file_name}")

# show the animation in the cell output
HTML(animation_ddpm_context.to_jshtml())

In [None]:
def show_images(imgs, ctx, nrow=2):
    """Plot each image with a title derived from its context vector.

    imgs: batch of images in (channels, h, w) layout; values are clipped to [-1, 1]
    ctx: one context vector per image; the title is the label at the index of
        its largest entry
    nrow: number of rows of the subplot grid
    """
    labels = ["hero", "non-hero", "food", "spell", "side-facing", "human"]
    ncol = imgs.shape[0] // nrow
    _, axs = plt.subplots(nrow, ncol, figsize=(4, 3))
    contexts = ctx.detach().cpu().numpy()

    for img, context, ax in zip(imgs, contexts, axs.flatten()):
        # channels-last layout, rescaled from [-1, 1] to [0, 1] for imshow
        pixels = img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy()
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow((pixels + 1) / 2)
        ax.set_title(f"{labels[np.argmax(context)]}")

    plt.show()

    
ctx = torch.tensor([
    # hero, non-hero, food, spell, side-facing
    [1,0,0,0,0],    # hero
    [0,1,0,0,0],    # non-hero (was a copy of the hero vector)
    [0,0,1,0,0],    # food
    [0,0,0,1,0],    # spell
    [0,0,0,0,1],    # side-facing
    # NOTE(review): "human" is not one of the 5 classes; this row mixes spell +
    # side-facing, while the widget cell uses [1,0,0,1,0] - confirm the intended mix
    [0,0,0,1,1]     # human
]).float().to(device)
samples, _ = sample_ddpm_context(ctx.shape[0], ctx)
show_images(samples, ctx)

In [None]:
import matplotlib.pyplot as plt
import ipywidgets as widgets

# drop-down listing the context names that can be selected
context_dropdown = widgets.Dropdown(
    description='Select Context:',
    options=["hero", "non-hero", "food", "spell", "side-facing", "human"],
)
# Function to show images based on selected context
def show_images_for_context(context):
    """Generate and plot 6 images for the context chosen by name.

    context: one of the keys of the ``contexts`` table below
    raises: KeyError when ``context`` is not a known name
    """
    # context vectors over (hero, non-hero, food, spell, side-facing)
    contexts = {
        "hero": [1, 0, 0, 0, 0],
        "non-hero": [0, 1, 0, 0, 0],  # was a copy of the hero vector
        "food": [0, 0, 1, 0, 0],
        "spell": [0, 0, 0, 1, 0],
        "side-facing": [0, 0, 0, 0, 1],
        # NOTE(review): "human" is not one of the 5 classes; this mixes hero +
        # spell, while the ctx tables elsewhere use [0,0,0,1,1] - confirm
        "human": [1, 0, 0, 1, 0],
    }

    n_samples = 6  # Number of images to generate

    # one context row per sample, so the batch sizes match explicitly instead
    # of relying on broadcasting a single row inside the model
    ctx = torch.tensor([contexts[context]] * n_samples).float().to(device)

    # Generate images based on the selected context
    samples, _ = sample_ddpm_context(n_samples, ctx)

    # Determine the layout based on the number of samples
    ncols = min(n_samples, 3)
    nrows = -(-n_samples // ncols)  # Ceiling division

    _, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))
    axs = axs.flatten()

    # Show the images in the subplots
    for i in range(n_samples):
        # channels-last layout, rescaled from [-1, 1] to [0, 1] for imshow
        img = (samples[i].permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2
        axs[i].imshow(img)
        axs[i].axis('off')
        axs[i].set_title(f"Image {i + 1}", fontsize=10)
        axs[i].title.set_position([0.5, -0.3])  # Adjust title position below the image

    # Hide empty subplots if any
    for j in range(n_samples, nrows * ncols):
        axs[j].axis('off')

    plt.tight_layout()
    plt.show()

# Event handler for dropdown value change
def dropdown_eventhandler(change):
    """Re-render the images when the drop-down reports a new value.

    change: notification dict with at least the keys 'type' and 'name';
        'new' is read only for a 'change' of the 'value' attribute
    """
    # ignore every notification that is not a change of 'value'
    if change['type'] != 'change' or change['name'] != 'value':
        return
    show_images_for_context(change['new'])

# register the event handler on the dropdown widget
context_dropdown.observe(handler=dropdown_eventhandler)

# render the dropdown widget in the cell output
display(context_dropdown)

In [68]:
# 5 x 5 integer identity matrix: its rows are the 5 one-hot context vectors
torch.eye(5, dtype=torch.long)

tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1]])

## FAST sampling

In [20]:
def denoise_ddim(x, t, t_prev, pred_noise):
    """Move the sample from timestep ``t`` to ``t_prev`` in one DDIM update.

    x: current noisy sample
    t: current timestep, index into the schedule ab_t
    t_prev: timestep to move to, index into the schedule ab_t
    pred_noise: noise predicted by the network for ``x``
    returns: the sample at ``t_prev``; no random noise is added
    """
    alpha_bar_t = ab_t[t]
    alpha_bar_prev = ab_t[t_prev]

    # x with the predicted noise removed, rescaled to the level of t_prev
    signal = torch.sqrt(alpha_bar_prev) / torch.sqrt(alpha_bar_t) * (
        x - torch.sqrt(1 - alpha_bar_t) * pred_noise
    )
    # predicted noise re-added at the level of t_prev
    direction = torch.sqrt(1 - alpha_bar_prev) * pred_noise

    return signal + direction

**Without Context**

In [None]:
# restore the checkpoint used for sampling without context, then switch to evaluation mode
nn_model.load_state_dict(
    torch.load(f"{save_dir}/model_trained.pth", map_location=device)
)
nn_model.eval()

# we can sample quickly using DDIM
@torch.no_grad()
def sample_ddim(n_sample, n=20):
    """Generate images with DDIM, visiting only about ``n`` of the timesteps.

    n_sample: number of images to generate
    n: requested number of sampling steps; the stride is ``timesteps // n``
    returns: the final samples and every intermediate step stacked along a
        new leading axis
    """
    samples = torch.randn(n_sample, 3, height, height).to(device)

    intermediate = []
    # a stride of at least 1: range() rejects a zero step when n > timesteps
    step_size = max(timesteps // n, 1)
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # normalised timestep, reshaped to (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t)    # predicted noise
        # clamp at 0: when timesteps is not a multiple of the stride the last
        # i - step_size is negative and would index the schedule from its end
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate


# 32 samples drawn with 25 DDIM steps, animated step by step
plt.clf()
samples, intermediate = sample_ddim(32, n=25)
animation_ddim = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)

# show the animation in the cell output
HTML(animation_ddim.to_jshtml())


**With Context**

In [24]:
# restore the context checkpoint, then switch to evaluation mode
nn_model.load_state_dict(
    torch.load(f"{save_dir}/context_model_31.pth", map_location=device)
)
nn_model.eval()

@torch.no_grad()
def sample_ddim_context(n_sample, context, n=20):
    """Generate context-conditioned images with DDIM in about ``n`` steps.

    n_sample: number of images to generate
    context: context tensor handed to the model as ``c`` at every step
    n: requested number of sampling steps; the stride is ``timesteps // n``
    returns: the final samples and every intermediate step stacked along a
        new leading axis
    """
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    # a stride of at least 1: range() rejects a zero step when n > timesteps
    step_size = max(timesteps // n, 1)
    for i in range(timesteps, 0, -step_size):
        print(f'sampling timestep {i:3d}', end='\r')

        # normalised timestep, reshaped to (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        eps = nn_model(samples, t, c=context)    # predicted noise
        # clamp at 0: when timesteps is not a multiple of the stride the last
        # i - step_size is negative and would index the schedule from its end
        t_prev = max(i - step_size, 0)
        samples = denoise_ddim(samples, i, t_prev, eps)
        intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [None]:
plt.clf()

# 32 random one-hot contexts over the 5 classes, sampled with DDIM
ctx = F.one_hot(torch.randint(0, 5, (32,)), num_classes=5).float().to(device)
samples, intermediate = sample_ddim_context(32, ctx)
animation_ddpm_context = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)

# write the animation to a video file through ffmpeg
video_filename = "DDIM_with_context.mp4"
animation_ddpm_context.save(video_filename, writer="ffmpeg", fps=5)
print(f"Animation saved as {video_filename}")

# show the animation in the cell output
HTML(animation_ddpm_context.to_jshtml())

In [None]:
def show_images(imgs, ctx, nrow=2):
    """Plot each image, titled with the label at its position in the batch.

    imgs: batch of images in (channels, h, w) layout; values are clipped to [-1, 1]
    ctx: one context vector per image; it only bounds how many images are
        drawn, the titles come from the position in the batch
    nrow: number of rows of the subplot grid
    """
    labels = ["hero", "non-hero", "food", "spell", "side-facing", "human"]
    ncol = imgs.shape[0] // nrow
    _, axs = plt.subplots(nrow, ncol, figsize=(4, 3))
    contexts = ctx.detach().cpu().numpy()

    for position, (img, context, ax) in enumerate(zip(imgs, contexts, axs.flatten())):
        # channels-last layout, rescaled from [-1, 1] to [0, 1] for imshow
        pixels = img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy()
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow((pixels + 1) / 2)
        ax.set_title(f"{labels[position]}")

    plt.show()

    
ctx = torch.tensor([
    # hero, non-hero, food, spell, side-facing
    [1,0,0,0,0],    # hero
    [0,1,0,0,0],    # non-hero (was a copy of the hero vector)
    [0,0,1,0,0],    # food
    [0,0,0,1,0],    # spell
    [0,0,0,0,1],    # side-facing
    # NOTE(review): "human" is not one of the 5 classes; this row mixes spell +
    # side-facing, while the widget cell uses [1,0,0,1,0] - confirm the intended mix
    [0,0,0,1,1]     # human
]).float().to(device)
samples, _ = sample_ddim_context(ctx.shape[0], ctx)
show_images(samples, ctx)