<a href="https://colab.research.google.com/github/LTopaloglou/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torch.torch_version import Version
import torch
import torch.nn as nn
from torch.nn import functional as F

#I am using variable names defined in "Attention Is All You Need" for simplicity.

#Define a single Scaled Dot-Product Attention head. I am simplifying it by enforcing d_k = d_v
class Attention(nn.Module):
  """Single scaled dot-product attention head: softmax(Q K^T / sqrt(d_k)) V.

  Operates on the last two dimensions, so Q, K, V may be (seq, d_k) or carry
  leading batch dimensions, e.g. (batch, seq, d_k).

  Args:
    d_k: feature size of the Queries, Keys and Values (d_v is forced equal to d_k).
    mask: if True, query position i may only attend to key positions j <= i.
    context_len: largest sequence length covered by the causal-mask buffer; when
      mask is True, seq_q and seq_k must not exceed it.
  """

  def __init__(self, d_k, mask, context_len):
    super().__init__()
    self.mask = mask
    self.d_k = d_k
    # Registered as a buffer: stored in state_dict and moved by .to(device), but not trained.
    self.register_buffer('lowerTriangle', torch.tril(torch.ones(context_len, context_len)))

  def forward(self, Q, K, V):
    """Attend from the Queries Q over the Key/Value pairs (K, V).

    Args:
      Q: (..., seq_q, d_k) queries.
      K: (..., seq_k, d_k) keys.
      V: (..., seq_k, d_k) values.

    Returns:
      (..., seq_q, d_k) tensor of attention-weighted values.

    Raises:
      ValueError: if the last dimension of Q, K or V is not d_k. This is a subclass
        of Exception, so existing `except Exception` handlers still catch it.
    """
    if Q.shape[-1] != self.d_k or K.shape[-1] != self.d_k or V.shape[-1] != self.d_k:
      raise ValueError('Invalid Query, Key or Value Dimensions')
    seq_q, seq_k = Q.shape[-2], K.shape[-2]
    # (..., seq_q, d_k) @ (..., d_k, seq_k) -> (..., seq_q, seq_k)
    weight = Q @ K.transpose(-2, -1)
    # Scale by 1/sqrt(d_k) so the softmax does not saturate for large d_k
    weight = weight * (self.d_k ** -0.5)
    if self.mask:
      # Slice rows by the query length and columns by the key length separately, so
      # masked cross attention with seq_q != seq_k works instead of failing to broadcast.
      weight = weight.masked_fill(self.lowerTriangle[:seq_q, :seq_k] == 0, float('-inf'))
    # Softmax over the key axis: each query's weights sum to 1
    weight = weight.softmax(-1)
    # (..., seq_q, seq_k) @ (..., seq_k, d_k) -> (..., seq_q, d_k)
    return weight @ V

#Define a Multi-Head Attention layer. As in the paper, I am setting d_k = d_model/h
class MultiHeadAttention(nn.Module):
  """Multi-head attention: h independent Attention heads, concatenated, then projected by W_O.

  Args:
    d_model: feature size of the Queries, Keys and Values (their last dimension).
    h: number of heads; must divide d_model, since each head has d_k = d_model // h.
    mask: passed to the Attention head (True gives causal attention).
    context_len: passed to the Attention head (size of its causal-mask buffer).

  Raises:
    ValueError: if d_model is not divisible by h (a subclass of Exception).
  """

  def __init__(self, d_model, h, mask, context_len):
    super().__init__()
    self.mask = mask
    self.d_model = d_model
    self.h = h
    # Integer arithmetic instead of float division plus int() for the divisibility check
    if d_model % h != 0:
      raise ValueError('Invalid Dimensions Provided') #Ensure valid dimensions
    self.d_k = d_model // h
    # NOTE: creation order (W_O, then per-head Q/K/V) is kept so that seeded
    # initialization and state_dict ordering match earlier versions.
    self.W_O = nn.Linear(h*self.d_k, d_model, bias=False) #Output linear layer
    self.W_Q = nn.ModuleList() #The h different Query linear layers
    self.W_K = nn.ModuleList() #The h different Key linear layers
    self.W_V = nn.ModuleList() #The h different Value linear layers
    # Parameter-free, so a single instance can be shared by all heads
    self.Att = Attention(self.d_k, mask, context_len)
    for _ in range(h):
      #Initialize all h linear layers for Q, K, V
      self.W_Q.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_K.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_V.append(nn.Linear(d_model, self.d_k, bias=False)) #TODO: Instead of Linear layers, just create Matrices

  def forward(self, Q, K, V):
    """Apply every head to (Q, K, V), concatenate the head outputs, and project with W_O.

    Args:
      Q, K, V: tensors whose last dimension is d_model, e.g. (seq, d_model).
        Q may have a different sequence length from K/V (cross attention).

    Returns:
      Tensor with the same shape as Q.

    Raises:
      ValueError: if the last dimension of Q, K or V is not d_model.
    """
    if Q.shape[-1] != self.d_model or K.shape[-1] != self.d_model or V.shape[-1] != self.d_model:
      raise ValueError('Invalid Query, Key or Value Dimensions')
    # Call the modules (not .forward) so nn.Module hooks run; each head yields (..., seq_q, d_k)
    heads = [
        self.Att(W_q(Q), W_k(K), W_v(V))
        for W_q, W_k, W_v in zip(self.W_Q, self.W_K, self.W_V)
    ]
    # Concatenating h heads of width d_k along the feature axis gives width d_model
    return self.W_O(torch.cat(heads, -1))

