<a href="https://colab.research.google.com/github/VincentGariepy/Chess-Game/blob/main/Train_DALL_E.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install DALL-E training code (dalle-pytorch).
# NOTE(review): unpinned + --upgrade installs whatever release is newest at run time; consider pinning the
# version this notebook was last verified with so the DALLE / tokenizer APIs used below cannot drift.
!pip install dalle-pytorch --upgrade



In [None]:
# NOTE(review): nltk is not imported by any cell visible in this notebook — possibly unneeded; confirm before removing.
!pip install nltk



In [None]:
# Weights & Biases: used by the training cell for loss/sample logging and checkpoint artifacts.
# `wandb login` authenticates this runtime — never paste the API key into a notebook cell.
!pip install wandb
import wandb
!wandb login

Collecting wandb
  Downloading wandb-0.12.14-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 5.1 MB/s 
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 26.0 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.2.2-cp37-cp37m-manylinux1_x86_64.whl (36 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.5.8-py2.py3-none-any.whl (144 kB)
[K     |████████████████████████████████| 144 kB 10.7 MB/s 
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 1.5 MB/s 
[?25hCollecting smmap<6,>=3.0.1
  Downloading smmap-5.0.0-py3-none-any.whl (24

In [None]:
# NOTE(review): gdown is installed but never invoked in the visible cells — presumably a removed cell used it to
# download TrainingImages_Airplanes.zip; restore that step (or document the manual upload) so Run All works.
!pip install gdown



In [None]:
!unzip TrainingImages_Airplanes.zip
!rm -rf TrainingImages_Airplanes.zip

Archive:  TrainingImages_Airplanes.zip
  inflating: TrainingImages_Airplanes/100124.jpg  
  inflating: TrainingImages_Airplanes/100124.txt  
  inflating: TrainingImages_Airplanes/100404.jpg  
  inflating: TrainingImages_Airplanes/100404.txt  
  inflating: TrainingImages_Airplanes/100563.jpg  
  inflating: TrainingImages_Airplanes/100563.txt  
  inflating: TrainingImages_Airplanes/100746.jpg  
  inflating: TrainingImages_Airplanes/100746.txt  
  inflating: TrainingImages_Airplanes/100757.jpg  
  inflating: TrainingImages_Airplanes/100757.txt  
  inflating: TrainingImages_Airplanes/100974.jpg  
  inflating: TrainingImages_Airplanes/100974.txt  
  inflating: TrainingImages_Airplanes/101088.jpg  
  inflating: TrainingImages_Airplanes/101088.txt  
  inflating: TrainingImages_Airplanes/101223.jpg  
  inflating: TrainingImages_Airplanes/101223.txt  
  inflating: TrainingImages_Airplanes/101270.jpg  
  inflating: TrainingImages_Airplanes/101270.txt  
  inflating: TrainingImages_Airplanes/10131

In [None]:
from random import choice
from pathlib import Path

# torch

import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_

# vision imports

from PIL import Image
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

from dalle_pytorch import OpenAIDiscreteVAE, DiscreteVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer

# helpers

def exists(val):
    """Report whether ``val`` was provided, i.e. is anything other than ``None``.

    Falsy values such as ``0`` or ``''`` still count as existing.
    """
    is_missing = val is None
    return not is_missing

# argument parsing

VAE_PATH = None   # './vae.pt' - will use OpenAIs pretrained VAE if not set
DALLE_PATH = None # './dalle.pt' - set to resume training from a checkpoint written by save_model
TAMING = False  # use VAE from taming transformers paper
IMAGE_TEXT_FOLDER = './TrainingImages_Airplanes'
BPE_PATH = None   # optional tokenizer file for HugTokenizer; None keeps the default tokenizer
RESUME = exists(DALLE_PATH)

EPOCHS = 20
BATCH_SIZE = 4
LEARNING_RATE = 3e-4
GRAD_CLIP_NORM = 0.5

MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DEPTH = 2
HEADS = 4
DIM_HEAD = 64
REVERSIBLE = True

# tokenizer
# Swap in the custom tokenizer *before* reading vocab_size: reading it first sized the
# text-token vocabulary for the default tokenizer even when BPE_PATH was set.

if BPE_PATH is not None:
    tokenizer = HugTokenizer(BPE_PATH)

VOCAB_SIZE = tokenizer.vocab_size

# reconstitute vae
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
# map_location='cpu' lets a GPU-saved checkpoint load on any runtime; the model is moved
# to the GPU when DALLE is constructed further down.

if RESUME:
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'

    loaded_obj = torch.load(str(dalle_path), map_location = 'cpu')

    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']

    # Saved vae_params => a custom DiscreteVAE; presumably its weights are part of the DALL-E
    # state dict (DALLE receives the vae) and get restored when `weights` is loaded — confirm.
    # NOTE(review): a run trained with TAMING=True also stores vae_params=None and would be
    # rebuilt with OpenAIDiscreteVAE here.
    if vae_params is not None:
        vae = DiscreteVAE(**vae_params)
    else:
        vae = OpenAIDiscreteVAE()

    IMAGE_SIZE = vae.image_size

else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'

        loaded_obj = torch.load(str(vae_path), map_location = 'cpu')

        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print('using pretrained VAE for encoding images to tokens')
        vae_params = None

        if TAMING:
            # The VQGAN class was never imported by this cell (NameError before). Newer
            # dalle-pytorch releases export it as VQGanVAE, older ones as VQGanVAE1024.
            try:
                from dalle_pytorch import VQGanVAE as vae_klass
            except ImportError:
                from dalle_pytorch import VQGanVAE1024 as vae_klass
        else:
            vae_klass = OpenAIDiscreteVAE

        vae = vae_klass()

    IMAGE_SIZE = vae.image_size

    dalle_params = dict(
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        reversible = REVERSIBLE
    )

# helpers

def save_model(path):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with keys ``'hparams'`` (module-level ``dalle_params``),
    ``'vae_params'`` (module-level ``vae_params``) and ``'weights'`` (the state dict of
    the module-level ``dalle``) -- the same keys the RESUME branch reads back.
    """
    checkpoint = dict(
        hparams = dalle_params,
        vae_params = vae_params,
        weights = dalle.state_dict(),
    )
    torch.save(checkpoint, path)

# dataset loading

class TextImageDataset(Dataset):
    """Caption/image pairs matched by file stem under ``folder``.

    Every ``<stem>.txt`` (searched recursively) that has a same-stem ``.png``,
    ``.jpg`` or ``.jpeg`` image becomes one item ``(tokenized_text, image_tensor, mask)``:

    * ``tokenized_text`` -- one randomly chosen non-blank line of the caption file,
      tokenized to ``text_len`` tokens by the module-level ``tokenizer``
      (captions longer than that are truncated instead of raising);
    * ``image_tensor`` -- the image converted to RGB, randomly cropped to a square
      covering 75-100% of its area, resized to ``image_size`` and converted by ``T.ToTensor``;
    * ``mask`` -- ``tokenized_text != 0``.

    Raises ``ValueError`` from ``__getitem__`` if a caption file has no non-blank lines.
    """

    def __init__(self, folder, text_len = 256, image_size = 128):
        super().__init__()
        path = Path(folder)

        text_files = {t.stem: t for t in path.glob('**/*.txt')}
        # Same precedence as before: for a duplicated stem, a later pattern wins.
        image_files = {
            i.stem: i
            for pattern in ('**/*.png', '**/*.jpg', '**/*.jpeg')
            for i in path.glob(pattern)
        }

        keys = image_files.keys() & text_files.keys()

        # Sorted so item indices are stable across runs (set iteration order is not).
        self.keys = sorted(keys)
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        # Previously accepted but ignored: captions were always tokenized to the
        # tokenizer's default length regardless of TEXT_SEQ_LEN.
        self.text_len = text_len

        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale = (0.75, 1.), ratio = (1., 1.)),
            T.ToTensor()
        ])
        # Backward-compatible alias for the old misspelled attribute name.
        self.image_tranform = self.image_transform

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, ind):
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        descriptions = [line for line in text_file.read_text().splitlines() if line.strip()]
        if not descriptions:
            raise ValueError(f'caption file {text_file} has no non-empty lines')
        description = choice(descriptions)

        tokenized_text = tokenizer.tokenize(description, self.text_len, truncate_text = True).squeeze(0)
        mask = tokenized_text != 0

        # Transform inside the `with`: PIL loads pixel data lazily, and the context
        # manager closes the file handle that a bare Image.open leaked.
        with Image.open(image_file) as image:
            image_tensor = self.image_transform(image)
        return tokenized_text, image_tensor, mask

# create dataset and dataloader

# Build the caption/image dataset and a shuffled loader that drops the final ragged batch.
ds = TextImageDataset(IMAGE_TEXT_FOLDER, text_len = TEXT_SEQ_LEN, image_size = IMAGE_SIZE)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

dl = DataLoader(
    ds,
    batch_size = BATCH_SIZE,
    shuffle = True,
    drop_last = True,
)

# Wrap the VAE in DALL-E, move it to the GPU and, when resuming, restore the checkpointed weights.
dalle = DALLE(vae = vae, **dalle_params)
dalle = dalle.cuda()
if RESUME:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr = LEARNING_RATE)

# experiment tracker

import wandb

model_config = dict(
    depth = DEPTH,
    heads = HEADS,
    dim_head = DIM_HEAD
)

run = wandb.init(project = 'dalle_train_transformer', resume = RESUME, config = model_config)

CHECKPOINT_PATH = './dalle.pt'
FINAL_CHECKPOINT_PATH = './dalle-final.pt'
LOG_EVERY = 10      # iterations between loss prints / wandb scalars
SAMPLE_EVERY = 100  # iterations between sample generation + checkpoint upload


def log_model_artifact(path):
    """Upload ``path`` to wandb as a new version of the 'trained-dalle' model artifact."""
    artifact = wandb.Artifact('trained-dalle', type = 'model', metadata = dict(model_config))
    artifact.add_file(path)
    run.log_artifact(artifact)


# training

for epoch in range(EPOCHS):
    for i, (text, images, mask) in enumerate(dl):
        # NOTE(review): mask is moved to the GPU but never passed to dalle below.
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))

        loss = dalle(text, images, return_loss = True)

        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)

        opt.step()
        opt.zero_grad()

        log = {}

        if i % LOG_EVERY == 0:
            loss_value = loss.item()
            print(epoch, i, f'loss - {loss_value}')
            log.update(epoch = epoch, iter = i, loss = loss_value)

        if i % SAMPLE_EVERY == 0:
            sample_text = text[:1]
            token_list = sample_text.masked_select(sample_text != 0).tolist()
            decoded_text = tokenizer.decode(token_list)

            image = dalle.generate_images(
                sample_text,
                filter_thres = 0.9    # topk sampling at 0.9
            )

            save_model(CHECKPOINT_PATH)
            wandb.save(CHECKPOINT_PATH)

            log['image'] = wandb.Image(image, caption = decoded_text)

        wandb.log(log)

    # Re-save before uploading: the last in-loop save happened at the last multiple of
    # SAMPLE_EVERY iterations, so the per-epoch artifact used to hold stale weights.
    save_model(CHECKPOINT_PATH)
    log_model_artifact(CHECKPOINT_PATH)

save_model(FINAL_CHECKPOINT_PATH)
wandb.save(FINAL_CHECKPOINT_PATH)
log_model_artifact(FINAL_CHECKPOINT_PATH)

wandb.finish()

using pretrained VAE for encoding images to tokens
2000 image-text pairs found for training





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0 0 loss - 9.407890319824219
0 10 loss - 9.210505485534668
0 20 loss - 9.027332305908203
0 30 loss - 8.794010162353516
0 40 loss - 8.441352844238281
0 50 loss - 8.053003311157227
0 60 loss - 7.72253942489624
0 70 loss - 7.52573823928833
0 80 loss - 7.395773410797119
0 90 loss - 7.490198612213135
0 100 loss - 7.406127452850342
0 110 loss - 7.647394180297852
0 120 loss - 7.535918235778809
0 130 loss - 7.192729949951172
0 140 loss - 7.739445209503174
0 150 loss - 7.4507036209106445
0 160 loss - 7.245863437652588
0 170 loss - 7.28628396987915
0 180 loss - 7.098884582519531
0 190 loss - 7.299960136413574
0 200 loss - 7.342501163482666
0 210 loss - 7.5867228507995605
0 220 loss - 7.296907424926758
0 230 loss - 7.257593154907227
0 240 loss - 7.231086254119873
0 250 loss - 7.376694202423096
0 260 loss - 7.2111711502075195
0 270 loss - 7.629433631896973
0 280 loss - 7.198481559753418
0 290 loss - 7.201820373535156
0 300 loss - 7.672863483428955
0 310 loss - 7.774088382720947
0 320 loss - 7.2819