<a href="https://colab.research.google.com/github/LTopaloglou/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torch.torch_version import Version
import torch
import torch.nn as nn
from torch.nn import functional as F

#I am using variable names defined in "Attention Is All You Need" for simplicity.

#Global Hyperparams
block_size = 8 #Max number of characters of context (seq <= block_size); also the side length of the causal-mask buffer registered by Attention

#Define a single Scaled Dot-Product Attention head. I am simplifying it by enforcing d_k = d_v
class Attention(nn.Module):
  """Single scaled dot-product attention head: softmax(Q K^T / sqrt(d_k)) V.

  Args:
    d_k: width of every query/key/value vector (d_k = d_v here).
    mask: if True, apply a causal (lower-triangular) mask so query position i
      only attends to key positions j <= i.
    max_len: side length of the precomputed causal-mask buffer. Defaults to the
      module-level ``block_size``. Longer masked sequences are still supported;
      their mask is built on the fly.

  Q, K and V may be 2-D ``(seq, d_k)`` or carry leading batch dimensions
  ``(..., seq, d_k)``. Queries and keys/values may have different sequence
  lengths (e.g. cross-attention).

  Raises:
    ValueError: if the last dimension of Q, K or V is not d_k.
  """

  def __init__(self, d_k, mask, max_len=None):
    super().__init__()
    self.mask = mask
    self.d_k = d_k
    if max_len is None:
      max_len = block_size
    #Registered as a buffer so it follows .to(device) and is stored in state_dict
    self.register_buffer('lowerTriangle', torch.tril(torch.ones(max_len, max_len)))

  def forward(self, Q, K, V):
    if Q.shape[-1] != self.d_k or K.shape[-1] != self.d_k or V.shape[-1] != self.d_k:
      raise ValueError('Invalid Query, Key or Value Dimensions')
    seq_q, seq_k = Q.shape[-2], K.shape[-2]
    #Dot product of queries & keys over the feature axis: (..., seq_q, seq_k)
    weight = Q @ K.transpose(-2, -1)
    #Scale by 1/sqrt(d_k) to keep the softmax out of its saturated region
    weight = weight * (self.d_k ** -0.5)
    if self.mask:
      #Hide future key positions; -inf becomes 0 after the softmax
      weight = weight.masked_fill(self._causal_mask(seq_q, seq_k) == 0, float('-inf'))
    #Normalize over the key axis so each query's weights sum to 1
    weight = weight.softmax(-1)
    #Weighted sum of values: (..., seq_q, seq_k) @ (..., seq_k, d_k) = (..., seq_q, d_k)
    return weight @ V

  def _causal_mask(self, seq_q, seq_k):
    """Return a (seq_q, seq_k) lower-triangular 0/1 mask, reusing the buffer when it is large enough."""
    buf = self.lowerTriangle
    if seq_q <= buf.shape[0] and seq_k <= buf.shape[1]:
      return buf[:seq_q, :seq_k]
    return torch.tril(torch.ones(seq_q, seq_k, device=buf.device))

