In [25]:
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
import random
from tqdm import tqdm
import numpy as np
import torchvision.transforms as T
from imagen_pytorch.t5 import t5_encode_text
import math
import types
from functools import partial

# Seed the NumPy, stdlib and PyTorch random number generators with a fixed value.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Uncomment to force CPU execution regardless of CUDA availability:
# device = torch.device('cpu')

In [2]:
# U-Net for the single 32x32 stage.
unet = Unet(
    dim=128,  # number of filters output by the first layer
    cond_dim=128,
    dim_mults=(1, 2, 4),  # channel dimensions inside the model, as multiples of dim
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
)

# dynamic_thresholding stays off here; ImagenOnnx (defined further down) refuses it.
imagen = Imagen(
    unets=unet,
    image_sizes=32,
    timesteps=250,
    cond_drop_prob=0.1,
    dynamic_thresholding=False,
)

# Load the checkpoint through the trainer; the trailing ";" keeps the
# return value of load() out of the cell output.
trainer = ImagenTrainer(imagen)
trainer.load("giddy-capybara.ckpt");

loading saved imagen at version 1.17.1, but current package version is 1.18.1
checkpoint loaded from giddy-capybara.ckpt


In [26]:
# Replacing all uses of "expm1" to avoid:
# UnsupportedOperatorError: Exporting the operator ::expm1 to ONNX opset version 13 is not supported

def faux_expm1(x):
    """Stand-in for ``torch.expm1`` built only from ``torch.exp``.

    Evaluates ``exp(x) * (1 - exp(-x))``, which is algebraically ``exp(x) - 1``.
    """
    grown = torch.exp(x)
    shrunk = torch.exp(-x)
    return grown * (1 - shrunk)

def beta_linear(t):
    """Return ``-log(exp(1e-4 + 10 * t**2) - 1)`` for the time tensor ``t``.

    The ``exp(...) - 1`` term goes through ``faux_expm1`` rather than ``torch.expm1``.
    """
    exponent = 1e-4 + 10 * (t ** 2)
    return -torch.log(faux_expm1(exponent))  # <-------------- Replacing expm1

def right_pad_dims_to(x, t):
    """Append trailing size-1 dimensions to ``t`` until it has as many dims as ``x``.

    ``t`` is returned unchanged (the same object) when it already has at least
    as many dimensions as ``x``.
    """
    missing = x.ndim - t.ndim
    if missing > 0:
        trailing_ones = [1] * missing
        return t.view(*t.shape, *trailing_ones)
    return t

