In [1]:
%config IPCompleter.greedy = True

In [2]:
import math
import torch
from torch import nn

In [3]:
class PatchEmbeddings(nn.Module):
    """Cut an image into non-overlapping patches and project each patch.

    One strided convolution whose kernel size and stride both equal the patch
    size performs the patch extraction and the projection to ``hidden_size``
    in a single step.

    Args:
        config: Mapping that provides ``image_size``, ``patch_size``,
            ``num_channels`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        image_size = config["image_size"]
        patch_size = config["patch_size"]
        num_channels = config["num_channels"]
        hidden_size = config["hidden_size"]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        # Patches along one side, squared: the image is treated as square.
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side ** 2

        # kernel_size == stride, so every patch is visited exactly once.
        self.projection = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the sequence of projected patches for the image batch ``x``.

        The convolution output is flattened from its third dimension onward
        and the last two dimensions are swapped, so for a batched input the
        patch index comes before the embedding dimension.
        """
        projected = self.projection(x)
        flattened = torch.flatten(projected, start_dim=2)
        return flattened.transpose(1, 2)

In [4]:
class Embeddings(nn.Module):
    """Combine the patch embeddings with the class token and position embeddings.

    Args:
        config: Mapping that provides ``hidden_size`` and
            ``hidden_dropout_prob`` in addition to the keys read by
            ``PatchEmbeddings``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)

        # Learnable token that is prepended to the input sequence and used to classify.
        self.classify_t = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))

        # Learnable position embeddings for the token and the patch embeddings;
        # the sequence length is num_patches + 1 to account for the token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, config["hidden_size"])
        )
        # Lower-case attribute name, matching what forward() uses.
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        """Embed ``x`` and return the token sequence with the class token first.

        The patch embeddings are prefixed with one copy of the classification
        token per batch element, the position embeddings are added, and
        dropout is applied to the sum.
        """
        x = self.patch_embeddings(x)
        batch_size = x.size(0)

        # Broadcast the single learnable token across the batch (no copy is made).
        classify_ts = self.classify_t.expand(batch_size, -1, -1)

        # Prepend the token along the sequence dimension.
        x = torch.cat((classify_ts, x), dim=1)
        x = x + self.position_embeddings
        return self.dropout(x)

In [5]:
class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Several of these are used together in multi-head attention.

    Args:
        hidden_size: Size of the last dimension of the input.
        attention_head_size: Output size of the query, key and value projections.
        dropout: Dropout probability applied to the attention probabilities.
        bias: Whether the query, key and value projections use a bias term.
    """

    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size

        # Query, key and value projection layers.
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            A tuple ``(attention_output, attention_probs)``. The probabilities
            are returned after dropout has been applied to them.
        """
        # The same input is projected into query, key and value.
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # softmax(Q @ K^T / sqrt(head_size)) @ V
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        # Scale BEFORE the softmax; the scaled scores must be the ones normalised.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of the values.
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)

In [None]:
class MultiHeadAttention(nn.Module):
    # Multi-head attention.
    # Per the original note, this module is meant to be used in the
    # Transformer encoder module.
    #
    # NOTE(review): only the constructor is visible here and no forward() is
    # defined in this span, so the class looks unfinished - confirm.

    def __init__(self, config):
        # Reads "hidden_size", "num_attention_heads" and "bias" from config.
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]

        # Per-head size via integer division; all_head_size is therefore
        # smaller than hidden_size when hidden_size is not divisible by the
        # number of heads.
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Whether to use a bias in the projections.
        self.qkv_bias = config["bias"]

        # Query, key and value projection layers.
        # NOTE(review): `hidden_size`, `attention_head_size`, `bias` and
        # `dropout` below are bare names that are neither parameters of
        # __init__ nor assigned locally. Unless they exist as module-level
        # globals, these lines raise NameError. They look copied from
        # AttentionHead.__init__; presumably the self.* attributes computed
        # above (and a dropout probability from config) were intended -
        # TODO confirm the intended projection width and dropout config key.
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)
        
        self.dropout = nn.Dropout(dropout)    