Adapted from https://github.com/lucidrains/muse-maskgit-pytorch/

All models take precomputed tokens rather than raw images/text wherever possible, and various unused settings were removed. The HighRes model now has a separate embedding layer for the conditioning image tokens.

Structure was modified to easily swap sequence models.

In [1]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn

import torch.nn.functional as F

In [2]:
# Query the encoded dimension reported for the default T5 checkpoint
# (presumably the width of the T5 text embeddings -- confirm against the t5 module).
# NOTE(review): this cell was executed before the cell that imports
# `get_encoded_dim` and `DEFAULT_T5_NAME`, hence the NameError recorded below.
get_encoded_dim(DEFAULT_T5_NAME)

NameError: name 'get_encoded_dim' is not defined

In [3]:
from random import random

In [4]:
from muse_maskgit_pytorch.t5 import t5_encode_text, DEFAULT_T5_NAME, get_encoded_dim

from einops import rearrange

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def encode_text(texts):
  """Encode prompts with the default T5 checkpoint.

  Args:
    texts: the prompts to encode; forwarded unchanged to `t5_encode_text`.

  Returns:
    Whatever `t5_encode_text` returns for `texts` under `DEFAULT_T5_NAME`.
  """
  # the prompts go first; the checkpoint name is selected by keyword so it
  # can never be mistaken for the text input
  return t5_encode_text(texts, name = DEFAULT_T5_NAME)

These functions generate a boolean tensor of a given shape in which each element has probability `prob` of being `True`. They are used for classifier-free guidance, and are taken from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py

In [6]:
def uniform(shape, min = 0, max = 1, device = None):
    """Return a float tensor of `shape` with entries drawn uniformly between `min` and `max`.

    Args:
        shape: shape of the tensor to create.
        min: lower bound of the sampling interval (default 0).
        max: upper bound of the sampling interval (default 1).
        device: device on which the tensor is allocated (default: torch's default).

    Returns:
        A floating point tensor of the requested shape.
    """
    # honour the requested bounds; with the defaults this is the unit interval
    return torch.zeros(shape, device = device).float().uniform_(min, max)

def prob_mask_like(shape, prob, device = None):
  """Build a boolean tensor of `shape` whose entries are True with probability `prob`.

  The two degenerate probabilities are answered with constant tensors
  and draw no random numbers.
  """
  if prob == 1:
    # certain event: everything is True
    result = torch.ones(shape, device = device, dtype = torch.bool)
  elif prob == 0:
    # impossible event: everything is False
    result = torch.zeros(shape, device = device, dtype = torch.bool)
  else:
    # a uniform draw falls below `prob` with probability `prob`
    draws = uniform(shape, device = device)
    result = draws < prob
  return result

This function takes a boolean mask tensor and returns a random subset of it: if a row of the mask contains $n$ `True` values, then roughly a fraction `prob` of them (about `prob` $\times\, n$) stay `True` and the rest are switched to `False`. The positions that stay `True` are the ones that are subsequently unmasked.

From https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py

In [7]:
def get_mask_subset_prob(mask, prob, min_mask = 0):
    """Pick a random subset of the True entries in each row of a 2-D mask.

    Per row, the number of entries kept is the count of True entries times
    `prob`, floored at `min_mask`. Entries that are False in `mask` are
    always False in the result.
    """
    rows, cols = mask.shape
    outside = ~mask

    # how many True entries survive in each row (kept as a column vector)
    keep_count = (mask.sum(dim = -1, keepdim = True) * prob).clamp(min = min_mask)

    # random scores; positions outside the mask get -1 so they sort first
    scores = torch.rand((rows, cols), device = mask.device).masked_fill(outside, -1)

    # a double argsort converts scores into per-row ranks
    ranks = scores.argsort(dim = -1).argsort(dim = -1).float()

    # the lowest ranks belong to the positions outside the mask;
    # shift so that the in-mask positions are ranked from 0 upwards
    ranks = ranks - outside.sum(dim = -1, keepdim = True)

    chosen = ranks < keep_count
    return chosen.masked_fill(outside, False)


Gumbel sampling: Gumbel-distributed noise is added to the (temperature-scaled) logits and the argmax is taken. I call this softmax with noise. Also from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py

In [8]:
def log(t, eps = 1e-20):
  """Natural logarithm of `t`, with every entry first raised to at least `eps`."""
  # the floor keeps zero / negative entries from producing -inf or nan
  floored = t.clamp(min = eps)
  return torch.log(floored)

