<a href="https://colab.research.google.com/github/R12942159/NTU_DLCV/blob/Hw2/p1_ModelA_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import os
import torch
import numpy as np
import pandas as pd
from torch import nn
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as tr
from torchvision.utils import make_grid

#### Download Mnist data

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !pip install gsutil
# !gsutil cp /content/drive/MyDrive/NTU_DLCV/Hw2/hw2_data.zip /content/hw2_data.zip
!unzip /content/hw2_data.zip

In [6]:
train_csv = pd.read_csv('/content/hw2_data/digits/mnistm/train.csv').values.tolist()
val_csv = pd.read_csv('/content/hw2_data/digits/mnistm/val.csv').values.tolist()

#### Get cuda from GPU

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using: {device}")

Using: cuda


#### Construct a Dataset

In [16]:
class MnistDataset(torch.utils.data.Dataset):
    def __init__(self, train_csv: list, test_csv: list, join_path: str, transform) -> None:
        self.transform = transform
        self.img_paths = []
        self.img_labels = []

        for data_csv in [train_csv, test_csv]:
            for row in data_csv:
                self.img_paths.append(os.path.join(join_path, row[0]))
                self.img_labels.append(row[1])
        assert len(self.img_paths) == len(self.img_labels)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx) -> (torch.Tensor, int):
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB')
        img = self.transform(img)

        label = self.img_labels[idx]
        return img, label

In [17]:
BATCH_SIZE = 256

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

dataset = MnistDataset(train_csv,
                  val_csv,
                  join_path = '/content/hw2_data/digits/mnistm/data',
                  transform = tr.Compose([
                      tr.ToTensor(),
                      tr.Normalize(mean=mean, std=std),
                  ]),
                  )
dataset_loader = torch.utils.data.DataLoader(dataset, BATCH_SIZE, shuffle=True, num_workers=4)

#### Build UNet Model

In [10]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, residual: bool=False) -> None:
        super().__init__()
        self.in_cahnnels = in_channels
        self_out_channels = out_channels
        self.residual = residual

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(1, out_channels),
            nn.GELU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(1, out_channels),
            nn.GELU()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        if self.residual:
            if self.in_channels == self.out_channels:
                out = x + x2
            else:
                out = x1 + x2
            return out / 1.414
        else:
            return x2


class Encoder_set(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.encoder_set = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        return self.encoder_set(x)


class Decoder_set(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.decoder_set = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ConvBlock(out_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x, skip_connect):
        x = torch.cat((x, skip_connect), 1)
        x = self.decoder_set(x)

        return x


class EmbeddingFC(nn.Module):
    def __init__(self, input_dim, output_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embeddingfc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        )

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        x = self.embeddingfc(x)

        return x

class UNet(nn.Module):
    def __init__(self, in_channels, n_features, n_classes) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.n_features = n_features
        self.n_classes = n_classes

        self.init_conv = ConvBlock(in_channels, n_features, residual=True)

        self.encoder1 = Encoder_set(n_features, 1 * n_features)
        self.encoder2 = Encoder_set(n_features, 2 * n_features)

        self.to_vec = nn.Sequential(
            nn.AvgPool2d(7),
            nn.GELU(),
        )

        self.timeembedding1 = EmbeddingFC(1, 2 * n_features)
        self.timeembedding2 = EmbeddingFC(1, 1 * n_features)

        self.contextembedding1 = EmbeddingFC(n_classes, 2 * n_features)
        self.contextembedding2 = EmbeddingFC(n_classes, 1 * n_features)

        self.decoder0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_features, 2 * n_features, 7, 7),
            nn.GroupNorm(8, 2 * n_features),
            nn.ReLU(),
        )
        self.decoder1 = Decoder_set(4 * n_features, n_features)
        self.decoder2 = Decoder_set(2 * n_features, n_features)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_features, n_features, 3, 1, 1),
            nn.GroupNorm(8, n_features),
            nn.ReLU(),
            nn.Conv2d(n_features, self.in_channels, 3, 1, 1)
        )

    def forward(self, x, c, t, context_mask):
        # x is (noisy) image, c is context label, t is timestep,
        # context_mask says which samples to block the context on

        x = self.init_conv(x)
        encode1 = self.encoder1(x)
        encode2 = self.encoder2(encode1)
        hiddenvec = self.to_vec(encode2)

        # convert to 1-hot embedding
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        # mask out context if context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1,self.n_classes)
        context_mask = (-1*(1-context_mask)) # flip 0 <-> 1
        c *= context_mask

        # embed context with time step
        c_embed1 = self.contextembedding1(c).view(-1, self.n_features * 2, 1, 1)
        t_embed1 = self.timeembedding1(t).view(-1, self.n_features * 2, 1, 1)
        c_embed2 = self.contextembedding2(c).view(-1, self.n_features * 1, 1, 1)
        t_embed2 = self.timeembedding2(t).view(-1, self.n_features * 1, 1, 1)

        decode1 = self.decoder0(hiddenvec)
        decode2 = self.decoder1(c_embed1 * decode1 + t_embed1, encode2)
        decode3 = self.decoder2(c_emb2 * dec2 + t_emb2, encode1)
        out = self.out(torch.cat((decode3, x), 1))

        return out