#Define a Multi-Head Attention layer. As in the paper, I am setting d_k = d_model/h
class MultiHeadAttention(nn.Module):
  """Multi-head attention: h parallel attention heads of width d_k = d_model / h.

  Each head projects Q, K and V from d_model down to d_k with its own bias-free
  linear layer. The h head outputs are concatenated back to width d_model and
  mixed by the output projection W_O.

  Args:
    d_model: width of the input and output vectors.
    h: number of heads; must divide d_model.
    mask: passed to every head; True enables the causal mask.

  Raises:
    ValueError: if d_model is not divisible by h, or if an input's last
      dimension is not d_model.
  """

  def __init__(self, d_model, h, mask):
    super().__init__()
    self.mask = mask
    self.d_model = d_model
    self.h = h
    #Exact integer check (the float division d_model/h can be imprecise)
    if d_model % h != 0:
      raise ValueError('Invalid Dimensions Provided') #Ensure valid dimensions
    self.d_k = d_model // h
    self.W_O = nn.Linear(h*self.d_k, d_model, bias=False) #Output linear layer
    self.W_Q = nn.ModuleList() #The h different Query linear layers
    self.W_K = nn.ModuleList() #The h different Key linear layers
    self.W_V = nn.ModuleList() #The h different Value linear layers
    #Attention has no learnable parameters (only a mask buffer), so one
    #instance can safely be shared by all heads; gradients still flow through it.
    self.Att = Attention(self.d_k, mask)
    for _ in range(h):
      #Created in the same order as before so parameter init and state_dict keys are unchanged
      self.W_Q.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_K.append(nn.Linear(d_model, self.d_k, bias=False))
      self.W_V.append(nn.Linear(d_model, self.d_k, bias=False))

  def forward(self, Q, K, V):
    """Apply multi-head attention.

    Q, K and V have shape (..., seq, d_model); Q's sequence length may differ
    from K/V's (cross-attention). Returns (..., seq_q, d_model).
    """
    if Q.shape[-1] != self.d_model or K.shape[-1] != self.d_model or V.shape[-1] != self.d_model:
      raise ValueError('Invalid Query, Key or Value Dimensions')
    #Call the modules (not .forward) so nn.Module hooks are honored
    heads = [
        self.Att(w_q(Q), w_k(K), w_v(V))
        for w_q, w_k, w_v in zip(self.W_Q, self.W_K, self.W_V)
    ]
    #Concatenate the h heads of width d_k back to d_model, then mix with W_O
    return self.W_O(torch.cat(heads, -1))

#Define a position-wise Feed Forward Network
class FeedForward(nn.Module):
  """Two-layer MLP applied independently to every position.

  Computes Linear(d_model -> d_hidden) -> ReLU -> Linear(d_hidden -> d_model),
  so the output has the same shape as the input.
  """

  def __init__(self, d_model, d_hidden):
    super().__init__()
    expand = nn.Linear(d_model, d_hidden, bias=True)
    contract = nn.Linear(d_hidden, d_model, bias=True)
    #Kept as a Sequential at indices 0/1/2 so state_dict keys stay network.0.* / network.2.*
    self.network = nn.Sequential(expand, nn.ReLU(), contract)

  def forward(self, x):
    return self.network(x)

#Define Layer Normalization:
class LayerNorm(nn.Module):
  """Layer normalization over the last (feature) dimension.

  Each vector along the last axis is shifted to zero mean and scaled to unit
  variance, then multiplied by the learnable gamma and shifted by the learnable beta.
  As in Ba et al. (2016) and torch.nn.LayerNorm, the variance is the biased
  (population) estimator, and epsilon is added for numerical stability.

  Args:
    d_model: size of the last dimension of the inputs.
  """

  def __init__(self, d_model):
    super().__init__()
    self.epsilon = 1e-5
    self.gamma = nn.Parameter(torch.ones(d_model))  #Learnable per-feature scale
    self.beta = nn.Parameter(torch.zeros(d_model))  #Learnable per-feature shift

  #Defined as forward (not __call__) so nn.Module.__call__ still runs hooks
  def forward(self, x):
    mean = x.mean(-1, keepdim=True) #Mean across the features of each row
    variance = x.var(-1, keepdim=True, unbiased=False) #Biased variance across the features of each row
    norm = (x - mean) / torch.sqrt(variance + self.epsilon) #Normalize
    #After scaling/shifting, each row has std ~= |gamma| and mean ~= beta (per feature)
    return self.gamma * norm + self.beta

