In [63]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [64]:
class MyMSA(nn.Module):
  def __init__(self, d, n_heads=2):
    super(MyMSA, self).__init__()
    self.d = d
    self.n_heads = n_heads
    
    assert d % n_heads == 0, "The dimension d of the token must be divisible by number of heads n_heads"
    
    self.d_head = int(d / n_heads)
    
    self.q_mappins = [nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)]
    self.k_mappins = [nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)]
    self.v_mappins = [nn.Linear(self.d_head, self.d_head) for _ in range(n_heads)]
    
    self.softmax = nn.Softmax(dim=-1)
  
  def forward(self, sequences):
    result = []
    for sequence in sequences:
      seq_result = []
      for head in range(self.n_heads):
        q_mappin = self.q_mappins[head]
        k_mappin = self.k_mappins[head]
        v_mappin = self.v_mappins[head]
        
        seq = sequence[:, head * self.d_head:(head + 1) * self.d_head]
        
        q, k, v = q_mappin(seq), k_mappin(seq), v_mappin(seq)
        
        attention = self.softmax((q @ k.T) / np.sqrt(self.d_head))
        
        seq_result.append(attention @ v)
      result.append(torch.hstack(seq_result))
    return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [65]:
class MyVIT(nn.Module):
  def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2):
    super(MyVIT, self).__init__()
    self.input_shape = input_shape # (C, H, W)
    self.n_patches = n_patches
    self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
    self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
    
    self.hidden_d = hidden_d # even if C > 1 we will keep this dimension and we pass each row vector of 16 to the linear mapper each time
    
    assert input_shape[1] % n_patches == 0, "Height mst be divisible by number of patches"
    assert input_shape[2] % n_patches == 0, "Width mst be divisible by number of patches"
    
    # 1) linear mapper
    self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)
    
    # 2) classification token
    self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))
    
    # 3) positional embeddings
    # in forward pass
  
    # 4)a) layer normalization 1
    self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d)) 
    
    # 4)b) multi-head self attention (msa) and classification token
    self.msa = MyMSA(self.hidden_d, n_heads)
    
  def forward(self, images):
    n, c, h, w = images.shape
    # reshapes the images into patches
    patches = images.reshape(n, self.n_patches ** 2, self.input_d)
    
    # running the patches into the linear mapper for tokenization
    tokens = self.linear_mapper(patches) # or embeddings
    
    # adding the classification token
    patches = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])
    
    # adding the positional embeddings to the tokens
    tokens += self.get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d).repeat(n, 1, 1)
    
    # TRANSFORMER ENCODER BEGINS #############################
    # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER #####
    
    # running Layer Normalization, MSA and residual connection
    out = tokens + self.msa(self.ln1(tokens))
    
    return out
  
  def get_positional_embeddings(self, sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
      for j in range(d):
        result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result
  
  def __call__(self, images):
    return self.forward(images)

    

In [68]:
# main program for testing purposes
if __name__ == '__main__':
  model = MyVIT(input_shape=(1, 28, 28))
  images = torch.rand(3, 1, 28, 28)
  print(model(images).shape)
  plt.imshow(model.get_positional_embeddings(100, 300), cmap='hot', interpolation='nearest')
  plt.show()

RuntimeError: The size of tensor a (49) must match the size of tensor b (50) at non-singleton dimension 1