In [1]:
from typing import Dict, Tuple
import sys 
!{sys.executable} -m pip install tqdm
from tqdm import tqdm
!{sys.executable} -m pip install torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
!{sys.executable} -m pip install torchvision
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
!{sys.executable} -m pip install matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML
!{sys.executable} -m pip install h5py



In [3]:
from diffusion_utilities import *
torch.cuda.is_available()

False

In [53]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and number of classes
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  #assume h == w. must be divisible by 4, so 28,24,20,16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Initialize the down-sampling path of the U-Net with two levels
        self.down1 = UnetDown(n_feat, n_feat)        # down1 #[10, 256, 8, 8]
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # down2 #[10, 256, 4,  4]
        
         # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample  
            nn.GroupNorm(8, 2 * n_feat), # normalize                       
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps   #in_channels, out_channels, kernel_size, stride=1, padding=0
            nn.GroupNorm(8, n_feat), # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """
        x : (batch, n_feat, h, w) : input image
        t : (batch, n_cfeat)      : time step
        c : (batch, n_classes)    : context label
        """
        # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on

        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        
        # pass the result through the down-sampling path
        down1 = self.down1(x)       #[10, 256, 8, 8]
        down2 = self.down2(down1)   #[10, 256, 4, 4]
        
        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)
        
        # mask out context if context_mask == 1
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
            
        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)     # (batch, 2*n_feat, 1,1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        # print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape} hiddenvec {hiddenvec.shape}")

        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1 + temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2*up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        # print(f'out {out} shape {out.shape}')
        return out


In [32]:
# hyperparameters

# diffusion hyperparameters
timesteps = 500
beta1 = 1e-4
beta2 = 0.02

# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
save_dir = './weights/'

# training hyperparameters
batch_size = 100
n_epoch = 3
lrate=1e-3

In [6]:
# construct DDPM noise schedule
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()    
ab_t[0] = 1

# Training

In [34]:
# load dataset and construct optimizer
dataset = CustomDataset("./sprites_1788_16x16.npy", "./sprite_labels_nc_1788_16x16.npy", transform, null_context=False)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [None]:
# save dataset to hdf format consumable from the F# code
dataset.savehdf()

In [36]:
# helper function: perturbs an image to a specified noise level
def perturb_input(x, t, noise):
    return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise

In [73]:
torch.manual_seed(0)
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)
optim = torch.optim.Adam(nn_model.parameters(), lr=lrate)

In [74]:
# training without context code
# set into train mode
nn_model.train()


for ep in range(3):
    print(f'epoch {ep}')
    
    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)
    print(f'lr {lrate*(1-ep/n_epoch)}')
    
    x_loss = 1.0
    i = 0
    # pbar = tqdm(dataloader, mininterval=2 )
    for x, _ in dataloader:   # x: images
        torch.manual_seed(i)
        optim.zero_grad()
        x = x.to(device)
        
        # perturb data
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device) 

        x_pert = perturb_input(x, t, noise)
        
        # use network to recover noise
        pred_noise = nn_model(x_pert, t / timesteps)
        
        # loss is mean squared error between the predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        x_loss = loss.item()
        loss.backward()
        
        optim.step()

        if i % 20 == 0:
            print(f'loss {x_loss}')
        
        i += 1
        # if i >= 110: 
        #     break

    # # save model periodically
    if ep%4==0 or ep == int(n_epoch-1):
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        torch.save(nn_model.state_dict(), save_dir + f"model_{ep}.pth")
        print('saved model at ' + save_dir + f"model_{ep}.pth")

epoch 0
lr 0.001
loss 1.1516586542129517
loss 0.492750346660614
loss 0.42482617497444153
loss 0.4150702953338623
loss 0.3941064178943634
loss 0.314238965511322
loss 0.381748229265213
loss 0.3399651348590851
loss 0.3238184452056885
loss 0.2681446671485901
loss 0.4064839780330658
loss 0.32247278094291687
loss 0.4071526527404785
loss 0.334504097700119
loss 0.31350448727607727
loss 0.3659830689430237
loss 0.2884584367275238
loss 0.26043570041656494
loss 0.2589876055717468
loss 0.3609120547771454
loss 0.24572817981243134
loss 0.2274949699640274
loss 0.27715063095092773
loss 0.25526225566864014
loss 0.3018032908439636
loss 0.23833973705768585
loss 0.24979878962039948
loss 0.2835819125175476
loss 0.2343187928199768
loss 0.2291814386844635
loss 0.21802441775798798
loss 0.21870696544647217
loss 0.26937758922576904
loss 0.2468492090702057
loss 0.22930386662483215
loss 0.23147143423557281
loss 0.2918083071708679
loss 0.22675734758377075
loss 0.23769094049930573
loss 0.23509490489959717
loss 0.274

# Sampling

In [75]:
# helper function; removes the predicted noise (but adds some noise back in to avoid collapse)
def denoise_add_noise(x, t, pred_noise, z=None):
    if z is None:
        z = torch.randn_like(x)
    noise = b_t.sqrt()[t] * z
    mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
    return mean + noise

In [76]:
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, save_rate = 20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)  

    # array to keep track of generated steps for plotting
    intermediate = [] 

    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = nn_model(samples, t)    # predict noise e_(x_t,t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate ==0 or i==timesteps:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

# View Epoch 0

In [77]:
# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(f"{save_dir}/model_0.pth", map_location=device))
nn_model.eval()
print("Loaded in Model")

Loaded in Model


In [78]:
torch.manual_seed(0)
samples, intermediate_ddpm = sample_ddpm(32)

sampling timestep   1

In [79]:
# visualize samples
plt.clf()

animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False)
HTML(animation_ddpm.to_jshtml())

gif animating frame 24 of 25

<Figure size 640x480 with 0 Axes>

# View Epoch 2

In [80]:
# load in model weights and set to eval mode
nn_model.load_state_dict(torch.load(f"{save_dir}/model_2.pth", map_location=device))
nn_model.eval()
print("Loaded in Model")

Loaded in Model


In [88]:
torch.manual_seed(0)
samples, intermediate_ddpm = sample_ddpm(32)

sampling timestep   1

In [89]:
# visualize samples
plt.clf()

animation_ddpm = plot_sample(intermediate_ddpm,32,4,save_dir, "ani_run", None, save=False)
HTML(animation_ddpm.to_jshtml())

gif animating frame 24 of 25

<Figure size 640x480 with 0 Axes>