In [None]:
import os
import sys
from diffusion_utils import plot_utils
from diffusion_utils import data_utils
from diffusion_utils import context_unet
from diffusion_utils import sampler
from diffusion_utils import trainer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from IPython.display import HTML

In [None]:
# Paths and run switches.
# The default root is a raw string: in a normal literal "\A", "\P" and "\F" are
# invalid escape sequences (SyntaxWarning on Python 3.12+). Set the FIBREAUG_ROOT
# environment variable to point at a different project root.
folder_path = os.environ.get("FIBREAUG_ROOT") or r'C:\Applications\Projets\FibreAug'
image_path = os.path.join(folder_path, "dataset", "sprites_1788_16x16.npy")
label_path = None  # no label file is passed to CustomDataset
save_dir = os.path.join(folder_path, "test_diffusion_model", "test0")
checkpoint = os.path.join(save_dir, "checkpoint_19.pth")  # resumed from if the file exists
context_enable = False  # enables the context-conditioned dataset/trainer/sampling paths
ddim_enable = True      # sample with sample_ddim instead of sample_ddpm
train_enable = True     # run the training cell

# Make folder_path importable. Guarded so re-running this cell does not keep
# appending duplicate entries. NOTE: the diffusion_utils imports in the first
# cell execute before this line, so they cannot rely on it.
if folder_path not in sys.path:
    sys.path.append(folder_path)

# Create save_dir if it does not exist
os.makedirs(save_dir, exist_ok=True)

# Setting Things Up

In [None]:
# hyperparameters

# diffusion hyperparameters (shared by the Trainer and the Sampler)
timesteps = 500
beta1 = 1e-4   # presumably the start of the beta noise schedule
beta2 = 0.02   # presumably the end of the beta noise schedule

# network hyperparameters
# torch.device takes the device name directly; no nested torch.device needed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
width = 16 # image width
height = 16 # image height

# training hyperparameters
batch_size = 4
n_epoch = 4
lrate = 1e-3

# Reproducibility: seed torch's global RNG (all devices) before model
# initialisation, DataLoader shuffling and sampling noise are drawn.
seed = 42
torch.manual_seed(seed)

In [None]:
# Build the ContextUnet for 3-channel width x height images using the network
# hyperparameters, then move it onto the selected device.
my_model = context_unet.ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    width=width,
    height=height,
)
my_model = my_model.to(device)

# Training

In [None]:
# Load the sprite dataset (optionally with context labels), wrap it in a
# shuffling DataLoader, and build an Adam optimizer over the model parameters.
dataset = data_utils.CustomDataset(
    image_path,
    label_path,
    data_utils.transform,
    context_enable=context_enable,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
optim = torch.optim.Adam(params=my_model.parameters(), lr=lrate)

In [None]:
# Bundle model, data, optimizer and the diffusion/training settings into a
# Trainer. Arguments are positional, so their order must match the
# trainer.Trainer signature (defined outside this notebook).
my_trainer = trainer.Trainer(
    my_model, dataloader, optim,
    timesteps, lrate, device,
    beta1, beta2,
    context_enable, save_dir,
)

In [None]:
# Resume from `checkpoint` when that file is present: the value returned by
# load_checkpoint (presumably the last completed epoch) plus one becomes the
# starting epoch; otherwise training starts from epoch 0.
# NOTE(review): the checkpoint is checkpoint_19.pth while n_epoch is 4 — confirm
# how Trainer.train combines n_epoch and start_epoch so a resume is not a no-op.
start_epoch = (
    my_trainer.load_checkpoint(checkpoint) + 1
    if os.path.exists(checkpoint)
    else 0
)

In [None]:
# Train the model, starting at start_epoch (non-zero only when a checkpoint was
# loaded). Whether context vectors are used is governed by context_enable,
# which was handed to the Trainer.
if train_enable:
    my_trainer.train(n_epoch, start_epoch)
elif not os.path.exists(checkpoint):
    # Training is disabled and no checkpoint was loaded, so the sampling cells
    # below would run on freshly initialised weights.
    print(f"Warning: train_enable is False and no checkpoint found at {checkpoint}; "
          "samples will come from an untrained model.")

# Sampling

In [None]:
# Construct the sampler with the same timesteps/beta1/beta2 that were passed to
# the Trainer, then switch the model to eval mode for inference.
my_sampler = sampler.Sampler(model=my_model, timesteps=timesteps, width=width, height=height, device=device, beta1=beta1, beta2=beta2)
# Trailing ';' stops Jupyter from echoing the full module repr that eval() returns.
my_model.eval();

In [None]:
# visualize samples with randomly selected context
plt.clf()

n_sample = 32  # number of images generated and animated in this cell

# With context enabled, draw one random one-hot context per sample. The class
# count follows n_cfeat so it stays in sync with the model's context size.
ctx = None
if context_enable:
    ctx = F.one_hot(torch.randint(0, n_cfeat, (n_sample,)), n_cfeat).to(device=device).float()

# The second positional argument (20) is passed unchanged to the chosen
# sampler; its meaning is defined in sampler.Sampler.
if ddim_enable:
    samples, intermediate = my_sampler.sample_ddim(n_sample, 20, ctx)
else:
    samples, intermediate = my_sampler.sample_ddpm(n_sample, 20, ctx)

# Animate the intermediate outputs returned by the sampler (4 is passed through
# to plot_sample — presumably the number of grid rows). Despite its name,
# animation_ddpm holds the DDIM animation when ddim_enable is True.
animation_ddpm = plot_utils.plot_sample(intermediate, n_sample, 4, save_dir, "ani_run", None, save=False)
HTML(animation_ddpm.to_jshtml())

In [None]:
# Generate one image per hand-picked context vector and show the results.
if context_enable:
    ctx = torch.tensor(
        [
            # hero, non-hero, food, spell, side-facing
            [1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1],
            [0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0],
        ],
        dtype=torch.float,
        device=device,
    )

    # Pick DDIM or DDPM; the batch size equals the number of context rows.
    samples, intermediate = (
        my_sampler.sample_ddim if ddim_enable else my_sampler.sample_ddpm
    )(ctx.shape[0], 20, ctx)

    plot_utils.show_images(samples)

In [None]:
# Generate images from context vectors that blend several categories
# (including fractional weights) and show the results.
if context_enable:
    ctx = torch.tensor(
        [
            # hero, non-hero, food, spell, side-facing
            [1, 0, 0, 0, 0],      # human
            [1, 0, 0.6, 0, 0],
            [0, 0, 0.6, 0.4, 0],
            [1, 0, 0, 0, 1],
            [1, 1, 0, 0, 0],
            [1, 0, 0, 1, 0],
        ],
        dtype=torch.float,
        device=device,
    )

    # Keep the first return value (samples); the second is discarded into `_`.
    samples, _ = (
        my_sampler.sample_ddim if ddim_enable else my_sampler.sample_ddpm
    )(ctx.shape[0], 20, ctx)

    plot_utils.show_images(samples)

In [None]:
# Rough wall-clock comparison of the two samplers: 32 samples each, a single
# timing repeat (-r 1), called without a context argument.
# NOTE(review): these calls pass 25 as the second argument while the sampling
# cells above pass 20 — confirm the difference is intentional.
%timeit -r 1 my_sampler.sample_ddim(32, 25)
%timeit -r 1 my_sampler.sample_ddpm(32, 25)