# Практическое задание по занятию Denoising Diffusion Models

ФИО: Усцов Артем Алексеевич

# Задание

0) Скачайте репозиторий: git clone https://github.com/awjuliani/pytorch-diffusion.git

1) Обучите модель на датасете Fashion MNIST. Продемонстрируйте обратный диффузионный процесс для нескольких random seeds (1 балл)

2) Добавьте к обучению DDPM условие на метку с помощью Classifier Free Guidance https://arxiv.org/abs/2207.12598. (1 балл)

3) Обучите модель в режиме inpainting (параграф 4.1 в https://arxiv.org/pdf/2201.09865.pdf). (1 балл)


# Load pytorch-diffusion

In [None]:
# Clone the reference DDPM implementation; its `data` and `model` modules are imported below.
! git clone https://github.com/awjuliani/pytorch-diffusion.git

In [None]:
# Work inside the cloned repository so `from data import ...` / `from model import ...` resolve.
%cd pytorch-diffusion

In [None]:
! pip install pytorch-lightning

In [None]:
import torch
from data import DiffSet
import pytorch_lightning as pl
from model import DiffusionModel
from torch.utils.data import DataLoader
import imageio
import glob
import matplotlib.pyplot as plt
import torchvision

# Data

In [None]:
# ---- Training hyperparameters ----
dataset_choice: str = "Fashion"   # dataset key: "MNIST", "Fashion" or "CIFAR"
diffusion_steps: int = 1000       # number of diffusion timesteps (t_range)
batch_size: int = 128
max_epoch: int = 10

# ---- Checkpoint loading ----
# Set load_model=True to resume lightning_logs/<dataset_choice>/version_<load_version_num>.
load_version_num: int = 1
load_model: bool = False


In [None]:

# Optionally resume from the newest checkpoint of an earlier run.
pass_version = None     # logger version to reuse (None = start a new version)
last_checkpoint = None  # checkpoint file to resume from (None = train from scratch)

if load_model:
    import os

    pass_version = load_version_num
    ckpt_pattern = (
        f"./lightning_logs/{dataset_choice}/version_{load_version_num}/checkpoints/*.ckpt"
    )
    candidates = glob.glob(ckpt_pattern)
    if not candidates:
        raise FileNotFoundError(f"No checkpoints match {ckpt_pattern}")
    # glob() returns matches in arbitrary order, so `[-1]` is not "the latest";
    # pick the most recently written file instead.
    last_checkpoint = max(candidates, key=os.path.getmtime)

In [None]:
# Train / validation splits of the chosen dataset and their loaders.
train_dataset = DiffSet(True, dataset_choice)
val_dataset = DiffSet(False, dataset_choice)

loader_kwargs = dict(batch_size=batch_size, num_workers=4, shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# Either restore the weights found above or build a fresh model.
in_size = train_dataset.size * train_dataset.size
if not load_model:
    model = DiffusionModel(in_size, diffusion_steps, train_dataset.depth)
else:
    model = DiffusionModel.load_from_checkpoint(
        last_checkpoint,
        in_size=in_size,
        t_range=diffusion_steps,
        img_depth=train_dataset.depth,
    )

In [None]:
# Load Trainer model
tb_logger = pl.loggers.TensorBoardLogger(
    "lightning_logs/",
    name=dataset_choice,
    version=pass_version,
)

trainer = pl.Trainer(
    max_epochs=max_epoch, 
    log_every_n_steps=10, 
    auto_select_gpus=True,
    resume_from_checkpoint=last_checkpoint, 
    logger=tb_logger,
    gpus=1
)

# Simple Diffusion Model

In [None]:
# Train the unconditional DDPM, validating on the held-out split.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Reverse diffusion from pure noise for five seeds. A snapshot is kept every
# 50 steps, plus the final denoised image: the loop ends at t = 1 (not a
# multiple of 50), so without the extra append the result would be missing.
gen_samples = []
for seed in range(5):
    samples = []
    torch.manual_seed(seed)
    x = torch.randn((1, train_dataset.depth, train_dataset.size, train_dataset.size))
    # t runs from t_range - 1 down to 1; t = 0 is not visited.
    sample_steps = torch.arange(model.t_range - 1, 0, -1)
    for t in sample_steps:
        x = model.denoise_sample(x, t)
        if t % 50 == 0:
            samples.append(x)
    samples.append(x)  # fully denoised sample
    gen_samples.append(samples)

In [None]:
# One entry per seed, each a list of reverse-process snapshots.
len(gen_samples)

In [None]:
# One figure per seed: reverse-process snapshots from noise (left) to image (right).
for row in gen_samples:
    n_cols = len(row)
    plt.figure(figsize=(20, 3))
    for col, snapshot in enumerate(row, start=1):
        plt.subplot(1, n_cols, col)
        pixels = snapshot.detach().numpy().reshape(32, 32)
        plt.imshow(pixels, cmap=plt.cm.Greys_r)
        plt.axis('off')
    plt.show()

# Condition Model

In [None]:
import numpy as np
from torch import nn
import math

In [None]:
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
from torchvision import transforms

In [None]:
class ConditionDiffSet(DiffSet):
    """``DiffSet`` that also yields the class label of every image.

    ``ds[i]`` returns ``(image, label)``; ``ds[a:b]`` returns ``(images, labels)``.
    """

    def __init__(self, train, dataset="MNIST"):
        """
        Args:
            train: use the training split if True, otherwise the test split.
            dataset: one of "MNIST", "Fashion", "CIFAR".
        """
        super(ConditionDiffSet, self).__init__(train, dataset=dataset)
        datasets = {
            "MNIST": MNIST,
            "Fashion": FashionMNIST,
            "CIFAR": CIFAR10,
        }

        # Labels are read from a separate torchvision instance of the same split
        # (assumes its order matches the images stored by DiffSet -- TODO confirm).
        source = datasets[dataset](
            "./data", download=True, train=train
        )
        # CIFAR10.targets is a plain list while (Fashion)MNIST.targets is a
        # tensor; normalise to a LongTensor so indexing/slicing behave the same.
        self.labels = torch.as_tensor(source.targets, dtype=torch.long)

    def __getitem__(self, item):
        return self.input_seq[item], self.labels[item]

In [None]:
class ConditionDiffusionModel(DiffusionModel):
    """Class-conditional DDPM trained with Classifier-Free Guidance (CFG).

    Ho & Salimans, "Classifier-Free Diffusion Guidance", arXiv:2207.12598.
    The label is injected exactly like the timestep: ``self.pos_encoding`` of a
    scalar token is added after every U-Net stage. Token 0 is reserved for the
    "no label" (unconditional) branch and real class ``c`` is encoded as
    ``c + 1``, so class 0 is never mistaken for the null token.
    """

    # Label token fed to the network for the unconditional branch.
    NULL_LABEL = 0.0

    def __init__(self, in_size, t_range, img_d, num_classes=10, time_dim=256, p_uncond=0.2):
        """
        Args:
            in_size, t_range, img_d: forwarded unchanged to ``DiffusionModel``.
            num_classes: number of real classes (size of ``label_emb``).
            time_dim: embedding width of ``label_emb``.
            p_uncond: probability of replacing a label with the null token
                during training (the CFG paper reports 0.1-0.2 working well).
        """
        super().__init__(in_size, t_range, img_d)
        self.time_dim = time_dim
        self.num_classes = num_classes
        self.p_uncond = p_uncond
        # NOTE: forward() does not use this embedding (labels go through
        # pos_encoding); it is kept so existing state_dicts still load.
        self.label_emb = nn.Embedding(num_classes, self.time_dim)

    def _condition(self, t, y, channels, size):
        """Timestep encoding plus label-token encoding for one U-Net stage."""
        return self.pos_encoding(t, channels, size) + self.pos_encoding(y, channels, size)

    def forward(self, x, t, y=None):
        """Predict the noise contained in ``x``.

        Model is U-Net with added positional encodings and self-attention
        layers; the label token is encoded with the same ``pos_encoding`` as t.

        Args:
            x: noisy images (B, C, H, W); the hard-coded encoding sizes below
                (16/8/4/8/16/32) imply 32x32 inputs.
            t: float timestep(s), shape (B, 1) or a broadcastable scalar tensor.
            y: float label token(s) -- 0 = unconditional, ``c + 1`` = class c --
                shape (B, 1) or broadcastable; ``None`` means unconditional.

        Returns:
            The predicted noise, combined elementwise with ``x`` by callers.
        """
        if y is None:
            y = torch.full((x.shape[0], 1), self.NULL_LABEL, device=x.device)

        x1 = self.inc(x)
        x2 = self.down1(x1) + self._condition(t, y, 128, 16)
        x3 = self.down2(x2) + self._condition(t, y, 256, 8)
        x3 = self.sa1(x3)
        x4 = self.down3(x3) + self._condition(t, y, 256, 4)
        x4 = self.sa2(x4)

        x = self.up1(x4, x3) + self._condition(t, y, 128, 8)
        x = self.sa3(x)
        x = self.up2(x, x2) + self._condition(t, y, 64, 16)
        x = self.up3(x, x1) + self._condition(t, y, 64, 32)
        return self.outc(x)

    def get_loss(self, batch, batch_idx):
        """Noise-prediction MSE with label dropout (Ho et al., 2020, Alg. 1 + CFG).

        Args:
            batch: ``(images, labels)`` pair as produced by ``ConditionDiffSet``.
            batch_idx: unused here.
        """
        x, y = batch
        # Encode class c as c + 1 and, with probability p_uncond, replace it by
        # the null token 0 so the same network also learns eps(x_t, t).
        keep = torch.rand(y.shape[0], device=y.device) >= self.p_uncond
        y_tokens = (y.float() + 1.0) * keep.float()

        ts = torch.randint(0, self.t_range, [x.shape[0]], device=self.device)
        epsilons = torch.randn(x.shape, device=self.device)
        noise_imgs = []
        for i in range(len(ts)):
            a_hat = self.alpha_bar(ts[i])
            noise_imgs.append(
                (math.sqrt(a_hat) * x[i]) + (math.sqrt(1 - a_hat) * epsilons[i])
            )
        noise_imgs = torch.stack(noise_imgs, dim=0)
        e_hat = self.forward(
            noise_imgs, ts.unsqueeze(-1).type(torch.float), y_tokens.unsqueeze(-1)
        )
        return nn.functional.mse_loss(
            e_hat.reshape(-1, self.in_size), epsilons.reshape(-1, self.in_size)
        )

    def denoise_sample(self, x, t, y=None, guidance_scale=2.0):
        """One reverse step x_t -> x_{t-1} with classifier-free guidance.

        Corresponds to the inner loop of Algorithm 2 from (Ho et al., 2020),
        with the noise estimate replaced by the CFG combination
        ``(1 + w) * eps(x_t, t, y) - w * eps(x_t, t)`` (Ho & Salimans, eq. 6).

        Args:
            x: current sample x_t.
            t: integer timestep (scalar tensor or int).
            y: raw class label(s) in ``0..num_classes-1``, either a scalar or one
                per sample; ``None`` samples unconditionally.
            guidance_scale: guidance weight ``w``; 0 gives plain conditional
                sampling, larger values follow the label more strongly.
        """
        with torch.no_grad():
            z = torch.randn_like(x) if t > 1 else 0
            t_in = torch.as_tensor(t, dtype=torch.float, device=x.device)
            null = torch.full((x.shape[0], 1), self.NULL_LABEL, device=x.device)
            e_hat = self.forward(x, t_in, null)
            if y is not None:
                y_tokens = (
                    torch.as_tensor(y, dtype=torch.float, device=x.device).reshape(-1, 1) + 1.0
                )
                e_cond = self.forward(x, t_in, y_tokens)
                e_hat = (1 + guidance_scale) * e_cond - guidance_scale * e_hat

            pre_scale = 1 / math.sqrt(self.alpha(t))
            e_scale = (1 - self.alpha(t)) / math.sqrt(1 - self.alpha_bar(t))
            post_sigma = math.sqrt(self.beta(t)) * z
            return pre_scale * (x - e_scale * e_hat) + post_sigma

In [None]:
# Labelled train / validation splits and loaders for the conditional model.
train_dataset, val_dataset = (
    ConditionDiffSet(is_train, dataset_choice) for is_train in (True, False)
)
train_loader, val_loader = (
    DataLoader(ds, batch_size=batch_size, num_workers=4, shuffle=True)
    for ds in (train_dataset, val_dataset)
)

In [None]:
# Conditional model sized from the dataset (image depth is no longer
# hard-coded to 1, so RGB datasets such as CIFAR work too); 10 classes.
conditional_model = ConditionDiffusionModel(
    train_dataset.size * train_dataset.size,
    diffusion_steps,
    train_dataset.depth,
    num_classes=10,
)

In [None]:
# Load Trainer model
tb_logger = pl.loggers.TensorBoardLogger(
    "lightning_logs/",
    name=dataset_choice,
)

trainer = pl.Trainer(
    max_epochs=max_epoch, 
    log_every_n_steps=10, 
    auto_select_gpus=True,
    logger=tb_logger,
    gpus=1
)

In [None]:
# Train the CFG model on (image, label) batches.
trainer.fit(conditional_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Sampling below creates its tensors on the CPU, so keep the model there too.
conditional_model.cpu()

In [None]:
# Draw 16 samples conditioned on class 1, stepping t from t_range - 1 down to 1.
n_images, label = 16, 1.0
x = torch.randn((n_images, train_dataset.depth, train_dataset.size, train_dataset.size))
labels = torch.full((n_images,), label)
sample_steps = torch.arange(conditional_model.t_range - 1, 0, -1)
for t in tqdm(sample_steps):
    x = conditional_model.denoise_sample(x, t, labels)

In [None]:
# 4x4 grid of the class-conditional samples.
plt.figure(figsize=(10, 10))
for idx, img in enumerate(x[:16]):
    plt.subplot(4, 4, idx + 1)
    plt.imshow(img.detach().numpy().reshape(32, 32), cmap=plt.cm.Greys_r)
    plt.axis('off')


Видно, что среди сгенерированных изображений преобладает класс 1 (Trouser в Fashion MNIST).

# Inpainting

In [None]:
# First five training images for the inpainting demo. Slicing a
# ConditionDiffSet returns (images, labels), a DiffSet returns images only;
# keep just the image tensor so the loops below iterate over images.
first_items = train_dataset[:5]
images = first_items[0] if isinstance(first_items, tuple) else first_items

In [None]:
# Inpainting mask: 1 = pixel kept from the source image, 0 = hole to generate
# (a central 10x10 square).
mask = torch.ones(32, 32)
mask[10:20, 10:20] = 0

In [None]:
# RePaint resampling count U: how many times each reverse step is repeated.
resampling = 5

In [None]:
from tqdm import *

In [None]:
# Inpainting below builds its tensors on the CPU, so move the model there.
model.cpu()

In [None]:
import numpy as np

In [None]:
from tqdm import std
from numpy.ma.core import mean
# RePaint inpainting (Lugmayr et al., 2022, Algorithm 1) with the
# unconditional model: the known region is taken from the noised source image,
# the hole from the model's reverse step, and each step is resampled U times.
# Slicing a ConditionDiffSet gives (images, labels); keep only the images.
source_images = images[0] if isinstance(images, tuple) else images
img_shape = (1, train_dataset.depth, train_dataset.size, train_dataset.size)

gen_samples = []
for image in source_images:
    image = image.reshape(img_shape)  # clean x_0, aligned with x
    samples = []
    x = torch.randn(img_shape)
    sample_steps = torch.arange(model.t_range - 1, 0, -1)
    for t in tqdm(sample_steps):
        for u in range(resampling):
            # Known region: x_known ~ N(sqrt(a_bar) x_0, (1 - a_bar) I).
            # The noise is scaled by the std sqrt(1 - a_bar), not (1 - a_bar).
            a_bar = model.alpha_bar(t)
            eps = torch.randn(img_shape) if t > 1 else torch.zeros(img_shape)
            x_known = math.sqrt(a_bar) * image + math.sqrt(1 - a_bar) * eps
            # Unknown region: one reverse-diffusion step of the trained model.
            x_unknown = model.denoise_sample(x, t)

            x = mask * x_known + (1 - mask) * x_unknown
            # Resample: jump back one step, x_t ~ N(sqrt(1 - b) x_{t-1}, b I)
            # with b = beta_{t-1}, so the next pass can harmonise both regions.
            if t > 1 and u < resampling - 1:
                b = model.beta(t - 1)
                x = math.sqrt(1 - b) * x + math.sqrt(b) * torch.randn(img_shape)
        if t % 50 == 0:
            samples.append(x)
    # The loop ends at t = 1, so store the fully inpainted result explicitly.
    samples.append(x)
    gen_samples.append(samples)

In [None]:
# One figure per source image: inpainting snapshots from noise to result.
for row in gen_samples:
    n_cols = len(row)
    plt.figure(figsize=(20, 3))
    for col, snapshot in enumerate(row, start=1):
        plt.subplot(1, n_cols, col)
        pixels = snapshot.detach().numpy().reshape(32, 32)
        plt.imshow(pixels, cmap=plt.cm.Greys_r)
        plt.axis('off')
    plt.show()

In [None]:
# Unmasked source images. Slicing a ConditionDiffSet gives (images, labels);
# keep only the images, and size the row by the images actually shown.
originals = images[0] if isinstance(images, tuple) else images
plt.figure(figsize=(20, 20))
for col, img in enumerate(originals, start=1):
    plt.subplot(1, len(originals), col)
    plt.imshow(img[-1].detach().numpy().reshape(32, 32), cmap=plt.cm.Greys_r)
    plt.axis('off')
plt.show()

In [None]:
# Source images with the hole blanked out by the mask (what the model "sees").
# Slicing a ConditionDiffSet gives (images, labels); keep only the images.
originals = images[0] if isinstance(images, tuple) else images
plt.figure(figsize=(20, 20))
for col, img in enumerate(originals, start=1):
    plt.subplot(1, len(originals), col)
    plt.imshow((img[-1] * mask).detach().numpy().reshape(32, 32), cmap=plt.cm.Greys_r)
    plt.axis('off')
plt.show()

In [None]:
# Inpainted results: the last snapshot stored for each source image.
n_results = len(gen_samples)
plt.figure(figsize=(20, 20))
for col, snapshots in enumerate(gen_samples, start=1):
    plt.subplot(1, n_results, col)
    final = snapshots[-1].detach().numpy().reshape(32, 32)
    plt.imshow(final, cmap=plt.cm.Greys_r)
    plt.axis('off')
plt.show()

Feedback (опционально): сложно

Здесь вы можете оставить список опечаток из лекции или семинара:

Здесь вы можете оставить комментарии по лекции или семинару: