[Leo]

This code helps you explore the basics of DALL-E and check that everything works. 

First, install the library with pip: `pip install dalle-pytorch`

## DALL-E and VAE

Initialize a discrete variational autoencoder (VAE) and pass it to the DALL-E constructor (`__init__`).

The pretrained VAE weights are downloaded on the first run.
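When the VAE is passed in, DALL-E infers the image sequence length and the number of image tokens from it (see the `vae` comment in the cell below). The number of image tokens is the VAE's codebook size (8192 for OpenAI's VAE). The sequence length can be worked out by hand; here is a minimal sketch, assuming the VAE halves the image resolution in each of its `num_layers` downsampling layers and that OpenAI's VAE has 3 such layers. The helper name `dalle_sequence_lengths` is hypothetical, not part of `dalle_pytorch`.

```python
from typing import Tuple


def dalle_sequence_lengths(image_size: int = 256, num_layers: int = 3,
                           text_seq_len: int = 256) -> Tuple[int, int]:
    """Return (image_seq_len, total_seq_len) for a text + image token sequence.

    Assumption: the discrete VAE halves the spatial resolution once per layer,
    so a 256x256 image with 3 layers becomes a 32x32 grid of image tokens.
    """
    fmap_size = image_size // (2 ** num_layers)  # side length of the token grid
    image_seq_len = fmap_size ** 2               # one token per grid cell
    return image_seq_len, text_seq_len + image_seq_len


image_seq_len, total = dalle_sequence_lengths()
print(image_seq_len, total)  # 1024 1280
```

Under these assumptions, each training example in the cell below is a sequence of 256 text tokens followed by 1024 image tokens.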

In [1]:
import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # loads pretrained OpenAI VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 10000,    # vocab size for text
    text_seq_len = 256,         # text sequence length
    depth = 1,                  # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1,           # feedforward dropout
    reversible = False,         # True lets you use a bigger network with little extra memory, at roughly 2x the computation
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')  # cycles between these four types of attention, can also pick one
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()

100%|████████████████████████| 215185363/215185363 [00:31<00:00, 6750386.17it/s]
100%|████████████████████████| 175360231/175360231 [00:24<00:00, 7199733.14it/s]


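The cell above is the core of a DALL-E training step: a forward pass that returns the loss, followed by backpropagation. A full training run wraps this in a loop with an optimizer and real text-image pairs.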
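In the training cell, `mask` is all ones because the random `text` fills all 256 positions. Real captions are usually shorter than `text_seq_len`, so they would be padded and the padding marked as `False` in the mask. A minimal plain-Python sketch follows; the helper `pad_and_mask` is hypothetical, and reserving token id 0 for padding is an assumption, not something the library prescribes.

```python
from typing import List, Tuple

PAD_ID = 0  # assumption: token id 0 is reserved for padding


def pad_and_mask(token_lists: List[List[int]], text_seq_len: int = 256
                 ) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad (or truncate) each token list to text_seq_len and build a boolean mask.

    Mask entries are True for real tokens and False for padding positions.
    """
    padded, mask = [], []
    for tokens in token_lists:
        tokens = tokens[:text_seq_len]      # truncate captions that are too long
        n_pad = text_seq_len - len(tokens)  # number of padding slots needed
        padded.append(tokens + [PAD_ID] * n_pad)
        mask.append([True] * len(tokens) + [False] * n_pad)
    return padded, mask


text, mask = pad_and_mask([[5, 17, 42], [7, 8]], text_seq_len=4)
print(text)  # [[5, 17, 42, 0], [7, 8, 0, 0]]
print(mask)  # [[True, True, True, False], [True, True, False, False]]
```

The resulting lists can be turned into tensors with `torch.tensor(text)` and `torch.tensor(mask)` before being passed to the model.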

Note that `OpenAIDiscreteVAE` is pretrained, whereas `DiscreteVAE` is not and has to be trained first.

DALL-E itself is not pretrained: there are no official weights.

## CLIP
OpenAI uses CLIP to filter DALL-E's outputs, ranking the generated images and presenting only the best ones.
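The filtering step amounts to scoring each candidate image against the caption and keeping the highest-scoring ones. A toy sketch, assuming the caption and each image have already been embedded into the same vector space (the 2-D embeddings here are made up; real CLIP embeddings come from its text and image encoders). Cosine similarity as the score and the helper names `cosine` and `rerank` are illustrative choices, not the library's API.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def rerank(text_emb: Sequence[float], image_embs: List[Sequence[float]],
           top_k: int) -> List[int]:
    """Return the indices of the top_k images most similar to the text, best first."""
    scores = [cosine(text_emb, img) for img in image_embs]
    order = sorted(range(len(image_embs)), key=lambda i: scores[i], reverse=True)
    return order[:top_k]


# Hypothetical embeddings for one caption and four generated images
text_emb = [1.0, 0.0]
candidates = [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5], [1.0, 0.05]]
print(rerank(text_emb, candidates, top_k=2))  # [3, 1]
```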

At inference time, the trained CLIP model is passed to DALL-E so that the generated images can be scored against the text.

The `CLIP` model in this library must be trained as well; no official weights are provided for it.

In [3]:
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
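CLIP is trained contrastively: within a batch, each caption should score highest against its own image, and each image against its own caption. The plain-Python sketch below illustrates a symmetric cross-entropy loss in the style of the CLIP paper; it is not `dalle_pytorch`'s code, and the helper names and the `logit_scale` parameter are assumptions for illustration. Latents are assumed to be L2-normalized already.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def contrastive_loss(text_latents: List[List[float]], image_latents: List[List[float]],
                     logit_scale: float = 1.0) -> float:
    """Symmetric cross-entropy over the text-image similarity matrix.

    Pair (text i, image i) is the positive; every other pair in the batch is a negative.
    """
    n = len(text_latents)
    # sim[i][j] = scaled dot product of text i and image j
    sim = [[logit_scale * sum(t * v for t, v in zip(text_latents[i], image_latents[j]))
            for j in range(n)] for i in range(n)]
    sim_t = [list(col) for col in zip(*sim)]  # transpose for the image-to-text direction
    loss_t2i = -sum(log_softmax(sim[i])[i] for i in range(n)) / n
    loss_i2t = -sum(log_softmax(sim_t[i])[i] for i in range(n)) / n
    return (loss_t2i + loss_i2t) / 2


texts = [[1.0, 0.0], [0.0, 1.0]]
matched = contrastive_loss(texts, [[1.0, 0.0], [0.0, 1.0]], logit_scale=10.0)
swapped = contrastive_loss(texts, [[0.0, 1.0], [1.0, 0.0]], logit_scale=10.0)
print(matched < swapped)  # True
```

Correctly paired batches give a loss near zero, while mismatched pairs are penalized heavily; this is the signal that `loss.backward()` propagates through both encoders.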