# Cold Latent Diffusion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.utils.data import Dataset

import os
import matplotlib.pyplot as plt
import numpy as np
import imageio
import copy
import math
from tqdm.notebook import trange, tqdm
from Unet import Unet

from diffusers.models import AutoencoderKL

In [None]:
# Optimisation settings
lr = 2e-5
batch_size = 64

# Epoch index at which the training loop stops (exclusive upper bound)
train_epoch = 3000

# Height and width of the square latents fed to the network
latent_size = 32

# Root folder of the datasets
data_set_root = "../../datasets"

<b> Use a GPU if available </b>

In [None]:
# Select the compute device: a CUDA GPU when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()

# Preferred GPU index. Machines with fewer GPUs than this index would fail as
# soon as a tensor is moved to the device, so fall back to the first GPU.
gpu_indx = 1
if use_cuda and gpu_indx >= torch.cuda.device_count():
    gpu_indx = 0

# Spell out the device type instead of relying on a bare integer index
device = torch.device(f"cuda:{gpu_indx}" if use_cuda else "cpu")

In [None]:
class LatentDataset(Dataset):
    """Dataset of pre-computed latents stored as individual ``.npy`` files.

    Every item is the array saved in one file, returned as a tensor with the
    shape and dtype it was saved with.
    """

    def __init__(self, latent_dir):
        """Index the ``.npy`` files located directly inside ``latent_dir``.

        Args:
            latent_dir: Directory holding one ``.npy`` file per latent.
        """
        self.latent_dir = latent_dir
        # Keep only NumPy array files: the directory may also contain
        # checkpoints, notebooks or source files that np.load cannot read.
        # Sorting keeps the index -> file mapping deterministic.
        self.latent_files = sorted(
            name for name in os.listdir(latent_dir) if name.endswith(".npy")
        )

    def __len__(self):
        """Return the number of latent files that were indexed."""
        return len(self.latent_files)

    def __getitem__(self, idx):
        """Load the ``idx``-th latent file and return it as a tensor."""
        latent_path = os.path.join(self.latent_dir, self.latent_files[idx])
        # from_numpy wraps the freshly loaded array without an extra copy
        return torch.from_numpy(np.load(latent_path))

In [None]:
# Read the latents from the current working directory; this replaces the
# dataset root chosen in the training-parameters cell.
data_set_root = "."
trainset = LatentDataset(data_set_root)

# Shuffled mini-batches, loaded by 4 worker processes
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

## Cosine schedule

In [None]:
def cosine_alphas_bar(timesteps, s=0.008):
    """Return the cumulative alpha values of a cosine noise schedule.

    Args:
        timesteps: Number of schedule entries to return.
        s: Small offset added to the normalised time before the cosine.

    Returns:
        A 1-D tensor of length ``timesteps``: the squared-cosine curve,
        normalised by its first value, with that first value dropped.
    """
    n_points = timesteps + 1
    grid = torch.linspace(0, n_points, n_points)

    # Map the grid to [0, 1], shift by s and squash into a quarter cosine period
    phase = ((grid / n_points) + s) / (1 + s) * torch.pi * 0.5
    curve = torch.cos(phase).pow(2)

    # Normalise so the curve starts at 1, then discard that starting point
    return (curve / curve[0])[1:]

## Reverse Cold Diffusion Process 
We implement a DDIM-style sampler, specifically the one from Cold Diffusion. Cold Diffusion can be used with any degradation transformation; here the degradation is simply added noise, so the model only performs de-noising.

