In [None]:
class Imagen(nn.Module):
    """Cascading diffusion wrapper that ties a sequence of unets to their schedules.

    For every unet the constructor stores a noise scheduler, a prediction
    objective, a dynamic-thresholding flag and a p2 loss weight. The first
    unet is cast to run without low-resolution conditioning; every later unet
    is cast to be conditioned on a low-resolution image.

    Args:
        unets: a single `Unet` or a list/tuple of them, ordered by stage.
        image_sizes: image size per stage (an int, or a list/tuple with one entry per unet).
        text_encoder_name: name used to look up the text embedding dimension.
        text_embed_dim: explicit text embedding dimension; looked up from
            `text_encoder_name` when None.
        channels: number of image channels, used for both unet input and output.
        timesteps: timesteps for the noise schedulers (one value, or one per unet).
        cond_drop_prob: conditioning dropout probability; classifier free
            guidance is flagged as available only when this is > 0.
        loss_type: 'l1', 'l2' or 'huber'.
        noise_schedules: schedule name(s); missing entries default to
            'cosine' for the first two unets and 'linear' for the rest.
        pred_objectives: prediction objective(s), one value or one per unet.
        lowres_noise_schedule: schedule used to noise the low-res conditioning image.
        lowres_sample_noise_level: noise level stored for the low-res
            conditioning image at sample time.
        per_sample_random_aug_noise_level: stored flag; not used inside this constructor.
        condition_on_text: whether the unets are cast to condition on text.
        auto_normalize_img: when True, images are mapped [0, 1] -> [-1, 1]
            and back; when False the identity is used and the caller must
            supply [-1, 1] ranged images.
        continuous_times: selects `GaussianDiffusionContinuousTimes` over
            `GaussianDiffusion` and enables learned sinusoidal embeddings.
        p2_loss_weight_gamma: p2 loss weight (https://arxiv.org/abs/2204.00227),
            one value or one per unet; must not exceed 2.
        p2_loss_weight_k: the `k` constant of the p2 loss weight.
        dynamic_thresholding: flag(s), one value or one per unet.
        dynamic_thresholding_percentile: percentile stored for dynamic thresholding.

    Raises:
        NotImplementedError: if `loss_type` is not one of the supported names.
        AssertionError: if a unet is not a `Unet`, the number of image sizes
            does not match the number of unets, or a gamma exceeds 2.
    """

    def __init__(
        self,
        unets,
        *,
        image_sizes,
        text_encoder_name = DEFAULT_T5_NAME,
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,
        per_sample_random_aug_noise_level = False,
        condition_on_text = True,
        auto_normalize_img = True,
        continuous_times = True,
        p2_loss_weight_gamma = 0.5,
        p2_loss_weight_k = 1,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.9,
    ):
        super().__init__()

        # loss

        if loss_type == 'l1':
            loss_fn = F.l1_loss
        elif loss_type == 'l2':
            loss_fn = F.mse_loss
        elif loss_type == 'huber':
            loss_fn = F.smooth_l1_loss
        else:
            raise NotImplementedError()

        self.loss_type = loss_type
        self.loss_fn = loss_fn

        # conditioning hparams

        self.condition_on_text = condition_on_text
        self.unconditional = not condition_on_text

        # channels

        self.channels = channels

        # accept a single unet as well as a list/tuple of unets

        unets = cast_tuple(unets)
        num_unets = len(unets)

        # one timestep count per unet

        timesteps = cast_tuple(timesteps, num_unets)

        # noise schedule defaults to 'cosine', 'cosine', and then 'linear' for the remaining super-resolution unets

        noise_schedules = cast_tuple(noise_schedules)
        noise_schedules = pad_tuple_to_length(noise_schedules, 2, 'cosine')
        noise_schedules = pad_tuple_to_length(noise_schedules, num_unets, 'linear')

        # construct noise schedulers (zip stops at the number of unets)

        noise_scheduler_klass = GaussianDiffusion if not continuous_times else GaussianDiffusionContinuousTimes
        self.noise_schedulers = nn.ModuleList([])

        for timestep, noise_schedule in zip(timesteps, noise_schedules):
            noise_scheduler = noise_scheduler_klass(noise_schedule = noise_schedule, timesteps = timestep)
            self.noise_schedulers.append(noise_scheduler)

        # lowres augmentation noise schedule

        self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

        # ddpm objectives - predicting noise by default

        self.pred_objectives = cast_tuple(pred_objectives, num_unets)

        # text encoder name and embedding dimension (looked up lazily only when not given)

        self.text_encoder_name = text_encoder_name
        self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))

        # construct unets: the first one is not conditioned on a low resolution image,
        # the rest are conditioned on the low resolution image produced by the previous unet

        self.unets = nn.ModuleList([])

        for ind, one_unet in enumerate(unets):
            assert isinstance(one_unet, Unet)
            is_first = ind == 0

            one_unet = one_unet.cast_model_parameters(
                lowres_cond = not is_first,
                cond_on_text = self.condition_on_text,
                text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
                channels = self.channels,
                channels_out = self.channels,
                learned_sinu_pos_emb = continuous_times
            )

            self.unets.append(one_unet)

        # unet image sizes
        # the count is checked on the normalised tuple so that a bare int is accepted for a single unet

        self.image_sizes = cast_tuple(image_sizes)
        assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'

        self.sample_channels = cast_tuple(self.channels, num_unets)

        # cascading ddpm related stuff

        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
        assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

        self.lowres_sample_noise_level = lowres_sample_noise_level
        self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

        # classifier free guidance

        self.cond_drop_prob = cond_drop_prob
        self.can_classifier_guidance = cond_drop_prob > 0.

        # normalize and unnormalize image functions

        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity

        # dynamic thresholding

        self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

        # p2 loss weight

        self.p2_loss_weight_k = p2_loss_weight_k
        self.p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)

        assert all([(gamma_value <= 2) for gamma_value in self.p2_loss_weight_gamma]), 'in paper, they noticed any gamma greater than 2 is harmful'

        # one non-persistent buffer, kept only so the module's device can be tracked

        self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

        # default to device of unets passed in

        self.to(next(self.unets.parameters()).device)
        

In [None]:
def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None."""
    return not (val is None)

def identity(t, *args, **kwargs):
    """Hand back `t` untouched; any extra positional or keyword arguments are ignored."""
    passthrough = t
    return passthrough

def maybe(fn):
    """Wrap a one-argument function so that a None argument is passed through.

    The returned wrapper calls `fn(x)` when `x` is not None and returns `x`
    (that is, None) otherwise.
    """
    @wraps(fn)
    def inner(x):
        return fn(x) if exists(x) else x
    return inner

def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`.

    A callable `d` is invoked (with no arguments) to produce the fallback, so
    expensive defaults are only computed when they are needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

In [None]:
def cast_tuple(val, length = None):
    """Coerce `val` to a tuple, optionally enforcing its length.

    Lists become tuples and tuples are kept as they are. Any other value is
    repeated `length` times (once when `length` is None). When `length` is
    given, the resulting tuple must have exactly that many entries.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = default(length, 1)
        output = (val,) * repeats

    if exists(length):
        assert len(output) == length

    return output

