In [None]:
######################
# Diffusion-model training code adapted from:
#   https://github.com/dome272/Diffusion-Models-pytorch
# Imagen (Google's text-to-image neural network) implementation in PyTorch, adapted from:
#   https://github.com/lucidrains/imagen-pytorch
######################

import os
import copy
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch import optim
from helpers import *
import logging
from torch.utils.tensorboard import SummaryWriter

# Root logger: INFO and above, prefixed with a wall-clock timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%I:%M:%S",
)

In [None]:
# Seed every RNG the notebook (and the helpers it star-imports) may draw from,
# so that data subsetting, model init and sampling are reproducible.
import random

seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:
# Remove artifacts of previous runs: checkpoints (models/), sample images
# (results/) and TensorBoard logs (runs/).
# NOTE: deleting models/ also deletes the ckpt_*.pt files that `train` looks for
# when resuming; set CLEAN_PREVIOUS_RUNS = False to keep them and resume instead.
import glob
import shutil

CLEAN_PREVIOUS_RUNS = True

if CLEAN_PREVIOUS_RUNS:
    for artifact_dir in ("models", "results", "runs"):
        # glob.glob("dir/*") skips dotfiles and yields nothing for a missing dir,
        # mirroring the shell `rm -rf dir/*` this replaces.
        for entry in glob.glob(os.path.join(artifact_dir, "*")):
            if os.path.isdir(entry) and not os.path.islink(entry):
                shutil.rmtree(entry)
            else:
                os.remove(entry)

In [None]:
import sys
import warnings

warnings.filterwarnings("ignore")
sys.path.append("./imagen/")

from imagen_pytorch import Unet, Imagen, ImagenTrainer, NullUnet

In [None]:
# The two cascade stages share width, conditioning size and level multipliers;
# they differ only in per-level depth and where attention is applied.
# Stage 1: 3 resnet blocks per level, self- and cross-attention on levels 2-4.
# Stage 2: deeper stacks (2, 4, 8, 8), attention only on the deepest level.
unets = [
    Unet(dim=32, cond_dim=512, dim_mults=(1, 2, 4, 8), **stage_cfg)
    for stage_cfg in (
        dict(
            num_resnet_blocks=3,
            layer_attns=(False, True, True, True),
            layer_cross_attns=(False, True, True, True),
        ),
        dict(
            num_resnet_blocks=(2, 4, 8, 8),
            layer_attns=(False, False, False, True),
            layer_cross_attns=(False, False, False, True),
        ),
    )
]
unet1, unet2 = unets

In [None]:
def sample(trainer, n_images, label_embeds=None, continuous_embeds=None):
    """Draw ``n_images`` samples from ``trainer`` and convert them to 8-bit pixels.

    Args:
        trainer: object exposing ``sample(batch_size=..., label_embeds=...,
            continuous_embeds=..., use_tqdm=...)`` (an ``ImagenTrainer`` here).
        n_images: number of images to generate.
        label_embeds: optional label conditioning, forwarded unchanged.
        continuous_embeds: optional continuous conditioning, forwarded unchanged.

    Returns:
        A ``torch.uint8`` tensor with values in [0, 255], same shape as the samples.
    """
    x = trainer.sample(batch_size=n_images,
                       label_embeds=label_embeds,
                       continuous_embeds=continuous_embeds,
                       use_tqdm=True)
    # Map [-1, 1] -> [0, 1].
    # NOTE(review): upstream imagen-pytorch appears to return samples already
    # mapped to [0, 1] by default; if the ./imagen/ fork does too, this line
    # squeezes images into [0.5, 1] (washed out) — verify against the fork.
    x = (x.clamp(-1, 1) + 1) / 2
    # Round before casting: a bare uint8 cast truncates (0.999 * 255 -> 254).
    return (x * 255).round().to(torch.uint8)

In [None]:
def _trainer_ckpt_path(run_name, k):
    """Path of the full ImagenTrainer checkpoint written for U-Net ``k``."""
    return os.path.join("models", run_name, f"ckpt_trainer_{k}.pt")


def _meta_ckpt_path(run_name, k):
    """Path of the small ``{'epoch', 'loss'}`` bookkeeping checkpoint for U-Net ``k``."""
    return os.path.join("models", run_name, f"ckpt_{k}.pt")


def _load_trainer(trainer, path, device):
    """Load a trainer checkpoint, passing a CPU ``map_location`` when not on CUDA."""
    if device == "cuda":
        trainer.load(path)
    else:
        trainer.load(path, map_location=torch.device('cpu'))


def _continuous_dim(args):
    """Width of the continuous-conditioning vector.

    The config cell defines ``args.continuous_embed_dim`` (which also sizes the
    Imagen model); ``args.continuous_dim`` is accepted as a fallback.
    """
    dim = getattr(args, "continuous_embed_dim", None)
    return dim if dim is not None else args.continuous_dim


def _resume(trainer, args, k):
    """Restore U-Net ``k`` from its checkpoints if they exist.

    Returns:
        ``(first_epoch_to_run, last_saved_loss)``; ``(0, None)`` when starting fresh.
    """
    try:
        checkpoint = torch.load(_meta_ckpt_path(args.run_name, k))
        _load_trainer(trainer, _trainer_ckpt_path(args.run_name, k), args.device)
    except FileNotFoundError:
        logging.info(f"Starting training from scratch for unet_{k}")
        return 0, None
    start_epoch = checkpoint['epoch'] + 1
    logging.info(f"Resuming training from epoch: {start_epoch} for unet_{k}")
    if start_epoch >= args.epochs:
        logging.info(f"No more epochs to train for unet_{k}")
    return start_epoch, checkpoint['loss']


