In [None]:
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
#### Hyperparameters
# Optimisation settings.
learning_rate = 0.001
num_epochs = 30
# batch_size used to be assigned twice in this cell (16, then 32); only the
# second assignment ever took effect, so the dead one has been dropped.
batch_size = 32

# Input geometry: images are cropped to image_size x image_size and cut into
# non-overlapping patch_size x patch_size tiles.
image_size = 72
patch_size = 6
num_patches = (image_size // patch_size) ** 2  # spatial tiles per image

# Width of the token embeddings used throughout the model.
hidden_size = 64

In [None]:
# Training-time preprocessing: upscale, take a random image_size crop, randomly
# mirror, convert to a tensor and standardise each channel.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.RandomResizedCrop(size=image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel mean / std (presumably CIFAR-100 statistics -- confirm).
        transforms.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ]
)

# CIFAR-100 training split, downloaded into ./data when not already present.
train_ds = datasets.CIFAR100(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of batch_size samples.
train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)

In [None]:
# Pull a single batch from the loader and report the shape of its image tensor.
sample_batch = next(iter(train_dl))
print("data batch size: {}".format(sample_batch[0].size()))

data batch size: torch.Size([32, 3, 72, 72])


In [None]:
class PatchLayerWithoutChannel(nn.Module):
    """Cut a batch of square images into non-overlapping square patches.

    All channels of a patch stay together, so each patch is a small
    (C, P, P) image. Patches are ordered row-major over the image grid:
    patch index = row * (W // P) + col.

    Args:
        patch_size: side length P of each patch, in pixels.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    # input = (N, C, H, W) ===> return = (N, HW / P^2, C, P, P)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the patches of ``x``.

        Args:
            x: image batch of shape (N, C, H, W) with H == W and H divisible
               by the patch size.

        Returns:
            Tensor of shape (N, HW / P^2, C, P, P) on the same device and with
            the same dtype as ``x`` (previously the result was always a CPU
            float32 tensor, whatever the input was).

        Raises:
            AssertionError: if the image is not square or not divisible by P.
        """
        assert x.size(2) == x.size(3), "image size should be square"
        batch, channels, side, _ = x.shape
        p = self.patch_size
        assert side % p == 0, "image size should be devisible by patch!"
        grid = side // p
        # Split H into (grid, p) and W into (grid, p), then bring the two grid
        # axes to the front so each (C, P, P) tile becomes one patch.
        tiles = x.reshape(batch, channels, grid, p, grid, p)
        tiles = tiles.permute(0, 2, 4, 1, 3, 5)
        return tiles.reshape(batch, grid * grid, channels, p, p)
# Visual check: show one (normalised) training image, then every patch cut from it.
img = train_ds[1][0]
plt.imshow(img.detach().cpu().numpy().transpose(1, 2, 0))  # CHW -> HWC for matplotlib
plt.axis("off")

patcher = PatchLayerWithoutChannel(6)
patched_img = patcher(img.unsqueeze(0))

# Two expression statements that had no effect here (a .size() call and a
# .numpy() conversion whose results were discarded) have been removed.

# Lay the patches out on the grid they were cut from; the side of that grid is
# derived from the image and patch sizes instead of being hard-coded to 12.
grid = img.size(-1) // patcher.patch_size
fig, axes = plt.subplots(grid, grid, figsize=(6, 6), squeeze=False)
for i in range(grid):
    for j in range(grid):
        axes[i, j].imshow(patched_img[0, i * grid + j].detach().cpu().numpy().transpose(1, 2, 0))
        axes[i, j].axis("off")

