In [1]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import string
import re
import itertools
import io
import json
import os
import sys
import ast
import time
import requests
import random
import math
import inspect

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import PIL
import scipy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import einops
import transformers
import diffusers
import accelerate
#import clip
import torchvision.transforms.functional as TF

import datasets
from datasets import load_dataset

import pickle

accelerator=accelerate.Accelerator()
device=accelerator.device

base_dir="<PATH TO SAVE>"



In [2]:
#FID module to compute FID score for generated images during training
class FID_Module(nn.Module):
  def __init__(self, train_data):
    super(FID_Module, self).__init__()
    self.train_data=train_data
    self.inception =torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True).to(device)
    #register hook before the final classifier
    self.inception.Mixed_7c.register_forward_hook(self.hook_wo_fc)
    self.avg_pool=nn.AdaptiveAvgPool2d(output_size=(1,1))

    self.true_mean, self.true_cov=self.get_true_images_mean_cov(train_data) #2048-d vectors
  
  def get_frechet_distance(self, mean, cov):
    #mean vectors, cov matrices: batch size should be equal.
    #test set is for generalization => we want to know how well the distribution is estimated
    #take FID score for whle 50k traindata

    diff = (mean - self.true_mean).cpu().numpy()
    cov=cov.cpu().numpy()
    true_cov=self.true_cov.cpu().numpy()
    # Product might be almost singular
    #be aware of negative elemnts in cov products.
    covmean, _ = scipy.linalg.sqrtm(cov.dot(true_cov), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(cov.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((cov + offset).dot(true_cov + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(cov)
            + np.trace(true_cov) - 2 * tr_covmean)
  
  def get_mean_cov(self, inception_output):
    #inception output: (B,2048)
    #mean vector: (2048,)
    #cov matrix: (2048,2048)
    mean=torch.mean(inception_output, dim=0)
    cov=torch.cov(inception_output.permute(1,0))
    return mean, cov

  def hook_wo_fc(self, module, input, output):
    self.wo_fc_output=output
    return
  
  def get_true_images_mean_cov(self, train_data):
    batch_size=256
    output_list=[]
    #preprocess train data into (299,299)
    fid_preprocess = torchvision.transforms.Compose([
      torchvision.transforms.Resize(299),
      torchvision.transforms.CenterCrop(299),
    ])
    train_loader=torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    pbar=tqdm(desc="Mean, Cov for Train Images", total=len(train_data))
    #no need to normalize, [0,1] for FID score computation
    for batch_data in train_loader:
      x, labels=batch_data
      x=fid_preprocess(x)
      x=x.to(device)
      _, output=self.get_output(x)
      output_list.append(output)
      pbar.update(batch_size)
    pbar.close()
    whole_output=torch.cat(output_list, dim=0)
    true_mean, true_cov=self.get_mean_cov(whole_output)
    return true_mean, true_cov

  def get_output(self, x):
    with torch.no_grad():
      incep_output=self.inception(x).logits
      wo_fc_output=self.avg_pool(self.wo_fc_output)
      wo_fc_output=wo_fc_output.reshape(x.size(0), 2048)
    return incep_output, wo_fc_output
  
  def get_inception_score(self, incep_output):
    #incep output: (B,1000)
    batch_incep_probs=incep_output.softmax(dim=-1)
    incep_marginal=torch.mean(batch_incep_probs, dim=0)
    avg_kl_div=0
    for incep_probs in batch_incep_probs:
      #has numerical errors at this moment.
      incep_categorical=torch.distributions.categorical.Categorical(probs=incep_probs)
      marginal_categorical=torch.distributions.categorical.Categorical(probs=incep_marginal)
      kl_div=torch.distributions.kl.kl_divergence(incep_categorical, marginal_categorical)
      avg_kl_div=avg_kl_div+(kl_div.item()-avg_kl_div)
    incep_score=np.exp(avg_kl_div)
    return incep_score
  
  def forward(self, gen_images):
    #get FID score
    batch_size=256
    #preprocess
    fid_preprocess = torchvision.transforms.Compose([
      torchvision.transforms.Resize(299),
      torchvision.transforms.CenterCrop(299),
    ])
    gen_images=fid_preprocess(gen_images)

    gen_loader=torch.utils.data.DataLoader(gen_images, batch_size=batch_size, shuffle=True)
    incep_output_list=[]
    wo_fc_output_list=[]
    pbar=tqdm(desc="", total=gen_images.size(0))
    #getting UNet output of generated images
    for gen_batch in gen_loader:
      incep_output, wo_fc_output=self.get_output(gen_images)
      incep_output_list.append(incep_output)
      wo_fc_output_list.append(wo_fc_output)
      pbar.update(batch_size)
    whole_incep_output=torch.cat(incep_output_list, dim=0)
    whole_wo_fc_output=torch.cat(wo_fc_output_list, dim=0)
    #computing FID score
    gen_mean, gen_cov=self.get_mean_cov(whole_wo_fc_output)
    fid=self.get_frechet_distance(gen_mean, gen_cov)
    incep_score=self.get_inception_score(whole_incep_output)
    return incep_score, fid

In [3]:
#DDPM Model
class DDPM(nn.Module):
  def __init__(self, betas, image_size, train_time_steps=1000, ddim_steps=20):
    super(DDPM, self).__init__()
    self.img_h=image_size[0]
    self.img_w=image_size[1]
    #train time steps
    self.train_time_steps=train_time_steps
    self.ddim_steps=ddim_steps
    self.ddpm_timesteps=torch.arange(1, train_time_steps+1, 1).to(device) #should be in range of [1,T]
    #CIFAR10 uses quadratic trajectoryes
    c_value=(train_time_steps-1)/(ddim_steps-1)**2
    dti=[int(c_value*(i**2)) for i in range(0, ddim_steps)]
    self.ddim_timestep_indices=torch.LongTensor(dti).to(device)

    #unet for predicting epsilon vector => the only learnable paramters
    #model_id = "google/ddpm-cifar10-32"
    #pretrained = diffusers.DDPMPipeline.from_pretrained(model_id)
    #self.unet=pretrained.unet.to(device) #using pretrained UNet
    self.unet=diffusers.UNet2DModel(sample_size=image_size, ).to(device)
    
    #forward process variance schedule
    self.betas=betas
    self.alphas=1-betas
    self.sqrt_alphas=torch.sqrt(self.alphas)
    self.inv_sqrt_alphas=1/self.sqrt_alphas
    #alpha_bar_t
    self.alpha_bars=torch.cumprod(self.alphas, dim=0)
    self.sqrt_alpha_bars=torch.sqrt(self.alpha_bars)
    self.sqrt_one_minus_alpha_bars=torch.sqrt(1-self.alpha_bars)
    #alpha_bar_(t-1) => used for computing beta_tilde => to alpha_bars: add 1 to first index and push 1 index right
    self.alpha_bars_prev=torch.zeros_like(self.alpha_bars).to(device)
    self.alpha_bars_prev[0]=1
    self.alpha_bars_prev[1:]=self.alpha_bars[:-1]
    self.sqrt_alpha_bars_prev=torch.sqrt(self.alpha_bars_prev)
    self.beta_tildes=((1-self.alpha_bars_prev)/(1-self.alpha_bars))*self.betas

    #set variance for backward process as beta_tilde or betas'
    self.backward_std_hat=torch.sqrt(self.betas)
    self.backward_std=torch.sqrt(self.beta_tildes)
  
  def normalize(self, tensor):
    #[0,1] => [-1,1]
    #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    return tensor*2-1
  
  def unnormalize(self, tensor):
    #[-1,1] => [0,1]
    #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    return (tensor+1)*0.5
    
  #used for training => same as q_sampling but with estimated noise vector
  def get_x_t_vector(self, x_0, timesteps, epsilon):
    #modeling the forward process from x_0 -> x_t =>
    x_0_coeff=torch.gather(self.sqrt_alpha_bars, 0, timesteps).reshape(timesteps.size(0),1,1,1)
    eps_coeff=torch.gather(self.sqrt_one_minus_alpha_bars, 0, timesteps).reshape(timesteps.size(0),1,1,1)
    x_t=x_0_coeff*x_0+eps_coeff*epsilon
    return x_t
  
  #rest used for sequential sampling
  def q_posterior_sample(self, x_t, x_0, timestep_idx):
    batch_size=x_0.size(0)
    normal_vector=torch.randn_like(x_0).to(device)
    posterior_mean=self.get_posterior_mean(x_t, x_0, timestep_idx)
    x_prev=posterior_mean+self.beta_tildes[timestep_idx]*normal_vector
    return x_prev
  
  def get_x_0_vector(self, x_t, timestep_idx, epsilon):
    x_0_coeff=self.sqrt_alpha_bars[timestep_idx]
    eps_coeff=self.sqrt_one_minus_alpha_bars[timestep_idx]
    x_0=(x_t-eps_coeff*epsilon)/x_0_coeff
    return x_0
  
  #used to demonstrate forward process
  def q_sample(self, x_0, timestep_idx):
    #estimating x_t from x_0 using q(x_t|x_t-1)
    batch_size=x_0.size(0)
    normal_vector=torch.randn_like(x_0).to(device)
    x_t=self.sqrt_alpha_bars[timestep_idx]*x_0+self.sqrt_one_minus_alpha_bars[timestep_idx]*normal_vector
    return x_t

  def get_backward_mean(self, x_t, timestep_idx, epsilon):
    mu_vector=self.inv_sqrt_alphas[timestep_idx]*(x_t-(self.betas[timestep_idx]/self.sqrt_one_minus_alpha_bars[timestep_idx])*epsilon)
    return mu_vector

  def p_sample(self, x_t, timestep_idx):
    #x_(t-1)=N(x_t; mu(x_t, t), sigma_t*I)
    timestep=self.ddpm_timesteps[timestep_idx]
    timestep_tensor=torch.full([x_t.size(0)], timestep).to(device)
    eps_pred=self.forward(x_t, timestep_tensor) #pretrained UNet gives almost gaussian tensor.
    mu_vector=self.get_backward_mean(x_t, timestep_idx, eps_pred)
    if timestep==0:
      normal_vector=torch.zeros_like(x_t).to(device)
    else:
      normal_vector=torch.randn_like(x_t).to(device)
    x_prev=mu_vector+self.backward_std_hat[timestep_idx]*normal_vector
    return x_prev
  
  def ddim_sample(self, x_t, idx, next_idx):
    timestep_tensor=torch.full([x_t.size(0)], self.ddpm_timesteps[idx]).to(device)
    #timestep idx in DDPM lists => ddim timstep -1
    eps_pred=self.forward(x_t, timestep_tensor)
    normal_vector=torch.randn_like(x_t).to(device)
    x_0=self.get_x_0_vector(x_t, idx, eps_pred)
    if idx==0:
      #directly sample x_0 from x_t
      return x_0
    else:
      #alpha_bars_prev=alpha_bars[next_idx] => b.c. ddim steps skipping steps => alpha_bars_prev's index !=idx-1
      eps_coeff=torch.sqrt(1-self.alpha_bars[next_idx]-self.ddim_std[idx]**2)
      x_prev=self.sqrt_alpha_bars_prev[next_idx]*x_0+eps_coeff*eps_pred+self.ddim_std[idx]*normal_vector
      return x_prev
  
  def generate(self, batch_size, mode="ddpm", ddim_eta=None, verbose=True):
    assert mode in ['ddpm', 'ddim'], "Invalid Mode, Valid Modes: ddpm, ddim"
    assert (mode=="ddim" and ddim_eta!=None) or (mode!='ddim' and ddim_eta==None), "DDIM eta should be provided"
    #ddim with eta=1 => DDPM, eta=0 => DDIM (deterministic)
    x=torch.randn(batch_size, 3, self.img_w, self.img_w).to(device)
    #iterate through T steps to generate unconditionally
    inference_steps=self.train_time_steps if mode=="ddpm" else self.ddim_steps
    if mode=='ddim':
      self.ddim_std=ddim_eta*self.backward_std
    else:
      self.ddim_std=self.backward_std
    if verbose:
      pbar=tqdm(desc="Unconditional Generation", total=inference_steps)
    for timestep_idx in np.arange(inference_steps-1, -1, -1):
      #predicting x_(t-1)
      with torch.no_grad():
        if mode=="ddpm":
          x=self.p_sample(x, timestep_idx)
          #x=self.ddim_sample(x, timestep_idx)
        else:
          idx=self.ddim_timestep_indices[timestep_idx] #steps in DDPM corresponding to DDIM
          if timestep_idx==0:
            next_idx=0
          else:            
            next_idx=self.ddim_timestep_indices[timestep_idx-1]
          x=self.ddim_sample(x, idx, next_idx)
      if verbose:
        pbar.update(1)
    x=self.unnormalize(x)
    x=torch.clamp(x, min=0, max=1)
    if verbose:
      pbar.close()
    return x

  def forward(self, x_t, timesteps):
    #can be processed in batch x_t : b,c,h,w & timesteps: b,
    #predicting eps(x_t, t), timesteps ranging [0, T-1]
    eps_pred=self.unet(x_t, timesteps).sample
    #eps_pred: same shape as x_t
    return eps_pred

In [4]:
class DDPM_Trainer():
  def __init__(self, data, image_size, num_train_steps, ddim_steps, beta_low, beta_high):
    self.image_size=image_size
    #pipeline = diffusers.DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32")
    #self.unet=pipeline.unet.to(device)
    #ddpm schedule parameters
    self.num_train_steps=num_train_steps
    self.ddim_steps=ddim_steps
    self.betas=betas=torch.linspace(beta_low, beta_high, num_train_steps).to(device)

    self.ddpm=DDPM(self.betas, image_size, num_train_steps, ddim_steps)
    self.fid_module=FID_Module(data['train'])

    #testing with CIFAR 10
    self.train_data=data['train']
    self.test_data=data['test']

    self.train_logs={
        'loss_history': [],
        'val_fid_history': [],
        'is_history': [],
        'validation_images': [],
    }
  
  def show_noising_image(self):
    train_loader=torch.utils.data.DataLoader(self.train_data, batch_size=4, shuffle=True)
    milestones=[int(a*self.num_train_steps) for a in [0.25, 0.5, 0.75, 1]]
    images=[[] for _ in milestones]
    for batch_data in train_loader:
      x, _ =batch_data
      x=self.ddpm.normalize(x)
      x=x.to(device)
      #run forward process
      for timestep in range(self.num_train_steps):
        x=self.ddpm.q_sample(x, timestep)
      break
    #display images
    plt.figure()
  
  def plot_train_logs(self):
    plt.figure(figsize=(20,5))
    #plot loss
    plt.subplot(1,3,1)
    lh=self.train_logs['loss_history']
    plt.plot(np.arange(1, len(lh+1), 1), lh)
    plt.xlabel("Update steps")
    plt.ylabel("Loss")

    #plot FID scores
    plt.subplot(1,3,2)
    fidh=self.train_logs['val_fid_history']
    plt.plot(np.arange(1, len(fidh)+1, 1), fidh)
    plt.xlabel("Validation steps")
    plt.ylabel("FID scores")

    #plot IS scores
    plt.subplot(1,3,3)
    ish=self.train_logs['is_history']
    plt.plot(np.arange(1, len(ish)+1, 1), ish)
    plt.xlabel("Validation steps")
    plt.ylabel("FID scores")

    plt.show()
    return
  
  def save_train_logs(self, name):
    with open(os.path.join(base_dir, "train_logs/{:s}.pkl".format(name)), 'wb') as file:
      pickle.dump(self.train_logs, file)
    #also save model
    torch.save(self.ddpm.state_dict(), os.path.join(base_dir, "models/{:s}.pkl".format(name)))
    return

  def get_val_sizes(self, val_batch_size, len_validation):
    if len_validation%val_batch_size==0:
      length=int(len_validation/val_batch_size)
      val_sizes=[val_batch_size for _ in range(length)]
    else:
      length=len_validation//val_batch_size
      val_sizes=[val_batch_size for _ in range(length)]
      val_sizes.append(int(len_validation-length*val_batch_size))
    return val_sizes
  
  def validate(self, epoch, val_batch_size, len_validation):
    #generate images and observe quality of images qualitatively.
    gen_images_list=[]
    #generate 50k images
    pbar=tqdm(desc="Validating", total=len_validation)
    val_sizes=self.get_val_sizes(val_batch_size, len_validation)
    for val_size in val_sizes:
      #apply DDIM sampling for efficiency.
      gen_images=self.ddpm.generate(val_size, mode="ddim", ddim_eta=0, verbose=False)
      gen_images_list.append(gen_images)
      pbar.update(val_size)
    pbar.close()
    whole_gen_images=torch.cat(gen_images_list, dim=0)
    #get fid score
    incep_score, fid_score=self.fid_module(whole_gen_images)
    #show some images
    print("Validation for Epoch: {:d}, IS: {:.3f}, FID score: {:.3f}".format(epoch, incep_score, fid_score.item()))
    plt.figure(figsize=(20,5))
    for idx in range(16):
      plt.subplot(2,8,idx+1)
      plt.imshow(whole_gen_images[idx].permute(1,2,0).cpu().numpy())
    plt.show()
    self.train_logs['validation_images'].append(whole_gen_images[:100]) #(1000,3,32,32) => large size => store only 100 images.
    self.train_logs['is_history'].append(incep_score)
    self.train_logs['val_fid_history'].append(fid_score.item())
    return incep_score, fid_score
  
  def show_generation(self, epoch):
    print("Validation for Epoch: {:d}".format(epoch))
    gen_images=self.ddpm.generate(16, verbose=False)
    plt.figure(figsize=(20,5))
    for idx in range(16):
      plt.subplot(2, 8, idx+1)
      plt.imshow(gen_images[idx].permute(1,2,0).cpu().numpy())
    plt.show()
    return

  def test(self):
    return
  
  def get_loss(self, x_0):
    batch_size, ch, img_h, img_w=x_0.size()
    #loss ftn: MSE Loss
    mse_loss=nn.MSELoss(reduction="mean")
    normal_vector=torch.randn_like(x_0).to(device)
    timestep_indices=torch.randint(low=0, high=self.num_train_steps, size=[batch_size]).to(device)
    timesteps=torch.gather(self.ddpm.ddpm_timesteps, 0, timestep_indices)
    #x_t: corrupted image from x_0
    x_t=self.ddpm.get_x_t_vector(x_0, timestep_indices, normal_vector)
    #estimated noise conditioned on timestep from x_t
    pred_eps=self.ddpm(x_t, timesteps)
    loss=mse_loss(pred_eps, normal_vector)
    return loss
  
  def train(self, num_epochs, batch_size, lr, reg, validate_every, val_batch_size, len_validation):
    #training the UNet network of the DDPM
    optimizer=optim.AdamW(self.ddpm.parameters(), lr=lr, weight_decay=reg)
    train_loader=torch.utils.data.DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
    update_steps=(len(self.train_data)//batch_size)*num_epochs if len(self.train_data)%batch_size==0 else (len(self.train_data)//batch_size+1)*num_epochs
    #horizontal flip operation
    horiz_flip=torchvision.transforms.RandomHorizontalFlip(p=0.5)
    pbar=tqdm(desc="DDPM Training Update Steps", total=update_steps)
    for epoch in range(1, num_epochs+1):
      for b_idx, batch_data in enumerate(train_loader):
        x_0, _=batch_data
        x_0=x_0.to(device)
        x_0=horiz_flip(x_0)
        x_0=self.ddpm.normalize(x_0) #normalize to range of [-1,1]
        
        loss=self.get_loss(x_0)
        self.train_logs['loss_history'].append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.update(1)
      if epoch%validate_every==0:
        #incep_score, fid=self.validate(epoch, val_batch_size, len_validation)
        #print(incep_score, fid)
        self.show_generation(epoch)
    #self.test()
    pbar.close()
    return self.train_logs

In [10]:
#get CIFAR 10 dataset
cifar_preprocess=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    #torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
cifar10_train=torchvision.datasets.CIFAR10(root='data/train', train=True, transform=cifar_preprocess, download=True)
small_cifar10_train=torch.utils.data.Subset(cifar10_train, np.arange(0, 5000, 1))
cifar10_test=torchvision.datasets.CIFAR10(root='data/test', train=False, transform=cifar_preprocess, download=True)
#use full data for training
small_cifar10={
    'train': small_cifar10_train,
    'test': cifar10_test
}
cifar10={
    'train': cifar10_train,
    'test': cifar10_test
}

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/train/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/train/cifar-10-python.tar.gz to data/train
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/test/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/test/cifar-10-python.tar.gz to data/test


In [None]:
trainer=DDPM_Trainer(data=cifar10, image_size=(32,32), num_train_steps=1000, ddim_steps=20, beta_low=1e-4, beta_high=0.2)
#incompatible with MPS
train_logs=trainer.train(
    num_epochs=100,
    batch_size=128,
    lr=2e-4,
    reg=0,
    validate_every=10,
    val_batch_size=128,
    len_validation=10000,
)