<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p1_Denoising_Diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Build UNet Model

In [None]:
import torch
from torch import nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> GroupNorm -> GELU stages with an optional residual sum.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved. ``GroupNorm`` is used with a single group.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolution stages.
        residual: When True, ``forward`` adds a shortcut to the second stage's
            output and divides the sum by 1.414; when False the second
            stage's output is returned unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int, residual: bool=False) -> None:
        super().__init__()
        # forward() reads both channel counts to pick the shortcut source, so
        # they must be stored on the instance under exactly these names.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.residual = residual

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(1, out_channels),
            nn.GELU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(1, out_channels),
            nn.GELU()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages and, if enabled, the scaled residual sum.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels.

        Returns:
            The second stage's output, or ``(shortcut + output) / 1.414`` when
            ``residual`` is True.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        if not self.residual:
            return x2

        # The raw input can only be added when its channel count matches the
        # output; otherwise the first stage's output serves as the shortcut.
        shortcut = x if self.in_channels == self.out_channels else x1
        # 1.414 ~= sqrt(2): rescales the sum of the two branches.
        return (shortcut + x2) / 1.414


class Encoder_set(nn.Module):
    """Encoder stage: a ``ConvBlock`` followed by ``MaxPool2d(2)``.

    Args:
        in_channels: Channel count passed to the ``ConvBlock`` as its input size.
        out_channels: Channel count passed to the ``ConvBlock`` as its output size.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = [
            ConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        ]
        self.encoder_set = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv block and then the pooling layer."""
        pooled = self.encoder_set(x)
        return pooled


class Decoder_set(nn.Module):
    """Decoder stage: transposed-conv upsampling followed by two ``ConvBlock``s.

    Args:
        in_channels: Channel count of the concatenated (input + skip) tensor
            fed to the transposed convolution.
        out_channels: Channel count produced by the transposed convolution and
            kept by both ``ConvBlock``s.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # Layers are created in the same order they are applied.
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine_first = ConvBlock(out_channels, out_channels)
        refine_second = ConvBlock(out_channels, out_channels)
        self.decoder_set = nn.Sequential(upsample, refine_first, refine_second)

    def forward(self, x, skip_connect):
        """Concatenate ``x`` with ``skip_connect`` along dim 1 and decode the result."""
        merged = torch.cat((x, skip_connect), 1)
        return self.decoder_set(merged)


class EmbeddingFC(nn.Module):
    """Small MLP embedding: Linear -> GELU -> Linear.

    Args:
        input_dim: Size of the last dimension the input is reshaped to.
        output_dim: Width of both linear layers' outputs.
    """

    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.embeddingfc = nn.Sequential(*layers)

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` with ``view`` and apply the MLP."""
        flat = x.view(-1, self.input_dim)
        return self.embeddingfc(flat)

class UNet(nn.Module):
    def __init__(self, in_channels, n_features, n_classes) -> None:
        """Create the layers of the U-Net.

        Args:
            in_channels: Channel count of the input; also the channel count of
                the final convolution's output.
            n_features: Base feature width; wider layers use twice this value.
            n_classes: Input size of the two context embedding MLPs.
        """
        super().__init__()
        self.in_channels = in_channels
        self.n_features = n_features
        self.n_classes = n_classes

        base = n_features
        wide = 2 * n_features

        self.init_conv = ConvBlock(in_channels, base, residual=True)

        # Two encoder stages; the second doubles the feature width.
        self.encoder1 = Encoder_set(base, base)
        self.encoder2 = Encoder_set(base, wide)

        # NOTE(review): a 7x7 average pool presumably collapses a 7x7 feature
        # map to 1x1 -- confirm against forward().
        self.to_vec = nn.Sequential(
            nn.AvgPool2d(7),
            nn.GELU(),
        )

        # Timestep embeddings take a single scalar per sample.
        self.timeembedding1 = EmbeddingFC(1, wide)
        self.timeembedding2 = EmbeddingFC(1, base)

        # Context embeddings take an n_classes-wide vector per sample.
        self.contextembedding1 = EmbeddingFC(n_classes, wide)
        self.contextembedding2 = EmbeddingFC(n_classes, base)

        # Kernel 7 / stride 7 transposed conv mirrors the 7x7 pooling in to_vec.
        self.decoder0 = nn.Sequential(
            nn.ConvTranspose2d(wide, wide, 7, 7),
            nn.GroupNorm(8, wide),
            nn.ReLU(),
        )
        # Decoder inputs are concatenations, hence the 4x and 2x input widths.
        self.decoder1 = Decoder_set(4 * n_features, base)
        self.decoder2 = Decoder_set(wide, base)
        self.out = nn.Sequential(
            nn.Conv2d(wide, base, 3, 1, 1),
            nn.GroupNorm(8, base),
            nn.ReLU(),
            nn.Conv2d(base, self.in_channels, 3, 1, 1)
        )

    def forward(self, x, c, t, context_mask):

#### Denoising Diffusion Probabilistic Models

In [None]:
import torch
from tqdm import tqdm


class Denoising_Diffusion:
    def __init__(self, noise_step=1000, beta_start=1e-4, beta_end=0.01, img_size=28, device='cuda'):
        self.noise_step = noise_step
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_step)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)
        return sqrt_alpha_hat * x +  sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        return  torch.randint(low=1, high=self.noise_step, size=(n,))

    def sample(self, model, n):
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_step)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise