In [1]:
from tqdm import tqdm
from datasets import load_dataset
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from torch.utils.data import Dataset
from torchvision import transforms as T
from transformers import T5Tokenizer, T5EncoderModel, T5Config
from einops import rearrange
import os
import wandb

NOTE: Redirects are currently not supported in Windows or MacOs.
    https://github.com/beartype/beartype#pep-585-deprecations
  warn(


In [2]:
seed = 3128974198
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

NameError: name 'np' is not defined

In [None]:
wandb.login()

In [None]:
def get_text_embeddings(name, labels, max_length = 256):
    if os.path.isfile(name):
        return torch.load(name)
    
    model_name = 'google/t5-v1_1-base'
    tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)

    model = T5EncoderModel.from_pretrained(model_name)
    model.eval()
    
    def photo_prefix(noun):
        if noun[0] in ['a', 'e', 'i', 'o', 'u']:
            return 'a photo of an ' + noun
        return 'a photo of a ' + noun

    texts = [photo_prefix(x) for x in labels]
    
    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = "pt",
        padding = 'longest',
        max_length = max_length,
        truncation = True
    )
    
    with torch.no_grad():
        output = model(input_ids=encoded.input_ids , attention_mask=encoded.attention_mask)
        encoded_text = output.last_hidden_state.detach()

    attn_mask = encoded.attention_mask.bool()
    
    encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.)
    
    torch.save(encoded_text, name)
    
    return encoded_text

class HFDataset(Dataset):
    def __init__(self, hf_dataset, embeddings, transform=None):
        assert len(hf_dataset.features['label'].names) == embeddings.shape[0]
        
        self.data = hf_dataset
        self.transform = transform
        self.embeddings = embeddings
        
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        img = sample['img']
        label = sample['label']
        
        if self.transform is not None:
            img = self.transform(img)
        
        text_embedding = self.embeddings[label]
        
        return img, text_embedding.clone()
    
def make(config):
    cfar = load_dataset('cifar10')
    labels = cfar['train'].features['label'].names
    text_embeddings = get_text_embeddings("cifar10-embeddings.pkl", labels)

    unet = Unet(
    dim = config.dim, # the "Z" layer dimension, i.e. the number of filters the outputs to the first layer
    cond_dim = config.cond_dim,
    dim_mults = config.dim_mults, # the channel dimensions inside the model (multiplied by dim)
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = (False, True, True)
    )

    imagen = Imagen(
        unets = unet,
        image_sizes = config.image_sizes,
        timesteps = config.timesteps,
        cond_drop_prob = config.cond_drop_prob
    ).cuda()

    trainer = ImagenTrainer(imagen)

    ds = HFDataset(cfar['train'], text_embeddings, transform=T.Compose([ T.RandomHorizontalFlip(), T.ToTensor() ]))


    tst_ds = HFDataset(cfar['test'], text_embeddings, transform=T.Compose([ T.RandomHorizontalFlip(), T.ToTensor() ]))

    trainer.add_train_dataset(ds, batch_size = config.batch_size)
    trainer.add_valid_dataset(tst_ds, batch_size=config.batch_size)

    return trainer, text_embeddings, labels


def train(trainer, text_embeddings, labels, config):
    for i in tqdm(range(config.steps)):
        loss = trainer.train_step(max_batch_size = config.batch_size)

        wandb.log({'train_loss': loss}, step=i)

        if not (i % config.validate_every):
            valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = config.batch_size)
            wandb.log({'valid loss': loss}, step=i)

        if not (i % config.sample_every):
            images = trainer.sample(text_embeds=text_embeddings, batch_size = config.batch_size, return_pil_images = True)
            samples = []
            for i, img in enumerate(images):
                samples.append(wandb.Image(img, caption=labels[i]))
            wandb.log({"samples": samples}, step=i)
            

In [36]:
hyperparams = {
    "steps": 200000,
    "dim": 128,
    "cond_dim": 256,
    "dim_mults": (1, 2, 4),
    "image_sizes": 32,
    "timesteps": 1000,
    "cond_drop_prob": 0.1,
    "batch_size": 4,
    "sample_every": 300,
    "validate_every": 1000,
}


config = wandb.config

def build(hyperparams):
    
    with wandb.init(project='cifar10-imagen', config=hyperparams):
        
        config = wandb.config
        
        trainer, embeddings, labels = make(config)
        
        train(trainer, embeddings, labels, config)
        
        return trainer

trainer = build(hyperparams)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

Found cached dataset cifar10 (C:/Users/Metabox/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|                                                                                                                                                                                                            | 0/200000 [00:00<?, ?it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  0%|▎                                                                                                                                                                                              | 300/200000 [02:10<11:38:43,  4.76it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  0%|▌                                                                                                                                                                                              | 600/200000 [04:34<11:52:39,  4.66it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  0%|▊                                                                                                                                                                                              | 900/200000 [06:57<11:24:28,  4.85it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  1%|█▏                                                                                                                                                                                            | 1200/200000 [09:20<11:37:15,  4.75it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  1%|█▍                                                                                                                                                                                            | 1500/200000 [11:44<11:32:31,  4.78it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  1%|█▋                                                                                                                                                                                            | 1800/200000 [14:08<11:36:39,  4.74it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  1%|█▉                                                                                                                                                                                            | 2100/200000 [16:31<11:25:09,  4.81it/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

  1%|█▉                                                                                                                                                                                            | 2100/200000 [16:37<26:06:06,  2.11it/s]


0,1
train_loss,█▄▃▃▂▃▁▂▁▂▂▂▁▁▂▁▂▂▂▂▂▂▁▂▁▂▁▁▂▁▁▁▂▁▁▁▁▁▂▁
valid loss,█▁▁

0,1
train_loss,0.02765
valid loss,0.01944


KeyboardInterrupt: 