# Задачи данного проекта:

- выбрать хотя бы два разных датасета (они могут быть как из компьютерного зрения, так и из других областей)
- выбрать хотя бы две метрики качества (одну внутреннюю NLL, а другую внешнюю - например, качество распознавания by 3rd-party classifier)
- пообучать генераторы с разным количеством шагов (хотя бы 3-5 значений)
- сгенерировать примеры на тесте; здесь также интересно, что будет, если остановить генерацию раньше, чем на последнем шаге по времени, использованном при обучении (т.е. исследовать early stopping on inference)
- посмотреть на подтипы guidance (например, задание класса объектов) с учетом сказанного выше; также интересно, как влияют разные стратегии управления дисперсией при генерации (процедуры эволюции альфа)

В итоге нужно будет предоставить код на своем GitHub и детальный отчет о проделанной работе (в виде презентации), который можно выложить там же.

In [None]:
import ssl

# NOTE(review): this replaces the stdlib's default HTTPS context factory with the
# unverified one, i.e. TLS certificate verification is switched off for every
# HTTPS request made through `ssl`'s default context in this process. It is
# presumably a workaround for certificate errors when downloading datasets;
# prefer fixing the local CA bundle instead of keeping this permanently.
_unverified_https_context = ssl._create_unverified_context
ssl._create_default_https_context = _unverified_https_context

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST
from torchvision.datasets import CIFAR10, StanfordCars

SEED = 0

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) with one value
# so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)


# Read the dataset (CIFAR10)

In [None]:
from Diffusion_project.utils import (
    show_first_batch,
    transform_data_for_show,
)

# Build the CIFAR10 loader (batches of 32) via the project helper, then preview
# its first batch.
dataset = transform_data_for_show(
    CIFAR10,
    batch_size=32,
)
show_first_batch(dataset)

# Some visualization

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
from Diffusion_project.models import CustomDiffusionModel
from Diffusion_project.unet import CustomUnet

# Diffusion hyper-parameters: number of timesteps and the beta range handed to
# CustomDiffusionModel (how betas are spread between the two bounds is decided
# inside the model class).
n_steps = 10
min_beta = 10 ** -4
max_beta = 0.02

ddpm = CustomDiffusionModel(
    CustomUnet(),
    n_steps=n_steps,
    min_beta=min_beta,
    max_beta=max_beta,
    device=device,
)

In [None]:
from Diffusion_project.utils import show_forward

# Preview what `ddpm` does to images from `dataset` (project helper; judging by
# its name it visualises the forward/noising process).
forward_args = (ddpm, dataset, device)
show_forward(*forward_args)

# Training

In [None]:
from tqdm import tqdm


def training(ddpm, dataloader, n_epochs, optimizer, device=None, store_path='ddpm.pt'):
    """Train a diffusion model with the noise-prediction (epsilon) MSE objective.

    For every batch, a random timestep ``t`` and Gaussian noise ``eps`` are drawn.
    The images are noised with ``ddpm(x, t, eps)``, and the model's noise estimate
    ``ddpm.reverse(noisy, t)`` is regressed onto ``eps``. After each epoch the
    state dict is written to ``store_path`` if that epoch's loss is the lowest so far.

    Args:
        ddpm: diffusion model (an ``nn.Module``) exposing ``n_steps``,
            ``__call__(x, t, eps)`` and ``reverse(x_t, t)``.
        dataloader: iterable of batches whose first element is the image tensor.
            It must expose ``.dataset``, which is used to average the loss per sample.
        n_epochs: number of passes over ``dataloader``.
        optimizer: optimizer over ``ddpm``'s parameters.
        device: device to move batches to. Defaults to the device of the
            model's first parameter.
        store_path: file the best checkpoint is saved to.
    """
    if device is None:
        device = next(ddpm.parameters()).device

    loss_function = torch.nn.MSELoss()
    best_loss = float('inf')
    n_samples = len(dataloader.dataset)
    ddpm.train()

    for epoch in tqdm(range(n_epochs)):
        epoch_loss = 0.0
        for batch in tqdm(dataloader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500"):
            x = batch[0].to(device)
            batch_size = len(x)

            # One random timestep per sample, plus the Gaussian noise to inject
            # (randn_like already allocates on x's device).
            t = torch.randint(0, ddpm.n_steps, (batch_size,), device=device)
            eps = torch.randn_like(x)

            noisy = ddpm(x, t, eps)
            # NOTE(review): the sampling cell calls `ddpm.backward(...)` for the
            # noise estimate; confirm `reverse` and `backward` are the same method.
            eps_est = ddpm.reverse(noisy, t)

            # The network must recover the injected noise. Comparing against the
            # noisy image itself (as before) optimises the wrong target.
            loss = loss_function(eps_est, eps)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sample-weighted average of the batch losses over the whole dataset.
            epoch_loss += loss.item() * batch_size / n_samples

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)


optimizer = torch.optim.Adam(ddpm.parameters(), lr=0.001)
# `device` must be passed (or inferred); the previous call omitted it and raised TypeError.
training(ddpm, dataset, n_epochs=1, optimizer=optimizer, device=device)

