<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p2_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import math
import torch
import numpy as np
import pandas as pd
from torch import nn
from PIL import Image
from tqdm import tqdm
from typing import List
import torch.nn.functional as F
import torchvision.transforms as tr
from torchvision.utils import save_image, make_grid

In [None]:
# Colab-only step: mount Google Drive at /content/drive so later cells can read
# the dataset archive from it and write the generated images back to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
!unzip /content/hw2_data.zip

In [None]:
# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using: {device}")

Using: cuda


In [None]:
# Module-level alias used by the Swish module below: swish is PyTorch's SiLU activation.
swish = F.silu

@torch.no_grad()
def variance_scaling_init_(tensor, scale=1, mode="fan_avg", distribution="uniform"):
    """Fill ``tensor`` in place with a variance-scaling initialisation.

    The target variance is ``scale`` divided by a fan term computed from the
    tensor's shape.

    Args:
        tensor: Parameter tensor to initialise (modified in place).
        scale: Numerator of the target variance.
        mode: ``"fan_in"`` or ``"fan_out"`` selects that fan; any other value
            uses the average of the two.
        distribution: ``"normal"`` draws from N(0, variance); any other value
            draws uniformly from ``[-sqrt(3 * variance), sqrt(3 * variance)]``.

    Returns:
        The same ``tensor`` object, after in-place initialisation.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)

    # Choose the denominator of the variance from the requested mode.
    if mode == "fan_in":
        denominator = fan_in
    elif mode == "fan_out":
        denominator = fan_out
    else:
        denominator = (fan_in + fan_out) / 2
    variance = scale / denominator

    if distribution != "normal":
        # A uniform distribution on [-b, b] has variance b^2 / 3.
        limit = math.sqrt(3 * variance)
        return tensor.uniform_(-limit, limit)

    return tensor.normal_(0, math.sqrt(variance))


def conv2d(
    in_channel,
    out_channel,
    kernel_size,
    stride=1,
    padding=0,
    bias=True,
    scale=1,
    mode="fan_avg",
):
    """Build an ``nn.Conv2d`` with variance-scaled weights and a zero bias.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied to the input.
        bias: Whether the layer has a bias term.
        scale: Variance scale forwarded to ``variance_scaling_init_``.
        mode: Fan mode forwarded to ``variance_scaling_init_``.

    Returns:
        The initialised ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )

    variance_scaling_init_(layer.weight, scale, mode=mode)

    # The bias parameter only exists when the layer was built with one.
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)

    return layer


def linear(in_channel, out_channel, scale=1, mode="fan_avg"):
    """Build an ``nn.Linear`` with variance-scaled weights and a zero bias.

    Args:
        in_channel: Size of each input sample.
        out_channel: Size of each output sample.
        scale: Variance scale forwarded to ``variance_scaling_init_``.
        mode: Fan mode forwarded to ``variance_scaling_init_``.

    Returns:
        The initialised ``nn.Linear`` layer.
    """
    layer = nn.Linear(in_channel, out_channel)

    nn.init.zeros_(layer.bias)
    variance_scaling_init_(layer.weight, scale, mode=mode)

    return layer


class Swish(nn.Module):
    """Parameter-free module that applies the module-level ``swish`` activation."""

    def forward(self, input):
        return swish(input)


class Upsample(nn.Sequential):
    """2x nearest-neighbour upsampling followed by a 3x3 convolution.

    The channel count is unchanged.
    """

    def __init__(self, channel):
        super().__init__(
            nn.Upsample(scale_factor=2, mode="nearest"),
            conv2d(channel, channel, 3, padding=1),
        )


class Downsample(nn.Sequential):
    """Stride-2 3x3 convolution that keeps the channel count unchanged."""

    def __init__(self, channel):
        super().__init__(conv2d(channel, channel, 3, stride=2, padding=1))