def _train_one_epoch(trainer, args, k, epoch, writer, loss=None):
    """Run one pass over ``args.dataloader`` for U-Net ``k``.

    Returns the loss of the last batch, or ``loss`` unchanged if the loader is empty.
    """
    device = args.device
    dataloader = args.dataloader
    cont_dim = _continuous_dim(args)
    pbar = tqdm(dataloader)
    for i, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)
        # Size the random continuous conditioning from the actual batch: the final
        # batch may be smaller than args.batch_size unless the loader drops it.
        continuous = torch.rand((images.shape[0], cont_dim)).to(device)

        loss = trainer(images=images, label_embeds=labels,
                       continuous_embeds=continuous,
                       unet_number=k)
        trainer.update(unet_number=k)

        pbar.set_postfix({f"MSE_{k}": loss})
        writer.add_scalar(f"MSE_{k}", loss, global_step=epoch * len(dataloader) + i)
    return loss


def _sample_epoch(trainer, args, epoch):
    """Generate a few images with ``trainer``, plot them and save them under results/."""
    logging.info(f"Starting sampling for epoch {epoch}:")
    n_images = getattr(args, "n_sample_images", 2)
    num_classes = getattr(args, "num_classes", 10)
    labels = torch.randint(0, num_classes, (n_images, )).to(args.device)
    continuous = torch.rand((n_images, _continuous_dim(args))).to(args.device)
    ema_sampled_images = sample(trainer, n_images=n_images,
                                label_embeds=labels,
                                continuous_embeds=continuous)
    plot_images(ema_sampled_images)
    save_images(ema_sampled_images, os.path.join("results", args.run_name, f"{epoch}_ema.jpg"))
    logging.info(f"Completed sampling for epoch {epoch}.")


def train(args):
    """Train each U-Net of the global cascaded ``imagen`` model in turn.

    For every stage ``k`` (1-based, over the global ``unets``), a fresh
    ``ImagenTrainer`` is built and resumed from ``models/<run_name>/ckpt_{k}.pt``
    plus ``ckpt_trainer_{k}.pt`` when both exist. It is then trained for the
    remaining epochs up to ``args.epochs`` and checkpointed after every epoch.
    Per-batch losses go to TensorBoard under ``runs/<run_name>``. Samples go to
    ``results/<run_name>/<epoch>_ema.jpg``.

    Reads from ``args``: ``run_name``, ``device``, ``dataloader``, ``lr``, ``epochs``,
    ``continuous_embed_dim`` (or ``continuous_dim``), ``num_classes`` (default 10),
    plus optional ``sample_every`` (default 1; 0/None disables sampling) and
    ``n_sample_images`` (default 2).

    ``setup_logging``, ``plot_images`` and ``save_images`` presumably come from
    ``helpers`` (star import).
    """
    setup_logging(args.run_name)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    sample_every = getattr(args, "sample_every", 1)
    try:
        for k in range(1, len(unets) + 1):
            # Follow args.device rather than forcing CUDA.
            trainer = ImagenTrainer(imagen, lr=args.lr, verbose=False).to(args.device)
            start_epoch, loss = _resume(trainer, args, k)

            for epoch in range(start_epoch, args.epochs):
                logging.info(f"Starting epoch {epoch}:")
                loss = _train_one_epoch(trainer, args, k, epoch, logger, loss)

                trainer.save(_trainer_ckpt_path(args.run_name, k))
                torch.save({'epoch': epoch, 'loss': loss}, _meta_ckpt_path(args.run_name, k))
                logging.info(f"Completed epoch {epoch}.")

                # Sample from the live trainer: it already holds the weights just
                # checkpointed. Rebuilding a trainer and reloading every stage's
                # checkpoint here used to clobber the outer `k`/`trainer`. It could
                # also load a missing or stale later-stage checkpoint into the
                # shared `imagen`.
                if sample_every and (epoch + 1) % sample_every == 0:
                    _sample_epoch(trainer, args, epoch)
    finally:
        logger.close()

In [None]:
import argparse

class DDPMArgs:
    """Plain attribute bag holding the run configuration (a notebook stand-in for argparse)."""


args = DDPMArgs()
args.run_name = "DDPM_cascaded"
args.epochs = 1
args.batch_size = 8
args.image_size = 64
args.num_classes = 10
args.continuous_embed_dim = 5
# Alias for code that reads `args.continuous_dim` (the training loop does). It
# must equal `continuous_embed_dim`, which sizes the model's continuous conditioning.
args.continuous_dim = args.continuous_embed_dim
# Overridable via CIFAR10_64_TRAIN_DIR so the notebook runs outside the original
# cluster account; defaults to the original absolute path.
args.dataset_path = os.environ.get(
    "CIFAR10_64_TRAIN_DIR",
    r"/rds/general/user/zr523/home/cifar10/cifar10-64/train",
)
args.device = "cuda"
args.lr = 3e-4
# NOTE(review): the two flags below are presumably read by get_cifar10_data
# (helpers); their exact meaning is not visible here — confirm there.
args.experiment = True
args.subset_data = 1000

args.dataloader = get_cifar10_data(args)
logging.info("Dataset loaded")

In [None]:
# Two-stage cascade built from `unets`, conditioned on class labels and a
# continuous vector whose widths come from the config cell.
# NOTE(review): both stages use args.image_size, so the second U-Net does not
# upsample relative to the first — confirm this is intended.
imagen = Imagen(
    unets = unets,
    condition_on_labels = True,
    label_embed_dim = args.num_classes,
    condition_on_continuous = True,
    continuous_embed_dim = args.continuous_embed_dim,
    # Tie model resolution to the configured data resolution instead of a literal.
    image_sizes = (args.image_size, args.image_size),
    timesteps = 1000,
    cond_drop_prob = 0.1
)

train(args)