In [None]:
# Report the GPU attached to this runtime (model, driver / CUDA version and
# current memory usage) before any training starts.
!nvidia-smi

Tue Dec  6 10:03:49 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P0    28W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install datasets imagen_pytorch transformers einops wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 28.5 MB/s 
[?25hCollecting imagen_pytorch
  Downloading imagen_pytorch-1.17.1-py3-none-any.whl (60 kB)
[K     |████████████████████████████████| 60 kB 4.5 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 59.0 MB/s 
[?25hCollecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[K     |████████████████████████████████| 41 kB 197 kB/s 
[?25hCollecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 68.4 MB/s 
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 74.2 MB/s 
Collecting huggingface-hub<1.0.0,

In [None]:
# Mount Google Drive at /content/drive. Checkpoints are written underneath this
# mount point so they can be downloaded after the run.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
from tqdm import tqdm_notebook as tqdm
import numpy as np
import random
from datasets import load_dataset
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from torch.utils.data import Dataset
from torchvision import transforms as T
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange
import os
import wandb

Downloading:   0%|          | 0.00/605 [00:00<?, ?B/s]

In [None]:
# Seed every random number generator used in this notebook (PyTorch CPU and
# CUDA, NumPy and Python's `random`) so repeated runs start from the same state.
seed = 3128974198
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# We will be saving checkpoints to our Google Drive so we can download them
# later. Keep the trailing "/": checkpoint file names are appended to this
# directory path.
model_save_dir = "/content/drive/MyDrive/imagen_colab/"

# `makedirs(..., exist_ok=True)` creates any missing parent folders as well and
# is a no-op when the directory already exists, so re-running this cell is safe.
os.makedirs(model_save_dir, exist_ok=True)

In [None]:
# Authenticate this runtime with Weights & Biases; the training loop below
# reports its metrics and sample images through `wandb.log`.
wandb.login()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
def get_text_embeddings(name, labels, max_length = 256, model_name = 'google/t5-v1_1-base'):
    """Return T5 encoder embeddings for one "a photo of a/an <label>" prompt per label.

    The result is cached on disk. If ``name`` already exists, its content is
    loaded and returned as-is; ``labels``, ``max_length`` and ``model_name`` are
    not consulted in that case, so delete the file to force a recomputation.

    Args:
        name: Path of the cache file (read with ``torch.load``, written with
            ``torch.save``).
        labels: Iterable of non-empty class-name strings; one prompt is built
            per label, in order.
        max_length: Maximum prompt length in tokens; longer prompts are
            truncated.
        model_name: Hugging Face identifier of the T5 checkpoint used for both
            the tokenizer and the encoder.

    Returns:
        The encoder's ``last_hidden_state`` for the whole batch of prompts,
        padded to the longest prompt, with every padding position set to 0.

    Raises:
        ValueError: If ``labels`` is empty or contains an empty string. This is
            checked before the (large) model is downloaded.
    """
    if os.path.isfile(name):
        # NOTE: torch.load is pickle-based. Only point `name` at cache files
        # produced by this function, never at files from an untrusted source.
        return torch.load(name)

    def photo_prefix(noun):
        # Choose the article from the first letter, ignoring its case.
        if not noun:
            raise ValueError("labels must be non-empty strings")
        article = 'an' if noun[0].lower() in 'aeiou' else 'a'
        return f'a photo of {article} {noun}'

    # Build the prompts first: it is cheap, and bad input fails here instead of
    # after the tokenizer and model have been downloaded.
    texts = [photo_prefix(x) for x in labels]
    if not texts:
        raise ValueError("labels must contain at least one entry")

    tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)

    model = T5EncoderModel.from_pretrained(model_name)
    model.eval()

    encoded = tokenizer(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = max_length,
        truncation = True
    )

    # Inference only: tensors produced under no_grad carry no autograd history.
    with torch.no_grad():
        output = model(input_ids=encoded.input_ids, attention_mask=encoded.attention_mask)
        encoded_text = output.last_hidden_state

    attn_mask = encoded.attention_mask.bool()

    # Zero the embedding of every padding token; the mask gets a trailing axis
    # so it broadcasts over the hidden dimension.
    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)

    torch.save(encoded_text, name)

    return encoded_text

class HFDataset(Dataset):
    """Torch dataset pairing each image of a Hugging Face split with the
    embedding row selected by the image's integer class label.

    Args:
        hf_dataset: Split exposing ``features['label'].names``, ``len()`` and
            integer indexing that yields mappings with ``'img'`` and ``'label'``.
        embeddings: Tensor indexed along its first axis by class label; it must
            have exactly one entry per class name.
        transform: Optional callable applied to every image before it is returned.
    """

    def __init__(self, hf_dataset, embeddings, transform=None):
        class_names = hf_dataset.features['label'].names
        # One embedding row per class, otherwise label lookups would be wrong.
        assert len(class_names) == embeddings.shape[0]

        self.data, self.transform, self.embeddings = hf_dataset, transform, embeddings

    def __len__(self):
        """Number of samples in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, embedding)`` for sample ``idx``."""
        record = self.data[idx]
        image, class_index = record['img'], record['label']

        if self.transform is not None:
            image = self.transform(image)

        # Return a copy so callers never alias the shared embedding table.
        return image, self.embeddings[class_index].clone()
    
def make(config):
    """Build the CIFAR-10 data pipeline and an Imagen trainer from ``config``.

    Args:
        config: Object exposing ``dim``, ``cond_dim``, ``dim_mults``,
            ``num_resnet_blocks``, ``image_sizes``, ``timesteps``,
            ``cond_drop_prob``, ``dynamic_thresholding``, ``lr`` and
            ``batch_size`` as attributes.

    Returns:
        ``(trainer, text_embeddings, labels)``: the ``ImagenTrainer`` with the
        train and validation datasets attached, the per-class text embeddings,
        and the list of class names.
    """
    cfar = load_dataset('cifar10')
    labels = cfar['train'].features['label'].names
    text_embeddings = get_text_embeddings("cifar10-embeddings.pkl", labels)

    # Attention (self and cross) is switched off for the first resolution level
    # and on for every following level.
    attn_per_level = (False,) + (True,) * (len(config.dim_mults) - 1)

    unet = Unet(
        dim = config.dim, # base channel width of the U-Net
        cond_dim = config.cond_dim,
        dim_mults = config.dim_mults, # per-level channel multipliers (applied to dim)
        num_resnet_blocks = config.num_resnet_blocks,
        layer_attns = attn_per_level,
        layer_cross_attns = attn_per_level
    )

    imagen = Imagen(
        unets = unet,
        image_sizes = config.image_sizes,
        timesteps = config.timesteps,
        cond_drop_prob = config.cond_drop_prob,
        dynamic_thresholding = config.dynamic_thresholding,
    ).cuda()

    trainer = ImagenTrainer(imagen, lr=config.lr)

    # Random flips are a training-time augmentation only. The held-out split is
    # converted to tensors without augmentation so validation losses are
    # computed on the same images every time.
    train_transform = T.Compose([ T.RandomHorizontalFlip(), T.ToTensor() ])
    eval_transform = T.ToTensor()

    ds = HFDataset(cfar['train'], text_embeddings, transform=train_transform)
    tst_ds = HFDataset(cfar['test'], text_embeddings, transform=eval_transform)

    trainer.add_train_dataset(ds, batch_size = config.batch_size)
    trainer.add_valid_dataset(tst_ds, batch_size=config.batch_size)

    return trainer, text_embeddings, labels


def train(trainer, text_embeddings, labels, config, sample_factor=None, validate_every=None, save_every=None, valid_batches=100):
    """Run the training loop, logging to Weights & Biases and saving checkpoints.

    Args:
        trainer: Object providing ``train_step``, ``valid_step``, ``sample``
            and ``save``.
        text_embeddings: Text embeddings passed to ``trainer.sample`` when
            sample images are generated.
        labels: Captions for the sampled images, in the same order as
            ``text_embeddings``.
        config: Object exposing ``steps``, ``batch_size`` and
            ``model_save_dir`` as attributes.
        sample_factor: If not ``None``, sample images whenever the step is a
            multiple of the current sampling interval (initially 10), then
            multiply the interval by this factor so sampling becomes rarer.
        validate_every: If not ``None``, log the mean validation loss every
            this many steps.
        save_every: If not ``None``, save a checkpoint every this many steps
            (never at step 0) and once more after the last step.
        valid_batches: Number of validation batches averaged per validation
            round.

    Raises:
        ValueError: If validation is enabled and ``valid_batches`` is below 1.
    """
    if validate_every is not None and valid_batches < 1:
        raise ValueError("valid_batches must be at least 1")

    def checkpoint_path(step):
        # os.path.join yields the same path whether or not model_save_dir ends with "/".
        return os.path.join(config.model_save_dir, f"{wandb.run.name}-{step}.ckpt")

    sample_every = 10
    last_saved_step = None

    for i in range(config.steps):
        loss = trainer.train_step(max_batch_size = config.batch_size)

        wandb.log({'train_loss': loss}, step=i)

        if validate_every is not None and i % validate_every == 0:
            total_loss = 0
            for _ in range(valid_batches):
                total_loss += trainer.valid_step(unet_number=1, max_batch_size=config.batch_size)
            # Log the mean, so the value is comparable with the per-step train loss.
            wandb.log({'valid loss': total_loss / valid_batches}, step=i)

        if sample_factor is not None and i % sample_every == 0:
            images = trainer.sample(text_embeds=text_embeddings, batch_size = config.batch_size, return_pil_images = True)
            samples = [wandb.Image(img, caption=caption) for img, caption in zip(images, labels)]
            wandb.log({"samples": samples}, step=i)
            # Never let the interval reach 0 (factor < 1), which would make the
            # modulo above divide by zero.
            sample_every = max(1, int(sample_every * sample_factor))

        if save_every is not None and i != 0 and i % save_every == 0:
            trainer.save(checkpoint_path(i))
            last_saved_step = i

    # Final save at the end if the last step did not already write a checkpoint.
    # Tracking the last saved step also covers a run whose only step is step 0
    # and avoids touching the loop variable when no step ran at all.
    final_step = config.steps - 1
    if save_every is not None and final_step >= 0 and last_saved_step != final_step:
        trainer.save(checkpoint_path(final_step))

In [None]:
# Settings for this run. The dict is handed to `wandb.init` as the run config
# and is read back through attribute access when the model is built and trained.
hyperparams = {
    "steps": 200_000,                   # number of training steps
    "dim": 128,                         # Unet `dim`
    "cond_dim": 128,                    # Unet `cond_dim`
    "dim_mults": (1, 2, 4),             # Unet channel multipliers, one per level
    "image_sizes": 32,                  # Imagen `image_sizes`
    "timesteps": 250,                   # Imagen `timesteps`
    "cond_drop_prob": 0.1,              # Imagen `cond_drop_prob`
    "batch_size": 64,                   # used for training, validation and sampling
    "lr": 1e-4,                         # ImagenTrainer learning rate
    "num_resnet_blocks": 3,             # Unet `num_resnet_blocks`
    "model_save_dir": model_save_dir,   # directory that receives the checkpoints
    "dynamic_thresholding": True,       # Imagen `dynamic_thresholding`
}

# NOTE(review): module-level alias of wandb's config object. `build` binds its
# own local `config` after `wandb.init`, so this one looks unused — TODO confirm.
config = wandb.config

def build(hyperparams):
    """Create a W&B run, build the trainer from its config, train it and return it.

    Args:
        hyperparams: Mapping of settings registered as the config of the
            'cifar10-imagen' W&B run.

    Returns:
        The trainer produced by ``make`` after ``train`` has finished.
    """
    run = wandb.init(project='cifar10-imagen', config=hyperparams)

    # The run is used as a context manager so it is closed when training ends
    # or raises.
    with run:
        # Read the settings back from wandb rather than from `hyperparams`.
        run_config = wandb.config

        trainer, text_embeddings, class_names = make(run_config)

        train(
            trainer,
            text_embeddings,
            class_names,
            run_config,
            sample_factor=1.3,
            validate_every=None,
            save_every=10_000,
        )

    return trainer

# Start the training run with the settings above and keep the returned trainer.
trainer = build(hyperparams)

[34m[1mwandb[0m: Currently logged in as: [33mlewington[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading builder script:   0%|          | 0.00/3.61k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/4.98k [00:00<?, ?B/s]

Downloading and preparing dataset cifar10/plain_text to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4...


Downloading data:   0%|          | 0.00/170M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar10 downloaded and prepared to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/990M [00:00<?, ?B/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-10000.ckpt


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-20000.ckpt


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-30000.ckpt


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-40000.ckpt


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-50000.ckpt
checkpoint saved to /content/drive/MyDrive/imagen_colab/giddy-capybara-30-60000.ckpt


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]