In [2]:
# Pin to the release this notebook was run against (see the install log); %pip targets the kernel's environment.
%pip install dalle-pytorch==0.12.5

Collecting dalle-pytorch
  Downloading dalle_pytorch-0.12.5-py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 4.2 MB/s eta 0:00:01
[?25hCollecting transformers
  Downloading transformers-4.6.1-py3-none-any.whl (2.2 MB)
[K     |████████████████████████████████| 2.2 MB 9.2 MB/s eta 0:00:01
Collecting taming-transformers
  Downloading taming_transformers-0.0.1-py3-none-any.whl (45 kB)
[K     |████████████████████████████████| 45 kB 6.2 MB/s  eta 0:00:01
[?25hCollecting g-mlp-pytorch
  Downloading g_mlp_pytorch-0.0.16-py3-none-any.whl (5.2 kB)
Collecting ftfy
  Downloading ftfy-6.0.3.tar.gz (64 kB)
[K     |████████████████████████████████| 64 kB 6.6 MB/s  eta 0:00:01
[?25hCollecting youtokentome
  Downloading youtokentome-1.0.6-cp36-cp36m-manylinux2010_x86_64.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 33.6 MB/s eta 0:00:01
[?25hCollecting einops>=0.3
  Downloading einops-0.3.0-py2.py3-none-any.whl (25 kB)
Collecting axial-positional-embe

In [5]:
# !pip install wandb
import wandb

In [9]:
# Authenticate with Weights & Biases so the wandb.init calls further down can log runs;
# as the output shows, it prompts for an API key and appends it to the user's netrc file.
wandb.login()

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


[34m[1mwandb[0m: Paste an API key from your profile and hit enter:  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nakachi/.netrc


True

In [10]:
# Pin to the version recorded in the install log; %pip targets the kernel's environment.
%pip install gdown==3.13.0

Collecting gdown
  Downloading gdown-3.13.0.tar.gz (9.3 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... [?25ldone
[?25h  Created wheel for gdown: filename=gdown-3.13.0-py3-none-any.whl size=9034 sha256=1e6cde5ddb37f8ee1dcad6510e59a199f890ead72da76352f80d4e99b8643cc6
  Stored in directory: /home/nakachi/.cache/pip/wheels/6a/87/bd/09b16161b149fd6711ac76b5420d78ed58bd6a320e892117c3
Successfully built gdown
Installing collected packages: gdown
Successfully installed gdown-3.13.0


In [11]:
# Fetch the two dataset archives from Google Drive (~1.15 GB CUB_200_2011.tgz, ~613 MB birds.zip).
# Each download is skipped when its archive already exists, so Restart & Run All does not re-fetch
# ~1.7 GB; -O pins the local filename that the guard and the extraction cell rely on.
![ -f CUB_200_2011.tgz ] || gdown "https://drive.google.com/uc?id=1vF8Ht0VThpobtmShD52_INhpIgy6eEXq" -O CUB_200_2011.tgz
![ -f birds.zip ] || gdown "https://drive.google.com/uc?id=1kaIqFwTLD7Ml3ib9NQpjoUSD4FUD21-I" -O birds.zip

Downloading...
From: https://drive.google.com/uc?id=1vF8Ht0VThpobtmShD52_INhpIgy6eEXq
To: /home/nakachi/DALLE-pytorch/notebooks/CUB_200_2011.tgz
1.15GB [00:34, 33.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1kaIqFwTLD7Ml3ib9NQpjoUSD4FUD21-I
To: /home/nakachi/DALLE-pytorch/notebooks/birds.zip
613MB [00:20, 30.2MB/s] 


In [None]:
# Extract from scratch: remove earlier extractions first. -o lets unzip overwrite any other leftover
# entries without an interactive prompt (which would hang the notebook); -q and dropping tar's -v
# keep the cell from printing one line per extracted file.
!rm -rf birds CUB_200_2011
!unzip -o -q birds.zip
!tar zxf CUB_200_2011.tgz

In [None]:
import math
from math import sqrt

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes

from dalle_pytorch import DiscreteVAE

# constants

SEED = 42  # fixes weight init and DataLoader shuffling so runs are repeatable

IMAGE_SIZE = 128
# Point ImageFolder at the extracted CUB image tree (expected layout: CUB_200_2011/images/<class>/*.jpg)
# rather than './'. Scanning the working directory would also pick up anything else that lands
# there, e.g. image files wandb stores under ./wandb during a previous run.
IMAGE_PATH = './CUB_200_2011/images'

EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
LR_DECAY_RATE = 0.98          # ExponentialLR gamma, applied every 100 iterations in the loop

NUM_TOKENS = 8192             # codebook size
NUM_LAYERS = 2
NUM_RESNET_BLOCKS = 2
SMOOTH_L1_LOSS = False
EMB_DIM = 512                 # passed as codebook_dim
HID_DIM = 256                 # passed as hidden_dim
KL_LOSS_WEIGHT = 0            # passed as kl_div_loss_weight

STARTING_TEMP = 1.            # initial temperature handed to the VAE forward pass
TEMP_MIN = 0.5                # floor for the annealed temperature
ANNEAL_RATE = 1e-6            # decay rate per global step

NUM_IMAGES_SAVE = 4           # images shown in each logged grid

torch.manual_seed(SEED)

# data

ds = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])
)

# Check the dataset before allocating the model on the GPU.
assert len(ds) > 0, 'folder does not contain any images'
print(f'{len(ds)} images found for training')

dl = DataLoader(ds, BATCH_SIZE, shuffle = True)

# Constructor kwargs are kept separately so save_model can store them next to the weights.
vae_params = dict(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim = EMB_DIM,
    hidden_dim   = HID_DIM,
    num_resnet_blocks = NUM_RESNET_BLOCKS
)

vae = DiscreteVAE(
    **vae_params,
    smooth_l1_loss = SMOOTH_L1_LOSS,
    kl_div_loss_weight = KL_LOSS_WEIGHT
).cuda()

def save_model(path):
    """Write the VAE checkpoint to `path`.

    The saved dict has two entries: 'hparams' (the DiscreteVAE constructor kwargs held in
    `vae_params`) and 'weights' (the current `vae.state_dict()`). The DALL-E cell reads the
    same two keys back when VAE_PATH is set.
    """
    checkpoint = dict(hparams = vae_params, weights = vae.state_dict())
    torch.save(checkpoint, path)

# Optimizer plus an exponential LR schedule (stepped from inside the training loop).
opt = Adam(vae.parameters(), lr = LEARNING_RATE)
sched = ExponentialLR(opt, LR_DECAY_RATE)

# Weights & Biases experiment tracking; only a subset of the hyperparameters is recorded.
import wandb

model_config = {
    'num_tokens': NUM_TOKENS,
    'smooth_l1_loss': SMOOTH_L1_LOSS,
    'num_resnet_blocks': NUM_RESNET_BLOCKS,
    'kl_loss_weight': KL_LOSS_WEIGHT,
}

run = wandb.init(
    project = 'dalle_train_vae',
    job_type = 'train_model',
    config = model_config,
)

# Training loop.
# Per iteration: one optimizer step on the VAE loss, then one wandb.log call (an empty dict on
# iterations that log nothing). Every 100 iterations (`i` restarts at 0 each epoch): log image grids
# and the codebook-index histogram, checkpoint to ./vae.pt, anneal the temperature and step the LR
# scheduler. Every 10 iterations: print and log epoch / iter / loss / lr.

global_step = 0
temp = STARTING_TEMP  # temperature passed to the VAE forward pass; annealed below

for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(dl):
        images = images.cuda()

        loss, recons = vae(
            images,
            return_loss = True,
            return_recons = True,
            temp = temp
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        logs = {}

        if i % 100 == 0:
            k = NUM_IMAGES_SAVE

            # "Hard" reconstructions: decode the first k images from their codebook indices.
            with torch.no_grad():
                codes = vae.get_codebook_indices(images[:k])
                hard_recons = vae.decode(codes)

            images, recons = map(lambda t: t[:k], (images, recons))
            images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes))
            # NOTE(review): newer torchvision deprecates make_grid's `range` in favour of `value_range`.
            images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons))

            logs = {
                **logs,
                'sample images':        wandb.Image(images, caption = 'original images'),
                'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
                'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
                'codebook_indices':     wandb.Histogram(codes),
                'temperature':          temp
            }

            save_model('./vae.pt')
            wandb.save('./vae.pt')

            # Temperature anneal: tau = max(STARTING_TEMP * exp(-ANNEAL_RATE * global_step), TEMP_MIN).
            # Recompute from STARTING_TEMP instead of multiplying the previous `temp`. The compounding
            # form shrank temp by exp(-ANNEAL_RATE * sum of every past anneal step) and hit TEMP_MIN
            # far earlier than the schedule intends.
            temp = max(STARTING_TEMP * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)

            # LR decay: one scheduler step per 100 iterations (not per epoch).
            sched.step()

        if i % 10 == 0:
            lr = sched.get_last_lr()[0]
            print(epoch, i, f'lr - {lr:6f} loss - {loss.item()}')

            logs = {
                **logs,
                'epoch': epoch,
                'iter': i,
                'loss': loss.item(),
                'lr': lr
            }

        wandb.log(logs)
        global_step += 1

    # Save the latest checkpoint to wandb as an artifact at the end of every epoch.
    # ./vae.pt exists here because iteration i == 0 always takes the i % 100 branch above
    # (assuming the loader yields at least one batch).

    model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
    model_artifact.add_file('vae.pt')
    run.log_artifact(model_artifact)

# save final vae and cleanup

save_model('./vae-final.pt')
wandb.save('./vae-final.pt')

model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
model_artifact.add_file('vae-final.pt')
run.log_artifact(model_artifact)

wandb.finish()


In [None]:
from random import choice
from pathlib import Path

# torch

import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# vision imports

from PIL import Image
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import OpenAIDiscreteVAE, DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer

# helpers

def exists(val):
    """Return True unless `val` is None; falsy values such as 0 or '' still count as present."""
    if val is None:
        return False
    return True

# configuration (plain constants standing in for command-line arguments)

VAE_PATH = None   # e.g. './vae.pt'; if None, a pretrained VAE is used (OpenAI's, or VQGAN when TAMING)
DALLE_PATH = None # e.g. './dalle.pt'; setting it resumes training from that checkpoint
TAMING = False    # use the VQGAN VAE from the taming-transformers paper as the pretrained VAE
IMAGE_TEXT_FOLDER = './'
BPE_PATH = None   # optional path handed to HugTokenizer; None keeps the default `tokenizer`
RESUME = exists(DALLE_PATH)

EPOCHS = 20
BATCH_SIZE = 4
LEARNING_RATE = 3e-4
GRAD_CLIP_NORM = 0.5

MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DEPTH = 2
HEADS = 4
DIM_HEAD = 64
REVERSIBLE = True

# tokenizer

if BPE_PATH is not None:
    tokenizer = HugTokenizer(BPE_PATH)

# Reconstitute the VAE and the DALLE constructor kwargs (`dalle_params`). Both branches also leave
# `vae_params` set (None for a pretrained VAE) because save_model stores it in the checkpoint.

if RESUME:
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'

    loaded_obj = torch.load(str(dalle_path))

    # `weights` is loaded into the DALLE model once it has been constructed.
    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']

    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    else:
        vae = OpenAIDiscreteVAE()

    dalle_params = dict(dalle_params)  # shallow copy of the saved kwargs

    IMAGE_SIZE = vae.image_size

else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'

        loaded_obj = torch.load(str(vae_path))

        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using pretrained VAE for encoding images to tokens')
        vae_params = None

        if TAMING:
            # VQGanVAE1024 is not among this cell's top-level imports; importing it here keeps
            # TAMING = True from failing with a NameError.
            from dalle_pytorch import VQGanVAE1024
            vae_klass = VQGanVAE1024
        else:
            vae_klass = OpenAIDiscreteVAE
        vae = vae_klass()

    IMAGE_SIZE = vae.image_size

    dalle_params = dict(
        num_text_tokens = tokenizer.vocab_size,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        reversible = REVERSIBLE
    )

# helpers

def save_model(path):
    """Checkpoint the DALL-E model to `path`.

    Stored keys: 'hparams' (the DALLE constructor kwargs), 'vae_params' (DiscreteVAE kwargs, or
    None when a pretrained VAE is used) and 'weights' (dalle.state_dict()). This is the layout
    the RESUME branch of the configuration code unpacks. Note that this redefines the VAE
    cell's `save_model` for the rest of the session.
    """
    torch.save(
        {'hparams': dalle_params, 'vae_params': vae_params, 'weights': dalle.state_dict()},
        path,
    )

# dataset loading

class TextImageDataset(Dataset):
    """Image/caption pairs found recursively under a folder.

    An image `x.png` / `x.jpg` / `x.jpeg` is used only if a caption file `x.txt` with the same
    stem exists somewhere under the folder. Each item is (tokenized_text, image_tensor, mask),
    where the caption is one non-empty line of the text file, picked at random on every access.
    """

    def __init__(self, folder, text_len = 256, image_size = 128):
        """
        Args:
            folder: root directory searched recursively for *.txt and *.png/*.jpg/*.jpeg files.
            text_len: token sequence length passed to `tokenizer.tokenize` (should match the
                DALLE model's text_seq_len).
            image_size: side length of the square random crop fed to the model.
        """
        super().__init__()
        path = Path(folder)

        # Later patterns overwrite earlier ones on a stem clash (png, then jpg, then jpeg).
        text_files = {t.stem: t for t in path.glob('**/*.txt')}
        image_files = {
            i.stem: i
            for pattern in ('**/*.png', '**/*.jpg', '**/*.jpeg')
            for i in path.glob(pattern)
        }

        # Sorted so that index -> sample is stable across interpreter runs
        # (set iteration order depends on per-process string hashing).
        keys = sorted(image_files.keys() & text_files.keys())

        self.keys = keys
        self.text_len = text_len
        self.text_files = {k: text_files[k] for k in keys}
        self.image_files = {k: image_files[k] for k in keys}

        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale = (0.75, 1.), ratio = (1., 1.)),
            T.ToTensor()
        ])
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.image_tranform = self.image_transform

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, ind):
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = [line for line in text_file.read_text(encoding = 'utf-8').split('\n') if len(line) > 0]
        if not descriptions:
            raise ValueError(f'caption file {text_file} contains no non-empty lines')
        description = choice(descriptions)

        # Honour text_len so the token length matches what the dataset was built for;
        # previously the constructor argument was ignored and the tokenizer default was used.
        tokenized_text = tokenizer.tokenize(description, self.text_len).squeeze(0)
        mask = tokenized_text != 0  # token id 0 is treated as padding

        # The context manager closes the file handle once the transform has read the pixels.
        with Image.open(image_file) as image:
            image_tensor = self.image_transform(image)

        return tokenized_text, image_tensor, mask

# Build the image/caption dataset and a shuffled loader that drops the final partial batch.
ds = TextImageDataset(IMAGE_TEXT_FOLDER, image_size = IMAGE_SIZE, text_len = TEXT_SEQ_LEN)

num_pairs = len(ds)
assert num_pairs > 0, 'dataset is empty'
print(f'{num_pairs} image-text pairs found for training')

dl = DataLoader(
    ds,
    batch_size = BATCH_SIZE,
    shuffle = True,
    drop_last = True,
)

# Build the transformer around the chosen VAE and move it to the GPU; on resume, restore saved weights.
dalle = DALLE(vae = vae, **dalle_params)
dalle = dalle.cuda()
if RESUME:
    dalle.load_state_dict(weights)

opt = Adam(params = dalle.parameters(), lr = LEARNING_RATE)

# Experiment tracking: only the transformer shape is recorded in the wandb config.
import wandb

model_config = {'depth': DEPTH, 'heads': HEADS, 'dim_head': DIM_HEAD}
run = wandb.init(
    project = 'dalle_train_transformer',
    resume = RESUME,
    config = model_config,
)

# training
#
# Per iteration: forward with return_loss, backward, gradient-norm clipping, optimizer step, then
# one wandb.log call (an empty dict on iterations that log nothing). Every 10 iterations the loss
# is printed and logged; every 100 iterations (`i` restarts at 0 each epoch) one sample image is
# generated from the batch's first caption and the model is checkpointed to ./dalle.pt.

for epoch in range(EPOCHS):
    for i, (text, images, mask) in enumerate(dl):
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))

        loss = dalle(text, images, mask = mask, return_loss = True)

        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)

        # Gradients are cleared after the step, so each iteration starts from zero grads.
        opt.step()
        opt.zero_grad()

        log = {}

        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')

            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': loss.item()
            }

        if i % 100 == 0:
            # Decode the first caption (dropping 0-valued padding tokens) to use as the image caption.
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = 0.9    # topk sampling at 0.9
            )

            save_model(f'./dalle.pt')
            wandb.save(f'./dalle.pt')

            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }

        wandb.log(log)

    # Save the latest checkpoint to wandb as an artifact at the end of every epoch.
    # ./dalle.pt exists here because iteration i == 0 always takes the i % 100 branch
    # (assuming the loader yields at least one batch).

    model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))
    model_artifact.add_file('dalle.pt')
    run.log_artifact(model_artifact)

