In [None]:
import torch
import torch.nn.functional as F
from torch.optim import Adam 
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Compute device: prefer the GPU when available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image geometry and mini-batch size shared by the data pipeline and sampler.
IMG_SIZE = 64
BATCH_SIZE = 20

# Number of diffusion steps, and the linear per-step noise variance schedule.
T = 300
beta = torch.linspace(0.0001, 0.02, T)

# Derived schedule terms; every tensor below is 1-D with length T (CPU).
alpha = 1. - beta
alpha_bar = alpha.cumprod(dim=0)
# alpha_bar shifted right by one step, with alpha_bar_prev[0] = 1.
alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

sqrt_alpha_bar = alpha_bar.sqrt()
sqrt_one_minus_alpha_bar = (1. - alpha_bar).sqrt()
sqrt_recip_alpha = (1.0 / alpha).sqrt()

# Variance of q(x_{t-1} | x_t, x_0); it is exactly 0 at t = 0 because
# alpha_bar_prev[0] == 1, so the final reverse step adds no noise.
posterior_variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)

In [None]:
def get_index_from_list(vals, t, x_shape):
    """Gather ``vals[t[i]]`` per batch element, shaped to broadcast against x.

    Args:
        vals: 1-D schedule tensor (indexed on the CPU).
        t: 1-D integer tensor of timesteps, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` (same rank as ``x_shape``),
        placed on ``t``'s device.
    """
    picked = vals.gather(-1, t.cpu())
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape((len(t),) + trailing_ones).to(t.device)


def forward_diffusion_sample(x_0, t, device):
    """Sample x_t from q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, I).

    Args:
        x_0: Clean input batch.
        t: 1-D integer tensor of timesteps, one per batch element.
        device: Device on which the outputs are returned.

    Returns:
        Tuple ``(x_t, eps)``, both on ``device``.
    """
    eps = torch.randn_like(x_0)
    mean_coef = get_index_from_list(sqrt_alpha_bar, t, x_0.shape)
    std_coef = get_index_from_list(sqrt_one_minus_alpha_bar, t, x_0.shape)

    x_0 = x_0.to(device)
    eps = eps.to(device)
    x_t = mean_coef.to(device) * x_0 + std_coef.to(device) * eps
    return x_t, eps

In [None]:
def show_tensor_image(image):
    """Plot an image tensor that lives in the model's [-1, 1] value range.

    Args:
        image: A (C, H, W) tensor, or a (B, C, H, W) batch; for a batch only
            the first image is shown. Values outside [-1, 1] (common for raw
            diffusion samples) are clamped before conversion to uint8.
    """
    reverse_transforms = transforms.Compose(
        [
            transforms.Lambda(lambda x: (x + 1) / 2),          # [-1, 1] -> [0, 1]
            transforms.Lambda(lambda x: x.permute(1, 2, 0)),   # CHW -> HWC
            transforms.Lambda(lambda x: x * 255.),
            transforms.Lambda(lambda x: x.numpy().astype(np.uint8)),
            transforms.ToPILImage()
        ]
    )

    if len(image.shape) == 4:
        # Index instead of squeeze(): squeeze() only removes the batch axis
        # when B == 1 (and would also drop a size-1 channel axis).
        image = image[0]

    # .numpy() needs a CPU tensor without grad tracking. Clamping makes
    # out-of-range values saturate instead of wrapping/undefined in the
    # float -> uint8 cast.
    image = image.detach().cpu().clamp(-1.0, 1.0)

    plt.imshow(reverse_transforms(image))

In [None]:
def get_loss(model, x_0, t):
    """Noise-prediction loss for one batch.

    Noises ``x_0`` to timestep ``t``, asks ``model`` to predict the injected
    noise, and returns the mean absolute (L1) error between the two.
    """
    noisy_images, target_noise = forward_diffusion_sample(x_0, t, device)
    predicted_noise = model(noisy_images, t)
    return F.l1_loss(target_noise, predicted_noise)

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """One stochastic reverse-diffusion step: draw x_{t-1} given x_t.

    mean = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
    x_{t-1} = mean + sqrt(posterior_variance_t) * z,  z ~ N(0, I)

    Uses the global ``model`` as the noise predictor eps_theta.
    """
    shape = x.shape
    betas_t = get_index_from_list(beta, t, shape)
    sigma_bar_t = get_index_from_list(sqrt_one_minus_alpha_bar, t, shape)
    inv_sqrt_alpha_t = get_index_from_list(sqrt_recip_alpha, t, shape)
    var_t = get_index_from_list(posterior_variance, t, shape)

    predicted_noise = model(x, t)
    mean = inv_sqrt_alpha_t * (x - betas_t * predicted_noise / sigma_bar_t)

    z = torch.randn_like(x)
    return mean + var_t.sqrt() * z


