In [1]:
import logging
import os
import math
import pathlib
from datetime import datetime

# 3rd Party
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

In [2]:
# Sans-serif font chosen so that Japanese text renders in matplotlib figures
plt.rcParams['font.sans-serif'] = "Microsoft YaHei"

# Pick the compute device (CUDA when available, otherwise CPU) and report it
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print(f"Using device: CPU")

Using device: NVIDIA GeForce GTX 1060 6GB


# Utils

In [3]:
def plot_images(images):
    """Display a batch of images next to each other in a single figure.

    The batch is copied to the CPU, its items are joined along the last
    axis, and the result is permuted with ``(1, 2, 0)`` (channel axis moved
    last) before being handed to ``plt.imshow``.

    Args:
        images: batch tensor; iterating over it must yield the individual
            images (presumably channel-first, given the permute - confirm
            against the caller).
    """
    plt.figure(figsize=(32, 32))
    tiles = list(images.cpu())
    row = torch.cat(tiles, dim=-1)
    # Single-row stack along the second-to-last axis, as in the original layout code.
    canvas = torch.cat([row], dim=-2)
    plt.imshow(canvas.permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    """Tile a batch of images into one grid and write it to ``path``.

    Args:
        images: batch passed straight to ``torchvision.utils.make_grid``.
        path: destination file; the format is chosen by PIL from the path.
        **kwargs: forwarded unchanged to ``make_grid``.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # PIL wants a channel-last NumPy array living in host memory.
    pixels = grid.permute(1, 2, 0).to('cpu').numpy()
    Image.fromarray(pixels).save(path)

## Dataset

In [4]:
# Directory holding the dataset files; the loading cell joins it onto the current working directory.
DATASET_DIR = 'Kuzushiji-49'
# NumPy archives with the training images and labels.
DATASET_TRAIN_X_FILE = 'k49-train-imgs.npz'
DATASET_TRAIN_Y_FILE = 'k49-train-labels.npz'
# NumPy archives with the test images and labels.
DATASET_TEST_X_FILE = 'k49-test-imgs.npz'
DATASET_TEST_Y_FILE = 'k49-test-labels.npz'
# CSV file that the loading cell reads into `class_map` with pandas.
DATASET_CLASSMAP = 'k49_classmap.csv'

In [5]:
# Load dataset
def _load_npz_tensor(filename, dtype):
    """Load the array stored under ``'arr_0'`` in a dataset archive as a tensor on ``device``.

    Args:
        filename: name of a ``.npz`` file inside ``DATASET_DIR``.
        dtype: torch dtype of the returned tensor.

    Returns:
        A tensor copy of the array, moved to the module-level ``device``.
    """
    # np.load keeps an .npz archive open until it is closed; the context
    # manager releases the file handle once the data has been copied out.
    with np.load(os.path.join(DATASET_DIR, filename)) as archive:
        return torch.tensor(archive['arr_0'], dtype=dtype).to(device)


train_images = _load_npz_tensor(DATASET_TRAIN_X_FILE, torch.float32)
train_labels = _load_npz_tensor(DATASET_TRAIN_Y_FILE, torch.int64)
test_images = _load_npz_tensor(DATASET_TEST_X_FILE, torch.float32)
test_labels = _load_npz_tensor(DATASET_TEST_Y_FILE, torch.int64)
print(f"Training Set: Input shape: {train_images.shape}. Output shape: {train_labels.shape}")
print(f"Testing Set: Input shape: {test_images.shape}. Output shape: {test_labels.shape}")

# Class lookup table shipped with the dataset
class_map = pd.read_csv(os.path.join(DATASET_DIR, DATASET_CLASSMAP))

Training Set: Input shape: torch.Size([232365, 28, 28]). Output shape: torch.Size([232365])
Testing Set: Input shape: torch.Size([38547, 28, 28]). Output shape: torch.Size([38547])


In [6]:
# One-hot encode the integer class labels (49 classes) and store them as float32
train_labels = F.one_hot(train_labels, num_classes=49).float()
test_labels = F.one_hot(test_labels, num_classes=49).float()

In [7]:
# Insert a channel axis at position 1: (N, 28, 28) -> (N, 1, 28, 28)
train_images = train_images[:, None]
test_images = test_images[:, None]

In [8]:
# Hold out the last 10% of the training samples as the validation set
validation_size = int(0.1 * train_images.shape[0])
# Right-hand sides are evaluated before assignment, so both slices see the full tensors.
validate_images, train_images = train_images[-validation_size:], train_images[:-validation_size]
validate_labels, train_labels = train_labels[-validation_size:], train_labels[:-validation_size]

In [9]:
# Pad every split by 2 pixels on each side, turning the 28x28 images into 32x32
pad2 = transforms.Pad(2)
train_images, validate_images, test_images = (
    pad2(split) for split in (train_images, validate_images, test_images)
)

In [10]:
# shift pixel values into [-1,1]
def shiftInputImage(images):
    """Rescale 8-bit pixel intensities from [0, 255] to exactly [-1, 1].

    The midpoint 127.5 is used for both offset and scale so that 0 maps to
    -1 and 255 maps to +1. This is the inverse of the ``(x + 1) / 2 * 255``
    mapping that ``Diffusion.sample`` applies when turning generated samples
    back into pixel values. (Dividing by 128 would top out at 0.992 and
    break that round trip.)

    Args:
        images: tensor (or scalar) of pixel values in [0, 255].

    Returns:
        The rescaled values; for a tensor input, a new tensor of the same shape.
    """
    return (images - 127.5) / 127.5
# Apply the pixel rescaling to every split
train_images, validate_images, test_images = (
    shiftInputImage(split) for split in (train_images, validate_images, test_images)
)

In [17]:
# Discard the test split: it is only needed for classification purposes
test_images = test_labels = None
# The validation split is kept on purpose: train() evaluates on it after every epoch.
# Hand cached GPU memory that is no longer referenced back to the driver
torch.cuda.empty_cache()

# Model

In [12]:
class Diffusion:
    """Forward noising and reverse sampling for a DDPM with a linear beta schedule.

    Attributes:
        noise_step: number of diffusion timesteps in the schedule.
        beta_start, beta_end: endpoints of the linear beta schedule.
        img_size: side length of the square, single-channel images that
            ``sample`` generates.
        beta, alpha, alpha_hat: per-timestep schedule tensors, created on the
            module-level ``device``; ``alpha = 1 - beta`` and ``alpha_hat``
            is the cumulative product of ``alpha``.

    Timestep index 0 is never drawn by ``sample_timesteps`` nor visited by
    ``sample``; both work on ``1 .. noise_step - 1``.
    """

    def __init__(self, noise_step=1000, beta_start=1e-4, beta_end=0.02, img_size=64):
        self.noise_step = noise_step
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return ``noise_step`` betas spaced linearly from ``beta_start`` to ``beta_end``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_step)

    def noise_images(self, x, t):
        """Noise a batch of clean images to the given timesteps.

        Computes ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps`` with
        fresh standard-normal ``eps``.

        Args:
            x: batch of images with the batch on axis 0 and three trailing
                axes (channel, height, width).
            t: 1-D integer tensor with one timestep index per image.

        Returns:
            ``(x_t, eps)``: the noised images and the noise that was added.
        """
        # One coefficient per sample, shaped (B, 1, 1, 1) so it broadcasts over
        # whatever channel/height/width ``x`` has instead of being tied to
        # ``self.img_size``.
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t]).reshape(t.shape[0], 1, 1, 1)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t]).reshape(t.shape[0], 1, 1, 1)
        eps = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps

    def sample_timesteps(self, n):
        """Draw ``n`` timestep indices uniformly from ``[1, noise_step)`` (CPU tensor)."""
        return torch.randint(low=1, high=self.noise_step, size=(n,))

    # denoise process
    def sample(self, model, n):
        """Generate ``n`` images by running the reverse process from pure noise.

        Args:
            model: noise predictor called as ``model(x, t)``; it must provide
                ``eval()`` and ``train()``.
            n: number of images to generate.

        Returns:
            ``uint8`` tensor of shape ``(n, 1, img_size, img_size)`` with
            values in ``[0, 255]``.

        The model is switched to eval mode while sampling and is put back
        into whichever mode it was in beforehand, even if sampling is
        interrupted.
        """
        logging.info(f"Sampling {n} new images ...")
        # Objects without a ``training`` flag are treated as being in train mode,
        # which matches the historical "always call train() afterwards" behaviour.
        was_training = getattr(model, "training", True)
        # Run on the device that holds the schedule rather than a module global.
        run_device = self.beta.device
        model.eval()
        try:
            with torch.no_grad():
                x = torch.randn((n, 1, self.img_size, self.img_size)).to(run_device)
                for i in reversed(range(1, self.noise_step)):
                    t_tensor = torch.full((n,), i, dtype=torch.long, device=run_device)
                    predicted_noise = model(x, t_tensor)
                    alpha = self.alpha[i]
                    alpha_hat = self.alpha_hat[i]
                    beta = self.beta[i]
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        # in the final iteration, we don't want to add noise back to X0
                        noise = torch.zeros_like(x)
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)

        # Map [-1, 1] back to 8-bit pixel values.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

In [13]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Each spatial position becomes one token with ``channels`` features. The
    block applies pre-LayerNorm 4-head attention with a residual connection,
    then a LayerNorm/Linear/GELU/Linear feed-forward with a second residual,
    and finally restores the input layout.

    Args:
        channels: number of feature channels (the attention embedding size).
        size: nominal side length of the feature map. It is stored for
            reference only; ``forward`` reads the real height and width from
            its input.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(batch, channels, height, width)``.

        Returns a tensor of the same shape as ``x``.
        """
        # Take the spatial extent from the input itself. Reshaping with a fixed
        # ``size`` and a free batch axis would silently fold spatial positions
        # into the batch (or fail) whenever the feature map is not size x size.
        batch, channels, height, width = x.shape
        tokens = x.reshape(batch, channels, height * width).swapaxes(1, 2)
        normed = self.ln(tokens)
        attention_value, _ = self.mha(normed, normed, normed)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).reshape(batch, channels, height, width)


