In [None]:
#@title Setup

# Python packages imported by the later cells (dalle_pytorch, wandb).
%pip install dalle-pytorch
%pip install wandb

%pip install gdown
# Download the aria2 1.35.0 source tarball, unpack it, then configure, build
# and install it from inside the extracted directory.
# NOTE(review): neither gdown nor aria2 is invoked in the cells visible here —
# confirm they are still needed before paying for the source build.
!wget https://github.com/aria2/aria2/releases/download/release-1.35.0/aria2-1.35.0.tar.gz
!tar -xf aria2-1.35.0.tar.gz
%cd "aria2-1.35.0"
!./configure && make && make install
# Leave the build directory so later cells run from /content.
%cd /content

# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
#@title Download COCO2017, VirtualGenome, and WIT_380k
# Fetch the three dataset archives from Dropbox with wget.
# NOTE(review): "VirtualGenome" presumably refers to the Visual Genome dataset — confirm.
# NOTE(review): no cell visible here extracts these archives or fills the
# folder the training cell reads from — confirm that step happens elsewhere.
!wget https://www.dropbox.com/s/qrufasneaduttmc/wit_en_small_380k.tar.gz
!wget https://www.dropbox.com/s/dtjjz9cpenmpowr/train2017.zip
!wget https://www.dropbox.com/s/30sdd7f9m2nuii2/virtual_genome_captions.tar.gz

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !git clone "https://github.com/afiaka87/dalle-pytorch-datasets.git"

In [None]:
# Root folder that the training cell hands to TextImageDataset.
image_text_folder = "/content/output" #@param
import wandb
# Run the wandb command-line login so the training cell can call wandb.init.
!wandb login

In [None]:
import argparse
from random import choice, sample
from pathlib import Path

# torch

import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# vision imports

from PIL import Image
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE1024, DiscreteVAE, DALLE
from dalle_pytorch.simple_tokenizer import tokenize, tokenizer, VOCAB_SIZE

# argument parsing

# Root folder handed to TextImageDataset; it is searched recursively for
# caption (.txt) and image (.png/.jpg/.jpeg) files.
image_text_folder = "/content/output" #@param
# When no custom DiscreteVAE is available: True selects VQGanVAE1024,
# False selects OpenAIDiscreteVAE.
taming = True #@param 

# helpers

def exists(val):
    """Tell whether `val` holds a value.

    Only None counts as missing; falsy values such as 0, '' or [] still exist.
    """
    if val is None:
        return False
    return True

# constants

# Path to a custom DiscreteVAE checkpoint; None means a pretrained VAE is used.
VAE_PATH = None
# Path to a DALL-E checkpoint to resume from; None means a fresh model is built.
DALLE_PATH = None # 'dalle.pt'
# True only when a DALL-E checkpoint path was given.
RESUME = exists(DALLE_PATH)

# Training-loop settings: passes over the DataLoader, batch size, Adam learning
# rate, and the max norm handed to clip_grad_norm_.
EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 3e-4
GRAD_CLIP_NORM = 0.8

# Model hyperparameters forwarded to DALLE (dim, text_seq_len, depth, heads,
# dim_head, reversible) when not resuming; a resumed run takes them from the
# checkpoint's 'hparams' instead. TEXT_SEQ_LEN is also given to the dataset,
# and DEPTH/HEADS/DIM_HEAD are what gets reported in the wandb config.
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DEPTH = 6
HEADS = 12
DIM_HEAD = 64
REVERSIBLE = False

# reconstitute vae

# Pretrained VAE class used whenever no custom DiscreteVAE hyperparameters are
# available (fresh run without VAE_PATH, or a checkpoint whose 'vae_params' is None).
vae_klass = VQGanVAE1024 if taming else OpenAIDiscreteVAE

