In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt

In [37]:
d_embed = 512       # embedding size for the attention modules
num_heads = 8       # Number of attention heads
num_batches = 1     # number of batches (1 makes it easier to see what is going on)
vocab = 50000       # vocab size
max_len = 5000      # Max length of TODO what exactly?
n_layers = 1        # number of attention layers (not used but would be an expected hyper-parameter)
d_ff = 2048         # hidden state size in the feed forward layers
epsilon = 1e-6      # epsilon to use when we need a small non-zero number


In [38]:

x = torch.tensor([[1, 2, 3]]) # input will be 3 tokens
y = torch.tensor([[1, 2, 3]]) # target will be same as the input for many applications
x_mask = torch.tensor([[1, 0, 1]]) # Mask the 2nd input token
y_mask = torch.tensor([[1, 0, 1]]) # Mask the 2nd target token
print("x", x.size())
print("y", y.size())

x torch.Size([1, 3])
y torch.Size([1, 3])


In [39]:

# Make the embedding module. It understands that each token should result in a separate embedding.
emb = nn.Embedding(vocab, d_embed)
x = emb(x)
# Scale the embedding
x = x * math.sqrt(d_embed)
print(x.size())

torch.Size([1, 3, 512])


In [40]:
pe = torch.zeros(max_len,d_embed, requires_grad=False)

In [41]:
pe.size()

torch.Size([5000, 512])

In [42]:
position = torch.arange(0, max_len).unsqueeze(1)
print(position.size())

torch.Size([5000, 1])


In [43]:
# Start with an empty tensor
pe = torch.zeros(max_len, d_embed, requires_grad=False)
# array containing index values 0...max_len
position = torch.arange(0, max_len).unsqueeze(1)
divisor = torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed))
# Make overlapping sine and cosine wave inside positional embedding tensor
pe[:, 0::2] = torch.sin(position * divisor)
pe[:, 1::2] = torch.cos(position * divisor)
pe = pe.unsqueeze(0)
# Add the position embedding to the main embedding
x = x + pe[:, :x.size(1)]
print(x.size())

torch.Size([1, 3, 512])


In [44]:
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
W1 = nn.Parameter(torch.ones(d_embed))
b1 = nn.Parameter(torch.zeros(d_embed))
x = W1 * (x - mean) / (std + epsilon) + b1
print(x.size())

torch.Size([1, 3, 512])


In [53]:
# Make three versions of x, for the query, key, and value
# We don't need to clone because these will immediately go through linear layers, making new tensors
k = x # key
q = x # query
v = x # value
# Make three linear layers
# This is where the network learns to make scores
linear_k = nn.Linear(d_embed, d_embed)
linear_q = nn.Linear(d_embed, d_embed)
linear_v = nn.Linear(d_embed, d_embed)
# We are going to fold the embedding dimensions and treat each fold as an attention head
d_k = d_embed // num_heads
# Pass q, k, v through their linear layers
q = linear_q(q)
k = linear_k(k)
v = linear_v(v)
# Do the fold, treating each h dimenssions as a head
# Put the head in the second position
q = q.view(num_batches, -1, num_heads, d_k).transpose(1, 2)
k = k.view(num_batches, -1, num_heads, d_k).transpose(1, 2)
v = v.view(num_batches, -1, num_heads, d_k).transpose(1, 2)
print("q", q.size())
print("x", k.size())
print("v", v.size())

q torch.Size([1, 8, 3, 64])
x torch.Size([1, 8, 3, 64])
v torch.Size([1, 8, 3, 64])


In [57]:
q.transpose(-2, -1).size()

torch.Size([1, 8, 64, 3])

In [77]:
d_k = q.size(-1)
# Compute the raw scores by multiplying k and q (and normalize)
scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
print("scores", scores.size())
# Mask out the scores
scores = scores.masked_fill(x_mask == 0, -epsilon)
attn = F.softmax(scores, dim = -1)
print("attention", attn.size())
# Apply the scores to v
x = torch.matmul(attn, v)
print("x", x.size())    

scores torch.Size([1, 8, 3, 3])
attention torch.Size([1, 8, 3, 3])
x torch.Size([1, 8, 3, 64])


In [75]:
F.softmax(scores, dim=-1).size()


torch.Size([1, 8, 3, 3])

In [76]:
F.softmax(scores, dim=-1)

tensor([[[[0.2252, 0.3892, 0.3856],
          [0.4831, 0.2371, 0.2798],
          [0.3575, 0.3727, 0.2698]],

         [[0.5478, 0.2035, 0.2487],
          [0.3136, 0.3352, 0.3512],
          [0.3864, 0.3723, 0.2413]],

         [[0.3558, 0.3108, 0.3334],
          [0.3114, 0.3117, 0.3768],
          [0.4919, 0.3081, 0.2000]],

         [[0.2825, 0.3268, 0.3907],
          [0.2674, 0.4068, 0.3257],
          [0.4279, 0.3212, 0.2510]],

         [[0.2240, 0.4012, 0.3748],
          [0.5012, 0.2689, 0.2299],
          [0.3869, 0.3520, 0.2611]],

         [[0.2569, 0.3529, 0.3902],
          [0.2282, 0.2568, 0.5150],
          [0.2551, 0.3774, 0.3676]],

         [[0.4467, 0.2480, 0.3053],
          [0.3810, 0.3086, 0.3104],
          [0.3773, 0.3159, 0.3068]],

         [[0.2699, 0.4120, 0.3181],
          [0.3626, 0.3377, 0.2997],
          [0.2675, 0.3964, 0.3361]]]], grad_fn=<SoftmaxBackward0>)