In [8]:
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.t5 import t5_encode_text
from imagen_pytorch.data import Dataset
import random
from tqdm import tqdm
import numpy as np
from einops import repeat
import torchvision.transforms as T
from torch.special import expm1
import matplotlib.pyplot as plt

# Seed every RNG this notebook imports (NumPy, stdlib random, torch) with the
# same value so repeated runs start from the same random state.
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(0)

# Use the GPU when torch reports one is available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Three-stage U-Net. The first stage has both attention flags switched off;
# the remaining two stages have `layer_attns` and `layer_cross_attns` enabled.
unet = Unet(
    dim=128,                                # the "Z" layer dimension (filters output by the first layer)
    cond_dim=256,
    dim_mults=(1, 2, 4),                    # per-stage channel multipliers, each applied to `dim`
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
)

In [11]:
# Four-stage U-Net. The attention flags are built from `dim_mults` so they
# always have exactly one entry per stage, with the first stage switched off.
dim_mults = (1, 2, 4, 8)

unet = Unet(
    dim = 128, # the "Z" layer dimension, i.e. the number of filters output by the first layer
    cond_dim = 128,
    dim_mults = dim_mults, # the channel dimensions inside the model (multiplied by dim)
    num_resnet_blocks = 3,
    layer_attns = (False,) + (True,) * (len(dim_mults) - 1),
    layer_cross_attns = (False,) + (True,) * (len(dim_mults) - 1)
)

# Move the model to the notebook-wide `device` (set in the setup cell) rather
# than calling .cuda() unconditionally, which raises on machines without a GPU.
imagen = Imagen(
    unets = unet,
    image_sizes = 32,
    timesteps = 250,
    cond_drop_prob = .1
).to(device)

trainer = ImagenTrainer(imagen, lr=1e-4)

In [3]:
# Wrap the current `unet` in an Imagen model that works on 32-pixel images with
# a 1000-step schedule. The model follows the notebook-wide `device` (set in
# the setup cell) instead of a hard-coded .cuda(), so the cell also runs on
# CPU-only machines.
imagen = Imagen(
    unets = unet,
    image_sizes = 32,
    timesteps = 1000,
    cond_drop_prob = 0.1
).to(device)


In [4]:
# images = imagen.sample(texts = [
#     'a whale breaching from afar',
#     # 'young girl blowing out candles on her birthday cake',
#     # 'fireworks with blue and green sparkles'
# ], cond_scale = 3)

# img = T.ToPILImage()(images[0])
# img.save("whale.png")

In [5]:
# Encode a single prompt with T5, also requesting the attention mask so both
# can be handed to the sampler; then report the embedding tensor's shape.
encodings, mask = t5_encode_text(
    ['a whale breaching from afar'],
    return_attn_mask=True,
)
print(encodings.shape)

torch.Size([1, 10, 768])


In [6]:
def sample(imagen, unet, cond_scale, text_embeds, text_mask, device, use_tqdm=True):
    """Run the sampling loop for the first stage of `imagen` and return the images.

    Starts from Gaussian noise and repeatedly calls `imagen.p_sample` over the
    (time, next time) pairs produced by the first noise scheduler.

    Args:
        imagen: Imagen instance; only its first-stage settings (index 0 of
            `noise_schedulers`, `sample_channels`, `image_sizes`,
            `pred_objectives`) are read.
        unet: network passed to `imagen.p_sample` at every step.
        cond_scale: conditioning scale forwarded unchanged to `p_sample`.
        text_embeds: text embeddings; the size of its first dimension is used
            as the batch size.
        text_mask: mask forwarded to `p_sample` alongside `text_embeds`.
        device: device on which the initial noise and timesteps are created.
        use_tqdm: when False the progress bar is disabled.

    Returns:
        The final image batch, clamped to [-1, 1] and then passed through
        `imagen.unnormalize_img` (presumably mapping to [0, 1], the range
        `T.ToPILImage` expects for float tensors - confirm against the
        Imagen configuration).
    """
    batch_size = text_embeds.shape[0]
    noise_scheduler = imagen.noise_schedulers[0]
    image_size = imagen.image_sizes[0]

    # Square noise image, one per prompt in the batch.
    img = torch.randn((batch_size, imagen.sample_channels[0], image_size, image_size), device = device)
    timesteps = noise_scheduler.get_sampling_timesteps(batch_size, device = device)

    for times, times_next in tqdm(timesteps, desc='sampling loop time step', total=len(timesteps), disable=not use_tqdm):
        # p_sample returns a pair; only the first element (the updated image) is kept.
        img, _ = imagen.p_sample(
            unet,
            img,
            times,
            t_next = times_next,
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_scale = cond_scale,
            noise_scheduler = noise_scheduler,
            pred_objective = imagen.pred_objectives[0],
            dynamic_threshold = imagen.dynamic_thresholding
        )

    img.clamp_(-1., 1.)

    # Previously the un-normalized tensor was computed and then discarded, so
    # callers received the [-1, 1] image; return the un-normalized result.
    return imagen.unnormalize_img(img)