if RESUME:
    # Resume: model hyperparameters, VAE hyperparameters and weights all come
    # from the DALL-E checkpoint. `weights` is loaded into the model later,
    # once the DALLE instance has been built.
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'

    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    loaded_obj = torch.load(str(dalle_path), map_location='cpu')

    dalle_params = dict(loaded_obj['hparams'])
    vae_params = loaded_obj['vae_params']
    weights = loaded_obj['weights']

    vae = DiscreteVAE(**vae_params) if vae_params is not None else vae_klass()
else:
    if exists(VAE_PATH):
        # Fresh DALL-E on top of a custom DiscreteVAE restored from VAE_PATH.
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'

        # Load onto the CPU, like the DALL-E checkpoint above, so a checkpoint
        # saved from a GPU does not need that device to be present; the whole
        # model is moved to the GPU later in one step.
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        loaded_obj = torch.load(str(vae_path), map_location='cpu')

        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using pretrained VAE for encoding images to tokens')
        # None is stored in saved checkpoints to signal "pretrained VAE".
        vae_params = None
        vae = vae_klass()

    dalle_params = dict(
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        reversible = REVERSIBLE,
        attn_types = ('axial_row', 'axial_col', 'conv_like')
    )

# The dataset crops images to the resolution the VAE reports.
IMAGE_SIZE = vae.image_size

# helpers

def save_model(path):
    """Write a checkpoint of the current DALL-E model to `path`.

    The checkpoint is a dict holding the model hyperparameters ('hparams'),
    the custom VAE hyperparameters or None ('vae_params'), and the state dict
    of the module-level `dalle` model ('weights').
    """
    checkpoint = dict(
        hparams = dalle_params,
        vae_params = vae_params,
        weights = dalle.state_dict(),
    )
    torch.save(checkpoint, path)

# dataset loading

class TextImageDataset(Dataset):
    """Caption/image pairs discovered recursively under a folder.

    A pair is a `.txt` file and a `.png`/`.jpg`/`.jpeg` file sharing the same
    file stem. Each item is a tuple `(tokenized_text, image_tensor, mask)`
    where `mask` is True at every non-zero token position.

    Args:
        folder: root directory to search.
        text_len: token budget for a caption. Captions that encode to at least
            this many tokens are cut to their first hundred words, and captions
            are tokenized to this context length.
        image_size: side length passed to RandomResizedCrop.
    """

    def __init__(self, folder, text_len = 256, image_size = 128):
        super().__init__()
        path = Path(folder)

        text_files = [*path.glob('**/*.txt')]

        image_files = [
            *path.glob('**/*.png'),
            *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg')
        ]

        # Index by stem; files that share a stem overwrite one another.
        text_files = {t.stem: t for t in text_files}
        image_files = {i.stem: i for i in image_files}

        keys = (image_files.keys() & text_files.keys())

        # Sorted so that index -> sample is stable between runs (plain set
        # iteration order is not).
        self.keys = sorted(keys)
        self.text_len = text_len
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}

        self.image_tranform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale = (0.9, 1.), ratio = (1., 1.)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.keys)

    def _tokenize_caption(self, ind):
        """Pick one caption of pair `ind` at random and return its token tensor.

        Raises:
            ValueError: the caption file has no non-empty line.
            Exception: whatever `tokenize` raises for a caption it cannot fit.
        """
        lines = self.text_files[self.keys[ind]].read_text().split('\n')
        descriptions = [line for line in lines if len(line) > 0]
        if not descriptions:
            raise ValueError(f"caption file for index {ind} has no usable line")

        description = choice(descriptions).replace("  ", " ")

        # Shorten captions that would not fit in the token budget. A hundred
        # words may still be too many tokens; the caller handles that failure.
        if len(tokenizer.encode(description)) >= self.text_len:
            print(f"Caption at idx {ind} too long. Selecting a hundred words.")
            description = " ".join(description.replace("  ", " ").split(" ")[:100])

        return tokenize(description, context_length = self.text_len).squeeze(0)

    def __getitem__(self, ind):
        """Return `(tokenized_text, image_tensor, mask)` for pair `ind`.

        If the caption of `ind` cannot be used, the following pairs are tried
        in turn (wrapping around) so one bad caption does not abort an epoch.

        Raises:
            IndexError: `ind` is out of range.
            RuntimeError: no pair in the dataset has a usable caption.
        """
        total = len(self)
        # Normalises negative indices and raises IndexError when out of range.
        start = range(total)[ind]

        last_error = None
        for offset in range(total):
            candidate = (start + offset) % total
            try:
                tokenized_text = self._tokenize_caption(candidate)
            except Exception as err:
                print(f"Tokenized text failed at index {candidate}. Trying the next pair instead.")
                last_error = err
                continue

            # Opened only once the caption is known to be usable, and closed
            # as soon as the tensor has been produced.
            with Image.open(self.image_files[self.keys[candidate]]) as image:
                image_tensor = self.image_tranform(image)

            mask = tokenized_text != 0
            return tokenized_text, image_tensor, mask

        raise RuntimeError('no image-text pair with a usable caption was found') from last_error

# create dataset and dataloader

# Build the paired caption/image dataset from the configured folder, using the
# text length of the model and the image size reported by the VAE.
ds = TextImageDataset(
    image_text_folder,
    text_len = TEXT_SEQ_LEN,
    image_size = IMAGE_SIZE
)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

# drop_last discards a final incomplete batch.
# NOTE(review): with fewer than BATCH_SIZE pairs the loader yields no batches
# at all, even though the assert above passes.
dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)

# initialize DALL-E
# (.cuda() moves the model to the GPU, so a CUDA device is required)

dalle = DALLE(vae = vae, **dalle_params).cuda()

# When resuming, restore the weights read from the DALL-E checkpoint.
if RESUME:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr = LEARNING_RATE)

# experiment tracker

import wandb

# Hyperparameters recorded as the wandb run config and as artifact metadata.
# These are the notebook constants, also when the model came from a checkpoint.
model_config = dict(
    depth = DEPTH,
    heads = HEADS,
    dim_head = DIM_HEAD
)

run = wandb.init(project = 'dalle-pytorch-datasets', resume = RESUME, config = model_config)

# training

# Train for EPOCHS passes over the DataLoader.
#   every step       : forward/backward, gradient clipping, optimizer update, wandb.log
#   every 10 steps   : print and log the loss
#   every 100 steps  : sample an image for the first caption of the batch and
#                      checkpoint to ./dalle.pt
#   every epoch end  : checkpoint again and upload it as a wandb artifact
for epoch in range(EPOCHS):
    for i, (text, images, mask) in enumerate(dl):
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))

        loss = dalle(text, images, mask = mask, return_loss = True)

        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)

        opt.step()
        opt.zero_grad()

        log = {}

        if i % 10 == 0:
            # Read the scalar once; each .item() call copies it off the GPU.
            loss_value = loss.item()
            print(epoch, i, f'loss - {loss_value}')

            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': loss_value
            }

        if i % 100 == 0:
            # Recover the caption text (token id 0 is padding) to caption the sample.
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = 0.9    # topk sampling at 0.9
            )

            save_model('./dalle.pt')
            wandb.save('./dalle.pt')

            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }

        # Called on every step, also with an empty dict.
        wandb.log(log)

    # save trained model to wandb as an artifact every epoch's end
    # Checkpoint first: otherwise the artifact would hold the weights from the
    # last multiple-of-100 step, or the file would be missing altogether if the
    # loader produced no batches.
    save_model('./dalle.pt')

    model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))
    model_artifact.add_file('dalle.pt')
    run.log_artifact(model_artifact)

# Final checkpoint, stored both as a run file and as a new artifact version.
save_model('./dalle-final.pt')
wandb.save('./dalle-final.pt')
model_artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))
model_artifact.add_file('dalle-final.pt')
run.log_artifact(model_artifact)

wandb.finish()