In [None]:
## step 1: patchify Layer
class PatchLayer(nn.Module):
    """Split a batch of square images into flattened single-channel patches.

    Every P x P tile of every channel becomes one token of length P^2.
    Tokens are ordered by grid row, then grid column, then channel:
    token index = (row * (W // P) + col) * C + channel.

    Args:
        patch_size: side length P of each patch, in pixels.
        device: kept for backward compatibility and stored on the instance;
            the output now follows the device of the input tensor.
    """

    def __init__(self, patch_size: int, device="cuda"):
        super().__init__()
        self.patch_size = patch_size
        self.device = device

    # input = (N, C, H, W) ===> return = (N, CHW/ P^2, P^2)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the flattened patches of ``x``.

        The previous implementation wrote each patch to index
        ``c + 2 * i + 1``, which ignores the column index and collides across
        rows and channels, so most output tokens stayed zero. It also
        allocated the output on a global ``device`` instead of next to the
        input. Both are fixed by building the result as a reshaped view.

        Args:
            x: image batch of shape (N, C, H, W) with H == W and H divisible
               by the patch size.

        Returns:
            Tensor of shape (N, C * HW / P^2, P^2) with the device and dtype
            of ``x``.

        Raises:
            AssertionError: if the image is not square or not divisible by P.
        """
        assert x.size(2) == x.size(3), "image size should be square"
        batch, channels, side, _ = x.shape
        p = self.patch_size
        assert side % p == 0, "image size should be devisible by patch!"
        grid = side // p
        # (N, C, gi, P, gj, P) -> (N, gi, gj, C, P, P): one P x P tile per
        # (row, col, channel) triple, each later flattened to P^2 values.
        tiles = x.reshape(batch, channels, grid, p, grid, p)
        tiles = tiles.permute(0, 2, 4, 1, 3, 5)
        return tiles.reshape(batch, grid * grid * channels, p * p)

# Smoke test: patchify one real batch and look at the resulting token shape.
patcher = PatchLayer(patch_size=6)
images, _ = next(iter(train_dl))
patcher(images).shape

torch.Size([32, 432, 36])

In [None]:
## step 2: Map Layer
class LinearMap(nn.Module):
    """Project every flattened patch to the model width with one linear layer.

    Args:
        patch_size: side length P of a patch; the layer expects P^2 input features.
        hidden_size: number of output features per token.
    """

    def __init__(self, patch_size, hidden_size):
        super().__init__()
        flat_patch_dim = patch_size ** 2
        self.fc = nn.Linear(in_features=flat_patch_dim, out_features=hidden_size)

    # input = (N, S, d) ===> return = (N, S, d')
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.fc(x)
        return projected

# Smoke test: patchify a batch, project each patch to 8 features, check the shape.
patcher = PatchLayer(patch_size=6)
mapper = LinearMap(patch_size=6, hidden_size=8)
images, _ = next(iter(train_dl))
mapper(patcher(images)).shape

torch.Size([32, 432, 8])

In [None]:
## step 3: Positional Encoding
class PositionalEncoding(nn.Module):
    """Prepend a learnable class token and add fixed sinusoidal position codes.

    Args:
        num_patches: number of tokens S in the incoming sequence; the encoding
            table holds S + 1 positions (one extra for the class token).
        hidden_size: feature width of each token. Odd widths are supported
            (previously they raised a shape-mismatch error while filling the
            cosine columns).
    """

    def __init__(self, num_patches, hidden_size) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        seq_len = num_patches + 1
        position = torch.arange(0, seq_len).float().unsqueeze(1)
        even_dims = torch.arange(0, hidden_size, 2).float()
        freq = torch.exp(even_dims * (-math.log(10_000) / hidden_size))
        angles = freq * position  # (seq_len, ceil(hidden_size / 2))

        pe = torch.zeros(1, seq_len, hidden_size)
        pe[0, :, 0::2] = torch.sin(angles)
        # There are only floor(hidden_size / 2) odd columns, so drop the
        # surplus frequency when hidden_size is odd.
        pe[0, :, 1::2] = torch.cos(angles[:, : hidden_size // 2])

        # Learnable class token, shared by every sample in a batch.
        self.v_class = nn.Parameter(torch.randn(1, hidden_size))
        # Buffer: moves with the module across devices but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` with the class token at position 0 plus position codes.

        Args:
            x: tokens of shape (N, num_patches, hidden_size).

        Returns:
            Tensor of shape (N, num_patches + 1, hidden_size).
        """
        x = torch.cat([self.v_class.expand(x.size(0), 1, -1), x], dim=1)
        return x + self.pe