class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding.

    Layout: GroupNorm -> Swish -> conv, inject the time embedding, then
    GroupNorm -> Swish -> dropout -> conv, and finally add the (optionally
    1x1-projected) input.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the output feature map.
        time_dim: Size of the incoming time embedding.
        use_affine_time: If true, the time embedding is projected to a
            per-channel (gamma, beta) pair that modulates the second
            normalisation; otherwise it is added to the features before
            that normalisation.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(
        self, in_channel, out_channel, time_dim, use_affine_time=False, dropout=0
    ):
        super().__init__()

        self.use_affine_time = use_affine_time

        # Affine mode needs two values per channel (gamma and beta), starts the
        # projection near zero, and drops GroupNorm's own affine parameters.
        if use_affine_time:
            time_out_dim, time_scale, norm_affine = out_channel * 2, 1e-10, False
        else:
            time_out_dim, time_scale, norm_affine = out_channel, 1, True

        self.norm1 = nn.GroupNorm(32, in_channel)
        self.activation1 = Swish()
        self.conv1 = conv2d(in_channel, out_channel, 3, padding=1)

        self.time = nn.Sequential(
            Swish(), linear(time_dim, time_out_dim, scale=time_scale)
        )

        self.norm2 = nn.GroupNorm(32, out_channel, affine=norm_affine)
        self.activation2 = Swish()
        self.dropout = nn.Dropout(dropout)
        # Near-zero init of the last convolution.
        self.conv2 = conv2d(out_channel, out_channel, 3, padding=1, scale=1e-10)

        # A 1x1 projection is only needed when the residual changes width.
        self.skip = (
            conv2d(in_channel, out_channel, 1) if in_channel != out_channel else None
        )

    def forward(self, input, time):
        """Apply the block to ``input`` using the time embedding ``time``."""
        batch = input.shape[0]

        hidden = self.conv1(self.activation1(self.norm1(input)))
        # Reshape the projected embedding so it broadcasts over the spatial dims.
        time_term = self.time(time).view(batch, -1, 1, 1)

        if self.use_affine_time:
            gamma, beta = time_term.chunk(2, dim=1)
            hidden = (1 + gamma) * self.norm2(hidden) + beta
        else:
            hidden = self.norm2(hidden + time_term)

        hidden = self.conv2(self.dropout(self.activation2(hidden)))

        residual = input if self.skip is None else self.skip(input)

        return hidden + residual

class SelfAttention(nn.Module):
    """Spatial self-attention over a feature map, with a residual connection.

    Args:
        in_channel: Channels of the input (and output) feature map.
        n_head: Number of attention heads the channels are split into.
    """

    def __init__(self, in_channel, n_head=1):
        super().__init__()

        self.n_head = n_head

        self.norm = nn.GroupNorm(32, in_channel)
        # One 1x1 convolution produces query, key and value together.
        self.qkv = conv2d(in_channel, in_channel * 3, 1)
        # Near-zero init of the output projection.
        self.out = conv2d(in_channel, in_channel, 1, scale=1e-10)

    def forward(self, input):
        """Attend over all spatial positions of ``input`` and add the result to it."""
        batch, channel, height, width = input.shape
        heads = self.n_head
        per_head = channel // heads

        projected = self.qkv(self.norm(input))
        projected = projected.view(batch, heads, per_head * 3, height, width)
        query, key, value = projected.chunk(3, dim=2)

        # Score of every query position (h, w) against every key position (y, x).
        # The scale uses the full channel count, not the per-head width.
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous()
        scores = scores / math.sqrt(channel)

        # Softmax jointly over all key positions: flatten (y, x), then restore.
        weights = torch.softmax(scores.view(batch, heads, height, width, -1), -1)
        weights = weights.view(batch, heads, height, width, height, width)

        mixed = torch.einsum("bnhwyx, bncyx -> bnchw", weights, value).contiguous()
        projected_out = self.out(mixed.view(batch, channel, height, width))

        return projected_out + input


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of timesteps.

    Each timestep ``t`` is mapped to ``dim`` values: the sines of ``t * f_k``
    followed by the cosines of ``t * f_k``, where the ``dim / 2`` frequencies
    are ``f_k = exp(-log(10000) * 2k / dim)``.

    Args:
        dim: Size of the embedding; must be even so the sine and cosine halves
            together fill exactly ``dim`` entries.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()

        # An odd dim would yield dim + 1 features and only fail later, inside
        # forward, with an unrelated-looking reshape error.
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")

        self.dim = dim

        inv_freq = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000) / dim)
        )

        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, input):
        """Embed ``input`` (any shape) into a tensor of shape ``(*input.shape, dim)``."""
        shape = input.shape
        # torch.outer replaces the deprecated torch.ger alias; reshape (rather
        # than view) also accepts non-contiguous timestep tensors.
        sinusoid_in = torch.outer(input.reshape(-1).float(), self.inv_freq)
        pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1)
        pos_emb = pos_emb.view(*shape, self.dim)

        return pos_emb