In [7]:
# Draw one batch of samples for the encoded prompt, with cond_scale set to 1.
out = sample(
    imagen=imagen,
    unet=unet,
    cond_scale=1.,
    text_embeds=encodings,
    text_mask=mask,
    device=device,
)

sampling loop time step: 100%|█████████████████████████████████████████████████████| 1000/1000 [00:39<00:00, 25.52it/s]


In [151]:
# Convert the first image of the sampled batch to PIL and write it to disk.
to_pil = T.ToPILImage()
img = to_pil(out[0])
img.save("whale7.png")

In [None]:

# @torch.jit.script
# def beta_linear_log_snr(t):
#     return -torch.log(expm1(1e-4 + 10 * (t ** 2)))

# def log_snr_to_alpha_sigma(log_snr):
#     return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr))

#     def q_posterior(self, x_start, x_t, t, *, t_next = None):
#         """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """
#         log_snr = self.log_snr(t)
#         log_snr_next = self.log_snr(t_next)
#         log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next))

#         alpha, sigma = log_snr_to_alpha_sigma(log_snr)
#         alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

#         # c - as defined near eq 33
#         c = -expm1(log_snr - log_snr_next)
#         posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

#         # following (eq. 33)
#         posterior_variance = (sigma_next ** 2) * c
#         posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
#         return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
# def normalize_neg_one_to_one(img):
#     return img * 2 - 1

# def unnormalize_zero_to_one(normed_img):
#     return (normed_img + 1) * 0.5

# def get_sampling_timesteps(batch, num_timesteps, device):
#     times = torch.linspace(1., 0., num_timesteps + 1, device = device)
#     times = repeat(times, 't -> b t', b = batch)
#     times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
#     times = times.unbind(dim = -1)
#     return times

# def right_pad_dims_to(x, t):
#     padding_dims = x.ndim - t.ndim
#     if padding_dims <= 0:
#         return t
#     return t.view(*t.shape, *((1,) * padding_dims))

# def predict_start_from_noise(self, x_t, t, noise):
#     log_snr = beta_linear_log_snr(t)
#     log_snr = right_pad_dims_to(x_t, log_snr)
#     alpha, sigma = log_snr_to_alpha_sigma(log_snr)
#     return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)

# def predict_start_from_v(self, x_t, t, v):
#     log_snr = beta_linear_log_snr(t)
#     log_snr = right_pad_dims_to(x_t, log_snr)
#     alpha, sigma = log_snr_to_alpha_sigma(log_snr)
#     return alpha * x_t - sigma * v


# def get_sampling_timesteps(batch, num_timesteps, device):
#     times = torch.linspace(1., 0., num_timesteps + 1, device = device)
#     times = repeat(times, 't -> b t', b = batch)
#     times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
#     times = times.unbind(dim = -1)
#     return times


#         pred = unet.forward_with_cond_scale(
#             img, 
#             beta_linear_log_snr(t), 
#             text_embeds = text_embeds, 
#             text_mask = text_mask,
#             cond_scale = cond_scale,
#         )
        
        
#         if pred_objective == 'noise':
#             x_start = predict_start_from_noise(x, t = times, noise = pred)
#         elif pred_objective == 'x_start':
#             x_start = pred
#         elif pred_objective == 'v':
#             x_start = predict_start_from_v(x, t = t, v = pred)
#         else:
#             raise ValueError(f'unknown objective {pred_objective}')

        
#         if dynamic_threshold:
#             # following pseudocode in appendix
#             # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
#             s = torch.quantile(
#                 rearrange(x_start, 'b ... -> b (...)').abs(),
#                 self.dynamic_thresholding_percentile,
#                 dim = -1
#             )

#             s.clamp_(min = 1.)
#             s = right_pad_dims_to(x_start, s)
#             x_start = x_start.clamp(-s, s) / s
#         else:
#             x_start.clamp_(-1., 1.)
        
        
#         (model_mean, _, model_log_variance) = q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
        
#         noise = torch.randn_like(x)
            
#         is_last_sampling_timestep = (times_next == 0) 
#         nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(batch_size, *((1,) * (len(x.shape) - 1)))
                    
#         img = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        
#     img.clamp_(-1., 1.)
    
    
#         # finally un-normalize