In [None]:
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad `t` with `fillvalue` until it has `length` entries.

    `t` is returned unchanged when it is already at least `length` long;
    otherwise a new tuple is built.
    """
    missing = length - len(t)
    if missing > 0:
        padding = (fillvalue,) * missing
        return (*t, *padding)
    return t

In [None]:
# Demo: expand a single schedule name into one entry per unet, printing each step.
# The first two slots fall back to 'cosine', the remaining ones to 'linear'.
num_unets       = 4

noise_schedules = cast_tuple('cosine')
print(noise_schedules)
noise_schedules = pad_tuple_to_length(noise_schedules, length = 2, fillvalue = 'cosine')
print(noise_schedules)
noise_schedules = pad_tuple_to_length(noise_schedules, length = num_unets, fillvalue = 'linear')
print(noise_schedules)

In [None]:
# Same expansion as the helper-based demo, written out with plain tuple arithmetic.
n_unets = 4

noise_schedules = 'cosine'
print(noise_schedules)

# pair the given schedule with a 'cosine' default for the second unet
noise_schedules = (noise_schedules, 'cosine')
print(noise_schedules)

# number of remaining unets that still need a schedule (never negative)
mults = max(n_unets - len(noise_schedules), 0)
noise_schedules = noise_schedules + ('linear',) * mults
print(noise_schedules)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ComplexModels import UNet
from GaussianDiffusion import GaussianDiffusion
from TextEncoder import TextEncoderT5Based

In [2]:
def normalize_neg_one_to_one(img):
    """Apply the affine map x -> 2x - 1, which sends the range [0, 1] to [-1, 1]."""
    doubled = img * 2
    return doubled - 1

def unnormalize_zero_to_one(normed_img):
    """Apply the affine map x -> (x + 1) / 2, which sends the range [-1, 1] to [0, 1]."""
    shifted = normed_img + 1
    return shifted * 0.5

def resize_image_to(image, target_image_size):
    """Resize `image` so that its last dimension equals `target_image_size`.

    Only the last entry of `image.shape` is inspected. When it already
    matches, the very same object is returned; otherwise the work is
    delegated to `resize` with the ratio target / current as scale factor.
    """
    current_size = image.shape[-1]

    if current_size != target_image_size:
        return resize(image, scale_factors = target_image_size / current_size)

    return image

In [13]:
class Imagen(nn.Module):
    """Simplified cascading diffusion wrapper around a sequence of unets.

    Compared with the fully configurable reference constructor, this version
    hard-codes several choices: the loss is always l2 (`F.mse_loss`), text
    conditioning is always on, images are always auto-normalised, and
    `GaussianDiffusion` is always used as the noise scheduler.

    Args:
        unets: sequence of unets, ordered by stage. Each must provide
            `lowres_change(lowres_cond=...)` and a `lowres_cond` attribute.
        image_sizes: image size per stage; stored as passed in.
        text_encoder_name: name handed to `TextEncoderT5Based`.
        text_embed_dim: accepted for interface compatibility but unused; the
            dimension is read from the text encoder.
        channels: number of image channels.
        timesteps: timestep count, shared by every unet's noise scheduler.
        cond_drop_prob: conditioning dropout probability; classifier free
            guidance is flagged as available only when this is > 0.
        loss_type: accepted but unused (the loss is fixed to 'l2').
        noise_schedules: schedule name for the first unet; the second unet
            uses 'cosine' and every further unet 'linear'.
        pred_objectives: prediction objective, shared by every unet.
        lowres_noise_schedule: schedule used to noise the low-res conditioning image.
        lowres_sample_noise_level: stored noise level for low-res conditioning at sample time.
        per_sample_random_aug_noise_level: stored flag; not used in the code shown here.
        condition_on_text: accepted but unused (conditioning is fixed to True).
        auto_normalize_img: accepted but unused (normalisation is always applied).
        continuous_times: accepted but unused.
        p2_loss_weight_gamma: p2 loss weight (https://arxiv.org/abs/2204.00227), shared by every unet.
        p2_loss_weight_k: the `k` constant of the p2 loss weight.
        dynamic_thresholding: flag shared by every unet.
        dynamic_thresholding_percentile: percentile used for dynamic thresholding.
    """

    def __init__(
        self,
        unets,
        *,
        image_sizes,
        text_encoder_name = 'google/t5-v1_1-small',
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,
        per_sample_random_aug_noise_level = False,
        condition_on_text = True,
        auto_normalize_img = True,
        continuous_times = True,
        p2_loss_weight_gamma = 0.5,
        p2_loss_weight_k = 1,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.9,
    ):
        super(Imagen, self).__init__()

        # loss (fixed to l2 in this version)

        self.loss_type   = 'l2'
        self.loss_fn     = F.mse_loss
        self.channels    = channels
        self.image_sizes = image_sizes

        # conditioning hparams (text conditioning is always on in this version)

        self.condition_on_text, self.unconditional = True, False

        self.lowres_sample_noise_level         = lowres_sample_noise_level
        self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

        n_unets   = len(unets)
        timesteps = (timesteps,)*n_unets

        # first unet: the given schedule, second: 'cosine', every further one: 'linear'
        noise_schedules = (noise_schedules, 'cosine')
        mults           = max(n_unets - len(noise_schedules), 0)
        noise_schedules = (*noise_schedules, *('linear',)*mults)

        self.lowres_noise_schedule = GaussianDiffusion(noise_type=lowres_noise_schedule)
        self.pred_objectives       = (pred_objectives,)*n_unets

        # build the text encoder once and read the embedding dimension from that same instance
        self.text_encoder_name = text_encoder_name
        self.text_encoder      = TextEncoderT5Based(text_encoder_name)
        self.text_embed_dim    = self.text_encoder.embed_dim

        # one noise scheduler per unet (zip stops at the number of unets)
        self.noise_schedulers = nn.ModuleList([])
        for timestep, noise_schedule in zip(timesteps, noise_schedules):
            noise_scheduler = GaussianDiffusion(noise_type=noise_schedule, timesteps=timestep)
            self.noise_schedulers.append(noise_scheduler)

        # the first unet runs without low-res conditioning, every later unet with it
        # NOTE(review): nothing here verifies that `lowres_change` actually produced that layout
        self.unets = nn.ModuleList([])
        for i, current_unet in enumerate(unets):
            self.unets.append(current_unet.lowres_change(lowres_cond = not (i == 0)))

        self.sample_channels = (self.channels,)*n_unets

        # classifier free guidance

        self.cond_drop_prob          = cond_drop_prob
        self.can_classifier_guidance = cond_drop_prob > 0.

        # images are always mapped [0, 1] -> [-1, 1] and back in this version

        self.normalize_img   = normalize_neg_one_to_one
        self.unnormalize_img = unnormalize_zero_to_one

        self.dynamic_thresholding            = (dynamic_thresholding,)*n_unets
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

        self.p2_loss_weight_k     = p2_loss_weight_k
        self.p2_loss_weight_gamma = (p2_loss_weight_gamma,)*n_unets

        # non-persistent buffer, kept only so `device` can report where the module lives
        self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

        # default to device of unets passed in
        self.to(next(self.unets.parameters()).device)

    @property
    def device(self):
        """Device of the module, read off the `_temp` tracking buffer."""
        return self._temp.device

    def get_unet(self, unet_number):
        """Return the unet for the 1-based `unet_number`.

        Raises:
            AssertionError: if `unet_number` is not in 1..len(self.unets).
        """
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1
        return self.unets[index]

    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        """Context manager that keeps only one unet on the GPU.

        Exactly one of `unet_number` (1-based) or `unet` must be given. On
        entry the whole module is moved to CUDA, then all unets are moved to
        the CPU and only the selected one back to CUDA. On exit every unet is
        returned to the device it was on just before the unets were shuffled,
        also when the body of the `with` block raises.
        """
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.get_unet(unet_number)

        self.cuda()

        devices = [module_device(one_unet) for one_unet in self.unets]
        self.unets.cpu()
        unet.cuda()

        try:
            yield
        finally:
            # restore placement even if the with-body failed
            for one_unet, device in zip(self.unets, devices):
                one_unet.to(device)

    def p_mean_variance(self, unet, x, t, noise_scheduler, text_embeds = None, text_mask = None,
                        cond_images = None, lowres_cond_img = None, lowres_noise_times = None,
                        cond_scale = 1., model_output = None, t_next = None, pred_objective = 'noise',
                        dynamic_threshold = True):
        """Compute the posterior for one reverse-diffusion step.

        The unet prediction (or the supplied `model_output`) is turned into
        an estimate of the clean image, that estimate is thresholded, and the
        result of `noise_scheduler.q_posterior` is returned.

        Args:
            unet: model exposing `forward_with_cond_scale`.
            x: current noisy sample.
            t: current time, passed to the noise scheduler.
            noise_scheduler: scheduler providing `get_condition`,
                `predict_start_from_noise` and `q_posterior`.
            text_embeds, text_mask, cond_images, lowres_cond_img: conditioning
                inputs forwarded to the unet.
            lowres_noise_times: noise times of the low-res conditioning image;
                converted with `self.lowres_noise_schedule.get_condition`.
            cond_scale: classifier free guidance scale; anything other than 1
                requires the model to have been built with `cond_drop_prob > 0`.
            model_output: precomputed prediction; when given, the unet is not called.
            t_next: next time, forwarded to `q_posterior`.
            pred_objective: 'noise' or 'x_start'.
            dynamic_threshold: use percentile-based thresholding instead of a
                plain clamp to [-1, 1].

        Raises:
            AssertionError: if guidance is requested but unavailable.
            ValueError: if `pred_objective` is unknown.
        """
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        # the lambda keeps the unet call lazy: it only runs when no model_output was supplied
        # NOTE(review): get_condition is called on lowres_noise_times even when it is None - confirm the scheduler accepts that
        pred = default(model_output, lambda: unet.forward_with_cond_scale(
            x,
            noise_scheduler.get_condition(t),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            cond_scale = cond_scale,
            lowres_cond_img = lowres_cond_img,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times)
        ))

        if pred_objective == 'noise':
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
        elif pred_objective == 'x_start':
            x_start = pred
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        if dynamic_threshold:
            # s is the dynamic threshold: a percentile of the absolute values of the
            # reconstructed sample, computed separately for each batch element
            s = torch.quantile(
                rearrange(x_start, 'b ... -> b (...)').abs(),
                self.dynamic_thresholding_percentile,
                dim = -1
            )

            # never shrink the valid range below [-1, 1]
            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            # in-place clamp: this also modifies a tensor passed in as model_output
            x_start.clamp_(-1., 1.)

        return noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
        

In [14]:
# Small base unet used to try out the Imagen wrapper. The first resolution level has
# neither attention nor cross attention; the three deeper levels enable both.
unet1 = UNet(dim = 8, cond_dim = 40, text_embed_dim = 40, dim_mults = (1, 2, 4, 8), num_resnet_blocks = 3,
             layer_attns = (False, True, True, True), layer_cross_attns = (False, True, True, True))

In [15]:
# Check which class the unet was built from (the cell output shows ComplexModels.UNet).
type(unet1)

ComplexModels.UNet

In [16]:
# A freshly built unet is not low-res conditioned (the cell output shows False).
unet1.lowres_cond

False

In [None]:
# Smoke test: build a single-stage Imagen at 32px around unet1, spelling out every
# keyword with the same value the constructor already uses as its default.
Imagen((unet1,),
        image_sizes=(32,),                                # for cascading ddpm, image size at each stage
        text_encoder_name = 'google/t5-v1_1-small',
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
        per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
        continuous_times = True,
        p2_loss_weight_gamma = 0.5,                 # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
        p2_loss_weight_k = 1,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.9,      # unsure what this was based on perusal of paper
    )

In [None]:
def lowers_change(self, lowres_cond):
    """Return a model whose `lowres_cond` setting equals the requested one.

    When `self.lowres_cond` already matches, `self` is returned as is.
    Otherwise a new instance of the same class is built from the keyword
    arguments saved in `self._locals`, with `lowres_cond` overridden.

    NOTE(review): Imagen.__init__ calls `lowres_change` on each unet, while
    this function is spelled `lowers_change` - confirm which name the UNet
    class actually defines.
    """
    if lowres_cond != self.lowres_cond:
        init_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
        return self.__class__(**init_kwargs)
    return self

In [None]:
# Signature note only: this cell records the parameters of Imagen.p_mean_variance and has no body.
# The required positional parameters are on the first line, the optional ones follow.
def p_mean_variance(self, unet, x, t, noise_scheduler, 
                    
                        text_embeds = None, text_mask = None,
                        cond_images = None, lowres_cond_img = None, lowres_noise_times = None,
                        cond_scale = 1., model_output = None, t_next = None, pred_objective = 'noise',
                        dynamic_threshold = True):

In [None]:
# Signature note only: this cell has no body.
# Everything after the bare `*` is keyword-only, and `noise_scheduler` has no default.
def p_sample(self, unet, x, t, *, noise_scheduler,
             
             t_next = None, text_embeds = None, text_mask = None,
             cond_images = None, cond_scale = 1., lowres_cond_img = None, lowres_noise_times = None,
             pred_objective = 'noise', dynamic_threshold = True):
        

In [None]:
# Signature note only: this cell has no body.
# Everything after the bare `*` is keyword-only, and `noise_scheduler` has no default.
def p_sample_loop(self, unet, shape, *, noise_scheduler, 
                  
                  lowres_cond_img = None,
                  lowres_noise_times = None, text_embeds = None, text_mask = None, cond_images = None,
                  cond_scale = 1, pred_objective = 'noise', dynamic_threshold = True):
        

In [None]:
# Signature note only: this cell has no body and no trailing colon, so it is not runnable as written.
# Every parameter besides `self` is optional.
def sample(
        self,
        texts: List[str] = None,
        text_masks = None,
        text_embeds = None,
        cond_images = None,
        batch_size = 1,
        cond_scale = 1.,
        lowres_sample_noise_level = None,
        stop_at_unet_number = None,
        return_all_unet_outputs = False,
        return_pil_images = False,
        device = None,
    )

In [None]:
# Signature note only: parameter list of p_losses written without `def`, so it is not runnable as written.
# Everything after the bare `*` is keyword-only, and `noise_scheduler` has no default.
p_losses(self, unet, x_start, times, *, noise_scheduler,
         lowres_cond_img = None, lowres_aug_times = None, text_embeds = None,
         text_mask = None, cond_images = None, noise = None, times_next = None,
         pred_objective = 'noise', p2_loss_weight_gamma = 0.)

In [None]:
def resize_image_to(image, target_image_size):
    """Bring the last dimension of `image` to `target_image_size`.

    The size is read from `image.shape[-1]` only. An image that already has
    the target size is returned as the same object; any other image is
    passed to `resize` with a scale factor of target size over current size.
    """
    source_size = image.shape[-1]
    needs_resize = source_size != target_image_size

    if not needs_resize:
        return image

    ratio = target_image_size / source_size
    return resize(image, scale_factors = ratio)

In [None]:
def forward(self, images, texts: List[str], unet_number, cond_images = None):
    """Compute the training loss of one unet of the cascade for a batch.

    Args:
        images: image batch; its shape is unpacked into four dimensions, the
            first being the batch size.
        texts: captions, encoded with `self.text_encoder`.
        unet_number: 1-based index of the unet to train.
        cond_images: optional conditioning images, forwarded to `p_losses`.

    Returns:
        Whatever `self.p_losses` returns for the selected unet.

    Raises:
        AssertionError: if the text embedding width does not equal `self.text_embed_dim`.
    """
    unet_index = unet_number - 1
    unet       = self.unets[unet_index]

    # per-unet settings
    noise_scheduler      = self.noise_schedulers[unet_index]
    p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
    pred_objective       = self.pred_objectives[unet_index]
    target_image_size    = self.image_sizes[unet_index]
    prev_image_size      = self.image_sizes[unet_index - 1] if unet_index > 0 else None

    # unpacking into four names also enforces a 4-dimensional batch
    b, c, h, w = images.shape
    device = images.device

    times = noise_scheduler.sample_random_times(b, device = device)

    # keep embeddings and masks as tensors so the shape check below is meaningful
    # NOTE(review): assumes the text encoder returns a (text_embeds, text_masks) pair of tensors - confirm
    text_embeds, text_masks = self.text_encoder(texts)
    text_embeds = text_embeds.to(device)
    text_masks  = text_masks.to(device)

    assert text_embeds.shape[-1] == self.text_embed_dim, 'invalid text embedding dimension'

    lowres_cond_img = lowres_aug_times = None
    if prev_image_size is not None:
        # simulate the previous stage's output: shrink to its size, then blow back up to this stage's size
        lowres_cond_img = resize_image_to(resize_image_to(images, prev_image_size), target_image_size)
        # a single augmentation time is sampled and shared by the whole batch
        lowres_aug_times = repeat(self.lowres_noise_schedule.sample_random_times(1, device = device), '1 -> b', b = b)

    images = resize_image_to(images, target_image_size)

    return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images,
                         noise_scheduler = noise_scheduler,
                         lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times,
                         pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma)


In [None]:
# Reference body of p_losses, pasted without its `def` header, so this cell is not runnable on its own.
# Steps: noise the normalised target image, noise the low-res conditioning image (if any),
# run the unet, then take a per-sample loss with optional p2 reweighting.
noise = default(noise, lambda: torch.randn_like(x_start))

        # normalize to [-1, 1]

        x_start = self.normalize_img(x_start)
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

        # get x_t

        x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # also noise the lowres conditioning image
        # at sample time, they then fix the noise level of 0.1 - 0.3

        lowres_cond_img_noisy = None
        if exists(lowres_cond_img):
            lowres_aug_times = default(lowres_aug_times, times)
            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

        # get prediction

        pred = unet.forward(
            x_noisy,
            noise_scheduler.get_condition(times),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
            lowres_cond_img = lowres_cond_img_noisy,
            cond_drop_prob = self.cond_drop_prob,
        )

        # prediction objective

        if pred_objective == 'noise':
            target = noise
        elif pred_objective == 'x_start':
            target = x_start
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        # losses

        losses = self.loss_fn(pred, target, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # p2 loss reweighting

        if p2_loss_weight_gamma > 0:
            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
            losses = losses * loss_weight

        return losses.mean()


In [None]:
    def p_losses(self, unet, x_start, times, text_embeds = None, text_mask = None, cond_images = None,
                 *, noise_scheduler,
                 lowres_cond_img = None, lowres_aug_times = None,
                 noise = None,
                 pred_objective = 'noise', p2_loss_weight_gamma = 0.):
        """Compute the diffusion training loss for one unet.

        Args:
            unet: model whose `forward` produces the prediction.
            x_start: clean image batch, normalised here with `self.normalize_img`.
            times: diffusion times for the batch.
            text_embeds, text_mask: text conditioning forwarded to the unet.
            cond_images: accepted for interface compatibility; not forwarded yet.
            noise_scheduler: keyword-only; provides `q_sample` and `get_condition`.
            lowres_cond_img: optional low-res conditioning image.
            lowres_aug_times: noise times for the low-res image; defaults to `times`.
            noise: noise to add to `x_start`; sampled with `torch.randn_like` when None.
            pred_objective: 'noise' or 'x_start'.
            p2_loss_weight_gamma: p2 reweighting exponent; no reweighting when <= 0.

        Returns:
            The mean over the batch of the (optionally reweighted) per-sample losses.

        Raises:
            ValueError: if `pred_objective` is unknown.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # get x_t
        x_start          = self.normalize_img(x_start)
        x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # also noise the low-res conditioning image
        lowres_cond_img_noisy = None
        if lowres_cond_img is not None:
            lowres_cond_img  = self.normalize_img(lowres_cond_img)
            lowres_aug_times = lowres_aug_times if lowres_aug_times is not None else times

            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))

        # NOTE(review): only text conditioning is forwarded for now. The draft left the
        # following keyword arguments disabled - confirm UNet.forward supports them before enabling:
        #   cond_images = cond_images,
        #   lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        #   lowres_cond_img = lowres_cond_img_noisy,
        #   cond_drop_prob = self.cond_drop_prob
        pred = unet.forward(x_noisy,
                            noise_scheduler.get_condition(times),
                            text_embeds = text_embeds,
                            text_mask = text_mask)

        # prediction objective
        if pred_objective == 'noise':
            target = noise
        elif pred_objective == 'x_start':
            target = x_start
        else:
            raise ValueError(f'unknown objective {pred_objective}')

        # per-sample loss: average over every dimension except the batch
        losses = self.loss_fn(pred, target, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # p2 loss reweighting

        if p2_loss_weight_gamma > 0:
            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
            losses = losses * loss_weight

        return losses.mean()


In [23]:
def exists(val):
    """Report whether `val` is set, meaning it is not None."""
    is_missing = val is None
    return not is_missing

def identity(t, *args, **kwargs):
    """Return the first argument as is; all further arguments are accepted and dropped."""
    unchanged = t
    return unchanged

def maybe(fn):
    """Make a one-argument function tolerant of None.

    The wrapper returned here forwards any non-None argument to `fn` and
    returns None arguments unchanged without calling `fn`.
    """
    @wraps(fn)
    def inner(x):
        if exists(x):
            return fn(x)
        return x
    return inner

In [25]:
from functools import partial, wraps

In [30]:
def cu(x):
    """Toy function for trying out `maybe`: returns `x` plus one."""
    incremented = x + 1
    return incremented

# 4 is not None, so the wrapper calls cu and yields 5 (see the cell output).
p = maybe(cu)(4)
print(p)

5


In [19]:
from einops_exts import check_shape

In [20]:
# Random batch with one image, 3 channels and a 4x4 spatial size.
x = torch.rand((1,3,4,4))

# Check the tensor against the 'b c h w' pattern with c fixed to 3.
# The cell output shows that the call hands the tensor back.
check_shape(x, 'b c h w', c = 3)

tensor([[[[7.6171e-01, 1.0307e-01, 7.5614e-01, 7.5960e-01],
          [6.1990e-01, 4.3320e-01, 9.8969e-01, 5.4741e-02],
          [3.0489e-01, 6.5976e-01, 7.5253e-01, 3.9645e-01],
          [7.5082e-01, 3.0904e-01, 7.9639e-01, 3.8395e-01]],

         [[1.4804e-01, 1.4207e-02, 6.3807e-01, 2.8259e-04],
          [2.4826e-01, 6.7222e-02, 8.7710e-01, 9.0763e-01],
          [9.6059e-02, 8.9632e-01, 3.2495e-01, 3.5925e-01],
          [4.1697e-01, 3.3956e-01, 6.8137e-01, 1.4271e-01]],

         [[5.1636e-02, 2.3114e-01, 7.9665e-01, 8.4681e-01],
          [3.5812e-01, 5.6943e-01, 9.3415e-01, 4.5141e-01],
          [8.1132e-01, 3.7259e-01, 6.0285e-01, 9.4935e-01],
          [2.7806e-01, 7.9014e-01, 3.4780e-01, 3.0346e-01]]]])

In [21]:
# Display x for comparison: its values equal those shown by the check_shape cell.
x

tensor([[[[7.6171e-01, 1.0307e-01, 7.5614e-01, 7.5960e-01],
          [6.1990e-01, 4.3320e-01, 9.8969e-01, 5.4741e-02],
          [3.0489e-01, 6.5976e-01, 7.5253e-01, 3.9645e-01],
          [7.5082e-01, 3.0904e-01, 7.9639e-01, 3.8395e-01]],

         [[1.4804e-01, 1.4207e-02, 6.3807e-01, 2.8259e-04],
          [2.4826e-01, 6.7222e-02, 8.7710e-01, 9.0763e-01],
          [9.6059e-02, 8.9632e-01, 3.2495e-01, 3.5925e-01],
          [4.1697e-01, 3.3956e-01, 6.8137e-01, 1.4271e-01]],

         [[5.1636e-02, 2.3114e-01, 7.9665e-01, 8.4681e-01],
          [3.5812e-01, 5.6943e-01, 9.3415e-01, 4.5141e-01],
          [8.1132e-01, 3.7259e-01, 6.0285e-01, 9.4935e-01],
          [2.7806e-01, 7.9014e-01, 3.4780e-01, 3.0346e-01]]]])

In [None]:
class Imagen(nn.Module):
    def __init__(
        self,
        unets,
        *,
        image_sizes,                                # for cascading ddpm, image size at each stage
        text_encoder_name = 'google/t5-v1_1-small',
        text_embed_dim = None,
        channels = 3,
        timesteps = 1000,
        cond_drop_prob = 0.1,
        loss_type = 'l2',
        noise_schedules = 'cosine',
        pred_objectives = 'noise',
        lowres_noise_schedule = 'linear',
        lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
        per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
        condition_on_text = True,
        auto_normalize_img = True,                  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
        continuous_times = True,
        p2_loss_weight_gamma = 0.5,                 # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time
        p2_loss_weight_k = 1,
        dynamic_thresholding = True,
        dynamic_thresholding_percentile = 0.9,      # unsure what this was based on perusal of paper
    ):
        super(Imagen, self).__init__()
        
        # loss

        self.loss_type   = 'l2'
        self.loss_fn     = F.mse_loss
        self.channels    = channels
        self.image_sizes = image_sizes

        # conditioning hparams

        self.condition_on_text, self.unconditional = True, False   
        
        self.lowres_sample_noise_level         = lowres_sample_noise_level
        self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level # False

        n_unets   = len(unets)
        timesteps = (timesteps,)*n_unets  
        
        noise_schedules = (noise_schedules, 'cosine')
        mults           = n_unets - len(noise_schedules) if n_unets - len(noise_schedules) > 0 else 0
        noise_schedules = (*noise_schedules, *('linear',)*mults)

        self.lowres_noise_schedule = GaussianDiffusion(noise_type=lowres_noise_schedule)
        self.pred_objectives       = (pred_objectives,)*n_unets

        self.text_encoder_name = text_encoder_name
        self.text_embed_dim    = TextEncoderT5Based(text_encoder_name).embed_dim
        self.text_encoder      = TextEncoderT5Based(text_encoder_name)     

        self.noise_schedulers = nn.ModuleList([])
        for timestep, noise_schedule in zip(timesteps, noise_schedules):
            noise_scheduler = GaussianDiffusion(noise_type=noise_schedule, timesteps=timestep)
            self.noise_schedulers.append(noise_scheduler)
            
        self.unets = nn.ModuleList([])        
        for i, current_unet in enumerate(unets):
            self.unets.append(current_unet.lowres_change(lowres_cond = not (i == 0)))

        self.sample_channels = (self.channels,)*n_unets

        lowres_conditions = tuple([t.lowres_cond for t in self.unets])
        # assert lowres_conditions == (False, *((True,) * (n_unets - 1)))

        self.cond_drop_prob          = cond_drop_prob
        self.can_classifier_guidance = cond_drop_prob > 0.

        self.normalize_img   = normalize_neg_one_to_one  # if auto_normalize_img else identity
        self.unnormalize_img = unnormalize_zero_to_one   # if auto_normalize_img else identity


        self.dynamic_thresholding            = (dynamic_thresholding,)*n_unets
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

        self.p2_loss_weight_k     = p2_loss_weight_k
        self.p2_loss_weight_gamma = (p2_loss_weight_gamma,)*n_unets

        self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

        self.to(next(self.unets.parameters()).device)
        
    @property
    def device(self):
        return self._temp.device

    def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        index = unet_number - 1
        return self.unets[index]

    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        """Context manager that keeps a single unet on the GPU while its body runs.

        Exactly one of `unet_number` (1-based, resolved through `get_unet`)
        or `unet` must be supplied.

        On entry the whole module is moved with `self.cuda()`, the device of
        every unet is recorded, all unets are moved to the CPU and only the
        selected one is moved back to CUDA. On exit every unet is moved back
        to the device recorded for it - also when the body raised, so a
        failure inside the `with` block cannot leave the other unets on the CPU.
        """
        assert exists(unet_number) ^ exists(unet)

        if exists(unet_number):
            unet = self.get_unet(unet_number)

        self.cuda()

        # recorded after self.cuda(), so these are the post-move devices
        devices = [module_device(stage) for stage in self.unets]
        self.unets.cpu()
        unet.cuda()

        try:
            yield
        finally:
            # distinct loop name so the selected `unet` argument is not rebound
            for stage, stage_device in zip(self.unets, devices):
                stage.to(stage_device)
            
    def p_losses(self, unet, x_start, times, text_embeds = None, text_mask = None, cond_images = None,
                 noise_scheduler = None,
                 lowres_cond_img = None, lowres_aug_times = None,
                 pred_objective = 'noise', p2_loss_weight_gamma = 0., noise = None):
        """Compute the training loss of one unet for a batch of images.

        The images are normalized, noised with `noise_scheduler.q_sample` at
        `times`, and the unet is asked to predict either the noise or the
        clean image, depending on `pred_objective`.

        Args:
            unet: the denoising network; its `forward` is called directly.
            x_start: clean images, passed through `self.normalize_img`.
            times: diffusion times handed to `noise_scheduler`.
            text_embeds, text_mask, cond_images: conditioning forwarded to the unet.
            noise_scheduler: required; must provide `q_sample` and `get_condition`.
                It has a `None` default only so it can sit after the optional
                parameters without changing their positional order.
            lowres_cond_img: optional low-resolution conditioning image; it is
                normalized and noised with `self.lowres_noise_schedule`.
            lowres_aug_times: times used to noise `lowres_cond_img`;
                falls back to `times` when omitted.
            pred_objective: 'noise' targets the sampled noise, anything else
                targets the normalized `x_start`.
            p2_loss_weight_gamma: when > 0, per-sample losses are multiplied by
                `(p2_loss_weight_k + exp(log_snr)) ** -gamma`.
            noise: optional pre-sampled noise; drawn with `torch.randn_like(x_start)`
                when omitted.

        Returns:
            The mean over the batch of the per-sample losses.

        Raises:
            ValueError: if `noise_scheduler` is not provided.
        """
        if noise_scheduler is None:
            raise ValueError('p_losses requires a noise_scheduler')

        if noise is None:
            noise = torch.randn_like(x_start)

        # get x_t
        x_start          = self.normalize_img(x_start)
        x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

        # low resolution conditioning is noised as well, and the unet is told how much
        lowres_cond_img_noisy = None
        lowres_noise_cond     = None

        if lowres_cond_img is not None:
            lowres_cond_img = self.normalize_img(lowres_cond_img)

            if lowres_aug_times is None:
                lowres_aug_times = times

            lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
                x_start = lowres_cond_img,
                t = lowres_aug_times,
                noise = torch.randn_like(lowres_cond_img)
            )
            lowres_noise_cond = self.lowres_noise_schedule.get_condition(lowres_aug_times)

        # same conditioning keywords the sampling path hands to the unet, plus the
        # conditional dropout that classifier free guidance relies on
        # NOTE(review): assumes unet.forward accepts these keywords - confirm against the Unet class
        pred = unet.forward(
            x_noisy,
            noise_scheduler.get_condition(times),
            text_embeds = text_embeds,
            text_mask = text_mask,
            cond_images = cond_images,
            lowres_noise_times = lowres_noise_cond,
            lowres_cond_img = lowres_cond_img_noisy,
            cond_drop_prob = self.cond_drop_prob
        )

        target = noise if pred_objective == 'noise' else x_start

        # one loss value per batch element
        losses = self.loss_fn(pred, target, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # p2 loss reweighting

        if p2_loss_weight_gamma > 0:
            loss_weight = (self.p2_loss_weight_k + log_snr.exp()) ** -p2_loss_weight_gamma
            losses = losses * loss_weight

        return losses.mean()
    
    def forward(self, images, texts: List[str], unet_number, cond_images = None):
        """Training step: return the diffusion loss of one unet for a batch.

        Args:
            images: batch of training images; the first dimension is the batch.
                They are resized to the target size of the selected stage.
            texts: captions, encoded with `self.text_encoder`.
            unet_number: 1-based index of the unet to train.
            cond_images: optional conditioning images forwarded to `p_losses`.

        Returns:
            The loss computed by `p_losses`.
        """
        unet_index = unet_number - 1
        unet       = self.get_unet(unet_number)   # also validates unet_number

        noise_scheduler      = self.noise_schedulers[unet_index]
        p2_loss_weight_gamma = self.p2_loss_weight_gamma[unet_index]
        pred_objective       = self.pred_objectives[unet_index]
        target_image_size    = self.image_sizes[unet_index]
        prev_image_size      = self.image_sizes[unet_index - 1] if unet_index > 0 else None

        batch, device = images.shape[0], images.device

        times = noise_scheduler.sample_random_times(batch, device = device)

        # keep embeddings and masks as whole tensors - their `.shape` is needed
        # below and the unet consumes them batched
        # NOTE(review): `sample` calls `self.text_encoder.textEncoder(texts)` instead - confirm which entry point is intended
        text_embeds, text_masks = self.text_encoder(texts)
        text_embeds = text_embeds.to(device)
        text_masks  = text_masks.to(device)

        assert text_embeds.shape[-1] == self.text_embed_dim, 'invalid text embedding dimension'

        lowres_cond_img = lowres_aug_times = None
        if prev_image_size is not None:
            # down to the previous stage's size and back up, so the conditioning
            # image has the target size
            lowres_cond_img  = resize_image_to(resize_image_to(images, prev_image_size), target_image_size)
            # a single random augmentation time shared by the whole batch
            lowres_aug_times = repeat(self.lowres_noise_schedule.sample_random_times(1, device = device), '1 -> b', b = batch)

        images = resize_image_to(images, target_image_size)

        return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images,
                             noise_scheduler = noise_scheduler,
                             lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times,
                             pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma)


        

In [None]:
@torch.no_grad()
@eval_decorator
def sample(self, texts: List[str] = None, cond_images = None, batch_size = 1, cond_scale = 1.,
           lowres_sample_noise_level = None, stop_at_unet_number = None, return_all_unet_outputs = False,
           return_pil_images = False, device = 'cpu'):
    """Generate images from captions by running the unets in cascade.

    Args:
        texts: captions, encoded with `self.text_encoder.textEncoder`.
        cond_images: optional conditioning images forwarded to every stage.
        batch_size: kept for interface compatibility; the effective batch size
            is the number of encoded captions.
        cond_scale: classifier free guidance scale forwarded to the sampling loop.
        lowres_sample_noise_level: noise level for the low-resolution
            conditioning image; defaults to `self.lowres_sample_noise_level`.
        stop_at_unet_number: 1-based stage after which the cascade stops early.
        return_all_unet_outputs: return the output of every stage that ran
            instead of only the last one.
        return_pil_images: convert the outputs to PIL images.
        device: kept for interface compatibility; sampling always runs on the
            device of the model's parameters.

    Returns:
        The last stage's image tensor, or a list with one tensor per stage when
        `return_all_unet_outputs` is set. With `return_pil_images`, tensors are
        replaced by lists of PIL images (one per batch element).
    """
    # the placement of the model decides where sampling happens
    first_param = next(self.parameters())
    is_cuda     = first_param.is_cuda
    device      = first_param.device

    text_embeds, text_masks = self.text_encoder.textEncoder(texts)
    text_embeds = text_embeds.to(device)
    text_masks  = text_masks.to(device)

    batch_size = text_embeds.shape[0]

    assert text_embeds.shape[-1] == self.text_embed_dim

    if lowres_sample_noise_level is None:
        lowres_sample_noise_level = self.lowres_sample_noise_level

    outputs = []
    img     = None   # output of the previous stage, conditions the next one

    stages = zip(range(1, len(self.unets) + 1), self.unets, self.sample_channels, self.image_sizes,
                 self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding)

    for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold in tqdm(stages, total = len(self.unets)):

        context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()

        with context:
            lowres_cond_img = lowres_noise_times = None
            shape = (batch_size, channel, image_size, image_size)

            if unet.lowres_cond:
                assert img is not None, 'a unet with low resolution conditioning cannot be the first stage'

                lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                lowres_cond_img = resize_image_to(img, image_size)
                lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

            img = self.p_sample_loop(unet, shape, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images,
                                     cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_times = lowres_noise_times,
                                     noise_scheduler = noise_scheduler, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

            outputs.append(img)

        if stop_at_unet_number is not None and stop_at_unet_number == unet_number:
            break

    output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

    if not return_pil_images:
        return outputs[output_index]

    if not return_all_unet_outputs:
        outputs = outputs[-1:]

    to_pil     = T.ToPILImage()
    pil_images = [[to_pil(single) for single in stage_out.unbind(dim = 0)] for stage_out in outputs]

    return pil_images[output_index]

In [None]:
    @torch.no_grad()
    def p_sample_loop(self, unet, shape, text_embeds = None, text_mask = None, cond_images = None,
                      cond_scale = 1, lowres_cond_img = None, lowres_noise_times = None,
                      noise_scheduler = None, pred_objective = 'noise', dynamic_threshold = True):
        """Run the full reverse diffusion of one unet, starting from Gaussian noise.

        Args:
            unet: the denoising network for this stage.
            shape: shape of the tensor to generate; `shape[0]` is the batch size.
            text_embeds, text_mask, cond_images: conditioning forwarded to `p_sample`.
            cond_scale: classifier free guidance scale.
            lowres_cond_img: optional low-resolution conditioning image; it is
                passed through `self.normalize_img` before use.
            lowres_noise_times: noise times of the low-resolution conditioning.
            noise_scheduler: required; provides `get_sampling_timesteps`.
            pred_objective, dynamic_threshold: forwarded to `p_sample`.

        Returns:
            The generated images, clamped to [-1, 1] and passed through
            `self.unnormalize_img`.

        Raises:
            ValueError: if `noise_scheduler` is not provided.
        """
        if noise_scheduler is None:
            raise ValueError('p_sample_loop requires a noise_scheduler')

        device = self.device
        batch  = shape[0]
        img    = torch.randn(shape, device = device)

        if lowres_cond_img is not None:
            lowres_cond_img = self.normalize_img(lowres_cond_img)

        timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

        for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps)):

            img = self.p_sample(unet, img, times, t_next = times_next, text_embeds = text_embeds, text_mask = text_mask,
                                cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler,
                                pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

        img.clamp_(-1., 1.)
        return self.unnormalize_img(img)


In [None]:
    @torch.no_grad()
    def p_sample(self, unet, x, t, t_next = None, text_embeds = None, text_mask = None,
                 cond_images = None, cond_scale = 1., lowres_cond_img = None,
                 lowres_noise_times = None, noise_scheduler = None,
                 pred_objective = 'noise', dynamic_threshold = True):
        """Perform one reverse diffusion step, from time `t` towards `t_next`.

        The posterior mean and log variance come from `p_mean_variance`; fresh
        Gaussian noise scaled by the posterior standard deviation is added,
        except for batch elements that are at their last sampling step.

        Returns:
            The sample for the next step, same shape as `x`.
        """
        batch = x.shape[0]

        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, t_next = t_next, text_embeds = text_embeds, text_mask = text_mask,
                                                                 cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                                                 lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler,
                                                                 pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

        noise = torch.randn_like(x)

        # continuous-time schedulers are finished when the *next* time is 0,
        # any other scheduler when the current time is 0
        is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)

        # one entry per batch element, reshaped to broadcast over the remaining dims of x
        trailing_dims = (1,) * (len(x.shape) - 1)
        nonzero_mask  = (1 - is_last_sampling_timestep.float()).reshape(batch, *trailing_dims)

        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

