In [None]:
# Install the Weights & Biases client used for experiment tracking in the training cell,
# then authenticate this session with `wandb login`.
# NOTE(review): no cell visible in this notebook installs `dalle_pytorch`, which later
# cells import — confirm it is provided by the environment or add an install line.
!pip install wandb
!wandb login

Collecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/33/ae/79374d2b875e638090600eaa2a423479865b7590c53fb78e8ccf6a64acb1/wandb-0.10.22-py2.py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 5.9MB/s 
[?25hCollecting subprocess32>=3.5.3
[?25l  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)
[K     |████████████████████████████████| 102kB 13.9MB/s 
[?25hCollecting docker-pycreds>=0.4.0
  Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl
Collecting GitPython>=1.0.0
[?25l  Downloading https://files.pythonhosted.org/packages/a6/99/98019716955ba243657daedd1de8f3a88ca1f5b75057c38e959db22fb87b/GitPython-3.1.14-py3-none-any.whl (159kB)
[K     |████████████████████████████████| 163kB 51.4MB/s 
Collecting sentry-sdk>=0.4.0
[?25l  Down

In [None]:
# Mount Google Drive into the Colab filesystem at /gdrive so files stored there can be read.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [None]:
import os
# Show what is stored in the DALLE folder on the mounted Drive (dataset archives and a model folder).
os.listdir("/gdrive/My Drive/DALLE")

['dataset_for_dalle_2nd_training.zip',
 'dataset_middle_three_images_tr.zip',
 'dalle_model']

In [1]:
## Download the dataset zip from the link below and upload it to the Colab session storage
## (the notebook's working directory) before running the next cell.
## https://drive.google.com/file/d/19Bhlmz4WHPMHMBvyNqsPqX3O7Jyl1DoA/view?usp=sharing

In [None]:
!unzip "dataset_middle_three_images_tr.zip"

In [None]:
# Sanity check: number of entries extracted into the dataset folder.
len(os.listdir("dataset_middle_three_images_tr"))

10254

In [None]:
# Folder containing the extracted image / caption files; read by the training cell below.
data_folder_location = "dataset_middle_three_images_tr"

In [None]:
import torch
# Ask PyTorch's CUDA caching allocator to release unused cached memory before training starts.
torch.cuda.empty_cache()


In [None]:
from random import choice
from pathlib import Path

# torch

import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# vision imports

from PIL import Image
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import OpenAIDiscreteVAE, DiscreteVAE, DALLE
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE

# helpers

def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True

# argument parsing

# Checkpoint locations; both are None here, so no checkpoint is loaded.
VAE_PATH = None # './vae.pt' - will use OpenAIs pretrained VAE if not set
DALLE_PATH = None # './dalle.pt'
# Folder holding the paired image / caption files (defined in an earlier cell).
IMAGE_TEXT_FOLDER = data_folder_location
# Resume only when a DALL-E checkpoint path has been supplied.
RESUME = exists(DALLE_PATH)

# constants

# Optimisation settings consumed by the dataloader and the training loop.
EPOCHS = 20
BATCH_SIZE = 5
LEARNING_RATE = 3e-4
GRAD_CLIP_NORM = 0.5

# Transformer hyperparameters forwarded to the DALLE constructor when training from scratch;
# TEXT_SEQ_LEN is also handed to the dataset.
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DEPTH = 2
HEADS = 4
DIM_HEAD = 64

# reconstitute vae

# Build the VAE and the DALL-E hyperparameters, either from a saved DALL-E checkpoint
# (RESUME) or from scratch. Both branches leave `vae`, `vae_params`, `dalle_params` and
# IMAGE_SIZE defined; the RESUME branch additionally defines `weights`.
if RESUME:
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'

    loaded_obj = torch.load(str(dalle_path))

    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']

    # save_model() stores vae_params = None when OpenAI's pretrained VAE was used and the
    # DiscreteVAE hyperparameters otherwise, so rebuild the matching VAE here.
    # vae_params itself is left untouched so the next save_model() writes it back unchanged.
    if exists(vae_params):
        vae = DiscreteVAE(**vae_params)
    else:
        vae = OpenAIDiscreteVAE()

    # The VAE is passed to DALLE(...) separately, so it must not also live in dalle_params:
    # that would be a duplicate keyword argument and would put a module into saved hparams.
    dalle_params = {k: v for k, v in dalle_params.items() if k != 'vae'}

    IMAGE_SIZE = vae.image_size
else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'

        loaded_obj = torch.load(str(vae_path))

        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using OpenAIs pretrained VAE for encoding images to tokens')
        vae_params = None

        vae = OpenAIDiscreteVAE()

    IMAGE_SIZE = vae.image_size

    dalle_params = dict(
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD
    )

# helpers

def save_model(path):
    """Write a DALL-E checkpoint to `path`.

    The checkpoint bundles the module-level `dalle_params` and `vae_params`
    with the state dict of the module-level `dalle` model.
    """
    checkpoint = dict(
        hparams = dalle_params,
        vae_params = vae_params,
        weights = dalle.state_dict()
    )
    torch.save(checkpoint, path)

# dataset loading

class TextImageDataset(Dataset):
    """Paired caption / image dataset read from a folder tree.

    Caption files (*.txt) and images (*.png, *.jpg, *.jpeg) are discovered recursively
    under `folder` and paired by file stem; files without a partner are ignored.

    Args:
        folder: root directory to search.
        text_len: context length handed to `tokenize` for every caption.
        image_size: side length used for the center crop and resize of each image.

    Each item is a tuple `(tokenized_text, image_tensor, mask)` where `mask` is True
    wherever the token id is non-zero.
    """

    def __init__(self, folder, text_len = 256, image_size = 128):
        super().__init__()
        path = Path(folder)

        text_files = {t.stem: t for t in path.glob('**/*.txt')}

        # Later patterns win when two images share a stem (same precedence as listing
        # png, then jpg, then jpeg into one dict).
        image_files = {
            i.stem: i
            for pattern in ('**/*.png', '**/*.jpg', '**/*.jpeg')
            for i in path.glob(pattern)
        }

        keys = image_files.keys() & text_files.keys()

        # Sorted so that index -> sample is stable across runs (set order is not).
        self.keys = sorted(keys)
        self.text_files = {k: text_files[k] for k in self.keys}
        self.image_files = {k: image_files[k] for k in self.keys}

        # Previously accepted but ignored; kept so captions are tokenized to this length.
        self.text_len = text_len

        # NOTE(review): the crop runs before the resize, so images larger than
        # image_size are center-cropped rather than scaled down — confirm this is intended.
        self.image_tranform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.CenterCrop(image_size),
            T.Resize(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Number of caption / image pairs."""
        return len(self.keys)

    def __getitem__(self, ind):
        """Return `(tokenized_text, image_tensor, mask)` for pair number `ind`.

        One non-empty line of the caption file is picked at random on every call.
        """
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = [line for line in text_file.read_text().split('\n') if len(line) > 0]
        description = choice(descriptions)

        tokenized_text = tokenize(description, self.text_len).squeeze(0)
        mask = tokenized_text != 0

        # Transform inside the context manager so the pixel data is read before the
        # underlying file is closed; avoids leaking one open file per sample.
        with Image.open(image_file) as image:
            image_tensor = self.image_tranform(image)

        return tokenized_text, image_tensor, mask

# create dataset and dataloader

# Pair up captions and images from the dataset folder, sized for the chosen VAE.
ds = TextImageDataset(IMAGE_TEXT_FOLDER, text_len = TEXT_SEQ_LEN, image_size = IMAGE_SIZE)

# Fail early if nothing was paired, otherwise report how many pairs will be trained on.
num_pairs = len(ds)
assert num_pairs > 0, 'dataset is empty'
print(f'{num_pairs} image-text pairs found for training')

# Shuffled batches; the final incomplete batch of each epoch is dropped.
dl = DataLoader(
    ds,
    shuffle = True,
    drop_last = True,
    batch_size = BATCH_SIZE
)


using OpenAIs pretrained VAE for encoding images to tokens


100%|███████████████████████| 215185363/215185363 [00:03<00:00, 64215417.40it/s]
100%|███████████████████████| 175360231/175360231 [00:02<00:00, 62358381.65it/s]


5127 image-text pairs found for training


In [None]:

# initialize DALL-E

# Build the model on the GPU; when resuming, restore the checkpointed weights.
dalle = DALLE(vae = vae, **dalle_params).cuda()

if RESUME:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr = LEARNING_RATE)

# experiment tracker

import wandb

# Hyperparameters are handed to init() through `config=`: a run has to exist before
# wandb.config can be written to, so assigning wandb.config.* ahead of init() is unsafe.
wandb.init(
    project = 'dalle_train_transformer_100_images_20_epoch_2desc',
    resume = RESUME,
    config = dict(
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD
    )
)

# training

for epoch in range(EPOCHS):
    for i, (text, images, mask) in enumerate(dl):
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))

        loss = dalle(text, images, mask = mask, return_loss = True)

        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)

        opt.step()
        opt.zero_grad()

        # wandb.log is called on every iteration (with an empty dict on most of them),
        # which keeps the tracker's step counter aligned with the iteration count.
        log = {}

        # Every 10th iteration: report the loss to stdout and to the tracker.
        if i % 10 == 0:
            loss_value = loss.item()
            print(epoch, i, f'loss - {loss_value}')

            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': loss_value
            }

        # Every 100th iteration: sample an image for the first caption of the batch,
        # checkpoint the model, and upload both.
        if i % 100 == 0:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = 0.9    # topk sampling at 0.9
            )

            save_model('./dalle.pt')
            wandb.save('./dalle.pt')

            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }

        wandb.log(log)

# Final checkpoint after the last epoch.
save_model('./dalle-final.pt')
wandb.save('./dalle-final.pt')
wandb.finish()


In [2]:
### After training, run the cells below to generate images from the trained model.
# Defaults to the final checkpoint written by the training cell; change it to load another file.
model_path = "./dalle-final.pt"

In [None]:
import argparse
from pathlib import Path
from tqdm import tqdm

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024, DALLE
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE



# load DALL-E

dalle_path = Path(model_path)

# is_file() rather than exists(): an empty model_path resolves to the current directory,
# which exists, and would only fail later inside torch.load with an unrelated error.
assert dalle_path.is_file(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')

dalle_params.pop('vae', None) # cleanup later

# The training cell stores vae_params = None when OpenAI's pretrained VAE was used and the
# DiscreteVAE hyperparameters otherwise; rebuild the same kind of VAE the weights belong to.
if vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

dalle = DALLE(vae = vae, **dalle_params).cuda()

dalle.load_state_dict(weights)

# generate images

image_size = vae.image_size

NUM_IMAGES = 128   # how many samples to draw for the prompt
CHUNK_SIZE = 5     # how many samples are generated per forward pass

text = tokenize(["a cup of coffee"], dalle.text_seq_len).cuda()

# Repeat the single tokenized prompt so every generated image shares the same caption.
text = repeat(text, '() n -> b n', b = NUM_IMAGES)

outputs = []

for text_chunk in tqdm(text.split(CHUNK_SIZE), desc = 'generating images'):
    output = dalle.generate_images(text_chunk, filter_thres = 0.9)
    outputs.append(output)

# One tensor holding all generated images, concatenated along the first dimension.
outputs = torch.cat(outputs)



100%|███████████████████████| 215185363/215185363 [00:06<00:00, 32433374.38it/s]
100%|███████████████████████| 175360231/175360231 [00:09<00:00, 18546900.04it/s]
generating images:  96%|█████████▌| 25/26 [21:54<00:52, 52.07s/it]

In [None]:
# save all images

outputs_dir = Path("./outputs")
# exist_ok=True keeps this cell re-runnable; a separate os.mkdir("outputs") would raise
# FileExistsError on the second run and is redundant with this call.
outputs_dir.mkdir(parents = True, exist_ok = True)

# total= lets tqdm draw a real progress bar; enumerate() alone has no length.
for i, image in tqdm(enumerate(outputs), total = len(outputs), desc = 'saving images'):
    save_image(image, outputs_dir / f'{i}.jpg')

# Report the number of images actually written instead of a hardcoded count.
print(f'created {len(outputs)} images at "{str(outputs_dir)}"')