In [2]:
import torch
import torch.nn as nn

In [3]:
IMG_SIZE = 144
PATCH_SIZE = 8
EMB_DIM = 32

In [4]:
from torchvision.transforms import v2
from torchvision import transforms
import torchvision
from torchvision.transforms.functional import to_pil_image

transform = transforms.Compose([
    transforms.Resize((144, 144)),
    transforms.ToTensor()
])

import matplotlib.pyplot as plt

def show_images(images, num_samples=20, cols=4):
    """Plots some samples from the dataset"""
    plt.figure(figsize=(15, 15))
    idx = int(len(dataset) / num_samples)
    print(images)

    for i, img in enumerate(images):
        if i % idx == 0:
            plt.subplot(int(num_samples / cols) + 1, cols, int(i / idx) + 1)
            plt.imshow(to_pil_image(img[0]))


dataset = torchvision.datasets.OxfordIIITPet('./ds', download=True, transform=transform)
# show_images(dataset)

100%|██████████| 792M/792M [00:36<00:00, 21.8MB/s] 
100%|██████████| 19.2M/19.2M [00:02<00:00, 9.35MB/s]


In [5]:
len(dataset.classes)

37

In [6]:
class PatchEmbedding(nn.Module):

    def __init__(self, in_chans=3, patch_size=8, emb_dim=32):
        super().__init__()
        self.patch_size = patch_size
        self.unfold = nn.Unfold(patch_size, stride=patch_size)
        self.proj = nn.Linear(3 * patch_size * patch_size, emb_dim)

    def forward(self, x):

        x = self.unfold(x) # (B, C*P*P, N)
        x = x.transpose(-1, -2) # (B, N, CPP)
        x = self.proj(x)
        return x
        
inp = torch.randn(1, 3, 144, 144)
pe = PatchEmbedding(in_chans=3, patch_size=8, emb_dim=128)
inp.shape, pe(inp).shape

(torch.Size([1, 3, 144, 144]), torch.Size([1, 324, 128]))

In [7]:
class Attention(nn.Module):

    def __init__(self, emb_dim=32, n_heads=2, dropout=0.):
        super().__init__()

        self.attention = nn.MultiheadAttention(emb_dim, n_heads, dropout, batch_first=True)

        self.q = nn.Linear(emb_dim, emb_dim)
        self.k = nn.Linear(emb_dim, emb_dim)
        self.v = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        attn_output, attn_output_weights = self.attention(q, k, v)
        return attn_output

In [8]:
attn = Attention()(torch.zeros(1, 5, 32))
attn[0].shape

torch.Size([5, 32])

In [9]:
class Prenorm(nn.Module):

    def __init__(self, fn, emb_dim=32):
        super().__init__()

        self.layer_norm = nn.LayerNorm(emb_dim)
        self.fn = fn

    def forward(self, x):
        out = self.fn(self.layer_norm(x)) # Attn(LN(x))
        return out
    

In [10]:
att = torch.tensor([
    [1, 1, 1, 1],
    [1, 2, 3, 10]
], device='cpu', dtype=torch.float32)

last_dim = att.shape[-1]
att = torch.unsqueeze(att, 0)
layer_norm = nn.LayerNorm(last_dim, device='cpu')

print(att.shape)
print(att)
layer_norm(att)

torch.Size([1, 2, 4])
tensor([[[ 1.,  1.,  1.,  1.],
         [ 1.,  2.,  3., 10.]]])


tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
         [-0.8485, -0.5657, -0.2828,  1.6971]]],
       grad_fn=<NativeLayerNormBackward0>)

In [11]:
Prenorm(Attention(32, 1, 0.0))(torch.zeros(1, 5, 32))