class ResBlockWithAttention(nn.Module):
    """A ``ResBlock`` optionally followed by a ``SelfAttention`` layer.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the output feature map.
        time_dim: Size of the incoming time embedding.
        dropout: Dropout probability used inside the residual block.
        use_attention: Whether to append self-attention after the block.
        attention_head: Number of heads for the attention layer.
        use_affine_time: Forwarded to ``ResBlock``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        time_dim,
        dropout,
        use_attention=False,
        attention_head=1,
        use_affine_time=False,
    ):
        super().__init__()

        self.resblocks = ResBlock(
            in_channel, out_channel, time_dim, use_affine_time, dropout
        )

        self.attention = (
            SelfAttention(out_channel, n_head=attention_head)
            if use_attention
            else None
        )

    def forward(self, input, time):
        """Run the residual block, then attention when it is enabled."""
        hidden = self.resblocks(input, time)

        return hidden if self.attention is None else self.attention(hidden)


def spatial_unfold(input, unfold):
    """Move channel groups into the spatial dimensions.

    A ``(batch, c * unfold**2, h, w)`` tensor becomes
    ``(batch, c, h * unfold, w * unfold)``: each group of ``unfold**2``
    channels is spread over an ``unfold x unfold`` pixel neighbourhood.

    Args:
        input: 4-D tensor laid out as (batch, channel, height, width).
        unfold: Spatial expansion factor; ``1`` returns ``input`` unchanged.

    Returns:
        The rearranged tensor, or ``input`` itself when ``unfold == 1``.
    """
    if unfold != 1:
        batch, _, height, width = input.shape

        # Split the channels into (c, unfold, unfold) ...
        blocks = input.view(batch, -1, unfold, unfold, height, width)
        # ... pair each unfold axis with its spatial axis ...
        interleaved = blocks.permute(0, 1, 4, 2, 5, 3)
        # ... and merge every pair into one enlarged spatial axis.
        return interleaved.reshape(batch, -1, height * unfold, width * unfold)

    return input

class UNet(nn.Module):
    """Time-conditioned U-Net used as the noise predictor.

    Encoder: an input convolution followed by six resolution stages of two
    residual blocks each, separated by five stride-2 downsamples. The stage
    widths are ``channel * (1, 1, 2, 2, 4, 4)``; the fifth stage uses
    self-attention. Every encoder output is kept as a skip connection.

    Middle: two residual blocks, the first with self-attention.

    Decoder: mirrors the encoder with three residual blocks per stage, each
    consuming one skip connection by channel concatenation, separated by five
    2x upsamples. The second decoder stage uses self-attention.

    Attribute names (``down1`` ... ``down18``, ``mid1``, ``mid2``, ``up1`` ...
    ``up23``, ``time``, ``out``) are the checkpoint's state-dict keys and must
    not be renamed.

    Args:
        in_channel: Channels of the input image.
        channel: Base width. All stage widths are derived from it, so it must
            be a multiple of 32 for the ``GroupNorm(32, ...)`` layers. The
            default of 128 reproduces the 128/256/512 layout.
        attn_heads: Number of heads in every self-attention layer.
        use_affine_time: Forwarded to every residual block.
        dropout: Dropout probability inside every residual block.
    """

    def __init__(
        self,
        in_channel = 3,
        channel = 128,
        attn_heads = 1,
        use_affine_time = False,
        dropout = 0,
    ):
        super().__init__()

        time_dim = channel * 4
        # Stage widths derived from the base width instead of hard-coded
        # 128 / 256 / 512, so the `channel` argument is actually honoured.
        c1, c2, c4 = channel, channel * 2, channel * 4

        def block(cin, cout, use_attention=False):
            # Every residual block shares the same conditioning settings.
            return ResBlockWithAttention(
                cin,
                cout,
                time_dim,
                dropout,
                use_attention=use_attention,
                attention_head=attn_heads,
                use_affine_time=use_affine_time,
            )

        self.time = nn.Sequential(
            TimeEmbedding(channel),
            linear(channel, time_dim),
            Swish(),
            linear(time_dim, time_dim),
        )

        # ---- encoder ----
        self.down1 = conv2d(in_channel, channel, 3, padding=1)
        self.down2 = block(c1, c1)
        self.down3 = block(c1, c1)
        self.down4 = Downsample(c1)
        self.down5 = block(c1, c1)
        self.down6 = block(c1, c1)
        self.down7 = Downsample(c1)
        self.down8 = block(c1, c2)
        self.down9 = block(c2, c2)
        self.down10 = Downsample(c2)
        self.down11 = block(c2, c2)
        self.down12 = block(c2, c2)
        self.down13 = Downsample(c2)
        self.down14 = block(c2, c4, use_attention=True)
        self.down15 = block(c4, c4, use_attention=True)
        self.down16 = Downsample(c4)
        self.down17 = block(c4, c4)
        self.down18 = block(c4, c4)

        # ---- middle ----
        self.mid1 = block(c4, c4, use_attention=True)
        self.mid2 = block(c4, c4)

        # ---- decoder ----
        # Input width of each block = current width + width of the skip
        # connection popped for it.
        self.up1 = block(c4 + c4, c4)
        self.up2 = block(c4 + c4, c4)
        self.up3 = block(c4 + c4, c4)
        self.up4 = Upsample(c4)
        self.up5 = block(c4 + c4, c4, use_attention=True)
        self.up6 = block(c4 + c4, c4, use_attention=True)
        self.up7 = block(c4 + c2, c4, use_attention=True)
        self.up8 = Upsample(c4)
        self.up9 = block(c4 + c2, c2)
        self.up10 = block(c2 + c2, c2)
        self.up11 = block(c2 + c2, c2)
        self.up12 = Upsample(c2)
        self.up13 = block(c2 + c2, c2)
        self.up14 = block(c2 + c2, c2)
        self.up15 = block(c2 + c1, c2)
        self.up16 = Upsample(c2)
        self.up17 = block(c2 + c1, c1)
        self.up18 = block(c1 + c1, c1)
        self.up19 = block(c1 + c1, c1)
        self.up20 = Upsample(c1)
        self.up21 = block(c1 + c1, c1)
        self.up22 = block(c1 + c1, c1)
        self.up23 = block(c1 + c1, c1)

        self.out = nn.Sequential(
            nn.GroupNorm(32, c1),
            Swish(),
            conv2d(c1, 3 , 3, padding=1, scale=1e-10),
        )

    def forward(self, x, time):
        """Run the network on images ``x`` at timesteps ``time``.

        Args:
            x: Input batch laid out as (batch, in_channel, height, width).
            time: One timestep per batch element.

        Returns:
            A 3-channel tensor produced by the output head.
        """
        time_embed = self.time(time)

        # Encoder: every layer's output is stored as a skip connection. Only
        # the residual blocks take the time embedding; the input convolution
        # and the downsamples do not.
        feats = []
        for idx in range(1, 19):
            layer = getattr(self, f"down{idx}")
            if isinstance(layer, ResBlockWithAttention):
                x = layer(x, time_embed)
            else:
                x = layer(x)
            feats.append(x)

        x = self.mid1(x, time_embed)
        x = self.mid2(x, time_embed)

        # Decoder: each residual block consumes the most recent unused skip
        # connection; upsample layers only change the resolution.
        for idx in range(1, 24):
            layer = getattr(self, f"up{idx}")
            if isinstance(layer, Upsample):
                x = layer(x)
            else:
                x = layer(torch.cat((x, feats.pop()), 1), time_embed)

        out = self.out(x)
        # With a factor of 1 this returns the tensor unchanged.
        out = spatial_unfold(out, 1)

        return out

#### Read the noise files into a single torch.Tensor

In [None]:
def noise_read(noise_dir: str) -> torch.Tensor:
    """Load every ``.pt`` file in ``noise_dir`` and concatenate them.

    Files are read in sorted filename order and joined along dimension 0.
    Tensors are loaded onto the CPU so that noise saved from a GPU session can
    also be read on a CPU-only machine; move the result to the target device
    afterwards.

    Args:
        noise_dir: Directory containing the ``.pt`` noise tensors.

    Returns:
        The concatenation of all loaded tensors, on the CPU.

    Raises:
        FileNotFoundError: If ``noise_dir`` contains no ``.pt`` file.
    """
    noise_names = sorted(name for name in os.listdir(noise_dir) if name.endswith('.pt'))
    if not noise_names:
        raise FileNotFoundError(f"no .pt noise files found in {noise_dir!r}")

    noise_imgs = [
        torch.load(os.path.join(noise_dir, name), map_location="cpu")
        for name in noise_names
    ]
    return torch.cat(noise_imgs)

#### DDIM

In [None]:
class DDIM(nn.Module):
    """DDIM sampler over a linear beta schedule.

    The sampling trajectory visits ``sample_steps`` timesteps
    ``1, 1 + s, 1 + 2s, ...`` (with ``s = noise_steps // sample_steps``) in
    descending order and finishes at timestep 0. ``alpha_bar`` is indexed
    directly with the timestep value.

    Args:
        eta: Scales the random noise added at each step; ``0`` makes the
            sampler deterministic.
        sample_steps: Number of denoising steps.
        noise_steps: Length of the beta schedule.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        img_size: Nominal image side length (stored for reference; the
            sampler takes the actual shape from the loaded noise).
        device: Device on which ``alpha_bar`` is created. ``None`` selects
            CUDA when available and the CPU otherwise.

    Raises:
        ValueError: If ``sample_steps`` is outside ``[1, noise_steps]`` or the
            resulting largest timestep would index past ``alpha_bar``.
    """

    def __init__(self, eta=0, sample_steps=50, noise_steps=1000, beta_start=1e-4, beta_end=2e-2, img_size=256, device=None):
        super().__init__()

        if not 1 <= sample_steps <= noise_steps:
            raise ValueError(
                f"sample_steps must be in [1, noise_steps], got {sample_steps} with noise_steps={noise_steps}"
            )
        # Resolve the default at call time instead of binding a notebook
        # global when the class is defined.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.eta = eta
        self.sample_steps = sample_steps
        # Misspelt name kept as an alias for code that may already read it.
        self.smaple_steps = sample_steps
        self.img_size = img_size

        stride = noise_steps // sample_steps
        timesteps = [t + 1 for t in range(0, noise_steps, stride)]
        # alpha_bar has noise_steps entries and is indexed by timestep value.
        if timesteps[-1] >= noise_steps:
            raise ValueError(
                f"largest timestep {timesteps[-1]} is out of range for a schedule of {noise_steps} steps"
            )
        # Descending trajectory ending at 0, stored as (current, previous)
        # pairs of plain ints.
        schedule = timesteps[::-1] + [0]
        self.t_steps = list(zip(schedule[:-1], schedule[1:]))

        beta_t = torch.linspace(beta_start, beta_end, noise_steps, dtype=torch.float32)
        alpha_bar = torch.cumprod(1 - beta_t, dim=0).to(device)
        # Non-persistent buffer: follows .to(...) with the module while leaving
        # the state dict empty, as before.
        self.register_buffer("alpha_bar", alpha_bar, persistent=False)

    def sample(self, noise_dir, net, device, n=10):
        """Denoise the noise tensors stored in ``noise_dir``.

        Args:
            noise_dir: Directory of ``.pt`` noise tensors, read by ``noise_read``.
            net: Noise-prediction network called as ``net(x, t)``.
            device: Device on which sampling runs.
            n: Maximum number of noise tensors to use; the first ``n`` (in
                sorted file order) are taken.

        Returns:
            The generated batch. Values are not rescaled to ``[0, 1]``; the
            caller is expected to normalise before saving.
        """
        net.eval()
        net.to(device)
        # Index on the sampling device even if the module lives elsewhere.
        alpha_bar = self.alpha_bar.to(device)

        with torch.no_grad():
            x = noise_read(noise_dir)[:n].to(device)
            # Take the batch size from the data so it can never disagree with n.
            batch = x.shape[0]

            for step, prev_step in tqdm(self.t_steps):
                # Same timestep for every element of the batch.
                t = torch.full((batch,), step, dtype=torch.long, device=device)

                # Scalars; they broadcast over the whole batch.
                alpha_t = alpha_bar[step]
                alpha_prev = alpha_bar[prev_step]

                predicted_noise = net(x, t).to(device, dtype=torch.float32)

                # Predicted clean image, clamped to the valid pixel range.
                x0_t = torch.clamp((x - (predicted_noise * torch.sqrt((1 - alpha_t)))) / torch.sqrt(alpha_t), -1, 1)
                # c1: weight of fresh random noise; c2: weight of the
                # predicted noise direction.
                c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t))
                c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2)
                x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise

                # Random noise only matters when eta > 0, and none is added on
                # the final step. Its shape follows x rather than img_size.
                if self.eta > 0 and step > 1:
                    x = x + c1 * torch.randn_like(x)

        return x

In [None]:
# Build the network and load the pretrained weights. map_location remaps the
# checkpoint's tensors onto the active device, so a checkpoint saved from a GPU
# session also loads on a CPU-only runtime.
net = UNet()
net.to(device)
net.load_state_dict(torch.load('/content/hw2_data/face/UNet.pt', map_location=device))

<All keys matched successfully>

In [None]:
# eta multiplies the random-noise coefficient inside DDIM.sample, so eta=0
# removes the random term from every sampling step.
ddim = DDIM(eta=0)
ddim.eval()
ddim.to(device)

DDIM()

In [None]:
# Create the output folder up front: save_image does not create missing
# directories, and failing after ~1 minute of sampling would waste the run.
out_dir = '/content/drive/MyDrive/NTU_DLCV/Hw2/p2_img'
os.makedirs(out_dir, exist_ok=True)

x = ddim.sample('/content/hw2_data/face/noise', net, device, n=10)

# Files are named 00.png, 01.png, ... to match the ground-truth file names.
# normalize=True rescales each saved image to [0, 1] using its own min and max.
for i in range(x.shape[0]):
    save_image(x[i], f'{out_dir}/{i:02d}.png', normalize=True)

100%|██████████| 50/50 [00:53<00:00,  1.08s/it]


#### Generate face images of noise 00.pt ~ 03.pt with different eta in one grid.

In [None]:
out_dir = '/content/drive/MyDrive/NTU_DLCV/Hw2/p2_img'
os.makedirs(out_dir, exist_ok=True)

# One row of the comparison grid per eta value, built from noise 00.pt ~ 03.pt.
grid_rows = []

for eta in [0., 0.25, 0.5, 0.75, 1.]:
    ddim = DDIM(eta=eta)
    ddim.eval()
    ddim.to(device)

    x = ddim.sample('/content/hw2_data/face/noise', net, device, n=10)
    for j in range(10):
        save_image(x[j], f'{out_dir}/eta{eta}_0{j}.png', normalize=True)

    grid_rows.append(x[:4].cpu())

# Single grid image: rows follow the eta values above, columns are noise 00 ~ 03.
# scale_each=True normalises every tile on its own, matching the per-image
# normalisation used for the individual files.
save_image(torch.cat(grid_rows), f'{out_dir}/eta_grid.png', nrow=4, normalize=True, scale_each=True)

#### Check MSE value

In [None]:
import cv2
import numpy as np
from sklearn.metrics import mean_squared_error

# Running total of the per-image MSE over the ten images (a sum, not an average).
mse = 0

for i in range(10):
    gt_path = f'/content/hw2_data/face/GT/0{i}.png'
    pred_path = f'/content/drive/MyDrive/NTU_DLCV/Hw2/p2_img/0{i}.png'

    gt = cv2.imread(gt_path)
    pred = cv2.imread(pred_path)
    # cv2.imread signals a missing or unreadable file by returning None.
    if gt is None or pred is None:
        raise FileNotFoundError(f"could not read {gt_path} or {pred_path}")

    # Images are loaded as uint8; convert to float first so the squared
    # difference cannot wrap around in 8-bit arithmetic.
    gt = gt.astype(np.float64).flatten()
    pred = pred.astype(np.float64).flatten()
    mse += mean_squared_error(gt, pred)
print(f"Mean Squared Error: {mse:.2f}")

Mean Squared Error: 0.62