# Testing

In [None]:
store_path = 'ddpm.pt'

# Loading the trained model.
# Rebuild it with the same schedule arguments used at training time
# (n_steps, min_beta, max_beta from the construction cell). Otherwise sampling
# would use whatever CustomDiffusionModel defaults to, which may not match the
# schedule the weights were fitted under.
best_model = CustomDiffusionModel(
    CustomUnet(),
    n_steps=n_steps,
    min_beta=min_beta,
    max_beta=max_beta,
    device=device,
)
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")



In [None]:
from Diffusion_project.utils import show_images
import imageio
import einops


def _grid_rows(n_samples):
    """Return how many rows to use when tiling ``n_samples`` images into a near-square grid.

    This is the largest divisor of ``n_samples`` not exceeding its square root,
    so the batch always splits evenly (e.g. 100 -> 10, 16 -> 4, 10 -> 2).
    """
    rows = max(1, int(n_samples ** 0.5))
    while n_samples % rows:
        rows -= 1
    return rows


def _to_uint8_frame(x, n_rows):
    """Min-max scale each sample of ``x`` (n, c, h, w) to [0, 255] and tile them into one (H, W, c) uint8 array."""
    lo = x.amin(dim=(1, 2, 3), keepdim=True)
    # Clamp the per-image range so a constant image maps to 0 instead of NaN.
    span = (x.amax(dim=(1, 2, 3), keepdim=True) - lo).clamp_min(1e-8)
    normalized = (x - lo) * (255 / span)
    frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=n_rows)
    return frame.cpu().numpy().astype(np.uint8)


def generate_new_images(ddpm, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=1, h=28, w=28,
                        sigma_mode="beta"):
    """Sample new images from a DDPM by running the reverse process from pure noise.

    Intermediate states are captured as frames and written to ``gif_name``.

    Args:
        ddpm: diffusion model exposing ``n_steps``, ``alphas``, ``alpha_bars``,
            ``betas``, ``backward(x_t, t)`` and (when ``device`` is None) ``device``.
        n_samples: number of images to generate.
        device: device to sample on. Defaults to ``ddpm.device``.
        frames_per_gif: number of evenly spaced frame positions across the
            reverse loop. The last frame is also held for ``frames_per_gif // 3``
            extra frames.
        gif_name: output path of the animation.
        c, h, w: shape of each generated image. This must match the images
            the model was trained on.
        sigma_mode: variance of the noise re-injected at each step.
            ``"beta"`` uses sigma_t^2 = beta_t (previous behaviour).
            ``"beta_tilde"`` uses
            sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.

    Returns:
        The final denoised batch, a tensor of shape (n_samples, c, h, w).

    Raises:
        ValueError: if ``sigma_mode`` is not one of the supported values.
    """
    if sigma_mode not in ("beta", "beta_tilde"):
        raise ValueError(f"Unknown sigma_mode: {sigma_mode!r}")

    # Reverse-loop iterations (idx counts up from 0) at which a frame is captured.
    frame_idxs = set(np.linspace(0, ddpm.n_steps, frames_per_gif).astype(int).tolist())
    n_rows = _grid_rows(n_samples)
    frames = []

    with torch.no_grad():
        if device is None:
            device = ddpm.device

        # Start from pure Gaussian noise.
        x = torch.randn(n_samples, c, h, w).to(device)

        for idx, t in enumerate(reversed(range(ddpm.n_steps))):
            # Estimate the noise present at step t.
            # NOTE(review): the training cell calls `ddpm.reverse(...)` for this
            # role; confirm `backward` and `reverse` are the same method.
            time_tensor = torch.full((n_samples, 1), t, dtype=torch.long, device=device)
            eta_theta = ddpm.backward(x, time_tensor)

            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # Posterior mean: partially remove the estimated noise.
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)
                beta_t = ddpm.betas[t]
                if sigma_mode == "beta":
                    sigma_t = beta_t.sqrt()
                else:
                    beta_tilde_t = ((1 - ddpm.alpha_bars[t - 1]) / (1 - alpha_t_bar)) * beta_t
                    sigma_t = beta_tilde_t.sqrt()

                # Re-inject noise (Langevin-dynamics style) on every step but the last.
                x = x + sigma_t * z

            if idx in frame_idxs or t == 0:
                frames.append(_to_uint8_frame(x, n_rows))

    with imageio.get_writer(gif_name, mode="I") as writer:
        for frame in frames:
            writer.append_data(frame)
        # Hold the final frame a little longer at the end of the animation.
        if frames:
            for _ in range(frames_per_gif // 3):
                writer.append_data(frames[-1])
    return x


print("Generating new images")
# Sample with the same (c, h, w) as the training images. The defaults
# (1, 28, 28) do not describe the CIFAR10 loader the model was trained on.
sample_c, sample_h, sample_w = next(iter(dataset))[0].shape[1:]
generated = generate_new_images(
        best_model,
        n_samples=100,
        device=device,
        gif_name="fashion.gif",
        c=sample_c,
        h=sample_h,
        w=sample_w,
    )
show_images(generated, "Final result")