def gumbel_noise(t):
  """Draw noise shaped like `t` as -log(-log(u)) with u uniform on the unit interval.

  Range of the result, given that `log` floors its argument at 1e-20:
    u               lies in (0, 1), floored to (1e-20, 1)
    -log(u)         lies in (0, 46), floored to (1e-20, 46)
    log(-log(u))    lies in (-46, 3.8)
    -log(-log(u))   lies in (-3.8, 46)

  In practice the values sit around -2 to 4 because the logarithm
  grows so slowly.
  """
  u = torch.zeros_like(t).uniform_(0, 1)
  neg_log_u = -log(u)
  return -log(neg_log_u)

def gumbel_sample(t, temperature = 1., dim = -1):
  """Return the argmax along `dim` of `t / temperature` perturbed with Gumbel noise."""
  # floor the temperature so the division is always defined
  safe_temperature = max(temperature, 1e-10)
  perturbed = (t / safe_temperature) + gumbel_noise(t)
  return perturbed.argmax(dim = dim)

Class that wraps a generic sequence model. The wrapper prepares the image tokens and the text tokens (passing them through embedding layers and whatever else may be needed) and hands them to a `sequence_model`, which is called as `sequence_model(image_tokens, context = text_tokens, context_mask = ...)` and must return an output of the same shape as `image_tokens`.

The shapes of `image_tokens` and `text_tokens` must be equal to `(batch_size, image_token_count, dim)`.

In [9]:
class SequenceModelWrapper(nn.Module):
  """Embed image-token ids, run a pluggable sequence model and project to token logits.

  The wrapped `sequence_model` is called as
  `sequence_model(x, context = ..., context_mask = ...)`; its output is fed
  to a linear layer of input width `sequence_model_token_size`.
  """

  def __init__(
    self,
    sequence_model,
    sequence_model_token_size,
    token_value_count, # how many values can an image token take
    sequence_length, # how many tokens is each image represented with
    is_high_resolution # whether the model is a super res model, conditioned on lower res tokens
  ):
    super().__init__()
    self.sequence_model = sequence_model

    # kept as attributes: other modules size their own heads from the token
    # width, and forward() branches on the width of the logits
    self.sequence_model_token_size = sequence_model_token_size
    self.dim_out = token_value_count

    # takes an input of shape (batch_size, indice_count) that contains the token values for each image
    # and embeds it to shape (batch_size, indice_count, sequence_model_token_size)
    # the embedding layer associates a vector of size sequence_model_token_size
    # to each of the possible token values; the number of token values is +1
    # because we have to take the MASK token into account as well
    self.token_emb = nn.Embedding(token_value_count + 1, sequence_model_token_size)

    # token values span from 0 to token_value_count - 1, so the next value
    # is assigned to be the mask token
    self.mask_id = token_value_count

    if is_high_resolution:
      # embedding layer for the lowres image
      # this image doesn't require a MASK token
      self.lowres_token_emb = nn.Embedding(token_value_count, sequence_model_token_size)

    # associates a vector of size sequence_model_token_size to each position in the
    # sequence of image tokens
    self.pos_emb = nn.Embedding(sequence_length, sequence_model_token_size)

    # the sequence model outputs the same shape sequence
    # it represents each token as a vector of size sequence_model_token_size
    # so we project that to size token_value_count to obtain
    # a distribution over the values of each token
    self.to_logits = nn.Linear(sequence_model_token_size, token_value_count)

    self.encode_text = encode_text

    text_embed_size = get_encoded_dim(DEFAULT_T5_NAME)

    # text token embed size must be equal to image token size
    self.text_embed_proj = nn.Linear(text_embed_size, sequence_model_token_size, bias = False)

  def forward(
    self,
    x, # image tokens, preflattened
    text_embeds, # precomputed text tokens as returned by T5
    return_embed_only, # return sequence model output before passing it through to_logits; used for self critic
    return_loss_only, # when labels are given: return the loss instead of the logits
    # optional image tokens generated by a lower resolution model
    # precomputed and preflattened
    conditioning_image_ids = None,
    labels = None,
    ignore_index = 0,
    cond_drop_prob = 0.,
  ):
    """Run the sequence model on masked image tokens conditioned on text (and low-res tokens).

    Returns, in order of precedence:
      * `embed`            when `return_embed_only` is truthy;
      * `(embed, logits)`  when `labels` is None;
      * `loss`             when `return_loss_only` is truthy;
      * `logits`           otherwise.
    """
    device, batch_size, n = x.device, *x.shape

    context = self.text_embed_proj(text_embeds)

    # a text embedding is masked if all its entries are equal to 0
    context_mask = (text_embeds != 0).any(dim = -1)

    # classifier free guidance
    # drop the conditioning text tokens for a random subset of the samples in the batch
    mask = prob_mask_like((batch_size, 1), 1 - cond_drop_prob, device)
    context_mask = context_mask & mask

    if conditioning_image_ids is not None:
      # the low-res tokens have their own embedding table (no MASK entry);
      # it only exists when the wrapper was built with is_high_resolution
      cond_token_emb = self.lowres_token_emb(conditioning_image_ids)

      # concatenate the 2 conditioning sequences
      # resulting in a longer sequence
      context = torch.cat((context, cond_token_emb), dim = -2)

      # pad the context mask with True for the newly added conditioning tokens
      context_mask = F.pad(context_mask, (0, conditioning_image_ids.shape[-1]), value = True)

    x = self.token_emb(x)
    x = x + self.pos_emb(torch.arange(n, device = device))

    embed = self.sequence_model(x, context = context, context_mask = context_mask)

    if return_embed_only:
      return embed

    logits = self.to_logits(embed)

    if labels is None:
      return embed, logits

    if self.dim_out == 1:
      # loss for self-token-critic (single logit per token)
      loss = F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
    else:
      # loss for normal model
      loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)

    if return_loss_only:
      return loss

    return logits

  def forward_with_cond_scale(
    self,
    x, # image tokens, preflattened
    text_embeds, # precomputed text tokens as returned by T5
    return_embed_only, # return sequence model output before passing it through to_logits; used for self critic
    return_loss_only,
    # optional image tokens generated by a lower resolution model
    # precomputed and preflattened
    conditioning_image_ids = None,
    labels = None,
    ignore_index = 0,
    cond_drop_prob = 0.,
    cond_scale = 3.,
  ):
    """Forward pass with classifier-free guidance.

    `cond_drop_prob` is accepted for signature compatibility only: guidance
    always uses 0 for the conditional pass and 1 for the unconditional pass.

    Returns:
      * `cond_scale == 1`: exactly what `forward` returns with conditioning
        always kept (all flags and `labels` are honoured);
      * otherwise `embed` of the conditional pass when `return_embed_only`
        is truthy, else the tuple `(guided_logits, embed)`. Note the order
        differs from the `(embed, logits)` tuple returned by `forward`.

    Raises:
      ValueError: if `labels` is given together with `cond_scale != 1`;
        no loss is defined on guided logits here.
    """
    if cond_scale == 1:
      return self.forward(
        x,
        text_embeds,
        return_embed_only,
        return_loss_only,
        conditioning_image_ids = conditioning_image_ids,
        labels = labels,
        ignore_index = ignore_index,
        cond_drop_prob = 0.,
      )

    if return_embed_only:
      # only the conditional embedding is needed; skip both projections
      return self.forward(
        x,
        text_embeds,
        True,
        False,
        conditioning_image_ids = conditioning_image_ids,
        cond_drop_prob = 0.,
      )

    if labels is not None:
      raise ValueError("labels are only supported by forward_with_cond_scale when cond_scale == 1")

    # conditional pass: text conditioning always kept
    # labels = None makes forward return (embed, logits)
    embed, logits = self.forward(
      x,
      text_embeds,
      False,
      False,
      conditioning_image_ids = conditioning_image_ids,
      cond_drop_prob = 0.,
    )

    # unconditional pass: text conditioning always dropped
    _, null_logits = self.forward(
      x,
      text_embeds,
      False,
      False,
      conditioning_image_ids = conditioning_image_ids,
      cond_drop_prob = 1.,
    )

    # move away from the unconditional prediction, towards the conditional one
    return null_logits + (logits - null_logits) * cond_scale, embed

