In [1]:
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        return self.conv(X).flatten(2).transpose(1, 2)

In [3]:
class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))

In [4]:
class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = self.ln1(X)
        return X + self.mlp(self.ln2(
            X + self.attention(X, X, X, valid_lens)))

In [5]:
class ViT(d2l.Classifier):
    """Vision transformer."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(
            img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", ViTBlock(
                num_hiddens, num_hiddens, mlp_num_hiddens,
                num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        X = self.dropout(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X)
        return self.head(X[:, 0])

In [6]:
from tree_dataset import TreeDataset
import model as m
from torch.utils.data import DataLoader
import os
import torchvision.transforms as transforms

In [7]:
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
train_set = TreeDataset(os.path.join('..', 'data', 'trainset4'), preprocess) 
val_set = TreeDataset(os.path.join('..', 'data', 'trainset1'), preprocess)
print(f'Train size: {len(train_set)} Val size: {len(val_set)}')
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
device = m.get_device()
config = {'labels_key': 'digit_labels'}

Train size: 2000 Val size: 1000
Identified CUDA device: NVIDIA GeForce RTX 3060


In [8]:
img_size, patch_size = 224, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr).to(device)



In [10]:
model(torch.unsqueeze(train_set[0]['image'], 0).to(device))

tensor([[-0.2003,  0.8277, -0.0792,  0.1954,  0.0097,  0.3583, -0.0675, -0.6295,
         -0.1205,  0.0465]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [11]:
model

ViT(
  (patch_embedding): PatchEmbedding(
    (conv): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (blks): Sequential(
    (0): ViTBlock(
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attention): MultiHeadAttention(
        (attention): DotProductAttention(
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (W_q): Linear(in_features=512, out_features=512, bias=False)
        (W_k): Linear(in_features=512, out_features=512, bias=False)
        (W_v): Linear(in_features=512, out_features=512, bias=False)
        (W_o): Linear(in_features=512, out_features=512, bias=False)
      )
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): ViTMLP(
        (dense1): Linear(in_features=512, out_features=2048, bias=True)
        (gelu): GELU(approximate=none)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dense2): Linear(in_features=2048, out_features=512, b

In [12]:
m.train(model, 0.0001, 0, 100, train_loader, val_loader, device, os.path.join('..', 'models', 'd2lvit'), None)

Epoch 10 done, train loss: 0.0092 val acc: 1.0000
Epoch 20 done, train loss: 0.0024 val acc: 1.0000
Epoch 30 done, train loss: 0.0011 val acc: 1.0000
Epoch 40 done, train loss: 0.0006 val acc: 1.0000
Epoch 50 done, train loss: 0.0004 val acc: 1.0000
Epoch 60 done, train loss: 0.0004 val acc: 1.0000
Epoch 70 done, train loss: 0.0003 val acc: 1.0000
Epoch 80 done, train loss: 0.0002 val acc: 1.0000
Epoch 90 done, train loss: 0.0001 val acc: 1.0000
Epoch 100 done, train loss: 0.0001 val acc: 1.0000


In [14]:
train_acc = m.predict(model, train_loader, device, config, None)
print(train_acc)

1.0


In [15]:
val_acc = m.predict(model, val_loader, device, config, None)
print(val_acc)

1.0