def exists(val):
    """Return True for any value other than ``None`` (falsy values count as existing)."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    A callable ``d`` is invoked (with no arguments) only when the fallback is
    actually needed, and its result is returned.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def log_snr_to_alpha_sigma(log_snr):
    """Map a log-SNR tensor to the pair ``(alpha, sigma)``.

    ``alpha = sqrt(sigmoid(log_snr))`` and ``sigma = sqrt(sigmoid(-log_snr))``.
    """
    alpha = torch.sigmoid(log_snr).sqrt()
    sigma = torch.sigmoid(-log_snr).sqrt()
    return alpha, sigma

def log(t, eps: float = 1e-12):
    """Natural log of ``t`` after clamping every element to at least ``eps``."""
    floored = t.clamp(min = eps)
    return floored.log()

def q_posterior(self, x_start, x_t, t, *, t_next = None):
    """Posterior mean, variance and clipped log-variance for the step ``t -> t_next``.

    Reference: https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material

    Meant to be bound to a scheduler object: reads ``self.num_timesteps`` and
    calls ``self.log_snr``. When ``t_next`` is omitted it is taken as
    ``t - 1 / self.num_timesteps`` clamped at 0.

    Returns the tuple ``(posterior_mean, posterior_variance, posterior_log_variance_clipped)``.
    """
    if t_next is None:
        t_next = (t - 1. / self.num_timesteps).clamp(min = 0.)

    # log-SNR at both times, right-padded with singleton dims to x_t's rank
    pad_like_x_t = partial(right_pad_dims_to, x_t)
    log_snr_now = pad_like_x_t(self.log_snr(t))
    log_snr_later = pad_like_x_t(self.log_snr(t_next))

    alpha, _ = log_snr_to_alpha_sigma(log_snr_now)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_later)

    # c - as defined near eq 33
    c = -faux_expm1(log_snr_now - log_snr_later) # <-------------- Replacing expm1
    mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

    # following (eq. 33)
    variance = (sigma_next ** 2) * c
    log_variance_clipped = log(variance, eps = 1e-20)
    return mean, variance, log_variance_clipped


# Patch the first noise scheduler in place with the expm1-free replacements defined above.
scheduler = imagen.noise_schedulers[0]
# NOTE(review): beta_linear is installed as log_snr unconditionally, whatever schedule
# the scheduler was originally configured with. If that was not this same formula,
# the noise schedule used for sampling/export changes silently — confirm against the
# checkpoint's configuration.
scheduler.log_snr = beta_linear
# Bind q_posterior to this instance so its `self` is the scheduler.
scheduler.q_posterior = types.MethodType(q_posterior, scheduler)

In [27]:
# out = trainer.sample(["a photo of a truck"])
# print(out)
# img = T.ToPILImage()(out[0])
# img.save("truck-4.png")
# img

In [28]:
# out = imagen.sample(text_embeds=enc, text_masks=mask)
# print(out)
# img = T.ToPILImage()(out[0])
# img.save("truck-4.png")
# img

In [29]:
class ImagenOnnx(torch.nn.Module):
    """``nn.Module`` wrapper that exposes one Imagen sampling step as ``forward``.

    ``forward`` performs a single ``imagen.p_sample`` call, which makes it the
    unit handed to the ONNX exporter; ``sample`` runs the complete sampling loop
    in Python using the same step. Only the first noise scheduler and the first
    prediction objective of ``imagen`` are used, and dynamic thresholding is
    always disabled.
    """

    def __init__(self, unet, imagen):
        """Store ``unet`` and ``imagen`` after checking thresholding is off.

        Raises:
            ValueError: if ``imagen.dynamic_thresholding`` is truthy, or is a
                sequence containing a truthy entry.
        """
        super().__init__()

        # The setting may be a single flag or a sequence of flags; calling any()
        # directly on a bare bool would raise TypeError, so normalise first.
        thresholding = imagen.dynamic_thresholding
        flags = (thresholding,) if isinstance(thresholding, bool) else tuple(thresholding)
        if any(flags):
            raise ValueError(f"no dynamic thresholding allowed, got: {thresholding}")

        self.unet = unet
        self.imagen = imagen

    def sample(self, text_embeds, text_mask, cond_scale,  device, use_tqdm=True):
        """Run the full sampling loop and return the unnormalized image batch.

        Args:
            text_embeds: text conditioning; its first dimension sets the batch size.
            text_mask: mask passed through to ``p_sample`` alongside ``text_embeds``.
            cond_scale: conditioning scale forwarded to every ``p_sample`` call.
            device: device on which the initial noise and timesteps are created.
            use_tqdm: show a progress bar over the sampling steps when True.

        Returns:
            ``imagen.unnormalize_img`` applied to the final image clamped to [-1, 1].
        """
        batch_size = text_embeds.shape[0]
        noise_scheduler = self.imagen.noise_schedulers[0]
        image_size = self.imagen.image_sizes[0]
        img = torch.randn((batch_size, self.imagen.sample_channels[0], image_size, image_size), device = device)
        timesteps = noise_scheduler.get_sampling_timesteps(batch_size, device = device)

        # There is no previous prediction before the first step; without this
        # initialisation a self-conditioning unet would hit an unbound local.
        x_start = None

        for times, times_next in tqdm(timesteps, desc='sampling loop time step', total=len(timesteps), disable=not use_tqdm):
            # Feed the previous step's x_start back in only for self-conditioning unets.
            self_cond = x_start if self.unet.self_cond else None

            img, x_start = self.imagen.p_sample(
                self.unet,
                img,
                times,
                t_next = times_next,
                text_embeds = text_embeds,
                text_mask = text_mask,
                cond_scale = cond_scale,
                self_cond = self_cond,
                noise_scheduler = noise_scheduler,
                pred_objective = self.imagen.pred_objectives[0],
                dynamic_threshold = False
            )

        img.clamp_(-1., 1.)

        unnormalize_img = self.imagen.unnormalize_img(img)

        return unnormalize_img

    def forward(self, img, text_embeds, text_mask, times, times_next, cond_scale):
        """Run exactly one sampling step and return what ``imagen.p_sample`` returns.

        ``sample`` unpacks that result as ``(img, x_start)``. The scheduler and
        objective come from the wrapped ``self.imagen`` rather than from a
        notebook-level global, so the module stays self-contained.
        """
        return self.imagen.p_sample(
            self.unet,
            img,
            times,
            t_next = times_next,
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_scale = cond_scale,
            noise_scheduler = self.imagen.noise_schedulers[0],
            pred_objective = self.imagen.pred_objectives[0],
            dynamic_threshold = False
        )
    
# Wrapper instance around the loaded unet/imagen; this is the module passed to torch.onnx.export.
u = ImagenOnnx(unet, imagen)

In [30]:
# enc, mask = t5_encode_text(["a photo of a truck"], return_attn_mask = True)
# out = u.sample(enc, mask, 1., torch.device('cuda'), use_tqdm=False)
# print(out)
# img = T.ToPILImage()(out[0])
# img.save("truck-12.png")
# img

In [31]:
# out = u.forward(
#     torch.rand(1, 3, 32, 32).cuda(),  
#     torch.rand(1, 27, 768).cuda(), 
#     torch.ones(1, 27, dtype=bool).cuda(), 
#     torch.Tensor([0.9]).cuda(), 
#     torch.Tensor([0.896]).cuda(), 
#     torch.Tensor([1.]).cuda()
# )

In [32]:
# you will get a bunch of ONNX warnings, it's ok I checked them and they're all not an issue

# Build the example inputs on whatever device the model's weights live on, so the
# export also runs on machines without CUDA (hard-coded .cuda() would crash there).
export_device = next(u.parameters()).device

torch.onnx.export(
    u,
    (
        torch.rand(1, 3, 32, 32, device = export_device),                 # image
        torch.rand(1, 27, 768, device = export_device),                   # text_embeds
        torch.ones(1, 27, dtype = torch.bool, device = export_device),    # text_mask
        torch.tensor([0.5], device = export_device),                      # timestep
        # time_next: q_posterior steps to t - 1/num_timesteps, i.e. t_next < t,
        # so use a valid pair (0.5 -> 0.496 with 250 timesteps).
        torch.tensor([0.496], device = export_device),
        torch.tensor([1.1], device = export_device) # cond scale of exactly 1 causes the ONNX graph to skip some vital things, so use 1.1
    ),
    "toymodel/public/unet-32.onnx",
    input_names=['image', 'text_embeds', 'text_mask', 'timestep', 'time_next', 'cond_scale'],
    output_names=['prediction', 'x_start'], # you can basically ignore x_start
    dynamic_axes={
        'image': {0: 'batch_size'},
        'text_embeds': {0: 'batch_size', 1: 'n_tokens'},
        'text_mask': {0: 'batch_size', 1: 'n_tokens'},
    }
)