<a href="https://colab.research.google.com/github/Kkun84/patch_transformer/blob/master/colab/patch_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [52]:
!pip install torchinfo



In [53]:
import torch
from torch import nn, Tensor
from torch import tensor
import torch.nn.functional as F


class TransformerModel(nn.Module):
    """Patch-based Transformer encoder classifier for square images.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is flattened and linearly projected to ``dim``
    features, the patch sequence is passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, the encoded patches are averaged,
    and a linear head maps the average to ``num_classes`` logits.

    No positional embedding and no class token are used: the patch sequence
    is mean-pooled, so the prediction does not depend on patch order.

    Args:
        image_size: Height and width of the (square) input image.
        image_channels: Number of channels of the input image.
        patch_size: Side length of a patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Embedding width used inside the Transformer encoder.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Hidden width of each encoder layer's feed-forward.
        depth: Number of stacked encoder layers.
        dropout: Dropout probability passed to each encoder layer.

    Raises:
        AssertionError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(
        self,
        *,
        image_size: int = 28,
        image_channels: int = 1,
        patch_size: int = 7,
        num_classes: int = 10,
        dim: int = 64,
        nhead: int = 1,
        dim_feedforward: int = 64,
        depth: int = 3,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()

        assert (
            image_size % patch_size == 0
        ), f'{image_size}, {patch_size}, {image_size % patch_size}'

        # Shape of a single input sample (without the batch dimension).
        self.input_shape = (image_channels, image_size, image_size)

        self.image_size = image_size
        self.image_channels = image_channels
        self.patch_size = patch_size
        self.dim = dim

        # Patches per image, and number of scalar values in one flat patch.
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = image_channels * patch_size ** 2

        self.patch_embedding = nn.Linear(self.patch_dim, dim)

        # batch_first=True is essential: forward() feeds tensors laid out as
        # (batch, num_patches, dim). With the default (batch_first=False) the
        # encoder would read dim 0 as the sequence axis and attend across the
        # images of a batch instead of across the patches of one image.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=depth
        )

        self.mlp_head = nn.Linear(dim, num_classes)

    def unfold_patch(self, image: Tensor) -> Tensor:
        """Split a batch of images into non-overlapping square patches.

        Args:
            image: Tensor of shape
                ``(batch, image_channels, image_size, image_size)``.

        Returns:
            Tensor of shape
            ``(batch, num_patches, image_channels, patch_size, patch_size)``
            with patches ordered row by row (left to right, top to bottom).

        Raises:
            AssertionError: If ``image`` does not have the expected shape.
        """
        batch_size = image.shape[0]

        verify_shape = torch.Size(
            [batch_size, self.image_channels, self.image_size, self.image_size]
        )
        assert image.shape == verify_shape, f'{image.shape}, {verify_shape}'

        # Slide a non-overlapping window over height (dim 2), then width
        # (dim 3); each unfold appends the window contents as a new last dim.
        x = image.unfold(2, self.patch_size, self.patch_size).unfold(
            3, self.patch_size, self.patch_size
        )

        verify_shape = torch.Size(
            [
                batch_size,
                self.image_channels,
                self.image_size // self.patch_size,
                self.image_size // self.patch_size,
                self.patch_size,
                self.patch_size,
            ]
        )
        assert x.shape == verify_shape, f'{x.shape}, {verify_shape}'

        # Move the patch-grid axes in front of the channel axis, then merge
        # the two grid axes into a single patch index.
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(
            batch_size,
            self.num_patches,
            self.image_channels,
            self.patch_size,
            self.patch_size,
        )
        return x

    def forward(self, image: Tensor) -> Tensor:
        """Compute class logits for a batch of images.

        Args:
            image: Tensor of shape
                ``(batch, image_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            class scores (logits).
        """
        batch_size = image.shape[0]

        patches = self.unfold_patch(image)
        # Flatten each patch (channels, height, width) into one vector.
        x = patches.flatten(2)

        verify_shape = torch.Size([batch_size, self.num_patches, self.patch_dim])
        assert x.shape == verify_shape, f'{x.shape}, {verify_shape}'

        x = self.patch_embedding(x)
        x = self.transformer_encoder(x)
        # Mean-pool over the patch axis to get one feature vector per image.
        x = x.mean(1)

        verify_shape = torch.Size([batch_size, self.dim])
        assert x.shape == verify_shape, f'{x.shape}, {verify_shape}'

        x = self.mlp_head(x)
        return x

In [54]:
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm


# --- Hyperparameters -------------------------------------------------------
batch_size = 64
lr = 0.001
max_epoch = 30


# --- Runtime: compute device and TensorBoard logging -----------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter()
# Checkpoints are later stored next to the TensorBoard event files.
log_dir = Path(writer.get_logdir())

# --- Data: MNIST train/test splits converted to tensors --------------------
transform = transforms.Compose([transforms.ToTensor()])

train_dataset, test_dataset = (
    MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training split is shuffled; both loaders share the other settings.
train_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    for dataset, shuffle in ((train_dataset, True), (test_dataset, False))
)

# --- Model and optimizer ---------------------------------------------------
model = TransformerModel().to(device)
# Print a layer-by-layer summary using a dummy batch of two samples.
print(summary(model, (2, *model.input_shape)))

optimizer = optim.Adam(model.parameters(), lr=lr)

# Global training-step counter used as the x-axis for TensorBoard scalars.
n_iter = 0

Layer (type:depth-idx)                        Output Shape              Param #
TransformerModel                              --                        --
├─Linear: 1-1                                 [2, 16, 64]               3,200
├─TransformerEncoder: 1-2                     [2, 16, 64]               --
│    └─ModuleList: 2                          --                        --
│    │    └─TransformerEncoderLayer: 3-1      [2, 16, 64]               25,216
│    │    └─TransformerEncoderLayer: 3-2      [2, 16, 64]               25,216
│    │    └─TransformerEncoderLayer: 3-3      [2, 16, 64]               25,216
├─Linear: 1-3                                 [2, 10]                   650
Total params: 79,498
Trainable params: 79,498
Non-trainable params: 0
Total mult-adds (M): 0.06
Input size (MB): 0.01
Forward/backward pass size (MB): 0.21
Params size (MB): 0.32
Estimated Total Size (MB): 0.54


In [None]:
# Train for `max_epoch` epochs. After every epoch the model is evaluated on
# the test split and a checkpoint is written into the TensorBoard log dir.
for epoch in range(max_epoch):
    model.train()
    # The per-epoch batch index is not needed: the global step is `n_iter`.
    for images, labels in tqdm(
        train_dataloader,
        desc=f'Train {epoch}/{max_epoch-1}',
        total=len(train_dataloader),
    ):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = F.cross_entropy(outputs, labels, reduction='mean')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Batch accuracy: fraction of samples whose arg-max logit is the label.
        accuracy = (outputs.argmax(1) == labels).float().mean().item()

        writer.add_scalar('metrics/train_loss', loss.item(), n_iter)
        writer.add_scalar('metrics/train_accuracy', accuracy, n_iter)

        n_iter += 1

    model.eval()

    # Accumulate summed loss / correct counts so the epoch metrics are exact
    # averages over the whole test split, independent of the last batch size.
    loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(
            test_dataloader,
            desc=f'Test {epoch}/{max_epoch-1}',
            total=len(test_dataloader),
        ):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)

            loss += F.cross_entropy(outputs, labels, reduction='sum').item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += len(labels)

    loss /= total
    accuracy = correct / total

    # Test metrics are logged at the step reached at the end of the epoch.
    writer.add_scalar('metrics/test_loss', loss, n_iter)
    writer.add_scalar('metrics/test_accuracy', accuracy, n_iter)

    torch.save(model.state_dict(), log_dir / f'epoch{epoch:05}.pt')
    # Persist pending scalars together with the checkpoint, so an
    # interrupted run does not lose the metrics of completed epochs.
    writer.flush()

Train 0/29: 100%|██████████| 938/938 [00:13<00:00, 68.75it/s]
Test 0/29: 100%|██████████| 157/157 [00:01<00:00, 129.01it/s]
Train 1/29: 100%|██████████| 938/938 [00:13<00:00, 68.61it/s]
Test 1/29: 100%|██████████| 157/157 [00:01<00:00, 126.04it/s]
Train 2/29: 100%|██████████| 938/938 [00:13<00:00, 67.87it/s]
Test 2/29: 100%|██████████| 157/157 [00:01<00:00, 124.99it/s]
Train 3/29: 100%|██████████| 938/938 [00:13<00:00, 68.18it/s]
Test 3/29: 100%|██████████| 157/157 [00:01<00:00, 126.34it/s]
Train 4/29: 100%|██████████| 938/938 [00:13<00:00, 68.32it/s]
Test 4/29: 100%|██████████| 157/157 [00:01<00:00, 127.42it/s]
Train 5/29: 100%|██████████| 938/938 [00:13<00:00, 67.96it/s]
Test 5/29: 100%|██████████| 157/157 [00:01<00:00, 126.04it/s]
Train 6/29: 100%|██████████| 938/938 [00:13<00:00, 68.83it/s]
Test 6/29: 100%|██████████| 157/157 [00:01<00:00, 123.32it/s]
Train 7/29: 100%|██████████| 938/938 [00:13<00:00, 67.98it/s]
Test 7/29: 100%|██████████| 157/157 [00:01<00:00, 126.12it/s]
Train 8/

In [None]:
# Visual sanity check: for the first 20 test images show the image, its
# patch decomposition, the predicted class probabilities and the arg-max.
# These helpers are imported here so the cell runs on its own.
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid

# Disable dropout and gradient tracking for inference.
model.eval()
with torch.no_grad():
    for index in range(20):
        print(index)
        image, label = test_dataset[index]

        display(to_pil_image(image))

        # Add a batch dimension and move the sample to the model's device;
        # passing the CPU tensor directly fails when the model is on the GPU.
        batch = image[None].to(device)

        # Bring the patches back to the CPU for conversion to a PIL image.
        patches = model.unfold_patch(batch)[0].cpu()
        display(to_pil_image(make_grid(patches, nrow=model.num_patches, pad_value=0.5)))

        # Softmax over the class axis turns the logits into probabilities.
        output_prob = model(batch)[0].softmax(0).cpu()

        print(output_prob.tolist())

        print(output_prob.argmax())

        print()