#### Denoising Duffusion Probabilistic Models

In [11]:
class DDPM(nn.Module):
    def __init__(self, model, beta_start: int=1e-4, beta_end: int=0.02, noise_step: int=1000, device: str='cuda', drop_prob: int=0.1) -> None:
        super().__init__()
        self.model = model.to(device)
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.noise_step = noise_step
        self.device = device
        self.drop_prob = drop_prob
        self.mse_loss = nn.MSELoss()

        self.beta = torch.linspace(beta_start, beta_end, noise_step).to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def forward(self, x, t):
        """
        this method is used in training, so samples t and noise randomly
        """

        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x) # eps ~ N(0, 1)

        _ts = torch.randint(1, self.noise_step+1, (x.shape[0],)).to(self.device)  # t ~ Uniform(0, noise_step)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # We should predict the "error term" from this x_t. Loss is what we return.
        x_t = sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon

        # dropout context with some probability
        context_mask = torch.bernoulli(torch.zeros_like(t)+self.drop_prob).to(self.device)

        # return MSE between added noise, and our predicted noise
        return self.mse_loss(noise, self.model(x_t, t, _ts / self.noise_step, context_mask))

    def sample(self, n_sample, size, device, guide_w = 0.0):
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        c_i = torch.arange(0,10).to(device) # context for us just cycles throught the mnist labels
        c_i = c_i.repeat(n_sample // c_i.shape[0])

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        # keep track of generated steps in case want to plot something
        x_i_store = []

        for i in range(self.noise_step, 0, -1):
            alpha = self.alpha[i]
            alpha_hat = self.alpha_hat[i]
            beta = self.beta[i]

            t_is = torch.tensor([i / self.noise_step]).to(device)
            t_is = t_is.repeat(n_sample,1,1,1)

            # double batch
            x_i = x_i.repeat(2,1,1,1)
            t_is = t_is.repeat(2,1,1,1)

            if i > 1:
                 noise = torch.randn_like(n_sample, *size)
            else:
                 noise = torch.zeros_like(n_sample, *size)

            # split predictions and compute weighting
            predicted_noise = self.model(x_i, c_i, t_is, context_mask)
            epsilon1 = predicted_noise[:n_sample]
            epsilon2 = predicted_noise[n_sample:]
            epsilon = (1 + guide_w) * epsilon1 - guide_w * epsilon2
            x_i = x_i[:n_sample]
            x_i = (
                (1 / torch.sqrt(alpha)) * (x_i - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
            )
            if i%20==0 or i==self.noise_step or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store

#### Training process

In [12]:
def modling(dataloader, ddpm, optimizer):
    ddpm.train() # to training mode.
    optimizer.param_groups[0]['lr'] = lr * (1 - epoch / EPOCHS)
    loss_ema = None
    for batch_i, (x, t) in enumerate(tqdm(dataloader, leave=False)):
        x, t = x.to(device, non_blocking=True), t.to(device, non_blocking=True) # move data to GPU

        optimizer.zero_grad()
        loss = ddpm(x, t)
        loss.backward()
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
        optimizer.step() # update model params

    ddpm.eval()
    with torch.no_grad():
        n_sample = 40
        for w in [0.0, 0.5, 2.0]:
            x_gen, x_gen_store = ddpm.sample(n_sample, (3, 28, 28), device, guide_w=w)
            grid = make_grid(x_gen*(-1) + 1, nrow=3)
            save_image(grid, f'/content/img/epoch{epoch+1}_w{w:.1f}.png')

    torch.save(ddpm.state_dict(), f'/content/ckpt/epoch{epoch+1}.pth')

In [19]:
EPOCHS = 100

n_feature = 128

Unet = UNet(in_channels=3, n_features=n_feature, n_classes=10)
ddpm = DDPM(model=Unet)
ddpm.to(device)

lr = 1e-4
optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)

for epoch in tqdm(range(EPOCHS)):
    print(f'epoch{epoch+1}')
    modling(dataset_loader, ddpm, optimizer)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x79e8ed489ea0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 
  0%|          | 0/100 [00:00<?, ?it/s]

epoch1



  0%|          | 0/219 [00:00<?, ?it/s][A
  0%|          | 0/100 [00:01<?, ?it/s]


AttributeError: ignored