MaskedModel takes in image tokens, masks them according to a noise schedule and passes them to a sequence model wrapper.

In [10]:
class TokenCritic(nn.Module):
  """Predicts, for every token of a sequence, one logit from the embeddings of `net`.

  Given `labels`, `forward` returns the binary cross entropy between those
  logits and the labels; without `labels` it returns the logits.
  """

  def __init__(
      self, net
  ):
    # must run before any sub-module is assigned, otherwise nn.Module refuses the assignment
    super().__init__()
    self.net = net

    # one logit per token; the input width is read from the projection head
    # of `net`, which consumes the same embeddings
    self.to_pred = nn.Linear(net.to_logits.in_features, 1)

  def forward(self, x, *args, labels = None, **kwargs):
    """Score the tokens in `x`.

    Args:
      x: token ids handed to `net`.
      *args, **kwargs: forwarded unchanged to `net` (text embeddings,
        conditioning tokens, `cond_drop_prob`, ...).
      labels: optional per-token targets for the binary cross entropy.

    Returns:
      The BCE loss when `labels` is given, otherwise the per-token logits
      with the trailing singleton dimension removed.
    """
    # call the plain forward pass: the guided variant returns a
    # (logits, embed) tuple and overrides `cond_drop_prob`
    embeds = self.net(x, *args, return_embed_only = True, return_loss_only = False, **kwargs)

    # (..., 1) -> (...)
    logits = self.to_pred(embeds).squeeze(-1)

    if labels is None:
      return logits

    return F.binary_cross_entropy_with_logits(logits, labels)