In [None]:
    def p_mean_variance(self, unet, x, t, t_next = None, text_embeds = None, text_mask = None,
                        cond_images = None, cond_scale = 1., lowres_cond_img = None,
                        lowres_noise_times = None, noise_scheduler = None,
                        pred_objective = 'noise', dynamic_threshold = True, model_output = None):
        """Compute the posterior used to step from `x` at time `t` towards `t_next`.

        The unet prediction (or `model_output`, when given) is turned into an
        estimate of the clean image, that estimate is thresholded, and
        `noise_scheduler.q_posterior` is evaluated on it.

        Args:
            unet: network queried through `forward_with_cond_scale`.
            x: current noisy sample.
            t, t_next: current and next diffusion times.
            text_embeds, text_mask, cond_images, lowres_cond_img: conditioning
                forwarded to the unet.
            cond_scale: classifier free guidance scale; only 1 is allowed when
                the model was built with `cond_drop_prob` of 0.
            lowres_noise_times: noise times of the low-resolution conditioning.
            noise_scheduler: required; provides `get_condition`,
                `predict_start_from_noise` and `q_posterior`.
            pred_objective: 'noise' means the prediction is noise and is
                converted to a clean-image estimate, otherwise the prediction
                is taken as the clean image itself.
            dynamic_threshold: use percentile based thresholding instead of a
                plain clamp to [-1, 1].
            model_output: optional precomputed prediction; skips the unet call.

        Returns:
            Whatever `noise_scheduler.q_posterior` returns (`p_sample` unpacks
            it as mean, an unused value and log variance).

        Raises:
            ValueError: if `noise_scheduler` is not provided.
        """
        if noise_scheduler is None:
            raise ValueError('p_mean_variance requires a noise_scheduler')

        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        def run_unet():
            # the base unet has no low resolution conditioning, hence no noise times to encode
            lowres_noise_cond = None
            if lowres_noise_times is not None:
                lowres_noise_cond = self.lowres_noise_schedule.get_condition(lowres_noise_times)

            return unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), text_embeds = text_embeds, text_mask = text_mask,
                                                cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                                lowres_noise_times = lowres_noise_cond)

        pred = default(model_output, run_unet)

        x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) if pred_objective == 'noise' else pred

        if dynamic_threshold:
            # following pseudocode in appendix
            # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
            s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(),
                               self.dynamic_thresholding_percentile,
                               dim = -1)

            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            # out-of-place: x_start may be the caller's own `model_output` tensor
            x_start = x_start.clamp(-1., 1.)

        return noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)

