## Imports

In [1]:
# !pip install kaggle

In [2]:
# from google.colab import files
# files.upload()

In [3]:
# !mkdir ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

In [4]:
  # !kaggle datasets download -d mihailchirobocea/ffhq-64-train-50k

In [5]:
# !unzip /content/ffhq-64-train-50k.zip

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms 
from torch.utils.data import DataLoader
from torch import optim
from torch.cuda.amp import autocast, GradScaler
import copy
import numpy as np
import math
from PIL import Image
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import os

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

## Utils

In [7]:
def plot_images(images):
    plt.figure(figsize=(4, 4))
    plt.imshow(torch.cat([torch.cat([i for i in images.cpu()], dim=-1),], dim=-2).permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    im.save(path)


def get_data(args):
    
    transform = transforms.Compose([
        # transforms.Lambda(lambda img: transforms.Resize(args.image_size)(img) if min(img.size) < args.image_size 
        #                   else (transforms.Resize(2*args.image_size)(img) if max(img.size) >= 2*args.image_size else img)),
        # transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(0.5),
        # transforms.RandomVerticalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader

## Learning Rate Schduler

In [8]:
from math import cos, pi, floor, sin

from torch.optim import lr_scheduler

def anneal_linear(start, end, proportion):
    return start + proportion * (end - start)


def anneal_cos(start, end, proportion):
    cos_val = cos(pi * proportion) + 1

    return end + (start - end) / 2 * cos_val


class Phase:
    def __init__(self, start, end, n_iter, anneal_fn):
        self.start, self.end = start, end
        self.n_iter = n_iter
        self.anneal_fn = anneal_fn
        self.n = 0

    def step(self):
        self.n += 1

        return self.anneal_fn(self.start, self.end, self.n / self.n_iter)

    def reset(self):
        self.n = 0

    @property
    def is_done(self):
        return self.n >= self.n_iter


class CycleScheduler:
    def __init__(
        self,
        optimizer,
        lr_max,
        n_iter,
        momentum=(0.95, 0.85),
        divider=25,
        warmup_proportion=0.1,
        phase=('linear', 'cos'),
    ):
        self.optimizer = optimizer

        phase1 = int(n_iter * warmup_proportion)
        phase2 = n_iter - phase1
        lr_min = lr_max / divider

        phase_map = {'linear': anneal_linear, 'cos': anneal_cos}

        self.lr_phase = [
            Phase(lr_min, lr_max, phase1, phase_map[phase[0]]),
            Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]),
        ]

        self.momentum = momentum

        if momentum is not None:
            mom1, mom2 = momentum
            self.momentum_phase = [
                Phase(mom1, mom2, phase1, phase_map[phase[0]]),
                Phase(mom2, mom1, phase2, phase_map[phase[1]]),
            ]

        else:
            self.momentum_phase = []

        self.phase = 0

    def step(self):
        lr = self.lr_phase[self.phase].step()

        if self.momentum is not None:
            momentum = self.momentum_phase[self.phase].step()

        else:
            momentum = None

        for group in self.optimizer.param_groups:
            group['lr'] = lr

            if self.momentum is not None:
                if 'betas' in group:
                    group['betas'] = (momentum, group['betas'][1])

                else:
                    group['momentum'] = momentum

        if self.lr_phase[self.phase].is_done:
            self.phase += 1

        if self.phase >= len(self.lr_phase):
            for phase in self.lr_phase:
                phase.reset()

            for phase in self.momentum_phase:
                phase.reset()

            self.phase = 0

        return lr, momentum

## Modules

In [9]:
class EMA():
    def __init__(self, beta, step = 0):
        super().__init__()
        self.beta = beta
        self.step = step

    def update_model_average(self, ema_model, current_model):
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old_weight, up_weight = ema_params.data, current_params.data
            ema_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())
        
        

class HeadAttention(nn.Module):
    def __init__(self, channels):
        super(HeadAttention, self).__init__()
        self.channels = channels

        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        h, w = x.shape[-2:]
        x = x.view(-1, self.channels, h * w).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, h, w)
    