#Position-wise Feed Forward Network: Linear -> ReLU -> Linear
class FeedForward(nn.Module):
  """Two-layer MLP on the last dimension: Linear(d_model->d_hidden), ReLU, Linear(d_hidden->d_model).

  Args:
    d_model: input and output feature size.
    d_hidden: width of the hidden layer.
  """

  def __init__(self, d_model, d_hidden):
    super().__init__()
    expand = nn.Linear(d_model, d_hidden, bias=True)
    project = nn.Linear(d_hidden, d_model, bias=True)
    # Kept as an nn.Sequential named `network` so state_dict keys stay network.0 / network.2
    self.network = nn.Sequential(expand, nn.ReLU(), project)

  def forward(self, x):
    return self.network(x)

#Define Layer Normalization:
class LayerNorm(nn.Module):
  """Layer normalization over the last (feature) dimension, matching nn.LayerNorm.

  Each row is shifted to zero mean and scaled to unit variance, and then the
  learnable per-feature affine transform gamma * x + beta is applied.

  Args:
    d_model: size of the normalized (last) dimension.
  """

  def __init__(self, d_model):
    super().__init__()
    self.epsilon = 1e-5
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))

  # Implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    # Statistics over the last dim, so (seq, d_model) and (batch, seq, d_model) both normalize features
    mean = x.mean(-1, keepdim=True)
    # Population (biased) variance, as in the LayerNorm definition and nn.LayerNorm.
    # The unbiased estimator would divide by zero (NaN) when d_model == 1.
    variance = x.var(-1, unbiased=False, keepdim=True)
    norm = (x - mean) / torch.sqrt(variance + self.epsilon)
    # Afterwards each feature has std gamma and mean beta
    return self.gamma * norm + self.beta

#Define batch normalization
class BatchNorm(nn.Module):
  """Batch normalization of each feature over dim 0, using the current batch's statistics.

  No running statistics are kept, so the module behaves the same in train and eval mode.

  Args:
    d_model: number of features (size of the last dimension).
  """

  def __init__(self, d_model):
    super().__init__()
    self.epsilon = 1e-5
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))

  # Implement forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    mean = x.mean(0, keepdim=True) #Mean across the batch (dim 0)
    # Mini-batch (biased) variance, as in batch norm's normalization step; the unbiased
    # estimator divides by zero (NaN) for a batch of one row.
    variance = x.var(0, unbiased=False, keepdim=True)
    norm = (x - mean) / torch.sqrt(variance + self.epsilon) #Normalize
    # Afterwards each feature has std gamma and mean beta
    return self.gamma * norm + self.beta