tensor([[[-0.0205,  0.0404,  0.0117, -0.0022,  0.0241, -0.0385, -0.0441,
           0.0777,  0.0837, -0.0639, -0.0931, -0.0232, -0.0108, -0.0322,
           0.0176,  0.0527,  0.0941, -0.0132,  0.0397, -0.0193, -0.0001,
           0.0213, -0.0204,  0.0206, -0.0676, -0.0797, -0.0829, -0.0929,
           0.0083,  0.0649,  0.0093, -0.0020],
         [-0.0205,  0.0404,  0.0117, -0.0022,  0.0241, -0.0385, -0.0441,
           0.0777,  0.0837, -0.0639, -0.0931, -0.0232, -0.0108, -0.0322,
           0.0176,  0.0527,  0.0941, -0.0132,  0.0397, -0.0193, -0.0001,
           0.0213, -0.0204,  0.0206, -0.0676, -0.0797, -0.0829, -0.0929,
           0.0083,  0.0649,  0.0093, -0.0020],
         [-0.0205,  0.0404,  0.0117, -0.0022,  0.0241, -0.0385, -0.0441,
           0.0777,  0.0837, -0.0639, -0.0931, -0.0232, -0.0108, -0.0322,
           0.0176,  0.0527,  0.0941, -0.0132,  0.0397, -0.0193, -0.0001,
           0.0213, -0.0204,  0.0206, -0.0676, -0.0797, -0.0829, -0.0929,
           0.0083,  0.0649,  0

In [12]:
class FeedForward(nn.Sequential):
    def __init__(self, emb_dim=32, hidden_dim=32*3, dropout=0.):
        super().__init__(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout)
        )

ff = FeedForward()
ff(torch.ones(1, 5, 32)).shape

torch.Size([1, 5, 32])

In [13]:
class TranResBlock(nn.Module):

    def __init__(self, fn, emb_dim=32, dropout=0.):
        super().__init__()

        self.pre_norm = Prenorm(fn, emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        res = inp
        inp = self.pre_norm(inp)
        out = self.dropout(inp) + res
        return out

In [14]:
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

posemb_sincos_2d(144 // 8, 144 // 8, 128).shape

torch.Size([324, 128])

In [None]:
class ViT(nn.Module):
    def __init__(self, 
                in_chans=3, 
                img_size=IMG_SIZE, 
                patch_size=PATCH_SIZE, 
                emb_dim=EMB_DIM, 
                n_heads=2, 
                n_layers=4, 
                dim_out=37,
                dropout=0.1):
        
        super().__init__()

        self.in_chans = in_chans
        self.img_size = img_size
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dim_out = dim_out
        self.dropout = dropout
        self.n_patches = (img_size ** 2) // (patch_size ** 2) 

        self.patch_emb = PatchEmbedding(in_chans, patch_size, emb_dim)
        self.cls_token = nn.Parameter(torch.randn(1, emb_dim))

        self.register_buffer('pos_enc',
            posemb_sincos_2d(img_size // patch_size + 1,
                            img_size // patch_size + 1,
                            emb_dim)
            .unsqueeze(0)[:, :self.n_patches+1, :]
        )

        # self.pos_enc = posemb_sincos_2d(img_size // patch_size + 1, img_size // patch_size + 1, emb_dim) \
        #                                     .unsqueeze(0)[:, :self.n_patches+1, :]

        self.layers = nn.ModuleList([])

        for i in range(n_layers):
            transformer_block = nn.Sequential(
                TranResBlock(
                    fn=Attention(self.emb_dim, self.n_heads, self.dropout),
                    emb_dim=emb_dim,
                    dropout=dropout
                ),
                TranResBlock(
                    fn=FeedForward(emb_dim=emb_dim, hidden_dim=emb_dim*3, dropout=dropout),
                    emb_dim=emb_dim,
                    dropout=dropout
                )
            )
            self.layers.append(transformer_block)

        self.fc_cls = nn.Linear(emb_dim, dim_out)


    def forward(self, x):

        B = x.shape[0]

        patches = self.patch_emb(x)
        cls_token = self.cls_token.unsqueeze(0).expand(B, -1, -1)

        inp = torch.cat([cls_token, patches], dim=1)

        inp = inp + self.pos_enc[:, :inp.shape[1], :]

        for block in self.layers:
            inp = block(inp)

        out_logits = self.fc_cls(inp[:, 0, :])
        return out_logits

In [None]:
import torch.optim as optim

device = 'cuda'
print("training on ", device)

trainloader = torch.utils.data.DataLoader(dataset, 16, True)

vit = ViT(in_chans=3, img_size=144, patch_size=8, emb_dim=32, n_heads=2, n_layers=2, dim_out=37).to(device)

optimizer = optim.Adam(vit.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

epochs = 20

for i in range(1, epochs + 1):

    vit.train()
    running_loss = 0.0

    for imgs, labels in trainloader:

        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = vit(imgs)
        loss = loss_fn(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
    epoch_loss = running_loss / len(trainloader)
    print("Epoch {}, Loss {:10.6f}".format(i, epoch_loss))

torch.save(vit.state_dict(), './models/vit_v1.pt')

training on  cuda
Epoch 1, Loss   3.660806
Epoch 2, Loss   3.559315
Epoch 3, Loss   3.452833
Epoch 4, Loss   3.393206
Epoch 5, Loss   3.332496
Epoch 6, Loss   3.308988
Epoch 7, Loss   3.263788
Epoch 8, Loss   3.238080
Epoch 9, Loss   3.202475
Epoch 10, Loss   3.165187
Epoch 11, Loss   3.137716
Epoch 12, Loss   3.117096
Epoch 13, Loss   3.071120
Epoch 14, Loss   3.052170
Epoch 15, Loss   3.006948
Epoch 16, Loss   2.973209
Epoch 17, Loss   2.953056
Epoch 18, Loss   2.945295
Epoch 19, Loss   2.900073
Epoch 20, Loss   2.887796


RuntimeError: Parent directory ./models does not exist.

In [24]:
correct = 0
total = 0

vit.eval()

with torch.no_grad():
    for imgs, labels in trainloader:

        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = vit(imgs)
        predicted_classes = preds.argmax(dim=1)

        correct += (predicted_classes == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total * 100
print(f"Training Accuracy: {accuracy:.2f}%")


Training Accuracy: 22.64%


In [28]:
correct = 0
total = 0

vit_rand = ViT(in_chans=3, img_size=144, patch_size=8, emb_dim=32, n_heads=2, n_layers=2, dim_out=37).to(device)

vit_rand.eval()

with torch.no_grad():
    for imgs, labels in trainloader:

        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = vit_rand(imgs)
        predicted_classes = preds.argmax(dim=1)

        correct += (predicted_classes == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total * 100
print(f"Training Accuracy: {accuracy:.2f}%")


Training Accuracy: 2.61%


In [30]:
# guessing randomly
(1 / len(dataset.classes)) * 100

2.7027027027027026