In [None]:
import torch
from torchvision.datasets import OxfordIIITPet
import matplotlib.pyplot as plt
from random import random
from torchvision.transforms import Resize, ToTensor
from torchvision.transforms.functional import to_pil_image

# Preprocessing steps applied to every image, in order: resize to 144x144,
# then convert to a tensor. Consumed by Compose when the dataset is built.
to_tensor = [
  Resize((144, 144)),
  ToTensor(),
]

class Compose(object):
  """Chain several single-argument callables into one callable.

  Args:
    transforms: Iterable of callables. Each one receives the output of the
      previous one; the first receives the value passed to ``__call__``.
  """

  def __init__(self, transforms):
    self.transforms = transforms

  def __call__(self, image):
    """Run ``image`` through every transform in order and return the result."""
    result = image
    for transform in self.transforms:
      result = transform(result)
    return result
  
def show_img(images, num_samples=20, columns=4):
  """Plot an evenly spaced preview of ``images`` in a grid.

  Args:
    images: Collection supporting ``len()`` and integer indexing. Element
      ``[0]`` of each item is what gets drawn (presumably items are
      ``(image, label)`` pairs - confirm against the dataset used).
    num_samples: Maximum number of items to draw. Must be >= 1.
    columns: Number of grid columns. Must be >= 1.

  Raises:
    ValueError: If ``num_samples`` or ``columns`` is smaller than 1.
  """
  if num_samples < 1 or columns < 1:
    raise ValueError("num_samples and columns must both be >= 1")

  # Size the stride from the argument itself, not from a global dataset.
  total = len(images)
  if total == 0:
    return  # nothing to draw

  # Never ask for more samples than exist; this also keeps the stride >= 1.
  shown = min(num_samples, total)
  stride = total // shown
  rows = -(-shown // columns)  # ceiling division: just enough rows

  plt.figure(figsize=(15,15))
  for slot in range(shown):
    # Index directly so only the displayed items are loaded, instead of
    # iterating over (and decoding) every item in the collection.
    sample = images[slot * stride]
    plt.subplot(rows, columns, slot + 1)
    plt.imshow(to_pil_image(sample[0]))
  
  
# Load the Oxford-IIIT Pet dataset from the current directory (downloading it
# when requested by download=True) with the preprocessing pipeline attached,
# then show a preview grid.
dataset = OxfordIIITPet(
  root=".",
  download=True,
  transform=Compose(to_tensor),
)
show_img(dataset)

In [1]:
from torch import nn
from einops.layers.torch import Rearrange
from torch import Tensor

class PatchEmbedding(nn.Module):
  """Split an image batch into square patches and embed each one linearly.

  Input of shape ``(b, c, H, W)`` is cut into non-overlapping
  ``patch_size x patch_size`` patches, each patch is flattened to a vector
  of length ``patch_size * patch_size * c``, and a linear layer maps it to
  ``emb_size``. The output has shape ``(b, num_patches, emb_size)``.

  Args:
    in_channels: Number of channels ``c`` in the input images.
    patch_size: Side length of each square patch.
    emb_size: Width of the produced patch embeddings.
  """

  def __init__(self, in_channels=3, patch_size = 8, emb_size = 128):
    super().__init__()
    self.patch_size = patch_size

    # Length of one flattened patch.
    patch_dim = patch_size * patch_size * in_channels

    # (b, c, h*p1, w*p2) -> (b, h*w, p1*p2*c): one row per patch.
    to_patches = Rearrange(
      'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
      p1=patch_size,
      p2=patch_size,
    )
    to_embedding = nn.Linear(patch_dim, emb_size)

    self.projection = nn.Sequential(to_patches, to_embedding)

  def forward(self, x: Tensor) -> Tensor:
    """Return the patch embeddings for the image batch ``x``."""
    return self.projection(x)

In [None]:
class Attention(nn.Module):
  """Multi-head self-attention over a batch-first token sequence.

  The input is projected by three separate linear layers into queries, keys
  and values, which are then fed to ``nn.MultiheadAttention``.

  Args:
    dim: Embedding width of each token (input and output size).
    n_heads: Number of attention heads; ``dim`` must be divisible by it.
    dropout: Dropout probability passed to ``nn.MultiheadAttention``.
  """

  def __init__(self, dim, n_heads, dropout) -> None:
    super().__init__()
    self.n_heads = n_heads
    # batch_first=True is required: the tensors in this notebook are laid
    # out as (batch, tokens, dim). With the default (sequence-first) layout
    # the module would treat the batch axis as the sequence and attend
    # across different images instead of across the patches of one image.
    self.att = nn.MultiheadAttention(
      embed_dim=dim,
      num_heads=n_heads,
      dropout=dropout,
      batch_first=True,
    )
    self.q = nn.Linear(dim, dim)
    self.v = nn.Linear(dim, dim)
    self.k = nn.Linear(dim, dim)

  def forward(self, x):
    """Apply self-attention to ``x`` of shape ``(batch, tokens, dim)``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    q = self.q(x)
    v = self.v(x)
    k = self.k(x)
    # The attention weights are discarded, so skip computing them.
    attn_output, _ = self.att(q, k, v, need_weights=False)
    return attn_output
    

In [3]:
class FeedForward(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        # Layers are created in the order they are applied.
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        super().__init__(expand, activation, hidden_drop, contract, output_drop)

In [None]:
class Block(nn.Module):
    """One pre-norm Transformer encoder block.

    Self-attention followed by a feed-forward network. Each sub-layer is
    applied to a layer-normalised copy of its input and the result is added
    back to that input (residual connection).

    Args:
        emb_dim: Embedding width of the tokens.
        n_heads: Number of attention heads; must divide ``emb_dim``.
        dropout: Dropout probability passed to the attention layer.
    """

    def __init__(self, emb_dim=32, n_heads=2, dropout=0.1) -> None:
        super().__init__()
        # Attention requires (dim, n_heads, dropout); calling it without
        # arguments raised TypeError.
        self.sa = Attention(emb_dim, n_heads, dropout)
        # FeedForward is called with the embedding width only, matching the
        # one-argument FeedForward(emb_dim) class that is bound to the name
        # once all notebook cells have run.
        self.ffwd = FeedForward(emb_dim)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)


    def forward(self, ix):
        """Transform ``ix`` of shape ``(batch, tokens, emb_dim)``; shape is kept."""
        # Attention first, then the MLP: the usual encoder sub-layer order.
        ix = ix + self.sa(self.ln1(ix))
        ix = ix + self.ffwd(self.ln2(ix))
        return ix
        

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear (x4 expansion) -> Tanh -> Linear.

    NOTE(review): this class reuses the name of the ``nn.Sequential``-based
    ``FeedForward`` defined in an earlier cell and replaces it once this
    cell runs; consider keeping only one of the two.

    Args:
        emb_dim: Input and output feature size; the hidden layer is
            ``4 * emb_dim`` wide.
    """

    def __init__(self, emb_dim) -> None:
        super().__init__()
        self.lin1 = nn.Linear(emb_dim, 4*emb_dim)
        self.tanh = nn.Tanh()
        self.lin2 = nn.Linear(4*emb_dim, emb_dim)


    def forward(self, ix):
        """Return the MLP output for ``ix`` (same shape as ``ix``).

        No skip connection is added here: ``Block.forward`` already adds
        the input back (``ix + self.ffwd(...)``), so adding it in this
        method as well applied the residual twice.
        """
        out = self.lin1(ix)
        out = self.tanh(out)
        return self.lin2(out)
    

In [None]:
from einops import repeat

class ViT(nn.Module):
  """Vision Transformer classifier.

  Images are split into patches and embedded, a learnable class token is
  prepended, learnable position embeddings are added, the sequence is run
  through ``n_layers`` Transformer blocks, and the class-token output is
  mapped to ``out_dim`` logits by a linear head.

  Args:
    ch: Number of input image channels.
    img_size: Side length of the (square) input images.
    patch_size: Side length of each square patch.
    emb_dim: Token embedding width used throughout the model.
    n_layers: Number of Transformer blocks.
    out_dim: Number of output logits.
    dropout: Accepted but currently not forwarded to the blocks.
    heads: Accepted but currently not forwarded to the blocks.
  """

  def __init__(self, ch=3, img_size=144, patch_size=4, emb_dim=32,
                n_layers=6, out_dim=37, dropout=0.1, heads=2):
    super(ViT, self).__init__()

    self.channels = ch
    self.height = img_size
    self.width = img_size
    self.patch_size = patch_size
    self.n_layers = n_layers
    # Pass emb_dim explicitly so the blocks match the embedding width
    # instead of silently using Block's own default.
    self.blocks = nn.Sequential(*[Block(emb_dim) for _ in range(n_layers)])
    self.l_head = nn.Linear(emb_dim, out_dim)

    self.patch_embedding = PatchEmbedding(in_channels=ch, patch_size=patch_size, emb_size=emb_dim)
    num_patches = (img_size // patch_size) ** 2
    # One position embedding per patch plus one for the class token.
    self.pos_embedding = nn.Parameter(
        torch.randn(1, num_patches + 1, emb_dim))
    self.cls_token = nn.Parameter(torch.rand(1, 1, emb_dim))


  def forward(self, img):
    """Return logits of shape ``(batch, out_dim)`` for the image batch ``img``."""
    x = self.patch_embedding(img)
    b, n, _ = x.shape
    # Prepend one copy of the class token to every sequence in the batch.
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    x = torch.cat([cls_tokens, x], dim=1)
    x = x + self.pos_embedding[:, :(n + 1)]

    # The blocks live in self.blocks (an nn.Sequential); the previously
    # referenced self.layers attribute was never defined.
    x = self.blocks(x)

    # Classify from the class-token position; the head is self.l_head
    # (self.head was never defined).
    return self.l_head(x[:, 0, :])
    