#Define a block of Multi-Head Self-Attention with a Residual Connection & Layer Normalization
class NormalizedSelfAttention(nn.Module):
  """Post-norm self-attention sub-layer: LayerNorm(x + MultiHeadAttention(x, x, x)).

  Args:
    d_model: width of the input/output vectors.
    h: number of attention heads.
    mask: True for causal (decoder) self-attention.
  """

  def __init__(self, d_model, h, mask):
    super().__init__()
    self.MHA = MultiHeadAttention(d_model, h, mask)
    self.norm = LayerNorm(d_model)

  def forward(self, x):
    #Self-attention: the same tensor supplies queries, keys and values
    attended = self.MHA(x, x, x)
    residual = x + attended
    return self.norm(residual)

#Define a block of a Feed Forward Network with a Residual Connection & Layer Normalization
class NormalizedFeedForward(nn.Module):
  """Post-norm feed-forward sub-layer: LayerNorm(x + FeedForward(x)).

  Args:
    d_model: width of the input/output vectors.
    d_hidden: width of the feed-forward hidden layer.
  """

  def __init__(self, d_model, d_hidden):
    super().__init__()
    self.FF = FeedForward(d_model, d_hidden)
    self.norm = LayerNorm(d_model)

  def forward(self, x):
    residual = x + self.FF(x)
    return self.norm(residual)

#Define a stand-alone Decoder with N masked self-attention + feed forward blocks (without embeddings or softmax)
#Note that the self attention is Masked because this is a Decoder
class StandAloneDecoder(nn.Module):
  """Decoder-only stack: N x [masked NormalizedSelfAttention, NormalizedFeedForward].

  Args:
    N: number of (self-attention, feed-forward) block pairs.
    d_model: width of the input/output vectors.
    h: number of attention heads.
    d_hidden: width of the feed-forward hidden layer.
  """

  def __init__(self, N, d_model, h, d_hidden):
    super().__init__()
    layers = []
    for _ in range(N):
      #Causal masking (True) so each position only sees earlier positions
      layers += [NormalizedSelfAttention(d_model, h, True), NormalizedFeedForward(d_model, d_hidden)]
    self.network = nn.Sequential(*layers)

  def forward(self, x):
    return self.network(x)

#Define an entire Transformer with E encoder blocks and D decoder blocks (without embeddings or softmax)
#The output of the last encoder supplies the keys/values for cross attention in all D decoders
class Transformer(nn.Module):
  """Post-norm encoder-decoder Transformer (no embeddings, positional encodings or output softmax).

  Args:
    E: number of encoder blocks (self-attention + feed-forward).
    D: number of decoder blocks (masked self-attention + cross-attention + feed-forward).
    d_model: width of all input/output vectors.
    h: number of attention heads.
    d_hidden: width of the feed-forward hidden layers.

  ``self.decoders`` is a flat ModuleList holding 4 modules per decoder block, in
  the order [masked self-attention, cross-attention, cross-attention LayerNorm,
  feed-forward]; decoder d occupies indices 4*d .. 4*d+3.
  """

  def __init__(self, E, D, d_model, h, d_hidden):
    super().__init__()
    self.E = E
    self.D = D
    self.encoders = nn.Sequential()
    for _ in range(E):
      self.encoders.append(NormalizedSelfAttention(d_model, h, False))
      self.encoders.append(NormalizedFeedForward(d_model, d_hidden))
    self.decoders = nn.ModuleList()
    for _ in range(D):
      self.decoders.append(NormalizedSelfAttention(d_model, h, True))
      self.decoders.append(MultiHeadAttention(d_model, h, False))
      self.decoders.append(LayerNorm(d_model))
      self.decoders.append(NormalizedFeedForward(d_model, d_hidden))

  def forward(self, inputs, outputs):
    """Run the encoder stack on ``inputs`` and the decoder stack on ``outputs``.

    Args:
      inputs: sequence fed to the encoders, last dimension d_model.
      outputs: sequence fed to the decoders, last dimension d_model; its length
        may differ from ``inputs``.

    Returns:
      The final decoder output, with ``outputs``' sequence length and width d_model.
    """
    memory = self.encoders(inputs) #Final encoder output, shared by every decoder
    for d in range(self.D):
      base = 4 * d
      self_att = self.decoders[base]
      cross_att = self.decoders[base + 1]
      cross_norm = self.decoders[base + 2]
      feed_forward = self.decoders[base + 3]
      outputs = self_att(outputs) #Masked self attention (+ residual & norm)
      #Cross attention: queries from the decoder, keys/values from the encoder,
      #then the residual connection, then layer norm (post-norm, as in the other sub-layers)
      outputs = cross_norm(outputs + cross_att(outputs, memory, memory))
      outputs = feed_forward(outputs) #Feed forward (+ residual & norm)
    return outputs

  #Backward-compatible alias for the original, misnamed entry point;
  #nn.Module.__call__ dispatches to forward(), not __forward__().
  __forward__ = forward
