# DALL-E - Generating Images from Text

* [DALLE-pytorch](https://github.com/lucidrains/DALLE-pytorch)

In [1]:
COLAB = True

In [4]:
# --- Remember to change the runtime to use GPU for better performance
if COLAB:
    !nvidia-smi
    !git clone https://github.com/alpha2phi/jupyter-notebooks.git
    
!pip install -Uqq dalle-pytorch

Thu Mar 25 00:57:52 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
import logging
import torch
import numpy as np
import math
import itertools
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
%config InlineBackend.figure_format = 'retina'
from PIL import Image
from tqdm.autonotebook import tqdm, trange
import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

In [7]:
# Load ImageNet labels
captions = []

if COLAB:
  with open("jupyter-notebooks/nbs/test_data/imagenet_labels.txt", "r") as f:
      captions = [[s.strip(' ""\r\n')] for s in f.readlines()]
else:
  with open("test_data/imagenet_labels.txt", "r") as f:
      captions = [[s.strip(' ""\r\n')] for s in f.readlines()]

logging.info(captions)

2021-03-25 00:58:17 INFO     [['tench, Tinca tinca'], ['goldfish, Carassius auratus'], ['great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias'], ['tiger shark, Galeocerdo cuvieri'], ['hammerhead, hammerhead shark'], ['electric ray, crampfish, numbfish, torpedo'], ['stingray'], ['cock'], ['hen'], ['ostrich, Struthio camelus'], ['brambling, Fringilla montifringilla'], ['goldfinch, Carduelis carduelis'], ['house finch, linnet, Carpodacus mexicanus'], ['junco, snowbird'], ['indigo bunting, indigo finch, indigo bird, Passerina cyanea'], ['robin, American robin, Turdus migratorius'], ['bulbul'], ['jay'], ['magpie'], ['chickadee'], ['water ouzel, dipper'], ['kite'], ['bald eagle, American eagle, Haliaeetus leucocephalus'], ['vulture'], ['great grey owl, great gray owl, Strix nebulosa'], ['European fire salamander, Salamandra salamandra'], ['common newt, Triturus vulgaris'], ['eft'], ['spotted salamander, Ambystoma maculatum'], ['axolotl, mud puppy, Ambystoma mex

In [8]:
# Load OpenAI VAE
vae = OpenAIDiscreteVAE().cuda()

100%|███████████████████████| 215185363/215185363 [00:14<00:00, 15153619.13it/s]
100%|████████████████████████| 175360231/175360231 [00:21<00:00, 8188420.75it/s]


In [9]:
all_words = list(sorted(frozenset(sum(captions, []))))
word_tokens = dict(zip(all_words, range(1, len(all_words) + 1)))
caption_tokens = [[word_tokens[w] for w in c] for c in captions]

logging.info(f"{all_words}")
logging.info(f"{len(word_tokens)}")
logging.info(len(caption_tokens))

longest_caption = max(len(c) for c in captions)
captions_array = np.zeros((len(caption_tokens), longest_caption), dtype=np.int64)
for i in range(len(caption_tokens)):
    captions_array[i, : len(caption_tokens[i])] = caption_tokens[i]

captions_array = torch.from_numpy(captions_array).cuda()
captions_mask = captions_array != 0

dalle = DALLE(
    dim = 1024,
    vae = vae,                                 # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = len(word_tokens) + 1,    # vocab size for text
    text_seq_len = longest_caption,            # text sequence length
    depth = 8,                 # should aim to be 64
    heads = 2,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
).cuda()

2021-03-25 00:59:04 INFO     ['Afghan hound, Afghan', 'African chameleon, Chamaeleo chamaeleon', 'African crocodile, Nile crocodile, Crocodylus niloticus', 'African elephant, Loxodonta africana', 'African grey, African gray, Psittacus erithacus', 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 'Airedale, Airedale terrier', 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 'American alligator, Alligator mississipiensis', 'American black bear, black bear, Ursus americanus, Euarctos americanus', 'American chameleon, anole, Anolis carolinensis', 'American coot, marsh hen, mud hen, water hen, Fulica americana', 'American egret, great white heron, Egretta albus', 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 'Angora, Angora rabbit', 'Appenzeller', 'Arabian camel, dromedary, Camelus dromedarius', 'Arctic fox, white fox, Alopex lagopus', 'Australian terrier', 'Band Aid', 'Bedlington terrier', 'Ber

In [10]:
def generate_images(captions):
    generated_images = []
    with torch.no_grad():
        for i in trange(0, len(captions), 128):
            generated = dalle.generate_images(captions_array[i:i + 128, ...], mask=captions_mask[i:i + 128, ...], temperature=0.00001)
            generated_images.append(generated)
            break
          
    return torch.cat(generated_images, axis=0).cpu().numpy()

## Generate Images

In [None]:
%%timeit
generated_images = generate_images("blue bear")

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

In [None]:
print(generated_images.shape)
images = generated_images[0,...].reshape(256,256,3)
plt.imshow(images)