<a href="https://colab.research.google.com/github/LTopaloglou/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [54]:
import torch
import torch.nn as nn
from torch.nn import functional as F

#Set up hyperparameters. (None are defined yet.)


#Define a single Scaled Dot-Product Attention head
class Attention(nn.Module):
  """A single Scaled Dot-Product Attention head with an optional causal mask.

  Computes ``softmax(Q @ K^T * d_k**-0.5) @ V``, normalising each row of the
  score matrix. When ``mask`` is truthy, row ``i`` of the score matrix may only
  attend to columns ``0..i`` (lower-triangular mask).

  Args:
    d_k: Required size of the first dimension of ``Q``, ``K`` and ``V``. It is
      also the side length of the square mask and the value used in the
      ``1/sqrt(d_k)`` scale.
    mask: If truthy, apply the lower-triangular mask before the softmax.
  """

  def __init__(self, d_k, mask):
    super().__init__()
    self.mask = mask
    self.d_k = d_k
    # Stored as a buffer (not a parameter) so it moves with the module across
    # devices and is saved in the state_dict without being trained. The name
    # is kept unchanged for state_dict compatibility.
    self.register_buffer('lowerTriangle', torch.tril(torch.ones(d_k, d_k)))

  def forward(self, Q, K, V):
    """Attend from the rows of ``Q`` to the rows of ``K`` and mix the rows of ``V``.

    Args:
      Q: 2-D tensor of shape ``(d_k, c)``.
      K: 2-D tensor of shape ``(d_k, c)`` (same ``c`` as ``Q``).
      V: 2-D tensor of shape ``(d_k, m)``.

    Returns:
      Tensor of shape ``(d_k, m)``; row ``i`` is a softmax-weighted sum of the
      rows of ``V``.

    Raises:
      ValueError: If any input is not 2-D, if its first dimension is not
        ``d_k``, or if ``Q`` and ``K`` have different widths.
    """
    # Validate up front so bad inputs fail with a clear ValueError instead of
    # an IndexError/RuntimeError from transpose or matmul further down.
    for tensor in (Q, K, V):
      if tensor.dim() != 2 or tensor.shape[0] != self.d_k:
        raise ValueError('Invalid Query, Key or Value Dimensions')
    if Q.shape[1] != K.shape[1]:
      raise ValueError('Invalid Query, Key or Value Dimensions')

    # Dot product of every row of Q with every row of K: (d_k, d_k).
    weight = Q @ K.transpose(0, 1)
    # NOTE(review): "Attention Is All You Need" scales by 1/sqrt of the
    # query/key feature width (Q.shape[1] here), not by the number of rows.
    # Kept as d_k to preserve existing outputs; confirm the intended convention.
    weight = weight * (self.d_k ** -0.5)
    if self.mask:
      # Hide everything above the diagonal. The diagonal itself is always kept,
      # so the mask alone never turns a whole row into -inf.
      weight = weight.masked_fill(self.lowerTriangle == 0, float('-inf'))
    # Normalise each row over the key positions.
    weight = weight.softmax(dim=-1)
    # (d_k, d_k) @ (d_k, m) -> (d_k, m)
    return weight @ V


In [57]:
#For testing & validation
# Smoke test only: every value row is 1, and each softmax row sums to 1, so
# the output is all ones whether or not the mask is applied.
att = Attention(5, True)
keys = torch.ones(5, 1)
queries = torch.ones(5, 1)
values = torch.ones(5, 1)
# Call the module itself rather than .forward() so nn.Module.__call__ runs
# (including any registered forward hooks).
print(att(queries, keys, values))

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])