# Final checkpoint, uploaded both as a run file and as a new version of the 'trained-dalle' artifact.
save_model(f'./dalle-final.pt')
wandb.save('./dalle-final.pt')
model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))
model_artifact.add_file('dalle-final.pt')
run.log_artifact(model_artifact)

wandb.finish()


using pretrained VAE for encoding images to tokens


100%|████████████████████████| 215185363/215185363 [01:00<00:00, 3576736.81it/s]
100%|████████████████████████| 175360231/175360231 [00:32<00:00, 5408733.24it/s]


11788 image-text pairs found for training


[34m[1mwandb[0m: Currently logged in as: [33mnakachi-s[0m (use `wandb login --relogin` to force relogin)


0 0 loss - 9.411307334899902
0 10 loss - 9.197771072387695
0 20 loss - 9.0065336227417
0 30 loss - 8.774468421936035
0 40 loss - 8.47704792022705
0 50 loss - 8.048527717590332
0 60 loss - 7.711822032928467
0 70 loss - 7.688018321990967
0 80 loss - 7.687277317047119
0 90 loss - 7.8614726066589355
0 100 loss - 7.7845234870910645
0 110 loss - 7.805528163909912
0 120 loss - 7.855661869049072
0 130 loss - 7.56503963470459
0 140 loss - 7.761192321777344
0 150 loss - 7.647830486297607
0 160 loss - 7.597595691680908
0 170 loss - 7.574670314788818
0 180 loss - 7.8922014236450195
0 190 loss - 7.843993663787842
0 200 loss - 7.737320899963379
0 210 loss - 7.724544525146484
0 220 loss - 7.8991265296936035
0 230 loss - 7.631166934967041
0 240 loss - 7.627883434295654
0 250 loss - 7.477207660675049
0 260 loss - 7.693474292755127
0 270 loss - 7.832309246063232
0 280 loss - 7.734604358673096
0 290 loss - 7.75332498550415
0 300 loss - 7.706338882446289
0 310 loss - 7.584452152252197
0 320 loss - 7.60886