#Multi-head self-attention sub-layer with a residual connection, followed by Layer Normalization (post-norm)
class NormalizedSelfAttention(nn.Module):
  """Self-attention sub-layer computing LayerNorm(x + MultiHeadAttention(x, x, x)).

  Args:
    d_model: feature size of x.
    h: number of attention heads.
    mask: forwarded to MultiHeadAttention (True gives causal, decoder-style attention).
    context_len: forwarded to MultiHeadAttention (size of the causal-mask buffer).
  """

  def __init__(self, d_model, h, mask, context_len):
    super().__init__()
    self.MHA = MultiHeadAttention(d_model, h, mask, context_len)
    self.norm = LayerNorm(d_model)

  def forward(self, x):
    # Self attention: the Queries, Keys and Values all come from x
    attended = self.MHA(x, x, x)
    return self.norm(x + attended)

#Feed Forward sub-layer with a residual connection, followed by Layer Normalization (post-norm)
class NormalizedFeedForward(nn.Module):
  """Feed-forward sub-layer computing LayerNorm(x + FeedForward(x)).

  Args:
    d_model: feature size of x.
    d_hidden: hidden width of the FeedForward network.
  """

  def __init__(self, d_model, d_hidden):
    super().__init__()
    self.FF = FeedForward(d_model, d_hidden)
    self.norm = LayerNorm(d_model)

  def forward(self, x):
    transformed = self.FF(x)
    return self.norm(x + transformed)

#Stand-alone Encoder: N repetitions of (self attention, feed forward), without embeddings or softmax.
#Same layout as StandAloneDecoder, but with masking disabled so attention is bi-directional.
class StandAloneEncoder(nn.Module):
  """Stack of N unmasked NormalizedSelfAttention + NormalizedFeedForward pairs.

  Args:
    N: number of (attention, feed-forward) pairs.
    d_model, h, d_hidden, context_len: forwarded to the sub-layers.
  """

  def __init__(self, N, d_model, h, d_hidden, context_len):
    super().__init__()
    layers = []
    for _ in range(N):
      layers += [
          NormalizedSelfAttention(d_model, h, False, context_len),
          NormalizedFeedForward(d_model, d_hidden),
      ]
    # Same submodule order and indices (0, 1, 2, ...) as appending one by one
    self.network = nn.Sequential(*layers)

  def forward(self, x):
    return self.network(x)

#Stand-alone Decoder: N repetitions of (masked self attention, feed forward), without embeddings or softmax.
#The self attention is masked (causal) because this is a Decoder.
class StandAloneDecoder(nn.Module):
  """Stack of N masked NormalizedSelfAttention + NormalizedFeedForward pairs.

  Args:
    N: number of (attention, feed-forward) pairs.
    d_model, h, d_hidden, context_len: forwarded to the sub-layers.
  """

  def __init__(self, N, d_model, h, d_hidden, context_len):
    super().__init__()
    layers = []
    for _ in range(N):
      layers += [
          NormalizedSelfAttention(d_model, h, True, context_len),
          NormalizedFeedForward(d_model, d_hidden),
      ]
    # Same submodule order and indices (0, 1, 2, ...) as appending one by one
    self.network = nn.Sequential(*layers)

  def forward(self, x):
    return self.network(x)

#Define an entire Transformer with E encoder layers and D decoder layers (without embeddings or softmax).
#The output of the last encoder layer (the "memory") is the Key/Value input to the cross attention of every decoder layer.
class Transformer(nn.Module):
  """Post-norm encoder-decoder Transformer stack, without embeddings or the output softmax.

  Args:
    E: number of encoder layers (self attention + feed forward).
    D: number of decoder layers (masked self attention + cross attention + feed forward).
    d_model, h, d_hidden, context_len: forwarded to the sub-layers.
  """

  def __init__(self, E, D, d_model, h, d_hidden, context_len):
    super().__init__()
    self.E = E
    self.D = D
    self.encoders = nn.Sequential()
    for _ in range(E):
      self.encoders.append(NormalizedSelfAttention(d_model, h, False, context_len))
      self.encoders.append(NormalizedFeedForward(d_model, d_hidden))
    # Each decoder layer occupies 4 consecutive slots: masked self attention,
    # cross attention, the LayerNorm after cross attention, and the feed-forward block.
    self.decoders = nn.ModuleList()
    for _ in range(D):
      self.decoders.append(NormalizedSelfAttention(d_model, h, True, context_len))
      self.decoders.append(MultiHeadAttention(d_model, h, False, context_len))
      self.decoders.append(LayerNorm(d_model))
      self.decoders.append(NormalizedFeedForward(d_model, d_hidden))

  def forward(self, inputs, outputs):
    """Encode `inputs`, then decode `outputs` while attending to the encoded memory.

    Args:
      inputs: source sequence with last dimension d_model, e.g. (seq_in, d_model).
      outputs: target sequence with last dimension d_model, e.g. (seq_out, d_model).

    Returns:
      Decoder output with the same shape as `outputs`.
    """
    memory = self.encoders(inputs)
    # Iterate over the D decoder layers (not E); each layer spans 4 entries of self.decoders
    for d in range(self.D):
      self_att, cross_att, cross_norm, feed_forward = self.decoders[4 * d:4 * d + 4]
      outputs = self_att(outputs) #Masked self attention (with its own residual + LayerNorm)
      # Cross attention: Queries from the decoder, Keys & Values from the encoder memory,
      # followed by the residual connection and LayerNorm
      outputs = cross_norm(outputs + cross_att(outputs, memory, memory))
      outputs = feed_forward(outputs) #Feed forward (with its own residual + LayerNorm)
    return outputs

  # Backward-compatible alias for the original (misnamed) entry point
  __forward__ = forward
