<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p1_ModelA_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from torch import nn
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as tr
from torchvision.utils import save_image, make_grid

#### Copy and extract the MNIST-M data from Google Drive

In [2]:
# Mount Google Drive so the zipped homework data and the sample/checkpoint
# output folders under /content/drive are reachable from this runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Copy the zipped homework data from the mounted Drive folder onto the local
# Colab disk (presumably for faster reads than going through the Drive mount).
# NOTE(review): both source and destination are local paths, so a plain `!cp`
# would likely suffice and the unpinned `pip install gsutil` may be unnecessary
# — confirm before removing.
!pip install gsutil
!gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip

In [None]:
!unzip /content/hw2_data.zip

In [5]:
# Each split CSV becomes a plain list of rows (filename, label, ...) for MnistDataset.
train_csv = pd.read_csv('/content/hw2_data/digits/mnistm/train.csv').to_numpy().tolist()
val_csv = pd.read_csv('/content/hw2_data/digits/mnistm/val.csv').to_numpy().tolist()

In [6]:
# Total number of labelled images across both splits (train + val are both used for training).
sum(map(len, (train_csv, val_csv)))

56000

#### Select the compute device (GPU if available)

In [7]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using: {device}")

Using: cuda


#### Construct a Dataset

In [8]:
class MnistDataset(torch.utils.data.Dataset):
    """Image/label dataset formed by concatenating the rows of two CSV splits.

    Each CSV row is expected to start with ``[image_filename, label, ...]``; the
    filename is joined onto ``join_path``. Rows of ``train_csv`` come first,
    followed by rows of ``test_csv`` (this notebook passes the validation split).

    Args:
        train_csv: first list of CSV rows.
        test_csv: second list of CSV rows.
        join_path: directory that the image filenames are relative to.
        transform: callable applied to the RGB ``PIL.Image`` in ``__getitem__``.
    """

    def __init__(self, train_csv: list, test_csv: list, join_path: str, transform) -> None:
        self.transform = transform
        self.img_paths = []
        self.img_labels = []

        for data_csv in (train_csv, test_csv):
            for row in data_csv:
                self.img_paths.append(os.path.join(join_path, row[0]))
                # Coerce to int so the default collate produces an integer tensor,
                # which nn.functional.one_hot in UNet.forward requires.
                self.img_labels.append(int(row[1]))
        assert len(self.img_paths) == len(self.img_labels)

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx) -> tuple[torch.Tensor, int]:
        """Load image ``idx`` as RGB, apply ``transform`` and return ``(image, label)``."""
        # The context manager releases the file handle even if decoding fails.
        with Image.open(self.img_paths[idx]) as img:
            img = img.convert('RGB')
        return self.transform(img), self.img_labels[idx]

In [10]:
BATCH_SIZE = 256

# ImageNet per-channel statistics used by tr.Normalize below. Anything that turns
# model outputs back into viewable images must invert this as x * std + mean.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

dataset = MnistDataset(
    train_csv,
    val_csv,
    join_path='/content/hw2_data/digits/mnistm/data',
    transform=tr.Compose([
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std),
    ]),
)
# Pinned host memory is what makes `.to(device, non_blocking=True)` copies
# asynchronous; it only helps when the target device is CUDA.
dataset_loader = torch.utils.data.DataLoader(
    dataset,
    BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(device == "cuda"),
)

#### Build UNet Model

In [11]:
class ConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm(1 group) -> GELU) stages with an optional residual.

    Spatial size is preserved (stride 1, padding 1). With ``residual=True`` the
    result is ``(skip + stage2(stage1(x))) / 1.414``, where ``skip`` is the input
    itself when ``in_channels == out_channels`` and the first stage's output
    otherwise; the division by ~sqrt(2) presumably keeps the sum's scale close to
    that of its inputs.
    """

    def __init__(self, in_channels: int, out_channels: int, residual: bool=False) -> None:
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.residual = residual

        def stage(c_in: int) -> nn.Sequential:
            # conv -> norm -> activation, always producing out_channels maps
            return nn.Sequential(
                nn.Conv2d(c_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(1, out_channels),
                nn.GELU(),
            )

        self.conv1 = stage(in_channels)
        self.conv2 = stage(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv1(x)
        y = self.conv2(h)
        if not self.residual:
            return y
        # Only add the raw input when its channel count matches the output.
        skip = x if self.in_channels == self.out_channels else h
        return (skip + y) / 1.414


class Encoder_set(nn.Module):
    """Downsampling stage: a ConvBlock (in -> out channels) then 2x2 max pooling.

    The pooling halves the height and width of the feature map.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        layers = [ConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.encoder_set = nn.Sequential(*layers)

    def forward(self, x):
        downsampled = self.encoder_set(x)
        return downsampled


class Decoder_set(nn.Module):
    """Upsampling stage: concat with a skip tensor, 2x2 stride-2 transposed conv, two ConvBlocks.

    ``in_channels`` must equal the channel count of ``x`` plus that of
    ``skip_connect``, because the two are concatenated along dim 1 before the
    transposed convolution. The output has ``out_channels`` maps at twice the
    input's height and width.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine = [ConvBlock(out_channels, out_channels) for _ in range(2)]
        self.decoder_set = nn.Sequential(upsample, *refine)

    def forward(self, x, skip_connect):
        merged = torch.cat([x, skip_connect], dim=1)
        return self.decoder_set(merged)


class EmbeddingFC(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) embedding a flat vector.

    The input is flattened to ``(-1, input_dim)`` before the MLP, so e.g. a
    ``(B, 1, 1, 1)`` timestep tensor with ``input_dim=1`` yields ``(B, output_dim)``.

    Args:
        input_dim: size of each flattened input row.
        output_dim: width of both linear layers' outputs.
    """

    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embeddingfc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs, e.g. sliced or
        # transposed tensors, are accepted as well.
        x = x.reshape(-1, self.input_dim)
        return self.embeddingfc(x)

class UNet(nn.Module):
    """Class- and timestep-conditioned UNet that predicts the noise in an image.

    Spatial sizes noted in the comments below are for a 28x28 input.

    Args:
        in_channels: channels of the input image; also the number of output
            channels, since the network predicts noise shaped like its input.
        n_features: base channel width. Must be divisible by 8 because of the
            GroupNorm(8, ...) layers in ``decoder0`` and ``out``.
        n_classes: number of context classes for the one-hot label embedding.
    """
    def __init__(self, in_channels, n_features, n_classes) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.n_features = n_features
        self.n_classes = n_classes

        # in_channels -> n_features, spatial size unchanged (28x28).
        self.init_conv = ConvBlock(in_channels, n_features, residual=True)

        # Each encoder stage halves H and W: 28x28 -> 14x14 -> 7x7.
        self.encoder1 = Encoder_set(n_features, 1 * n_features)
        self.encoder2 = Encoder_set(n_features, 2 * n_features)

        # 7x7 average pool collapses the 7x7 bottleneck to 1x1.
        self.to_vec = nn.Sequential(
            nn.AvgPool2d(7),
            nn.GELU(),
        )

        # Per-decoder-stage embeddings of the scalar timestep and of the one-hot
        # class vector, sized to that stage's channel count.
        self.timeembedding1 = EmbeddingFC(1, 2 * n_features)
        self.timeembedding2 = EmbeddingFC(1, 1 * n_features)

        self.contextembedding1 = EmbeddingFC(n_classes, 2 * n_features)
        self.contextembedding2 = EmbeddingFC(n_classes, 1 * n_features)

        # 1x1 -> 7x7 (kernel 7, stride 7).
        self.decoder0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_features, 2 * n_features, 7, 7),
            nn.GroupNorm(8, 2 * n_features),
            nn.ReLU(),
        )
        # Decoder inputs are concatenated with the matching encoder output, hence
        # the doubled in_channels. decoder1: 7x7 -> 14x14, decoder2: 14x14 -> 28x28.
        self.decoder1 = Decoder_set(4 * n_features, n_features)
        self.decoder2 = Decoder_set(2 * n_features, n_features)
        # Head applied to concat(decoder2 output, init_conv output) -> in_channels.
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_features, n_features, 3, 1, 1),
            nn.GroupNorm(8, n_features),
            nn.ReLU(),
            nn.Conv2d(n_features, self.in_channels, 3, 1, 1)
        )

    def forward(self, x, c, t, context_mask):
        """Predict the noise contained in ``x``.

        Args:
            x: noisy images, presumably (batch, in_channels, 28, 28) — TODO confirm
                which other sizes are supported.
            c: integer class labels, one per sample (one-hot encoded below).
            t: timesteps, one value per sample (the DDPM wrapper passes
                ``step / noise_step``); flattened to (-1, 1) by the embeddings.
            context_mask: per-sample flag; 1 drops the class context for that sample.

        Returns:
            Tensor with ``in_channels`` channels and the same spatial size as ``x``.
        """
        # x is (noisy) image, c is context label, t is timestep,
        # context_mask says which samples to block the context on

        # Note: x is rebound to the init_conv features, which feed the final head.
        x = self.init_conv(x)
        encode1 = self.encoder1(x)
        encode2 = self.encoder2(encode1)
        hiddenvec = self.to_vec(encode2)

        # convert to 1-hot embedding
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # mask out context if context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1,self.n_classes)
        context_mask = (-1*(1-context_mask)) # mask 0 (keep) -> -1, mask 1 (drop) -> 0
        # Kept labels become a negated one-hot vector; dropped labels become all
        # zeros. Training and sampling both go through this same code path.
        c *= context_mask

        # embed context with time step
        c_embed1 = self.contextembedding1(c).view(-1, self.n_features * 2, 1, 1)
        t_embed1 = self.timeembedding1(t).view(-1, self.n_features * 2, 1, 1)
        c_embed2 = self.contextembedding2(c).view(-1, self.n_features * 1, 1, 1)
        t_embed2 = self.timeembedding2(t).view(-1, self.n_features * 1, 1, 1)

        decode1 = self.decoder0(hiddenvec)  # 1x1 -> 7x7
        # Each decoder stage is conditioned as (class embedding * features + time embedding).
        decode2 = self.decoder1(c_embed1 * decode1 + t_embed1, encode2)  # -> 14x14
        decode3 = self.decoder2(c_embed2 * decode2 + t_embed2, encode1)  # -> 28x28
        out = self.out(torch.cat((decode3, x), 1))

        return out


#### Denoising Diffusion Probabilistic Models (DDPM)

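Linear noise schedule used by `ddpm_schedules` (indices run over $t = 0, 1, \dots, T$, matching the `torch.arange(0, T + 1)` / `torch.cumprod` in the code):

$$
\begin{aligned}
\beta_t &= \beta_{\text{start}} + (\beta_{\text{end}} - \beta_{\text{start}})\,\frac{t}{T}, \qquad t = 0, 1, \dots, T \\
\alpha_t &= 1 - \beta_t \\
\bar{\alpha}_t &= \prod_{s=0}^{t}(1 - \beta_s) = \prod_{s=0}^{t}\alpha_s
\end{aligned}
$$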

In [12]:
def ddpm_schedules(beta1, beta2, T):
    """Precompute the linear-beta DDPM schedule tensors for steps 0..T.

    Args:
        beta1: beta at step 0.
        beta2: beta at step T; ``beta1 < beta2 < 1.0`` is asserted.
        T: number of diffusion steps; every returned tensor has length ``T + 1``.

    Returns:
        dict of float32 tensors, each indexed by step t (see key comments below).

    Raises:
        AssertionError: if the ordering ``beta1 < beta2 < 1.0`` does not hold.
    """
    assert beta1 < beta2 < 1.0

    steps = torch.arange(0, T + 1, dtype=torch.float32)
    beta_t = (beta2 - beta1) * steps / T + beta1
    alpha_t = 1 - beta_t
    alphabar_t = torch.cumprod(alpha_t, dim=0)
    sqrtmab = torch.sqrt(1 - alphabar_t)

    return dict(
        alpha_t=alpha_t,                           # 1 - beta_t
        oneover_sqrta=1 / torch.sqrt(alpha_t),     # 1 / sqrt(alpha_t)
        sqrt_beta_t=torch.sqrt(beta_t),            # sqrt(beta_t)
        alphabar_t=alphabar_t,                     # running product of alpha_0..alpha_t
        sqrtab=torch.sqrt(alphabar_t),             # sqrt(alphabar_t)
        sqrtmab=sqrtmab,                           # sqrt(1 - alphabar_t)
        mab_over_sqrtmab=(1 - alpha_t) / sqrtmab,  # (1 - alpha_t) / sqrt(1 - alphabar_t)
    )


"alpha_t": $\alpha_t$
"oneover_sqrta": $\frac{1}{\sqrt{\alpha_t}}$
"sqrt_beta_t": $\sqrt{\beta_t}$
"alphabar_t": $\bar{\alpha_t}$
"sqrtab": $\sqrt{\bar{\alpha_t}}$
"sqrtmab": $\sqrt{1-\bar{\alpha_t}}$
"mab_over_sqrtmab": $\frac{(1-\alpha_t)}{\sqrt{1-\bar{\alpha_t}}}$


In [13]:
class DDPM(nn.Module):
    """Diffusion wrapper: noise-prediction training loss and guided sampling.

    The tensors from ``ddpm_schedules`` are registered as buffers, so they move
    with ``.to(device)`` and are stored in the ``state_dict``.

    Args:
        model: noise-prediction network called as ``model(x_t, c, t, context_mask)``.
        beta_start: beta at step 0 of the linear schedule.
        beta_end: beta at the last step of the linear schedule.
        noise_step: number of diffusion steps T.
        device: device the model is moved to and index tensors are created on.
        drop_prob: probability of dropping a sample's class context during
            training (enables classifier-free guidance at sampling time).
    """

    def __init__(self, model, beta_start: float=1e-4, beta_end: float=0.02, noise_step: int=1000, device: str='cuda', drop_prob: float=0.1) -> None:
        super().__init__()
        self.model = model.to(device)
        self.noise_step = noise_step
        self.device = device
        self.drop_prob = drop_prob
        self.mse_loss = nn.MSELoss()

        for k, v in ddpm_schedules(beta_start, beta_end, noise_step).items():
            self.register_buffer(k, v)

    def forward(self, x, t):
        """Return the MSE between the added noise and the model's prediction.

        Args:
            x: batch of clean images, batch dimension first (indexed as 4-D below).
            t: class labels for ``x``; they are passed as the model's context
                argument (despite the name, this is not the diffusion timestep).
        """
        # One diffusion step per sample, drawn uniformly from {1, ..., noise_step}.
        _ts = torch.randint(1, self.noise_step + 1, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x)  # eps ~ N(0, I)

        # Closed-form forward process: x_t = sqrt(abar_t) * x + sqrt(1 - abar_t) * eps.
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )

        # Drop each sample's context with probability drop_prob (1 = dropped).
        context_mask = torch.bernoulli(torch.zeros_like(t) + self.drop_prob).to(self.device)

        # The model sees the step normalised to (0, 1].
        return self.mse_loss(noise, self.model(x_t, t, _ts / self.noise_step, context_mask))

    @torch.no_grad()
    def sample(self, n_sample, size, device, guide_w = 0.0):
        """Generate ``n_sample`` images with classifier-free guidance.

        Labels cycle through the classes (``label = index % n_classes``, with
        ``n_classes`` taken from ``model.n_classes`` if present, else 10), so
        ``n_sample`` need not be a multiple of the class count. Every step runs the
        model on a doubled batch (first half conditioned, second half with context
        dropped) and combines the two predictions as
        ``(1 + guide_w) * eps_cond - guide_w * eps_uncond``.

        Args:
            n_sample: number of images to generate.
            size: per-image shape, e.g. ``(3, 28, 28)``.
            device: device on which to create the sampling tensors.
            guide_w: guidance weight; 0 uses the conditional prediction only.

        Returns:
            ``(x_0, x_store)``: the final samples and a numpy array of intermediate
            ``x_i`` snapshots (the first step, every 20th step, and the last 7 steps).
        """
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, I)
        n_classes = getattr(self.model, 'n_classes', 10)
        # 0, 1, ..., n_classes-1, 0, 1, ... truncated to exactly n_sample labels.
        c_i = torch.arange(n_sample).to(device) % n_classes

        # Conditional half keeps its context; the unconditional half drops it.
        context_mask = torch.zeros_like(c_i).to(device)
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        # keep track of generated steps in case want to plot something
        x_i_store = []

        for i in range(self.noise_step, 0, -1):
            t_is = torch.tensor([i / self.noise_step]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)

            # double batch
            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            # No fresh noise is injected on the final step (i == 1).
            noise = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            predicted_noise = self.model(x_i, c_i, t_is, context_mask)
            eps_cond = predicted_noise[:n_sample]
            eps_uncond = predicted_noise[n_sample:]
            predicted_noise = (1 + guide_w) * eps_cond - guide_w * eps_uncond
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - predicted_noise * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * noise
            )
            if i % 20 == 0 or i == self.noise_step or i < 8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store

#### Training process

In [17]:
def modling(dataloader, ddpm, optimizer):
    """Train ``ddpm`` for one epoch, then periodically save sample grids and a checkpoint.

    Reads notebook globals set in other cells: ``epoch`` and ``EPOCHS`` (linear
    learning-rate decay and save schedule), ``lr`` (base learning rate),
    ``device``, and ``mean``/``std`` (the dataset's Normalize statistics, needed
    to turn samples back into images).

    Sampling is expensive (3 guidance weights x ``noise_step`` model calls), so it
    only runs on the epochs that are saved: every 10th epoch and the final one.
    The model is left in eval mode on return.

    Returns:
        Exponential moving average of the batch loss over the epoch, or ``None``
        if ``dataloader`` yielded no batches.
    """
    ddpm.train()
    # Linear decay from lr at epoch 0 towards 0 at the last epoch.
    for group in optimizer.param_groups:
        group['lr'] = lr * (1 - epoch / EPOCHS)

    loss_ema = None
    pbar = tqdm(dataloader, leave=False)
    for x, t in pbar:
        x, t = x.to(device, non_blocking=True), t.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = ddpm(x, t)
        loss.backward()
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
        optimizer.step()
        pbar.set_description(f"loss: {loss_ema:.4f}")

    ddpm.eval()
    if epoch % 10 == 0 or epoch == EPOCHS - 1:
        _save_samples_and_checkpoint(ddpm, epoch)
    return loss_ema


def _save_samples_and_checkpoint(ddpm, epoch):
    """Write guided sample grids (w = 0, 0.5, 2) and the model state_dict for ``epoch``."""
    os.makedirs('/content/drive/MyDrive/NTU_DLCV/Hw2/p1_img', exist_ok=True)
    os.makedirs('/content/drive/MyDrive/NTU_DLCV/Hw2/p1_ckpt', exist_ok=True)

    # Invert tr.Normalize(mean, std) so the saved PNGs show the real colours.
    img_mean = torch.tensor(mean, device=device).view(1, -1, 1, 1)
    img_std = torch.tensor(std, device=device).view(1, -1, 1, 1)

    n_samples = 30
    with torch.no_grad():
        for w in [0, 0.5, 2]:
            x_gen, _ = ddpm.sample(n_samples, (3, 28, 28), device, guide_w=w)
            grid = make_grid((x_gen * img_std + img_mean).clamp(0, 1), nrow=3)
            save_image(grid, f'/content/drive/MyDrive/NTU_DLCV/Hw2/p1_img/epoch{epoch+1}_w{w:.1f}.png')
    torch.save(ddpm.state_dict(), f'/content/drive/MyDrive/NTU_DLCV/Hw2/p1_ckpt/epoch{epoch+1}.pth')

In [None]:
EPOCHS = 300

n_feature = 128  # base UNet width; the author notes 256 may work better

Unet = UNet(in_channels=3, n_features=n_feature, n_classes=10)
# Pass the detected device explicitly: DDPM otherwise defaults to 'cuda' and
# would fail on a CPU-only runtime.
ddpm = DDPM(model=Unet, device=device)
ddpm.to(device)

lr = 1e-4
optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)

# `epoch` is deliberately a module-level name: modling() reads it (together
# with EPOCHS and lr) for the learning-rate decay and the save schedule.
for epoch in tqdm(range(EPOCHS)):
    modling(dataset_loader, ddpm, optimizer)

[1;30;43m串流輸出內容已截斷至最後 5000 行。[0m
 76%|███████▋  | 167/219 [01:15<00:23,  2.23it/s][A
 77%|███████▋  | 168/219 [01:15<00:22,  2.24it/s][A
 77%|███████▋  | 169/219 [01:16<00:22,  2.25it/s][A
 78%|███████▊  | 170/219 [01:16<00:21,  2.24it/s][A
 78%|███████▊  | 171/219 [01:17<00:21,  2.24it/s][A
 79%|███████▊  | 172/219 [01:17<00:20,  2.25it/s][A
 79%|███████▉  | 173/219 [01:18<00:20,  2.25it/s][A
 79%|███████▉  | 174/219 [01:18<00:20,  2.25it/s][A
 80%|███████▉  | 175/219 [01:18<00:19,  2.25it/s][A
 80%|████████  | 176/219 [01:19<00:19,  2.23it/s][A
 81%|████████  | 177/219 [01:19<00:18,  2.22it/s][A
 81%|████████▏ | 178/219 [01:20<00:18,  2.23it/s][A
 82%|████████▏ | 179/219 [01:20<00:17,  2.23it/s][A
 82%|████████▏ | 180/219 [01:21<00:17,  2.24it/s][A
 83%|████████▎ | 181/219 [01:21<00:16,  2.25it/s][A
 83%|████████▎ | 182/219 [01:22<00:16,  2.25it/s][A
 84%|████████▎ | 183/219 [01:22<00:15,  2.26it/s][A
 84%|████████▍ | 184/219 [01:22<00:15,  2.25it/s][A
 84%|██████