In [None]:
def forward_with_cond_scale(self, *args, cond_scale = 1., **kwargs):
    """Run `forward` with classifier free guidance.

    With `cond_scale` equal to 1 the plain `forward` output is returned.
    Otherwise `forward` is run a second time with `cond_drop_prob = 1.` and
    the result is `unconditioned + (conditioned - unconditioned) * cond_scale`.
    """
    conditioned = self.forward(*args, **kwargs)

    # no guidance requested, the second pass is not needed
    if cond_scale == 1:
        return conditioned

    unconditioned = self.forward(*args, cond_drop_prob = 1., **kwargs)
    guidance      = conditioned - unconditioned
    return unconditioned + guidance * cond_scale

In [None]:
    def p_mean_variance(self, unet, x, t, t_next = None, text_embeds = None, text_mask = None,
                        cond_images = None, cond_scale = 1., lowres_cond_img = None,
                        lowres_noise_times = None, noise_scheduler = None,
                        pred_objective = 'noise', dynamic_threshold = True, model_output = None):
        """Compute the posterior used to step from `x` at time `t` towards `t_next`.

        The unet prediction (or `model_output`, when given) is turned into an
        estimate of the clean image, that estimate is thresholded, and
        `noise_scheduler.q_posterior` is evaluated on it.

        Args:
            unet: network queried through `forward_with_cond_scale`.
            x: current noisy sample.
            t, t_next: current and next diffusion times.
            text_embeds, text_mask, cond_images, lowres_cond_img: conditioning
                forwarded to the unet.
            cond_scale: classifier free guidance scale; only 1 is allowed when
                the model was built with `cond_drop_prob` of 0.
            lowres_noise_times: noise times of the low-resolution conditioning.
            noise_scheduler: required; provides `get_condition`,
                `predict_start_from_noise` and `q_posterior`.
            pred_objective: 'noise' means the prediction is noise and is
                converted to a clean-image estimate, otherwise the prediction
                is taken as the clean image itself.
            dynamic_threshold: use percentile based thresholding instead of a
                plain clamp to [-1, 1].
            model_output: optional precomputed prediction; skips the unet call.

        Returns:
            Whatever `noise_scheduler.q_posterior` returns (`p_sample` unpacks
            it as mean, an unused value and log variance).

        Raises:
            ValueError: if `noise_scheduler` is not provided.
        """
        if noise_scheduler is None:
            raise ValueError('p_mean_variance requires a noise_scheduler')

        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'imagen was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        def run_unet():
            # the base unet has no low resolution conditioning, hence no noise times to encode
            lowres_noise_cond = None
            if lowres_noise_times is not None:
                lowres_noise_cond = self.lowres_noise_schedule.get_condition(lowres_noise_times)

            return unet.forward_with_cond_scale(x, noise_scheduler.get_condition(t), text_embeds = text_embeds, text_mask = text_mask,
                                                cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                                lowres_noise_times = lowres_noise_cond)

        pred = default(model_output, run_unet)

        x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) if pred_objective == 'noise' else pred

        if dynamic_threshold:
            # following pseudocode in appendix
            # s is the dynamic threshold, determined by percentile of absolute values of reconstructed sample per batch element
            s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(),
                               self.dynamic_thresholding_percentile,
                               dim = -1)

            s.clamp_(min = 1.)
            s = right_pad_dims_to(x_start, s)
            x_start = x_start.clamp(-s, s) / s
        else:
            # out-of-place: x_start may be the caller's own `model_output` tensor
            x_start = x_start.clamp(-1., 1.)

        return noise_scheduler.q_posterior(x_start = x_start, x_t = x, t = t, t_next = t_next)
    
    @torch.no_grad()
    def p_sample(self, unet, x, t, t_next = None, text_embeds = None, text_mask = None,
                 cond_images = None, cond_scale = 1., lowres_cond_img = None,
                 lowres_noise_times = None, noise_scheduler = None,
                 pred_objective = 'noise', dynamic_threshold = True):
        """Perform one reverse diffusion step, from time `t` towards `t_next`.

        The posterior mean and log variance come from `p_mean_variance`; fresh
        Gaussian noise scaled by the posterior standard deviation is added,
        except for batch elements that are at their last sampling step.

        Returns:
            The sample for the next step, same shape as `x`.
        """
        batch = x.shape[0]

        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, t_next = t_next, text_embeds = text_embeds, text_mask = text_mask,
                                                                 cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                                                 lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler,
                                                                 pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

        noise = torch.randn_like(x)

        # continuous-time schedulers are finished when the *next* time is 0,
        # any other scheduler when the current time is 0
        is_last_sampling_timestep = (t_next == 0) if isinstance(noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0)

        # one entry per batch element, reshaped to broadcast over the remaining dims of x
        trailing_dims = (1,) * (len(x.shape) - 1)
        nonzero_mask  = (1 - is_last_sampling_timestep.float()).reshape(batch, *trailing_dims)

        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    
    @torch.no_grad()
    def p_sample_loop(self, unet, shape, text_embeds = None, text_mask = None, cond_images = None,
                      cond_scale = 1, lowres_cond_img = None, lowres_noise_times = None,
                      noise_scheduler = None, pred_objective = 'noise', dynamic_threshold = True):
        """Run the full reverse diffusion of one unet, starting from Gaussian noise.

        Args:
            unet: the denoising network for this stage.
            shape: shape of the tensor to generate; `shape[0]` is the batch size.
            text_embeds, text_mask, cond_images: conditioning forwarded to `p_sample`.
            cond_scale: classifier free guidance scale.
            lowres_cond_img: optional low-resolution conditioning image; it is
                passed through `self.normalize_img` before use.
            lowres_noise_times: noise times of the low-resolution conditioning.
            noise_scheduler: required; provides `get_sampling_timesteps`.
            pred_objective, dynamic_threshold: forwarded to `p_sample`.

        Returns:
            The generated images, clamped to [-1, 1] and passed through
            `self.unnormalize_img`.

        Raises:
            ValueError: if `noise_scheduler` is not provided.
        """
        if noise_scheduler is None:
            raise ValueError('p_sample_loop requires a noise_scheduler')

        device = self.device
        batch  = shape[0]
        img    = torch.randn(shape, device = device)

        if lowres_cond_img is not None:
            lowres_cond_img = self.normalize_img(lowres_cond_img)

        timesteps = noise_scheduler.get_sampling_timesteps(batch, device = device)

        for times, times_next in tqdm(timesteps, desc = 'sampling loop time step', total = len(timesteps)):

            img = self.p_sample(unet, img, times, t_next = times_next, text_embeds = text_embeds, text_mask = text_mask,
                                cond_images = cond_images, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img,
                                lowres_noise_times = lowres_noise_times, noise_scheduler = noise_scheduler,
                                pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

        img.clamp_(-1., 1.)
        return self.unnormalize_img(img)
    
    @torch.no_grad()
    @eval_decorator
    def sample(self, texts: List[str] = None, cond_images = None, batch_size = 1, cond_scale = 1.,
               lowres_sample_noise_level = None, stop_at_unet_number = None, return_all_unet_outputs = False,
               return_pil_images = False, device = 'cpu'):
        """Generate images from captions by running the unets in cascade.

        Args:
            texts: captions, encoded with `self.text_encoder.textEncoder`.
            cond_images: optional conditioning images forwarded to every stage.
            batch_size: kept for interface compatibility; the effective batch size
                is the number of encoded captions.
            cond_scale: classifier free guidance scale forwarded to the sampling loop.
            lowres_sample_noise_level: noise level for the low-resolution
                conditioning image; defaults to `self.lowres_sample_noise_level`.
            stop_at_unet_number: 1-based stage after which the cascade stops early.
            return_all_unet_outputs: return the output of every stage that ran
                instead of only the last one.
            return_pil_images: convert the outputs to PIL images.
            device: kept for interface compatibility; sampling always runs on the
                device of the model's parameters.

        Returns:
            The last stage's image tensor, or a list with one tensor per stage when
            `return_all_unet_outputs` is set. With `return_pil_images`, tensors are
            replaced by lists of PIL images (one per batch element).
        """
        # the placement of the model decides where sampling happens
        first_param = next(self.parameters())
        is_cuda     = first_param.is_cuda
        device      = first_param.device

        text_embeds, text_masks = self.text_encoder.textEncoder(texts)
        text_embeds = text_embeds.to(device)
        text_masks  = text_masks.to(device)

        batch_size = text_embeds.shape[0]

        assert text_embeds.shape[-1] == self.text_embed_dim

        if lowres_sample_noise_level is None:
            lowres_sample_noise_level = self.lowres_sample_noise_level

        outputs = []
        img     = None   # output of the previous stage, conditions the next one

        stages = zip(range(1, len(self.unets) + 1), self.unets, self.sample_channels, self.image_sizes,
                     self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding)

        for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold in tqdm(stages, total = len(self.unets)):

            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()

            with context:
                lowres_cond_img = lowres_noise_times = None
                shape = (batch_size, channel, image_size, image_size)

                if unet.lowres_cond:
                    assert img is not None, 'a unet with low resolution conditioning cannot be the first stage'

                    lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                    lowres_cond_img = resize_image_to(img, image_size)
                    lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

                img = self.p_sample_loop(unet, shape, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images,
                                         cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_times = lowres_noise_times,
                                         noise_scheduler = noise_scheduler, pred_objective = pred_objective, dynamic_threshold = dynamic_threshold)

                outputs.append(img)

            if stop_at_unet_number is not None and stop_at_unet_number == unet_number:
                break

        output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs

        if not return_pil_images:
            return outputs[output_index]

        if not return_all_unet_outputs:
            outputs = outputs[-1:]

        to_pil     = T.ToPILImage()
        pil_images = [[to_pil(single) for single in stage_out.unbind(dim = 0)] for stage_out in outputs]

        return pil_images[output_index]