class SkipAttention(nn.Module):
    def __init__(self, f_in_g, f_in_x, f_out):
        super().__init__()
        
        self.w_g = nn.Sequential(
            nn.Conv2d(f_in_g, f_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(f_out)
        )
        
        self.w_x = nn.Sequential(
            nn.Conv2d(f_in_x, f_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(f_out)
        )

        # self.g_up = nn.ConvTranspose2d(f_out, f_out, 4, stride=2, padding=1)

        self.relu = nn.ReLU(inplace=True)
        
        self.psi = nn.Sequential(
            nn.Conv2d(f_out, 1, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
         
    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        # g1 = self.g_up(g1)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        return psi*x
    
    
    
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, emb_dim=512):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )
        
        self.emb_layer = nn.Linear(emb_dim, out_channels)

    def forward(self, x, t = None):
        if self.residual:
            x = F.gelu(x + self.double_conv(x))
        else:
            x = self.double_conv(x)
        if t is not None:
            emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
            x = x + emb
        return x
    
        
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=512):
        super().__init__()

        self.down = nn.Conv2d(in_channels, in_channels, 4, stride=2, padding=1)
        
        self.conv1 = DoubleConv(in_channels, in_channels, residual=True)
        self.conv2 = DoubleConv(in_channels, out_channels)

        self.emb_layer = nn.Linear(emb_dim, out_channels)

        self.head_attention = HeadAttention(out_channels)

    def forward(self, x, t):
        x = self.down(x)
        x = self.conv1(x)
        x = self.conv2(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        x = x + emb
        x = self.head_attention(x)
        return x


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, gated_attention, emb_dim=512):
        super().__init__()

        self.gated_attention = gated_attention

        self.attention = SkipAttention(in_channels // 2, in_channels // 2, in_channels // 2)

        self.conv1 = DoubleConv(in_channels, in_channels, residual=True)
        self.conv2 = DoubleConv(in_channels, out_channels)

        self.up = nn.ConvTranspose2d(out_channels, out_channels, 4, stride=2, padding=1)

        self.emb_layer = nn.Linear(emb_dim, out_channels)
        
        self.head_attention = HeadAttention(out_channels)
        

    def forward(self, x, skip_x, t):
        
        if self.gated_attention:
            skip_x = self.attention(g=x, x=skip_x)
        
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        x = x + emb
        x = self.head_attention(x)
        x = self.up(x)
        return x
    
    
class In(nn.Module):
    def __init__(self, out_channels, in_channels = 3, emb_dim=512):
        super().__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.GroupNorm(1, out_channels),
            nn.GELU(),
        )
        self.conv2 = DoubleConv(out_channels, out_channels, residual = True)

        self.emb_layer = nn.Linear(emb_dim, out_channels)

    def forward(self, x, t):
        x = self.conv1(x)
        x = self.conv2(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        x = x + emb
        return x
    
    
class Out(nn.Module):
    def __init__(self, in_channels, gated_attention, out_channels = 3, emb_dim=512):
        super().__init__()
        
        self.gated_attention = gated_attention

        self.attention = SkipAttention(in_channels // 2, in_channels // 2, in_channels // 2)
        
        self.conv1 = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual = True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, x, skip_x):
        
        if self.gated_attention:
            skip_x = self.attention(g=x, x=skip_x)
        
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv1(x)
        return x
    
    
class UNet(nn.Module):
    def __init__(self, device, gated_attention = False, c_in=3, c_out=3, time_dim=512):
        super().__init__()

        self.channels = [128 * k for k in range(1,5)]
        self.device = device
        self.time_dim = time_dim
        # self.time_embed = nn.Sequential(
        #     nn.Linear(time_dim, time_dim),
        #     nn.GELU(),
        #     nn.Linear(time_dim, time_dim),
        #     nn.GELU(),
        # )

        self.inc = In(self.channels[0])
        
        self.down1 = Down(self.channels[0], self.channels[1])
        self.down2 = Down(self.channels[1], self.channels[2])
        self.down3 = Down(self.channels[2], self.channels[3])

        self.bot1 = DoubleConv(self.channels[3], self.channels[3], residual = True)
        self.bot2 = DoubleConv(self.channels[3], self.channels[3], residual = True)
        self.bot3 = DoubleConv(self.channels[3], self.channels[3], residual = True)
        self.bot4 = DoubleConv(self.channels[3], self.channels[3], residual = True)
        self.bot5 = DoubleConv(self.channels[3], self.channels[3], residual = True)

        self.up1 = Up(2*self.channels[3], self.channels[2], gated_attention)
        self.up2 = Up(2*self.channels[2], self.channels[1], gated_attention)
        self.up3 = Up(2*self.channels[1], self.channels[0], gated_attention)

        self.outc = Out(2*self.channels[0], gated_attention)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        # t = self.time_embed(t)

        x1 = self.inc(x, t)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        x5 = self.bot1(x4, t = t)
        x5 = self.bot2(x5, t = t)
        x5 = self.bot3(x5, t = t)
        x5 = self.bot4(x5, t = t)
        x5 = self.bot5(x5, t = t)

        x = self.up1(x5, x4, t)
        x = self.up2(x, x3, t)
        x = self.up3(x, x2, t)
        x = self.outc(x, x1)
        
        return x

## Diffusion

In [10]:
class Diffusion:
    def __init__(self, config):
        self.img_size = config.generate_img_size
        self.noise_steps = config.noise_steps
        self.device = config.device
        self.img_size = config.generate_img_size

        self.beta = self.prepare_noise_schedule(schedule=config.schedule, noise_steps=self.noise_steps).to(self.device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    @staticmethod
    def prepare_noise_schedule(schedule, noise_steps, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):

        if schedule == "linear":
            betas = torch.linspace(linear_start, linear_end, noise_steps)

        # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
        elif schedule == "cosine":
            timesteps = (torch.arange(noise_steps + 1) / noise_steps + cosine_s)
            alphas = timesteps / (1 + cosine_s) * math.pi / 2
            alphas = torch.cos(alphas).pow(2)
            alphas = alphas / alphas[0]
            betas = 1 - alphas[1:] / alphas[:-1]
            betas = betas.clamp(max=0.999)

        return betas

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    @torch.no_grad()
    def sample(self, title, model, n, epoch):

        model.eval()

        x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
        for i in reversed(range(1, self.noise_steps)):
            t = (torch.ones(n) * i).long().to(self.device)
            predicted_noise = model(x, t)
            alpha = self.alpha[t][:, None, None, None]
            alpha_hat = self.alpha_hat[t][:, None, None, None]
            beta = self.beta[t][:, None, None, None]
            if i > 1:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)
            x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)

        save_images(x, f"sample{epoch}_{title}.jpg")
        plot_images(x)
        
        model.train()
        
        return x


def train(config, path = None, scheduler_path = None):
    
    dataloader = get_data(config)
    print("loaded")
    
    model = UNet(config.device, config.gated_attention).to(config.device)
    
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    
    optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    mse = nn.MSELoss()

    if config.sched == "cycle" and scheduler_path == None:
        print("sch")
        scheduler = CycleScheduler(
            optimizer,
            config.lr,
            n_iter= len(dataloader) * config.epochs,
        )
    
    losses = []
    start_epoch = 0
    
    if path is not None:
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        config_dict = checkpoint["config"]
        config.__dict__.update(config_dict)
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        losses = checkpoint['loss']
        if config.sched == "cycle" and scheduler_path is not None:
            with open(scheduler_path, "rb") as file:
                scheduler = pickle.load(file)
            scheduler.optimizer = optimizer
        
        steps_ema = len(dataloader)*start_epoch
        ema = EMA(0.995, steps_ema)
        ema_model = copy.deepcopy(model).eval().requires_grad_(False)
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        
        del checkpoint
        del config_dict
        
    else:
        ema = EMA(0.995)
        ema_model = copy.deepcopy(model).eval().requires_grad_(False)
        
    
    diffusion = Diffusion(config)

 
    if config.mix_precision:
        scaler = GradScaler()

    for epoch in range(start_epoch+1, 51):
        
        epoch_loss = []
        
        model.train()

        for k, (images, _) in enumerate(tqdm(dataloader)):
            
            optimizer.zero_grad()

            images = images.to(config.device)
            
            if config.mix_precision:
                with autocast():
                    t = diffusion.sample_timesteps(images.shape[0]).to(config.device)
                    x_t, noise = diffusion.noise_images(images, t)
                    predicted_noise = model(x_t, t)
                    loss = mse(noise, predicted_noise)           
                scaler.scale(loss).backward()
                if config.sched == "cycle":
                    scheduler.step()
                scaler.step(optimizer)
                scaler.update()
                ema.step_ema(ema_model, model)
            else:
                t = diffusion.sample_timesteps(images.shape[0]).to(config.device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)           
                loss.backward()
                if config.sched == "cycle":
                    scheduler.step()
                optimizer.step()
                ema.step_ema(ema_model, model)
            

            if k%100 == 0:
                print(f"e{epoch}  |  b{k}  |  MSE {loss.item()}  |  Lr {optimizer.param_groups[0]['lr']}  |  t {t}")
                
            epoch_loss.append(loss.item())
            
        epoch_loss = np.array(epoch_loss).mean()
        print(f"Epoch {epoch} loss: {epoch_loss}")
        losses.append(epoch_loss)

        if epoch%5 == 0 or epoch > 70:

            path = f'my_diff_ffhq_v5_e{str(epoch)}.pth'

            if torch.cuda.device_count() > 1:
                torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.module.state_dict(),
                        'ema_model_state_dict': ema_model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': losses,
                        'config': config.__dict__
                        }, path)
            else:
                torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'ema_model_state_dict': ema_model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': losses,
                        'config': config.__dict__
                        }, path)
                
            if config.sched == "cycle":
                with open(f"diff_v5_ffhq_sch_e{str(epoch)}.pkl", "wb") as file:
                    pickle.dump(scheduler, file, -1)


    
class ModelConfig:
    def __init__(self, batch_size=10, image_size=64, epochs=100, lr=1e-4, 
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 mix_precision = False,
                 gated_attention = True,
                 schedule = 'linear',
                 noise_steps = 1000,
                 generate_img_size = 64,
                 sched = None,
                 dataset_path = "/kaggle/input/ffhq-64-train-50k/Train"):
        
        self.batch_size = batch_size
        self.image_size = image_size
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.mix_precision = mix_precision
        self.gated_attention = gated_attention
        self.schedule = schedule
        self.noise_steps = noise_steps
        self.generate_img_size = generate_img_size
        self.sched = sched
        self.dataset_path = dataset_path

def launch(path = None, sched_path = None):
    config = ModelConfig()
    train(config, path, sched_path)

In [11]:
net = UNet(device="cpu", gated_attention = True)
print(sum([p.numel() for p in net.parameters()]))

96328847


In [12]:
!nvidia-smi

Mon May 22 10:21:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    27W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [13]:
launch("/kaggle/input/ffhq-checkpoints/my_diff_ffhq_v5_e50.pth")

loaded