# Smoke test for the first three stages: patchify -> project -> add class token
# and position codes. PatchLayer emits one token per tile per RGB channel,
# hence num_patches * 3 positions.
patcher = PatchLayer(patch_size=patch_size)
mapper = LinearMap(patch_size=patch_size, hidden_size=hidden_size)
po_en = PositionalEncoding(num_patches=num_patches * 3, hidden_size=hidden_size)
images, _ = next(iter(train_dl))
po_en(mapper(patcher(images))).shape

torch.Size([32, 433, 64])

In [None]:
## step 4: Transformer Block
class MultiHead(nn.Module):
    """Multi-head scaled dot-product attention for batch-first inputs.

    Args:
        d_model: feature width of the inputs and of the output.
        num_head: number of attention heads; must divide ``d_model``.
        mask: optional tensor broadcastable to (N, num_head, S_q, S_k).
            Positions where the mask equals 0 are excluded from attention.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, d_model=512, num_head=8, mask=None,  *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert d_model % num_head == 0, "d_model divisible by num_head"
        self.d_model = d_model
        self.num_head = num_head
        self.mask = mask
        self.dim_h = d_model // num_head

        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def split_(self, x):
        """Reshape (N, S, d_model) into per-head form (N, num_head, S, dim_h).

        Heads get their own axis so attention is computed independently per
        head. The earlier version folded heads into the sequence axis, which
        let every head of every token attend to every other head.
        """
        batch, seq_len, _ = x.size()
        return x.reshape(batch, seq_len, self.num_head, self.dim_h).transpose(1, 2)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """Attend from ``query`` to ``key`` / ``value``.

        Args:
            query: (N, S_q, d_model).
            key: (N, S_k, d_model).
            value: (N, S_k, d_model).

        Returns:
            Tensor of shape (N, S_q, d_model).
        """
        btz, seq_len, _ = query.size()
        # Linear + split
        q = self.split_(self.linear_q(query))
        k = self.split_(self.linear_k(key))
        v = self.split_(self.linear_v(value))

        # Scale the queries first, as before: softmax(q k^T / sqrt(dim_h)).
        query_scaled = q / math.sqrt(self.dim_h)
        attn_output_weights = torch.matmul(query_scaled, k.transpose(-2, -1))

        # `if self.mask:` raised for any multi-element tensor; test for None.
        if self.mask is not None:
            mask = torch.as_tensor(self.mask, device=attn_output_weights.device)
            # Use the dtype's most negative value so this also works in half precision.
            fill = torch.finfo(attn_output_weights.dtype).min
            attn_output_weights = attn_output_weights.masked_fill(mask == 0, fill)

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)

        atten_output = torch.matmul(attn_output_weights, v)  # (N, num_head, S_q, dim_h)

        # concat: put the head axis back next to dim_h and merge the two. The
        # earlier transpose(0, 1) + view mixed values from different batch
        # elements into the same output row.
        atten_output = atten_output.transpose(1, 2).contiguous().view(btz, seq_len, self.d_model)
        return self.linear_out(atten_output)


class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: self-attention and an MLP, each with a residual.

    Args:
        d_model: feature width of the tokens.
        num_head: number of attention heads.
        mask: optional attention mask handed to the attention layer
            (it was previously accepted here but silently dropped).
    """

    def __init__(self, d_model=64, num_head=8, mask=None):
        super().__init__()
        self.nl1 = nn.LayerNorm(d_model)
        self.mha = MultiHead(d_model=d_model, num_head=num_head, mask=mask)
        self.nl2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, S, d_model); shape is preserved."""
        # Normalise once and reuse the result as query, key and value.
        normed = self.nl1(x)
        x = x + self.mha(normed, normed, normed)
        x = x + self.mlp(self.nl2(x))
        return x

class TransformerEncoder(nn.Module):
    """A stack of ``L`` encoder blocks applied one after another.

    Args:
        d_model: feature width of the tokens.
        num_head: number of attention heads in every block.
        mask: optional attention mask passed to every block.
        L: number of blocks in the stack.
    """

    def __init__(self,  d_model=64, num_head=8, mask=None, L: int=2):
        super().__init__()
        blocks = [
            TransformerEncoderBlock(d_model=d_model, num_head=num_head, mask=mask)
            for _ in range(L)
        ]
        self.encoder = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        out = x
        for block in self.encoder:
            out = block(out)
        return out

# Smoke test: the encoder stack must leave the (batch, tokens, width) shape unchanged.
x = torch.randn(2, 433, 64)
tb = TransformerEncoder(d_model=hidden_size, num_head=8)
tb(x).shape

torch.Size([2, 433, 64])

In [None]:
## step 5: classifier Block
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input.

    Args:
        dropout: drop probability, or None / 0 to disable dropout.
        model_dim: number of input features.
        target_size: number of output classes.
    """

    def __init__(self, dropout=None, model_dim=64, target_size=100):
        super().__init__()
        self.dropout = dropout
        self.fc = nn.Linear(model_dim, target_size)

    def forward(self, x):
        """Return class logits of shape (..., target_size) for ``x``."""
        if self.dropout:
            # F.dropout defaults to training=True, which kept dropout active
            # after model.eval(); tie it to the module's mode instead.
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.fc(x)