#TODO: test a cross attention head whose Queries are a x d_model
#and whose Keys & Values are b x d_model, with a != b

In [None]:
#Global hyperparam: passed as context_len to the attention modules in the test cells below,
#where it sizes the causal-mask buffer (so test sequence lengths must not exceed it)
max_context = 8

In [None]:
#For testing NormalizedFeedForward: feed a constant input through one residual FF + LayerNorm block
d_model = 32
d_hidd = 64
seq = 3
normFeed = NormalizedFeedForward(d_model, d_hidd)
x = torch.ones(seq, d_model) #Named x rather than input to avoid shadowing the built-in input()
output = normFeed(x)
print(output)
print(output[0,:].std()) #Spread of the first row after LayerNorm

tensor([[ 2.9178e-01,  4.3697e-04, -8.7427e-01,  4.5068e-02,  3.5812e-01,
          1.1397e+00,  1.0114e+00,  1.0266e+00,  2.5649e-01, -6.8877e-01,
          1.0952e+00,  1.0583e+00,  1.9244e-01, -8.6496e-01,  5.6941e-01,
          4.1225e-01, -2.7264e+00, -1.0703e+00,  7.7482e-01,  1.7110e+00,
          9.4203e-01,  5.2723e-01, -1.1681e+00, -3.4162e-01, -1.4807e+00,
         -8.5708e-01,  1.2498e+00, -1.3937e-02, -1.3199e-01, -1.5662e+00,
          8.3566e-03, -8.8612e-01],
        [ 2.9178e-01,  4.3697e-04, -8.7427e-01,  4.5068e-02,  3.5812e-01,
          1.1397e+00,  1.0114e+00,  1.0266e+00,  2.5649e-01, -6.8877e-01,
          1.0952e+00,  1.0583e+00,  1.9244e-01, -8.6496e-01,  5.6941e-01,
          4.1225e-01, -2.7264e+00, -1.0703e+00,  7.7482e-01,  1.7110e+00,
          9.4203e-01,  5.2723e-01, -1.1681e+00, -3.4162e-01, -1.4807e+00,
         -8.5708e-01,  1.2498e+00, -1.3937e-02, -1.3199e-01, -1.5662e+00,
          8.3566e-03, -8.8612e-01],
        [ 2.9178e-01,  4.3697e-04, -8.74

In [None]:
#For testing NormalizedSelfAttention: feed a constant input through one masked self-attention + LayerNorm block
d_model = 32
heads = 8
seq = 3
selfAtt = NormalizedSelfAttention(d_model, heads, True, max_context)
x = torch.ones(seq, d_model) #Named x rather than input to avoid shadowing the built-in input()
output = selfAtt(x)
print(output)
print(output[0,:].std()) #Spread of the first row after LayerNorm

tensor([[ 0.5378,  0.5970,  0.9667, -0.6975,  0.2002, -1.5745, -0.6936,  0.8485,
         -0.3678,  1.9080, -0.2445, -1.1252, -1.3877,  0.0942, -0.6123, -0.1480,
          0.5845, -2.1231,  0.1803,  1.1711, -0.3985, -0.6275, -0.5578,  0.6858,
         -0.3138,  1.7304, -0.5328, -0.1623, -0.8022,  0.2883,  2.4591,  0.1171],
        [ 0.5378,  0.5970,  0.9667, -0.6975,  0.2002, -1.5745, -0.6936,  0.8485,
         -0.3678,  1.9080, -0.2445, -1.1252, -1.3877,  0.0942, -0.6123, -0.1480,
          0.5845, -2.1231,  0.1803,  1.1711, -0.3985, -0.6275, -0.5578,  0.6858,
         -0.3138,  1.7304, -0.5328, -0.1623, -0.8022,  0.2883,  2.4591,  0.1171],
        [ 0.5378,  0.5970,  0.9667, -0.6975,  0.2002, -1.5745, -0.6936,  0.8485,
         -0.3678,  1.9080, -0.2445, -1.1252, -1.3877,  0.0942, -0.6123, -0.1480,
          0.5845, -2.1231,  0.1803,  1.1711, -0.3985, -0.6275, -0.5578,  0.6858,
         -0.3138,  1.7304, -0.5328, -0.1623, -0.8022,  0.2883,  2.4591,  0.1171]],
       grad_fn=<AddBackw

In [None]:
#Sanity check for LayerNorm: print statistics of one column and one row of the output
#to show which axis is normalized (rows should be normalized, columns are not constrained)
raw = torch.randn(2, 5)
layernorm = LayerNorm(5)
randoms = layernorm(raw)

first_col = randoms[:, 0]
print("Column mean and std dev")
print(first_col.mean())
print(first_col.std())

first_row = randoms[0, :]
print("Row mean and std dev")
print(first_row.mean())
print(first_row.std())

print("We want the rows normalized")
print(randoms)

Column mean and std dev
tensor(0.5284, grad_fn=<MeanBackward0>)
tensor(1.3514, grad_fn=<StdBackward0>)
Row mean and std dev
tensor(-5.9605e-08, grad_fn=<MeanBackward0>)
tensor(1.0000, grad_fn=<StdBackward0>)
We want the rows normalized
tensor([[ 1.4840,  0.5001, -0.7490, -0.9520, -0.2830],
        [-0.4272,  0.1853, -0.5511, -0.8615,  1.6545]], grad_fn=<AddBackward0>)


In [None]:
#For testing & validation of Feed Forward: the output shape should equal the input shape (seq x d_model)
d_model = 64
d_hidden = 256
seq = 5
ff = FeedForward(d_model, d_hidden)
x = torch.ones(seq, d_model)
print(ff(x).shape) #Call the module (not .forward) so nn.Module hooks run

torch.Size([5, 64])


In [None]:
#For testing & validation of MultiHeadAttention
#Input is sequence size by d_model
#Q, K, V are all that size - in self attention Q=K=V=Input
#Hyperparams:
seq = 8 #Sequence size (i.e. number of characters). Must be at most max_context (the causal-mask size).
d_model = 16 #Size of the embedding vectors for each character
heads = 8 #Number of heads in our multihead attention
mult = MultiHeadAttention(d_model, heads, True, max_context)
x = torch.ones(seq, d_model)
output = mult(x, x, x) #Call the module (not .forward) so nn.Module hooks run
print(output.shape)
print(output)

torch.Size([8, 16])
tensor([[ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
        [ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
        [ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
        [ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
        [ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
        [ 0.2838, -0.4051,  0.2764,  0.5660,  0.2345, -0.0719,  0.1130, -0.2280,
         -0.0888,  0.1520, -0.2332, -0.1001,  0.5237, -0.2018, -0.0421, -0.2075],
  

In [None]:
#For testing & validation of Attention
#Q, K, V all have size seq x d_k
#Hyperparams:
seq = 4 #Sequence length (i.e. number of characters)
d_k = 5 #dimension of embedding vectors in this head
att = Attention(d_k, True, max_context)
x = torch.ones(seq, d_k)
#Every value row is all ones, so any convex combination of them (each output row) is all ones too
print(att(x, x, x)) #Call the module (not .forward) so nn.Module hooks run

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