In [11]:
class MaskedModel(nn.Module):
  """Training wrapper: masks image tokens, asks the sequence model to restore them
  and trains a token critic on the sampled predictions.
  """

  def __init__(
    self,
    sequence_model,
    noise_schedule,
    no_mask_token_prob = 0.,
  ):
    """
    Args:
      sequence_model: module exposing `mask_id` and a forward pass with the
        signature of `SequenceModelWrapper.forward`.
      noise_schedule: callable mapping random times to masking probabilities.
      no_mask_token_prob: fraction of the masked positions that are shown
        unmasked to the model while still counting towards the loss.
    """
    super().__init__()

    self.sequence_model = sequence_model
    self.mask_id = sequence_model.mask_id
    self.noise_schedule = noise_schedule

    self.token_critic = TokenCritic(self.sequence_model)

    # probability for some of the masked tokens to be unmasked
    self.no_mask_token_prob = no_mask_token_prob

  def forward(
    self,
    image_ids, # assumed to already be flattened
    ignore_index = -1,
    conditioning_token_ids = None, # assumed to already be flattened
    text_embeds = None,
    cond_drop_prob = None,
  ):
    """Compute the training loss for one batch of image tokens.

    Args:
      image_ids: ground-truth token ids, one row per image.
      ignore_index: label value given to positions excluded from the loss.
      conditioning_token_ids: optional low-resolution token ids.
      text_embeds: precomputed text embeddings, handed to the sequence model.
      cond_drop_prob: probability of dropping the text conditioning;
        None is treated as 0.

    Returns:
      Cross entropy of the sequence model on the masked positions, plus the
      critic's binary cross entropy when a token critic is present.
    """
    batch, seq_len, device = *image_ids.shape, image_ids.device

    # None cannot be used in `1 - cond_drop_prob` downstream
    if cond_drop_prob is None:
      cond_drop_prob = 0.

    # pick a random time for the noise scheduler
    # leading to a random number of tokens to be masked
    # each sample of the batch (each image) gets its own mask probability
    rand_time = uniform((batch, ), device = device)
    rand_mask_probs = self.noise_schedule(rand_time)
    num_token_masked = (seq_len * rand_mask_probs).round().clamp(min = 1)

    mask_id = self.mask_id

    # random permutation of the positions of each row
    # comparing it with the per-row count marks that many positions, at random places
    batch_randperm = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
    mask = batch_randperm < rearrange(num_token_masked, 'b -> b 1')

    # when computing the loss, we only care about the loss resulting from
    # the masked tokens (that the sequence model must unmask)
    # hence we mark the unmasked position with ignore_index
    # which the cross entropy loss will know to ignore
    labels = torch.where(mask, image_ids, ignore_index)

    if self.no_mask_token_prob > 0.:
      no_mask_mask = get_mask_subset_prob(mask, self.no_mask_token_prob)
      # get_mask_subset_prob keeps a fraction no_mask_token_prob of the masked positions as True
      # those positions are no longer masked because True & ~True = False
      # (they were labelled above, so they still count towards the loss)
      mask &= ~no_mask_mask

    x = torch.where(mask, mask_id, image_ids)

    # plain training pass; without labels the wrapper returns (embed, logits)
    # so a single pass yields the logits needed for both losses
    _, logits = self.sequence_model(
      x,
      text_embeds,
      return_embed_only = False,
      return_loss_only = False,
      conditioning_image_ids = conditioning_token_ids,
      cond_drop_prob = cond_drop_prob,
    )

    # cross entropy expects the class dimension second
    ce_loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = ignore_index)

    if self.token_critic is None:
      return ce_loss

    # normally we would take the most likely token value as the prediction
    # however for training the token critic, a noisy argmax is used to choose
    # the predicted values
    sampled_ids = gumbel_sample(logits, temperature = random())

    # masked positions receive the sampled prediction, the others keep their true id
    critic_input = torch.where(mask, sampled_ids, x)

    # 1.0 where the token given to the critic differs from the ground truth, 0.0 where it matches
    critic_labels = (image_ids != critic_input).float()

    # token critic is passed the predicted tokens
    # and is trained to flag the ones that differ from the ground truth
    bce_loss = self.token_critic(
      critic_input,
      text_embeds = text_embeds,
      conditioning_image_ids = conditioning_token_ids,
      labels = critic_labels,
      cond_drop_prob = cond_drop_prob
    )

    return ce_loss + bce_loss