<a href="https://colab.research.google.com/github/LTopaloglou/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
from torch.torch_version import Version
import torch
import torch.nn as nn
from torch.nn import functional as F

#I am using variable names defined in "Attention Is All You Need" for simplicity.

#Define a single Scaled Dot-Product Attention head. I am simplifying it by enforcing d_k = d_v
class Attention(nn.Module):
  """Single scaled dot-product attention head (with d_k = d_v).

  Computes softmax(Q K^T / sqrt(d_k)) V. When ``mask`` is truthy, a causal
  (lower-triangular) mask is applied before the softmax so that query
  position i only attends to key positions <= i.

  This module has no learnable parameters.

  Args:
    d_k: Feature size of every query, key and value vector.
    mask: If truthy, apply a causal mask to the attention scores.
  """

  def __init__(self, d_k, mask):
    super().__init__()
    self.mask = mask
    self.d_k = d_k
    # Kept only so that existing state_dicts (which contain 'lowerTriangle')
    # still load. It is d_k x d_k, which is unrelated to the sequence length,
    # so it is NOT used for masking; forward() builds a seq x seq mask instead.
    self.register_buffer('lowerTriangle', torch.tril(torch.ones(d_k, d_k)))

  def forward(self, Q, K, V):
    """Apply scaled dot-product attention.

    Args:
      Q: Queries, shape (..., n_q, d_k).
      K: Keys, shape (..., n_k, d_k).
      V: Values, shape (..., n_k, d_k).

    Returns:
      Tensor of shape (..., n_q, d_k).

    Raises:
      ValueError: if the last dimension of Q, K or V is not d_k.
    """
    if any(t.shape[-1] != self.d_k for t in (Q, K, V)):
      raise ValueError('Invalid Query, Key or Value Dimensions')
    # Scaled similarity scores, shape (..., n_q, n_k).
    weight = (Q @ K.transpose(-2, -1)) * (self.d_k ** -0.5)
    if self.mask:
      n_q, n_k = weight.shape[-2], weight.shape[-1]
      # Causal mask sized from the actual sequence lengths (not from d_k);
      # it broadcasts over any leading batch dimensions.
      causal = torch.ones(n_q, n_k, dtype=torch.bool, device=weight.device).tril()
      weight = weight.masked_fill(~causal, float('-inf'))
    # Normalise over the key axis so each query's weights sum to 1.
    weight = weight.softmax(dim=-1)
    # (..., n_q, n_k) @ (..., n_k, d_k) -> (..., n_q, d_k)
    return weight @ V

#Define a Multi-Head Attention layer. As in the paper, I am setting d_k = d_model/h
class MultiHeadAttention(nn.Module):
  """Multi-head attention layer with h heads of size d_k = d_model / h.

  Each head projects Q, K and V with its own bias-free linear maps and runs
  scaled dot-product attention. The concatenated head outputs are then
  projected back to d_model with W_O:
  MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O.

  Args:
    d_model: Model/embedding size; must be divisible by h.
    h: Number of heads; must be positive.
    mask: Passed to the attention head; if truthy, attention is causally masked.

  Raises:
    ValueError: if h is not positive or d_model is not divisible by h.
  """

  def __init__(self, d_model, h, mask):
    super().__init__()
    if h <= 0 or d_model % h != 0:
      raise ValueError('Invalid Dimensions Provided')  # Ensure valid dimensions
    self.mask = mask
    self.d_model = d_model
    self.h = h
    self.d_k = d_model // h
    self.W_O = nn.Linear(h * self.d_k, d_model, bias=False)  # Output projection
    self.W_Q = nn.ModuleList()  # Per-head query projections
    self.W_K = nn.ModuleList()  # Per-head key projections
    self.W_V = nn.ModuleList()  # Per-head value projections
    # Attention has no learnable parameters (only a buffer), so one instance
    # can safely be shared by all heads; gradients still flow through it.
    self.Att = Attention(self.d_k, mask)
    # Create projections head by head (Q, K, V interleaved) so the parameter
    # initialisation order under a fixed seed stays the same.
    for _ in range(h):
      self.W_Q.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_K.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_V.append(nn.Linear(d_model, self.d_k, bias=False))

  def forward(self, Q, K, V):
    """Run multi-head attention.

    Args:
      Q: Queries, shape (seq, d_model).
      K: Keys, shape (seq, d_model).
      V: Values, shape (seq, d_model). In self-attention Q = K = V.

    Returns:
      Tensor of shape (seq, d_model).

    Raises:
      ValueError: if the last dimension of Q, K or V is not d_model.
    """
    if any(t.shape[-1] != self.d_model for t in (Q, K, V)):
      raise ValueError('Invalid Query, Key or Value Dimensions')
    # Call the submodule (not .forward) so any registered hooks run.
    heads = [
        self.Att(w_q(Q), w_k(K), w_v(V))  # each head: (seq, d_k)
        for w_q, w_k, w_v in zip(self.W_Q, self.W_K, self.W_V)
    ]
    # Concatenate along the feature axis -> (seq, h * d_k), then apply the
    # output projection W_O -> (seq, d_model).
    return self.W_O(torch.cat(heads, dim=-1))

In [57]:
#For testing & validation of Attention
#Inputs are seq x d_k: 5 positions of 1-dimensional vectors, so d_k = 1
#(passing 5 here would make forward() reject the inputs' last dimension).
att = Attention(1, True)
keys = torch.ones(5, 1)
queries = torch.ones(5, 1)
values = torch.ones(5, 1)
#Every attention row sums to 1 and all values are 1, so the output is all ones.
print(att(queries, keys, values))

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])


In [20]:
#For testing & validation of MultiHeadAttention
#Input is seq x d_model
#Q, K, V all have that size - in self attention Q = K = V = Input
#Hyperparams
seq = 10 #Sequence length of 10
d_model = 16 #Embedding vectors are of size 16
heads = 8 #Number of heads in our multihead attention
mult = MultiHeadAttention(d_model, heads, False)
Input = torch.ones(seq, d_model)
#Call the module itself (not .forward) so hooks run; result is seq x d_model
mult(Input, Input, Input)

torch.Size([10, 16])
tensor([[ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
        [ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
        [ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
        [ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
        [ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
        [ 0.4809,  0.6958,  0.0646,  0.2008, -0.7846, -0.6914, -0.3471,  0.0936,
          0.4514, -0.0206, -0.0639,  0.2451,  0.0504, -0.9306, -0.2231,  0.5706],
 