#TODO: Test what happens when a cross-attention head gets inputs of:
#Queries: a x d_model
#Keys & Values: b x d_model
#where a != b

In [None]:
#For testing NormalizedFeedForward: the output should keep the input shape and each
#row should be layer-normalized (row std close to 1).
#NOTE: an all-ones input makes every row identical, so all output rows are identical too.
d_model = 32
d_hidd = 64
seq = 3
normFeed = NormalizedFeedForward(d_model, d_hidd)
x = torch.ones(seq, d_model) #Named x to avoid shadowing the builtin input()
output = normFeed(x)
print(output)
print(output[0,:].std())

tensor([[ 0.8433, -0.3277,  0.8559,  0.8177,  0.6663, -0.3877, -0.4108,  2.0639,
         -1.2623,  1.3789, -0.2999, -0.0571,  0.1014, -0.0150, -0.4671, -0.1986,
         -0.6606,  2.4152, -0.7222, -1.7776,  0.9840,  0.7970, -0.5530, -0.9412,
         -1.7836, -1.0741, -0.3088, -0.9661,  0.3586, -0.5448,  1.1079,  0.3682],
        [ 0.8433, -0.3277,  0.8559,  0.8177,  0.6663, -0.3877, -0.4108,  2.0639,
         -1.2623,  1.3789, -0.2999, -0.0571,  0.1014, -0.0150, -0.4671, -0.1986,
         -0.6606,  2.4152, -0.7222, -1.7776,  0.9840,  0.7970, -0.5530, -0.9412,
         -1.7836, -1.0741, -0.3088, -0.9661,  0.3586, -0.5448,  1.1079,  0.3682],
        [ 0.8433, -0.3277,  0.8559,  0.8177,  0.6663, -0.3877, -0.4108,  2.0639,
         -1.2623,  1.3789, -0.2999, -0.0571,  0.1014, -0.0150, -0.4671, -0.1986,
         -0.6606,  2.4152, -0.7222, -1.7776,  0.9840,  0.7970, -0.5530, -0.9412,
         -1.7836, -1.0741, -0.3088, -0.9661,  0.3586, -0.5448,  1.1079,  0.3682]],
       grad_fn=<AddBackw

In [None]:
#For testing NormalizedSelfAttention (masked): the output should keep the input shape
#and each row should be layer-normalized (row std close to 1).
#NOTE: with an all-ones input every position is identical, so this cannot reveal masking errors.
d_model = 32
heads = 8
seq = 3
selfAtt = NormalizedSelfAttention(d_model, heads, True)
x = torch.ones(seq, d_model) #Named x to avoid shadowing the builtin input()
output = selfAtt(x)
print(output)
print(output[0,:].std())

tensor([[ 1.2899, -0.0913, -0.1104,  0.7696,  0.0718,  0.2509,  0.3303, -1.9164,
         -0.0496,  0.5569,  0.1692, -0.1273, -1.1905,  0.4290, -0.8969, -0.1584,
         -0.1496, -1.1824, -0.4902,  0.4242, -0.4168, -0.4957,  0.9929, -2.2237,
         -0.8495, -0.5435,  1.3629, -0.7239,  2.6632,  1.3310,  1.2465, -0.2724],
        [ 1.2899, -0.0913, -0.1104,  0.7696,  0.0718,  0.2509,  0.3303, -1.9164,
         -0.0496,  0.5569,  0.1692, -0.1273, -1.1905,  0.4290, -0.8969, -0.1584,
         -0.1496, -1.1824, -0.4902,  0.4242, -0.4168, -0.4957,  0.9929, -2.2237,
         -0.8495, -0.5435,  1.3629, -0.7239,  2.6632,  1.3310,  1.2465, -0.2724],
        [ 1.2899, -0.0913, -0.1104,  0.7696,  0.0718,  0.2509,  0.3303, -1.9164,
         -0.0496,  0.5569,  0.1692, -0.1273, -1.1905,  0.4290, -0.8969, -0.1584,
         -0.1496, -1.1824, -0.4902,  0.4242, -0.4168, -0.4957,  0.9929, -2.2237,
         -0.8495, -0.5435,  1.3629, -0.7239,  2.6632,  1.3310,  1.2465, -0.2724]],
       grad_fn=<AddBackw

In [None]:
#For testing layer norm: rows (the feature dimension) should come out with mean ~0 and
#std ~1 (the exact std depends on which variance estimator LayerNorm uses; Tensor.std()
#below is the unbiased one). Columns are not normalized.
layernorm = LayerNorm(5)
randoms = layernorm(torch.randn(2, 5))

for label, vec in (("Column mean and std dev", randoms[:,0]), ("Row mean and std dev", randoms[0,:])):
  print(label)
  print(vec.mean())
  print(vec.std())

print("We want the rows normalized")
print(randoms)

Column mean and std dev
tensor(0.8183)
tensor(0.1750)
Row mean and std dev
tensor(-1.1921e-08)
tensor(1.0000)
We want the rows normalized
tensor([[ 0.9421, -1.5342,  0.7111,  0.2911, -0.4101],
        [ 0.6945,  0.5012,  0.8204, -1.5376, -0.4785]])


In [None]:
#For testing & validation of Feed Forward: the output shape should equal the input shape (seq x d_model)
d_model = 64
d_hidden = 256
seq = 5
ff = FeedForward(d_model, d_hidden)
x = torch.ones(seq, d_model)
#Call the module rather than .forward() so nn.Module hooks run
print(ff(x).shape)

torch.Size([5, 64])


In [None]:
#For testing & validation of MultiHeadAttention
#Input is sequence size by d_model
#Q, K, V are all that size - in self attention Q=K=V=Input
#Hyperparams:
seq = 8 #Sequence size (i.e. number of characters). Must be at most block_size for masked attention.
d_model = 16 #Size of the embedding vectors for each character
heads = 8 #Number of heads in our multihead attention
mult = MultiHeadAttention(d_model, heads, True)
x = torch.ones(seq, d_model)
#Call the module rather than .forward() so nn.Module hooks run
output = mult(x, x, x)
print(output.shape)
print(output)

torch.Size([8, 16])
tensor([[-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
        [-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
        [-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
        [-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
        [-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
        [-0.6204, -0.5655, -0.3487,  0.4491, -0.6671,  0.5785, -0.7576,  0.1988,
          0.4697,  0.6004, -0.5643, -0.0136,  0.8412, -0.1806,  0.3783,  0.0433],
  

In [None]:
#For testing & validation of Attention
#Q, K, V all have size seq x d_k
#NOTE: with all-ones Q, K and V every output row is all ones regardless of masking.
#Hyperparams:
seq = 4 #Sequence length (i.e. number of characters)
d_k = 5 #dimension of embedding vectors in this head
att = Attention(d_k, True)
x = torch.ones(seq, d_k)
#Call the module rather than .forward() so nn.Module hooks run
print(att(x, x, x))

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
