In [14]:
import numpy as np
import torch
import torch.nn as nn

In [15]:
class MyVIT(nn.Module):
  """Front end of a Vision Transformer: patchify, embed, prepend a class token.

  The module cuts every image into an ``n_patches x n_patches`` grid of
  non-overlapping tiles, flattens each tile, projects it to ``hidden_d``
  features with a shared linear layer, and puts one learnable
  classification token in front of each image's token sequence.

  Args:
    input_shape: Image shape as ``(C, H, W)``.
    n_patches: Number of tiles along each spatial axis; ``H`` and ``W``
      must both be divisible by it.
    hidden_d: Size of every output token.

  Attributes:
    patch_size: ``(patch_h, patch_w)`` in pixels, as integers.
    input_d: Length of one flattened tile, ``C * patch_h * patch_w``.
    linear_mapper: ``nn.Linear(input_d, hidden_d)`` applied to every tile.
    class_token: Learnable parameter of shape ``(1, hidden_d)``.

  Raises:
    AssertionError: If ``H`` or ``W`` is not divisible by ``n_patches``.
  """

  def __init__(self, input_shape, n_patches=7, hidden_d=8):
    # Zero-argument form; the two-argument form takes (type, instance) and
    # raises TypeError when the arguments are given the other way round.
    super().__init__()

    # Validate first so nothing below is derived from an invalid geometry.
    assert input_shape[1] % n_patches == 0, "Height must be divisible by number of patches"
    assert input_shape[2] % n_patches == 0, "Width must be divisible by number of patches"

    self.input_shape = input_shape # (C, H, W)
    self.n_patches = n_patches
    # Floor division keeps the tile size integral (exact, thanks to the asserts).
    self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
    self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])

    self.hidden_d = hidden_d

    # 1) linear mapper
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    # 2) classification token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

  def forward(self, images):
    """Turn a batch of images into class-token-prefixed patch embeddings.

    Args:
      images: Tensor of shape ``(N, C, H, W)`` matching ``input_shape``.

    Returns:
      Tensor of shape ``(N, n_patches ** 2 + 1, hidden_d)``. Index 0 along
      dimension 1 is the class token; the remaining entries are the embedded
      tiles in row-major grid order.
    """
    n, c, h, w = images.shape
    patch_h, patch_w = self.patch_size

    # Split H and W into (grid index, offset inside tile), then move both grid
    # axes next to each other so every tile becomes one contiguous row.
    # A flat reshape alone would group consecutive pixels of a row instead of
    # spatial tiles.
    grid = images.reshape(n, c, self.n_patches, patch_h, self.n_patches, patch_w)
    patches = grid.permute(0, 2, 4, 1, 3, 5).reshape(n, self.n_patches ** 2, self.input_d)

    # running the patches through the linear mapper
    tokens = self.linear_mapper(patches)

    # Prepend the shared class token to every sample; expand() broadcasts the
    # single parameter across the batch without a Python loop.
    class_tokens = self.class_token.unsqueeze(0).expand(n, -1, -1)
    tokens = torch.cat((class_tokens, tokens), dim=1)

    return tokens
    