In [None]:
class VIT(nn.Module):
    """Vision Transformer assembled from the stages defined above.

    Pipeline: patchify -> linear projection -> class token + positional
    encoding -> transformer encoder -> classifier on the class token.

    Args:
        device: passed to the patch layer.
        patch_size: side length of each square patch.
        hidden_size: token width used throughout the model.
        target_size: number of output classes.
        num_patches: number of tokens produced by the patch layer.
        num_head: attention heads per encoder block.
        mask: optional attention mask passed to the encoder.
        L: number of encoder blocks.
    """

    def __init__(self, device="cuda", patch_size: int = 6, hidden_size: int = 64, target_size: int=100, num_patches: int = 432, num_head=8, mask=None, L: int=2):
        super().__init__()
        # Sub-modules are created in the order the data flows through them.
        self.patcher = PatchLayer(patch_size=patch_size, device=device)
        self.mapper = LinearMap(patch_size=patch_size, hidden_size=hidden_size)
        self.pe = PositionalEncoding(num_patches=num_patches, hidden_size=hidden_size)
        self.transformer_encoder = TransformerEncoder(d_model=hidden_size, num_head=num_head, mask=mask, L=L)
        self.classifier = Classifier(dropout=0.5, model_dim=hidden_size, target_size=target_size)

    def forward(self, x):
        """Map an image batch to class logits of shape (N, target_size)."""
        tokens = x
        for stage in (self.patcher, self.mapper, self.pe, self.transformer_encoder):
            tokens = stage(tokens)
        # Position 0 holds the class token prepended by the positional-encoding stage.
        class_token = tokens[:, 0, :]
        return self.classifier(class_token)

# Smoke test: two random 72x72 RGB images should produce logits of shape (2, 100).
x = torch.randn(2, 3, 72, 72).to(device)
# Pass the detected device explicitly instead of relying on the constructor's
# "cuda" default, so this cell matches how the training model is built.
vit = VIT(device=device).to(device)
vit(x).size()

torch.Size([2, 100])

In [None]:
# Build the network on the selected device.
model = VIT(device=device)
model = model.to(device)

# Cross-entropy loss over the class logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train for num_epochs passes over the data, printing the mean batch loss per epoch.
for epoch in range(num_epochs):
    train_loss = 0.0
    for images, labels in tqdm(train_dl):
        # Move the batch next to the model.
        images, labels = images.to(device), labels.to(device)

        # Forward pass and loss.
        prediction = model(images)
        loss = criterion(prediction, labels)

        # Backward pass: clear old gradients, backpropagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    print(f'Epoch {epoch+1}/{num_epochs} | train loss: {train_loss/len(train_dl)}')