@torch.no_grad()
def sample_last_timestep(x):
    """Apply the final (t = 0) reverse-diffusion step and return its mean.

    The posterior variance at t = 0 is zero (alpha_bar_prev[0] == 1), so this
    step is deterministic: the model mean is the output.

    Args:
        x: Noisy image batch at the last step, on the model's device.

    Returns:
        Denoised tensor with the same shape as ``x``.
    """
    # One t = 0 entry per batch element, on x's device. The previous
    # shape-(1,) tensor only matched a batch of size one.
    t = torch.zeros(x.shape[0], device=x.device, dtype=torch.int64)

    beta_t = get_index_from_list(
        beta, t, x.shape
    )

    sqrt_one_minus_alpha_bar_t = get_index_from_list(
        sqrt_one_minus_alpha_bar, t, x.shape
    )

    sqrt_recip_alpha_t = get_index_from_list(
        sqrt_recip_alpha, t, x.shape
    )

    # Same mean formula as the stochastic step, without the noise term.
    model_mean = sqrt_recip_alpha_t * (
        x - beta_t * model(x, t) / sqrt_one_minus_alpha_bar_t
    )

    return model_mean


@torch.no_grad()
def sample_plot_image():
    """Generate one image with the reverse diffusion chain and plot it.

    Starts from Gaussian noise, applies the stochastic reverse step for
    t = T-1 down to 1, then the deterministic final step at t = 0, and
    displays the result.
    """
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(1, 1))
    plt.axis('off')

    # Stop at t = 1. The t = 0 step is done by sample_last_timestep below.
    # Iterating down to 0 here applied the final denoising step twice,
    # because sample_timestep at t = 0 adds zero noise and equals the last step.
    for i in reversed(range(1, T)):
        t = torch.full((1,), i, device=device, dtype=torch.int64)
        img = sample_timestep(img, t)

    img = sample_last_timestep(img)

    show_tensor_image(img.detach().cpu())
    plt.show()

In [None]:
# Training preprocessing: resize to a fixed square, random horizontal flip for
# augmentation, convert to a tensor, then rescale pixel values to [-1, 1].
data_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img * 2 - 1),
])

# Flowers-102 dataset, downloaded into the current directory if missing.
data_set = torchvision.datasets.Flowers102(root='.', download=True, transform=data_transform)

# Shuffled mini-batches; drop_last=True keeps every batch at BATCH_SIZE items.
dataloader = DataLoader(data_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
from unet import Unet

# Noise-prediction network (its output is compared against the injected noise
# in get_loss). Sized from the config constant instead of a repeated literal.
model = Unet(
    image_size=IMG_SIZE,
    image_channels=3
).to(device)

In [None]:
# Training hyper-parameters (previously inline magic numbers).
EPOCHS = 50
LEARNING_RATE = 1e-3

optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    model.train()
    running_loss, n_batches = 0.0, 0

    for batch in dataloader:
        x_0 = batch[0]  # images; batch[1] holds the unused class labels
        optimizer.zero_grad()

        # One random timestep per image, sized from the actual batch rather
        # than assuming BATCH_SIZE (only guaranteed while drop_last=True).
        t = torch.randint(0, T, size=(x_0.shape[0],), dtype=torch.int64, device=device)
        loss = get_loss(model, x_0, t)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the epoch mean. The last batch's loss is noisy, and `loss` would
    # be undefined if the loader yielded no batches.
    mean_loss = running_loss / max(n_batches, 1)
    print(f"Epoch {epoch} | Loss: {mean_loss:.4f}")

In [None]:
# Draw and display one sample from the trained model.
sample_plot_image()