# My Notes

📜 [Zero-Shot Text-to-Image Generation](https://arxiv.org/pdf/2102.12092)

> Text-to-image generation has traditionally focused on finding better modeling assumptions for training on a fixed dataset.

> We describe a simple approach for this task based on a transformer that auto-regressively models the text and image tokens as a single stream of data. With sufficient data and scale, our approach is competitive with previous domain-specific models when evaluated in a zero-shot fashion.

Rather than adding inductive biases to improve image modeling, they focus on data and scale, and as usual, it works!

> Recent advances fueled by large-scale generative models suggest a possible route for further improvements [in text-to-image modeling]. Specifically, when compute, model size, and data are scaled carefully, auto-regressive transformers have achieved impressive results in several domains such as text, images, and audio.

> Could dataset size and model size be the limiting factor of current approaches?

> In this work, we demonstrate that training a 12-billion parameter autoregressive transformer on 250 million image-text pairs collected from the internet results in a flexible, high fidelity generative model of images controllable through natural language.

> The resulting system achieves high quality image generation on the popular MS-COCO dataset zero-shot, without using any of the training labels.

They apply the same scaling hypothesis to text-to-image models, and once again it pays off. The resulting model is competitive with previous domain-specific models on datasets like MS-COCO when evaluated zero-shot, without ever using those datasets' training labels.

### Method

> Our goal is to train a transformer to auto-regressively model the text and image tokens as a single stream of data.

> However, using pixels directly as image tokens would require an inordinate amount of memory for high-resolution images.

> We address these issues by using a two-stage training procedure:

**Stage 1:** We train a discrete variational auto-encoder (dVAE) to compress each 256×256 RGB image into a 32 × 32 grid of image tokens, each element of which can assume 8192 possible values. This reduces the context size of the transformer.

**Stage 2:** We concatenate up to 256 BPE-encoded text tokens with the $32 \times 32$ image tokens, and train an autoregressive transformer to model the joint distribution over the text and image tokens.
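A quick back-of-the-envelope check of how much the dVAE shrinks the transformer's context. It uses only the numbers quoted above; the variable names are my own.

```python
# Numbers from the two-stage description above.
IMAGE_SIDE = 256          # input images are 256 x 256 RGB
CHANNELS = 3
GRID_SIDE = 32            # dVAE output grid is 32 x 32 tokens
MAX_TEXT_TOKENS = 256     # captions are capped at 256 BPE tokens

pixel_values = IMAGE_SIDE * IMAGE_SIDE * CHANNELS   # raw values per image
image_tokens = GRID_SIDE * GRID_SIDE                # dVAE tokens per image
reduction = pixel_values // image_tokens            # context shrink factor
sequence_length = MAX_TEXT_TOKENS + image_tokens    # full text + image stream

print(pixel_values, image_tokens, reduction, sequence_length)
# 196608 1024 192 1280
```

Treating every pixel value as a token would put roughly 197k positions in each sequence. With the dVAE, the whole text-plus-image stream fits in 1280 positions, a 192x reduction on the image side.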


They’re getting creative here. They use VQ-VAE-style discrete compression to shrink the image further so the context is smaller. The resulting codes then serve as image tokens (a bit like ViT’s patch tokens, except discrete) and go into the transformer in one stream with the text tokens, so the image tokens can attend to the caption.

> We can model the overall procedure as maximizing the evidence lower bound (ELB) on the joint likelihood of the model distribution over images $x$, captions $y$ and the tokens $z$ for the encoded RGB image.

$$
\ln p_{\theta,\psi}(x,y) \geqslant \mathbb{E}_{z \sim q_{\phi}(z \mid x)} \Big( \ln p_\theta(x \mid y,z) - \beta \, D_{KL}\big(q_\phi(y, z \mid x),\, p_{\psi}(y, z)\big) \Big)
$$

Here, $p_{\theta,\psi}(x,y)$ is the joint likelihood of an image $x$ and its caption $y$ under the model. We want to maximize it, but it is intractable to compute directly, so we maximize the lower bound on the right-hand side instead.

The KL term compares two distributions over captions and image tokens. The first is $q_\phi(y, z \mid x)$, the distribution over the caption $y$ and the dVAE encoder's image tokens $z$ given the original image $x$ (this is the pairing we actually observe in training runs). The second is $p_\psi(y, z)$, the prior over text and image tokens, which is what the transformer learns.

In other words, we want the transformer's joint distribution over captions and image tokens to match what the encoder actually produces for real image-caption pairs. The tokens are a compact but lossy summary of the image, which is why a decoder is still needed to get back to pixels.

Because the KL term is subtracted (weighted by $\beta$), pushing it toward zero raises the ELB.

Similarly, the $\ln p_\theta(x \mid y, z)$ term is the reconstruction term: it rewards the dVAE decoder for assigning high likelihood to the original image $x$ given the compressed image representation $z$ (the notation also conditions on $y$).

Critically, the expectation is taken over $z \sim q_\phi(z \mid x)$, the encoder's distribution over token grids for the image $x$. Optimizing the ELB with respect to $\phi$ therefore shapes the encoder so that its tokens both reconstruct $x$ well and are easy for the prior to model (small KL).
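To make the bound concrete, here is a toy numeric check. It assumes a single discrete latent $z$ with only 3 possible values and drops the caption $y$ for simplicity. All probabilities are made up, and the function names are hypothetical. The exact $\ln p(x)$ is computed by summing over $z$ and compared with the ELB for an arbitrary encoder distribution and for the true posterior.

```python
import math


def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as lists of probabilities."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)


def elb(q_z_given_x, prior_z, lik_x_given_z, beta=1.0):
    """Toy ELB: E_{z ~ q}[ln p(x | z)] - beta * KL(q(z | x) || p(z))."""
    recon = sum(qz * math.log(lz) for qz, lz in zip(q_z_given_x, lik_x_given_z))
    return recon - beta * kl_divergence(q_z_given_x, prior_z)


# Hypothetical toy model: z takes one of 3 values (a 3-entry "codebook").
prior = [0.5, 0.3, 0.2]     # p(z), the role played by the transformer prior
lik = [0.10, 0.40, 0.05]    # p(x | z), the role played by the dVAE decoder

# Exact log-evidence by summing over z: ln p(x) = ln sum_z p(z) p(x | z).
log_px = math.log(sum(pz * lx for pz, lx in zip(prior, lik)))

# An arbitrary encoder distribution q(z | x) gives a lower bound on ln p(x).
q = [0.2, 0.7, 0.1]
print(f"ln p(x) = {log_px:.4f}, ELB = {elb(q, prior, lik):.4f}")

# The bound is tight when q equals the true posterior p(z | x).
posterior = [pz * lx / math.exp(log_px) for pz, lx in zip(prior, lik)]
print(f"ELB at posterior = {elb(posterior, prior, lik):.4f}")
```

The gap between $\ln p(x)$ and the ELB is exactly $D_{KL}(q(z \mid x) \,\|\, p(z \mid x))$, so it closes when the encoder matches the posterior. With $\beta > 1$ (the paper later uses 6.6), the objective sits even lower than the $\beta = 1$ bound, because the KL term is non-negative.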

**1. Stage 1: Learning the Visual Codebook**

> In the first stage of training, we maximize the ELB with respect to $\phi$ and $\theta$, which corresponds to training a dVAE on the images alone.

First, they focus only on two distributions. $q_\phi$ is the distribution over the $32 \times 32$ image tokens that the dVAE encoder produces for an image $x$. $p_\theta$ is the distribution over RGB images that the dVAE decoder generates from those tokens.

In practice, this means training the encoder and decoder to compress images into tokens and then reconstruct the original images. Meanwhile, the prior $p_\psi$ is held fixed as a uniform categorical distribution over the 8192 codebook entries.

> The ELB now becomes difficult to optimize: as $q_\psi$ is a discrete distribution, and we cannot use the re-parameterization gradient to maximize it.

DALL·E represents images with discrete rather than continuous latents: a grid of tokens, each taking one of exactly 8192 values. Because of this, the usual VAE approach of sampling a continuous latent between the encoder and the decoder no longer works.

A standard VAE samples $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, which is differentiable in $\mu$ and $\sigma$. For a categorical choice among codebook indices, there is no meaningful $\mu$ or $\sigma$. Nudging an index does not give a "nearby" token; it just jumps to an unrelated one.

The sampling step itself is also not differentiable. Picking a discrete token is piecewise constant in the encoder's outputs, so its gradient is zero almost everywhere and undefined at the jumps.

> We instead use the gumbel-softmax relaxation, replacing the expectation over $q_\phi$ with one over $q_\phi^\tau$, where the relaxation becomes tight as the temperature $\tau \to 0$.

Rather than outputting $\mu$ and $\sigma$ to sample from, the encoder outputs logits: a score for each of the 8192 possible tokens at each grid position.

Gumbel noise is then added to these logits to make the draw random. A softmax with temperature $\tau$ is then applied, and the lower $\tau$ is, the closer the output gets to a one-hot vector. Taking the argmax of the noisy logits would give an exact categorical sample (the Gumbel-max trick), but the argmax is not differentiable.

This procedure, the Gumbel-softmax relaxation, gives a continuous, differentiable stand-in for sampling discrete tokens. As a result, the dVAE can be trained with ordinary backpropagation.
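Below is a minimal pure-Python sketch of the relaxation at a single grid position. It assumes a toy 3-token codebook instead of 8192 entries and hand-picked logits, and all function names are hypothetical. It compares a warm and a cold temperature on the same noise, then empirically checks the Gumbel-max identity that motivates the trick.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(n, rng):
    """n i.i.d. Gumbel(0, 1) samples via the inverse CDF: g = -log(-log(u))."""
    eps = 1e-12  # keep u strictly inside (0, 1) so both logs stay finite
    return [-math.log(-math.log(min(max(rng.random(), eps), 1 - eps)))
            for _ in range(n)]


def gumbel_softmax(logits, gumbels, tau):
    """Relaxed one-hot sample: softmax((logits + g) / tau)."""
    return softmax([(l + g) / tau for l, g in zip(logits, gumbels)])


rng = random.Random(0)
logits = [2.0, 1.0, 0.1]          # hypothetical scores for a 3-token codebook
g = sample_gumbel(len(logits), rng)

warm = gumbel_softmax(logits, g, tau=1.0)   # soft, spread-out weights
cold = gumbel_softmax(logits, g, tau=0.1)   # much closer to one-hot
hard = max(range(len(logits)), key=lambda i: logits[i] + g[i])
print([round(p, 3) for p in warm], [round(p, 3) for p in cold], hard)

# Gumbel-max trick: argmax(logits + g) is an exact categorical sample,
# so its empirical frequencies should approach softmax(logits).
counts = [0] * len(logits)
for _ in range(20000):
    gs = sample_gumbel(len(logits), rng)
    counts[max(range(len(logits)), key=lambda i: logits[i] + gs[i])] += 1
print([c / 20000 for c in counts], [round(p, 3) for p in softmax(logits)])
```

Lowering $\tau$ never moves the peak; it only sharpens the distribution around the same index that the hard argmax picks. This is why the relaxation "becomes tight" as $\tau \to 0$.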

> The likelihood for $p_\theta$ is evaluated using the log-laplace distribution.

> We also found that increasing the KL weight to β = 6.6 promotes better codebook usage and ultimately leads to a _smaller_ reconstruction error at the end of training.

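They increase the weight of the KL term in the ELB to $\beta = 6.6$ (the bound itself corresponds to $\beta = 1$). In stage one, the prior is uniform over the 8192 codes. A heavier KL penalty therefore pushes the encoder to spread probability across more of the codebook instead of collapsing onto a few entries. The authors report that this better codebook usage ultimately lowers the reconstruction error.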
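A short derivation (my own, not from the paper) makes the link between the KL weight and codebook usage explicit. Here $q$ is the encoder's distribution at a single grid position, $\mathcal{U}_K$ is the uniform prior, and $K = 8192$ is the codebook size:

$$
D_{KL}\big(q \,\|\, \mathcal{U}_K\big) = \sum_{k=1}^{K} q_k \ln \frac{q_k}{1/K} = \ln K + \sum_{k=1}^{K} q_k \ln q_k = \ln K - H(q)
$$

Penalizing this KL term more heavily is therefore the same as rewarding higher entropy $H(q)$ in the encoder's token distribution, which discourages collapse onto a handful of codes. The penalty reaches zero only when $q$ is itself uniform, where $H(q) = \ln 8192 \approx 9.01$ nats.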

**2. Stage 2: Learning the Prior**

> In the second stage, we fix $\phi$ and $\theta$, and learn the prior distribution over the text and image tokens by maximizing the ELB with respect to $\psi$.

Now the dVAE is frozen: both the image-to-token encoder ($q_\phi$) and the token-to-image decoder ($p_\theta$) stay fixed. Training focuses on the transformer prior $p_\psi$, the joint distribution of text tokens and image tokens.

> Given a text-image pair, we BPE-encode the lowercased caption using at most 256 tokens with vocabulary size $16,384$, and encode the image using $32 \times 32 = 1024$ tokens with vocabulary size $8192$.

> The image tokens are obtained using argmax sampling from the dVAE encoder logits, without adding any gumbel noise.

Gumbel noise is only needed to train the dVAE. When encoding images for the transformer, each grid position simply takes the most likely token (the argmax of the encoder's logits).

> The transformer is a decoder-only model in which each image token can attend to all text tokens in any one of its 64 self-attention layers.
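The toy sketch below shows why a standard causal mask already lets every image token see every text token when the text is placed before the image. It uses a plain lower-triangular mask and tiny made-up sizes. The paper's image-to-image attention additionally uses sparse row, column, and convolutional masks, which this sketch ignores.

```python
def causal_mask(seq_len):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


# Hypothetical tiny sizes; the real model uses 256 text + 1024 image positions.
NUM_TEXT, NUM_IMAGE = 4, 6
mask = causal_mask(NUM_TEXT + NUM_IMAGE)

text_cols = range(NUM_TEXT)
image_rows = range(NUM_TEXT, NUM_TEXT + NUM_IMAGE)

# Every image position can attend to every text position...
print(all(mask[i][j] for i in image_rows for j in text_cols))   # True
# ...but no text position can attend to any image position.
print(any(mask[j][i] for i in image_rows for j in text_cols))   # False
```

Ordering the stream as text first, then image, is what makes text-conditioned generation fall out of plain next-token prediction. By the time the model predicts any image token, the whole caption is already in its context.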

> Instead, we opt to learn a special padding token separately for each of the 256 text positions.

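When a caption is shorter than 256 BPE tokens, the remaining text slots are filled with padding tokens. This way every sequence has the same fixed layout: 256 text positions followed by 1024 image positions. Rather than using a single shared pad, they learn a separate padding token for each of the 256 text positions.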
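The hypothetical sketch below assembles one training sequence. It takes the argmax over made-up dVAE logits to get the image tokens, then pads the caption with position-specific pad tokens. The single shared id layout (text ids, then 256 pad ids, then image ids) is my own assumption for illustration, not the paper's implementation. The fake logits also use a 16-way stand-in codebook so the demo runs quickly.

```python
import random

TEXT_VOCAB = 16384      # BPE vocabulary size (from the paper)
MAX_TEXT = 256          # text positions per sequence
IMAGE_VOCAB = 8192      # dVAE codebook size
GRID = 32               # 32 x 32 image tokens

# Hypothetical id layout in one shared vocabulary:
#   [0, TEXT_VOCAB)                            BPE text tokens
#   [TEXT_VOCAB, TEXT_VOCAB + MAX_TEXT)        one learned pad token per position
#   [IMAGE_OFFSET, IMAGE_OFFSET + IMAGE_VOCAB) image tokens
PAD_OFFSET = TEXT_VOCAB
IMAGE_OFFSET = TEXT_VOCAB + MAX_TEXT


def image_tokens_from_logits(logits):
    """Argmax over the codebook at each grid position (no Gumbel noise)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def build_sequence(text_ids, image_ids):
    """Pad the caption to MAX_TEXT with per-position pads, then append image ids."""
    text_ids = text_ids[:MAX_TEXT]
    pads = [PAD_OFFSET + pos for pos in range(len(text_ids), MAX_TEXT)]
    return text_ids + pads + [IMAGE_OFFSET + t for t in image_ids]


rng = random.Random(0)
DEMO_CODEBOOK = 16      # tiny stand-in for 8192-way logits, keeps the demo fast
fake_logits = [[rng.gauss(0.0, 1.0) for _ in range(DEMO_CODEBOOK)]
               for _ in range(GRID * GRID)]
caption = [17, 512, 4093]          # made-up BPE ids for a short caption

seq = build_sequence(caption, image_tokens_from_logits(fake_logits))
print(len(seq), seq[:5])           # 1280 [17, 512, 4093, 16387, 16388]
```

Giving each text position its own pad id means the padding still carries positional identity. That could be one reason a learned per-position pad is preferable to a single shared one, though this is my interpretation rather than the paper's stated reasoning.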

**3. Data Collection**

> To scale up to 12-billion parameters, we created a dataset of a similar scale to JFT-300M by collecting 250 million text-images pairs from the internet.

**4. Mixed-Precision Training**

> To save GPU memory and increase throughput, most parameters, Adam moments, and activations are stored in 16-bit precision.

This is the first paper I’ve seen that goes into low-level compute details such as the floating-point precision used.

> Getting the model to train in 16-bit precision past one billion parameters, without diverging, was the most challenging part of this project. We believe the root cause of this instability to be underflow in the 16-bit gradients.

Here we hit an actual engineering challenge discussed in the paper.
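The gradient-underflow problem can be demonstrated with Python's `struct` module, whose `'e'` format packs a float as IEEE 754 half precision. The sketch below shows a small gradient flushing to zero in fp16, then recovering it with a single global loss scale. That global scale is the common generic remedy in mixed-precision training. The paper's own fix is a more elaborate per-resblock gradient scaling scheme, which is not shown here. The gradient value and scale are illustrative choices.

```python
import struct


def to_fp16(x):
    """Round-trip a float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8                      # a small but meaningful gradient value

# Cast directly to fp16: below the smallest subnormal (~6e-8), it flushes to 0.
print(to_fp16(grad))             # 0.0

# Loss scaling: multiply before the fp16 cast, divide afterwards in higher precision.
SCALE = 2.0 ** 16
recovered = to_fp16(grad * SCALE) / SCALE
print(recovered)                 # close to 1e-8, up to fp16 rounding
```

The catch is that a scale large enough to rescue tiny gradients can overflow large ones, since fp16 tops out at 65504. This suggests why a single global scale may not suffice for a very deep model whose gradient magnitudes vary across layers.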

**5. Distributed Optimization**

> Our 12-billion parameter model consumes about 24 GB of memory when stored in 16-bit precision, which exceeds the memory of a 16 GB NVIDIA V100 GPU. We address this using parameter sharding.

> Parameter sharding allows us to almost completely hide the latency of the intra-machine communication by overlapping it with compute-intensive operations.
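Here is the memory arithmetic from the quote above, plus a minimal sketch of splitting the parameters into near-equal contiguous shards. It counts only the 16-bit weights, ignoring gradients, Adam moments, and activations. The 8-way split is a hypothetical choice for illustration, and the communication that gathers shards during compute is not modeled.

```python
NUM_PARAMS = 12_000_000_000   # 12B parameters
BYTES_PER_PARAM = 2           # 16-bit storage
GPU_MEMORY_GB = 16            # the 16 GB V100 mentioned above


def shard_sizes(num_params, num_shards):
    """Split num_params into num_shards contiguous chunks of near-equal size."""
    base, rem = divmod(num_params, num_shards)
    return [base + (1 if i < rem else 0) for i in range(num_shards)]


total_gb = NUM_PARAMS * BYTES_PER_PARAM / 1e9
print(total_gb, total_gb > GPU_MEMORY_GB)        # 24.0 True

# Hypothetical: shard the parameters across 8 GPUs.
shards = shard_sizes(NUM_PARAMS, 8)
per_gpu_gb = max(shards) * BYTES_PER_PARAM / 1e9
print(per_gpu_gb)                                # 3.0
```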

**6. Sample Generation**

> We rerank the samples drawn from the transformer using a pre-trained contrastive model. Given a caption and a candidate image, the contrastive model assigns a score based on how well the image matches the caption.
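Below is a minimal sketch of the reranking step. It assumes the contrastive model maps the caption and each candidate image to embedding vectors and scores a match by cosine similarity, which is typical of contrastive setups but is my assumption here. The embeddings are made-up stand-ins, and `rerank` is a hypothetical helper.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def rerank(caption_emb, image_embs, top_k):
    """Return indices of the top_k candidates whose embeddings best match the caption."""
    scores = [cosine_similarity(caption_emb, e) for e in image_embs]
    order = sorted(range(len(image_embs)), key=lambda i: scores[i], reverse=True)
    return order[:top_k]


# Hypothetical embeddings standing in for a contrastive model's outputs.
caption = [1.0, 0.0, 0.5]
candidates = [
    [0.9, 0.1, 0.4],    # close to the caption
    [-1.0, 0.2, 0.0],   # points the other way
    [0.5, 0.5, 0.5],    # partially aligned
]
print(rerank(caption, candidates, top_k=2))   # [0, 2]
```

Because the transformer only needs to produce *some* good samples, drawing many candidates and keeping the best-scoring ones trades extra sampling compute for better caption alignment.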

> Training the transformer on the tokens from the dVAE encoder allows us to allocate its modeling capacity to the low-frequency information that makes images visually recognizable to us.

> However, it also disadvantages the model, since the heavy compression renders it unable to produce high-frequency details.

### Experiments

**1. Quantitative Results**

> Given a caption, the sample from our model receives the majority vote for better matching the caption 93% of the time [compared to DF-GAN]. It also receives the majority vote for being more realistic 90% of the time.

**2. Qualitative Results**

> We found that our model has the ability to generalize in ways that we did not originally anticipate. […] It has developed a rudimentary ability to compose unusual concepts at high levels of abstraction.

> Our model also appears to be capable of combinatorial generalization, such as when rendering text or when probed on sentences like “an illustration of a baby hedgehog in a Christmas sweater walking a dog.”

> To a limited degree of reliability, we also find our model to be capable of zero-shot image-to-image translation controllable by natural language.

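Here’s the beginning of editing images with text. Given a caption and the top part of an image, the model can complete the rest the way the caption dictates. For example, you can prompt with a photo of a cat and ask for the same cat rendered as a sketch below it.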

> This works with several other kinds of transformations.

### Conclusion

> We investigate a simple approach for text-to-image generation based on an autoregressive transformer, when it is executed at scale.

> Our findings suggest that improving generalization as a function of scale may be a useful driver for progress on this task.