class DoubleConv(nn.Module):
    """Two bias-free 3x3 convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two conv/norm pairs. With ``residual=True`` the
    input is added to the result and passed through another GELU, which
    requires ``in_channels == out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
        residual: whether to wrap the stack in a residual connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv stack on ``x``, with the residual path when enabled."""
        if not self.residual:
            return self.double_conv(x)
        return F.gelu(x + self.double_conv(x))


class Down(nn.Module):
    """Encoder stage: 2x max-pool, two ``DoubleConv`` blocks, plus a timestep bias.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
        emb_dim: size of the timestep embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Projects the timestep embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected timestep embedding ``t`` at every position."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2:]
        time_bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return features + time_bias


class Up(nn.Module):
    """Decoder stage: 2x bilinear upsample, skip concat, two ``DoubleConv`` blocks, timestep bias.

    Args:
        in_channels: channels after concatenating the skip tensor with the
            upsampled input.
        out_channels: channels produced by the stage.
        emb_dim: size of the timestep embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Projects the timestep embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x`` and add the projected timestep embedding."""
        upsampled = self.up(x)
        # The skip tensor goes first along the channel axis.
        merged = self.conv(torch.cat([skip_x, upsampled], dim=1))
        height, width = merged.shape[-2:]
        time_bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return merged + time_bias
    
class UNet(nn.Module):
    """Timestep-conditioned U-Net with self-attention, used as the noise predictor.

    The network has three ``Down`` stages, a three-block bottleneck and three
    ``Up`` stages, with a ``SelfAttention`` block after every stage. Each
    ``Up`` stage receives the concatenation of a skip tensor and the
    upsampled features, hence the summed channel counts.

    Args:
        c_in: channels of the input image.
        c_out: channels of the predicted noise.
        time_dim: size of the sinusoidal timestep embedding.
        sf: scaling factor of cnn kernel numbers (base channel width).
        img_size: side length of the square input images. It must be a
            positive multiple of 8 because the encoder halves the resolution
            three times. Defaults to 32.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, c_in=1, c_out=1, time_dim=256, sf=16, img_size=32):
        super().__init__()
        if img_size <= 0 or img_size % 8 != 0:
            raise ValueError(f"img_size must be a positive multiple of 8, got {img_size}")
        self.time_dim = time_dim
        self.img_size = img_size
        # Feature-map sides after one, two and three 2x poolings.
        half, quarter, eighth = img_size // 2, img_size // 4, img_size // 8

        self.inc = DoubleConv(c_in, sf*1)
        self.down1 = Down(sf*1, sf*2)
        self.sa1 = SelfAttention(sf*2, half)
        self.down2 = Down(sf*2, sf*4)
        self.sa2 = SelfAttention(sf*4, quarter)
        self.down3 = Down(sf*4, sf*4)
        self.sa3 = SelfAttention(sf*4, eighth)

        self.bot1 = DoubleConv(sf*4, sf*8)
        self.bot2 = DoubleConv(sf*8, sf*8)
        self.bot3 = DoubleConv(sf*8, sf*4)

        self.up1 = Up(sf*4+sf*4, sf*2)
        self.sa4 = SelfAttention(sf*2, quarter)
        self.up2 = Up(sf*2+sf*2, sf*1)
        self.sa5 = SelfAttention(sf*1, half)
        self.up3 = Up(sf*1+sf*1, sf*1)
        self.sa6 = SelfAttention(sf*1, img_size)
        self.outc = nn.Conv2d(sf*1, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal embedding of timesteps.

        Args:
            t: float tensor of shape ``(batch, 1)`` holding the timesteps.
            channels: embedding size; the first half of the output holds the
                sines and the second half the cosines.

        Returns:
            Tensor of shape ``(batch, channels)`` on the same device as ``t``.
        """
        # Build the frequencies on t's device so the embedding follows the
        # input instead of depending on a module-level device variable.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: noised images, ``(batch, c_in, img_size, img_size)``.
            t: 1-D tensor with one timestep per image.

        Returns:
            Tensor of shape ``(batch, c_out, img_size, img_size)``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder with skip connections from the encoder
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

# Train

In [14]:
def train(epochs=500, batch_size=6, image_size=32, lr=3e-4, patience=5):
    """Train the diffusion U-Net on the module-level training images.

    Relies on names defined elsewhere in the notebook: ``train_images``,
    ``validate_images``, ``device``, ``UNet``, ``Diffusion`` and
    ``save_images``. The data is consumed in its stored order, one
    contiguous slice per step.

    Every 100 optimisation steps, 16 images are sampled and written to
    ``results/<timestamp>/<global_step>.jpg``. After each epoch the noise
    prediction MSE is measured on the validation images and fed into the
    early-stopping counter.

    Args:
        epochs: maximum number of passes over the training images.
        batch_size: number of images per step.
        image_size: side length of the images handed to ``Diffusion``.
        lr: AdamW learning rate.
        patience: early-stopping budget. A validation loss that rises by
            more than 1e-4 over the previous epoch costs 1; any other epoch
            restores 0.5 (capped at ``patience``). Training stops once the
            budget drops below zero.

    Returns:
        The trained ``UNet``.
    """
    model = UNet(c_in=1, c_out=1, sf=8).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=image_size, noise_step=200)
    training_timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    pathlib.Path(f'results/{training_timestamp}').mkdir(parents=True, exist_ok=True)

    train_total_step = math.ceil(train_images.shape[0] / batch_size)
    val_total_step = math.ceil(validate_images.shape[0] / batch_size)

    prev_val_loss = math.inf
    patience_left = patience
    for epoch in range(epochs):
        print(f"Starting epoch {epoch}:")
        train_error = 0
        model.train()
        print(f"Train:")
        pbar = tqdm(range(train_total_step))
        for i in pbar:
            optimizer.zero_grad()
            # Slicing clips at the end of the tensor, so the last (short) batch
            # needs no special case.
            images = train_images[i*batch_size:(i+1)*batch_size]
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)
            train_error += loss.item()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=(train_error/(i+1)))

            global_step = epoch * train_total_step + i
            if global_step % 100 == 0:
                sampled_images = diffusion.sample(model, n=16)
                save_images(sampled_images, os.path.join("results", f"{training_timestamp}", f"{global_step}.jpg"))

        # validate
        model.eval()
        print(f"Validate:")
        val_error = 0
        pbar = tqdm(range(val_total_step))
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for i in pbar:
                images = validate_images[i*batch_size:(i+1)*batch_size]
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)
                val_error += loss.item()
                pbar.set_postfix(MSE=(val_error/(i+1)))
        avg_val_error = val_error/val_total_step

        # early-stop control
        if (avg_val_error - prev_val_loss) > 0.0001:
            patience_left -= 1
            print(f'Patience left: {patience_left}')
            # Test the running counter, not the constant `patience` argument,
            # otherwise the stop condition can never become true.
            if patience_left < 0:
                print(f'Early stopped')
                break
        else:
            patience_left = min(patience_left + 0.5, patience)
        prev_val_loss = avg_val_error

    return model

In [15]:
# Start a training run of at most 20 epochs with batches of 80 images
train(
    epochs=20,
    batch_size=80,
    image_size=32,
    lr=3e-4,
    patience=5,
)

Starting epoch 0:
Train:


100%|███████████████████████████████████████████████████████████████████| 2615/2615 [21:18<00:00,  2.05it/s, MSE=0.126]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:31<00:00,  3.17it/s, MSE=0.0806]


Starting epoch 1:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [20:41<00:00,  2.11it/s, MSE=0.0749]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:32<00:00,  3.14it/s, MSE=0.0721]


Starting epoch 2:
Train:


100%|███████████████████████████████████████████████████████████████████| 2615/2615 [20:37<00:00,  2.11it/s, MSE=0.069]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:31<00:00,  3.17it/s, MSE=0.0666]


Starting epoch 3:
Train:


100%|███████████████████████████████████████████████████████████████████| 2615/2615 [20:48<00:00,  2.10it/s, MSE=0.066]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:32<00:00,  3.16it/s, MSE=0.0646]


Starting epoch 4:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [21:31<00:00,  2.03it/s, MSE=0.0641]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:41<00:00,  2.87it/s, MSE=0.0633]


Starting epoch 5:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [23:01<00:00,  1.89it/s, MSE=0.0626]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:43<00:00,  2.82it/s, MSE=0.0615]


Starting epoch 6:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [22:19<00:00,  1.95it/s, MSE=0.0618]


Validate:


100%|█████████████████████████████████████████████████████████████████████| 291/291 [01:36<00:00,  3.02it/s, MSE=0.061]


Starting epoch 7:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [21:05<00:00,  2.07it/s, MSE=0.0609]


Validate:


100%|█████████████████████████████████████████████████████████████████████| 291/291 [01:32<00:00,  3.14it/s, MSE=0.061]


Starting epoch 8:
Train:


100%|██████████████████████████████████████████████████████████████████| 2615/2615 [20:58<00:00,  2.08it/s, MSE=0.0604]


Validate:


100%|████████████████████████████████████████████████████████████████████| 291/291 [01:32<00:00,  3.15it/s, MSE=0.0609]


Starting epoch 9:
Train:


 82%|███████████████████████████████████████████████████████▊            | 2147/2615 [17:16<03:45,  2.07it/s, MSE=0.06]


KeyboardInterrupt: 