[Cold Diffusion](https://arxiv.org/pdf/2208.09392.pdf)

In [None]:
def noise_from_x0(curr_img, img_pred, alpha):
    """Recover the noise implied by a clean-sample prediction.

    Inverts ``curr_img = sqrt(alpha) * img_pred + sqrt(1 - alpha) * noise``
    for ``noise``.

    Args:
        curr_img: Current noisy sample.
        img_pred: Predicted clean sample.
        alpha: Cumulative alpha (tensor) of the current noise level.

    Returns:
        The estimated noise tensor.
    """
    signal_part = alpha.sqrt() * img_pred
    # The small constant keeps the division finite when alpha is 1
    noise_scale = (1 - alpha).sqrt() + 1e-4
    return (curr_img - signal_part) / noise_scale

In [None]:
def cold_diffuse(diffusion_model, sample_in, total_steps):
    """Generate samples by iteratively de-noising ``sample_in``.

    At each step the model predicts the clean sample from the current noisy
    one. The noise implied by that prediction is used to move the current
    sample from its noise level to the next, less noisy, level of the cosine
    schedule.

    Args:
        diffusion_model: Module called as ``model(x, index)`` that returns a
            prediction of the clean sample; ``index`` is a long tensor of
            shape ``(batch,)``. Index 0 is the noisiest level.
        sample_in: Starting noise whose first dimension is the batch size.
            It is not modified.
        total_steps: Number of noise levels in the schedule (at least 1).

    Returns:
        The model prediction at the final (least noisy) step.

    Raises:
        ValueError: If ``total_steps`` is smaller than 1.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    if total_steps < 1:
        raise ValueError("total_steps must be at least 1")

    diffusion_model.eval()
    bs = sample_in.shape[0]
    # Keep every tensor on the device of the input rather than on a global one
    target_device = sample_in.device

    # Reverse the schedule so that index 0 is the noisiest level
    alphas = torch.flip(cosine_alphas_bar(total_steps), (0,)).to(target_device)

    # Work on a private copy so the caller's tensor is left untouched
    random_sample = sample_in.detach().clone()
    with torch.no_grad():
        for i in trange(total_steps - 1):
            index = torch.full((bs,), i, dtype=torch.long, device=target_device)

            # Predicted clean sample and the noise that prediction implies
            x0 = diffusion_model(random_sample, index)
            noise = noise_from_x0(random_sample, x0, alphas[i])

            # Re-noise the prediction to the current and to the next level;
            # adding their difference steps the sample to the next level.
            rep1 = alphas[i].sqrt() * x0 + (1 - alphas[i]).sqrt() * noise
            rep2 = alphas[i + 1].sqrt() * x0 + (1 - alphas[i + 1]).sqrt() * noise

            random_sample += rep2 - rep1

        # Final prediction at the last noise level
        index = torch.full((bs,), total_steps - 1, dtype=torch.long, device=target_device)
        img_output = diffusion_model(random_sample, index)

    return img_output


In [None]:
# Create a dataloader iterable object
dataiter = iter(train_loader)
# Sample one batch of latents from the iterable object
latents = next(dataiter)

In [None]:
# Number of noise levels in the diffusion schedule
timesteps = 500

# De-noising network: channels and out_dim are both set to the channel
# dimension of the sampled latent batch
u_net = Unet(
    channels=latents.shape[1],
    out_dim=latents.shape[1],
    img_size=latent_size,
    dim=64,
    dim_mults=(1, 2, 4, 8),
)
u_net = u_net.to(device)

# Adam over all network parameters
optimizer = optim.Adam(u_net.parameters(), lr=lr)

# Gradient scaler used by the mixed precision training loop
scaler = torch.cuda.amp.GradScaler()

# Cosine schedule, reversed so that index 0 is the noisiest level
alphas = cosine_alphas_bar(timesteps).flip(0).to(device)

In [None]:
# Count every element of every parameter tensor in the model
num_model_params = sum(param.numel() for param in u_net.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# Epoch to start (or resume) training from
start_epoch = 0
# Running sum of the losses of the current epoch
mean_loss = 0
# Loss of every training step
loss_log = []

In [None]:
# Load Checkpoint
# cp = torch.load("latent_u_net.pt")
# u_net.load_state_dict(cp["model_state_dict"])
# optimizer.load_state_dict(cp["optimizer_state_dict"])
# loss_log = cp["train_data_logger"]
# start_epoch = cp["epoch"]

In [None]:
# Train the U-Net to predict the clean latents from noised latents.
# After every epoch a checkpoint is written to "latent_u_net.pt".
pbar = trange(start_epoch, train_epoch, leave=False, desc="Epoch")
u_net.train()
for epoch in pbar:
    # Show the mean loss of the previous epoch (0 before the first one)
    pbar.set_postfix_str('Loss: %.4f' % (mean_loss/len(train_loader)))
    mean_loss = 0

    for latents in tqdm(train_loader, leave=False):
        latents = latents.to(device)

        # The size of the current minibatch
        bs = latents.shape[0]

        # Draw one random noise level per sample and noise the latents with it
        rand_index = torch.randint(timesteps, (bs, ), device=device)
        random_sample = torch.randn_like(latents)
        # Reshaped to (bs, 1, 1, 1) so it broadcasts over the remaining dims
        alpha_batch = alphas[rand_index].reshape(bs, 1, 1, 1)

        noise_input = alpha_batch.sqrt() * latents + (1 - alpha_batch).sqrt() * random_sample

        # Forward pass and loss under mixed precision
        with torch.cuda.amp.autocast():
            latent_pred = u_net(noise_input, rand_index)
            loss = F.l1_loss(latent_pred, latents)

        # Backpropagation with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Log the training loss; read it from the device only once per step
        loss_value = loss.item()
        loss_log.append(loss_value)
        mean_loss += loss_value

    # Store the next epoch index so that training can be resumed from it
    torch.save({'epoch': epoch + 1,
                'train_data_logger': loss_log,
                'model_state_dict': u_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                 }, "latent_u_net.pt")

In [None]:
# Plot the per-step training loss, skipping the first 1000 logged steps
plt.plot(loss_log[1000:])

In [None]:
# Load the pretrained "stabilityai/sd-vae-ft-ema" autoencoder and move it to the selected device
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)

In [None]:
# Start from scaled Gaussian noise: 8 samples with 4 channels each
latent_noise = torch.randn(8, 4, latent_size, latent_size, device=device) * 0.5

# Sample latents with the trained network and decode them with the VAE,
# without tracking gradients and under mixed precision
with torch.no_grad(), torch.cuda.amp.autocast():
    fake_latents = cold_diffuse(u_net, latent_noise, total_steps=timesteps)
    # NOTE(review): 0.18215 is presumably the scaling applied when the
    # latents were encoded - confirm against the latent extraction code
    fake_sample = vae.decode(fake_latents / 0.18215).sample

In [None]:
# Show the decoded samples as a normalised image grid, 4 images per row
plt.figure(figsize=(20, 10))
out = vutils.make_grid(
    fake_sample.detach().float().cpu(),
    nrow=4,
    normalize=True,
)
# Move the channel axis last, as expected by imshow
_ = plt.imshow(out.permute(1, 2, 0).numpy())