In [4]:
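# We need a dataset to train on. Here, we will use the Tiny Shakespeare dataset, saved locally as input.txt.
# Uncomment the next line to download it (requires wget; where wget is unavailable, `curl -O <same URL>` is an alternative):
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt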

zsh:1: command not found: wget


# Understanding our implementation of GPT, and our dataset.
We want to get a better understanding of what exactly we will be implementing here.
First, let us examine our dataset in more detail.

Then, we will examine a simple Bigram language model, which gives us motivation to build a transformer.

In [5]:
# Load the whole input dataset from disk into a single string called "text".
# The encoding is pinned to UTF-8 so the result does not depend on the platform default.
with open("input.txt", mode="r", encoding="utf-8") as dataset_file:
    text = dataset_file.read()

In [8]:
# Report the dataset size (in characters), then preview its first 300 characters:
print(f"Length of the dataset (in chars): {len(text)}", text[:300], sep="\n")

Length of the dataset (in chars): 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [9]:
# The vocabulary: every distinct character that occurs in the data, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Show the vocabulary as one string, followed by its size:
print(''.join(chars), vocab_size, sep="\n")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [10]:
# Encoding input text into IDs:
#    creating the two lookup tables (character -> ID and ID -> character);
#    IDs follow the position of each character in the sorted vocabulary "chars":
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


#    creating the encoder/decoder functions on top of the lookup tables:
def encode(s):
    """Encoder: str -> list[int].

    Maps every character of `s` to its integer ID.
    Raises KeyError if `s` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: iterable of int IDs -> str.

    Maps every ID in `l` back to its character and joins the result.
    Raises KeyError if an ID is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)


# Round trip: decoding an encoded string must give the original string back.
print(encode("Hi there!"))
print(decode(encode("Hi there!")))

# Here, we are encoding individual characters. In theory.pdf, we see that entire words are encoded.
# In practice, encoding is done on the sub-word level. Google uses SentencePiece for tokenization, and then each
#   token is encoded, where "tokens" are sub-words decided somehow.
#   note: maybe good idea to create a tokenizer from scratch?

# Bigger token size => smaller sequence length.

[20, 47, 1, 58, 46, 43, 56, 43, 2]
Hi there!


In [11]:
import torch

# Turn the entire dataset into integer token IDs and keep them in one int64 torch.Tensor:
data = torch.tensor(encode(text), dtype=torch.int64)

# Sanity check: shape and dtype first, then the first 300 token IDs.
print(data.shape, data.dtype)
print(data[:300])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [13]:
# Train/validation split: the first 90% of the tokens are used for training,
# the remaining 10% are held out for validation.
n = int(0.9 * len(data))   # index of the first validation token
train_data, val_data = data[:n], data[n:]

### Training and Validation
In practice, training is difficult to do on the entire text all at once. Instead, the training occurs in *chunks* of text, randomly sampled from `train_data`.
These chunks of text must have some maximum length to keep the computation tractable. We will call this maximum length `block_size`.



In [14]:
# block_size: the maximum length (in tokens) of one training chunk.
block_size = 8
# Display block_size + 1 tokens from the start of the training data.
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [15]:
# One chunk of block_size+1 tokens packs block_size training examples at once:
# every prefix of the chunk is an input, and the token right after it is the label.
x = train_data[:block_size]       # inputs
y = train_data[1:block_size+1]    # labels: the same window shifted one token ahead

for t in range(block_size):
    target = y[t]                 # the "next" token in the sequence
    context = x[:t+1]             # the first t+1 tokens (t starts at 0)
    print(f"When input is {context} the target is {target}.")

# Training on every prefix gets the transformer used to inputs of any length from 1 to block_size.

When input is tensor([18]) the target is 47.
When input is tensor([18, 47]) the target is 56.
When input is tensor([18, 47, 56]) the target is 57.
When input is tensor([18, 47, 56, 57]) the target is 58.
When input is tensor([18, 47, 56, 57, 58]) the target is 1.
When input is tensor([18, 47, 56, 57, 58,  1]) the target is 15.
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58.


In [19]:
# Now we will combine multiple blocks into one batch; batch_size is the number of blocks per batch.
# This is done in practice for computational efficiency: the blocks in a batch are processed in parallel.

torch.manual_seed(1337)    # to replicate results consistently w/ random generation
batch_size = 4             # how many independent sequences do we process in parallel?
# Length (in tokens) of every sequence in the batch:
block_size = 8

def get_batch(split):
    """Sample a small random batch of inputs "x" and targets "y".

    split: "train" selects train_data; any other value selects val_data
           (only "train" is checked for).

    Returns (x, y): two stacked tensors with one row per sampled block. Each row
    of y is the corresponding row of x shifted one token ahead in the data.

    Reads train_data, val_data, block_size and batch_size from the enclosing
    scope at call time.
    """
    source = val_data if split != "train" else train_data

    # Draw batch_size random block-start offsets. The exclusive upper bound
    # len(source) - block_size leaves room for the shifted target window, so every
    # slice below holds exactly block_size tokens (torch.stack needs equal-sized rows).
    starts = torch.randint(len(source) - block_size, (batch_size, ))

    # Collect the input window and the one-step-shifted target window per offset:
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start+block_size])
        targets.append(source[start+1:start+block_size+1])

    # Stack the windows as rows:
    return torch.stack(inputs), torch.stack(targets)

xb, yb = get_batch("train")

# Show the shape and contents of the sampled inputs and targets:
for title, batch in (("Inputs:", xb), ("Targets:", yb)):
    print(title)
    print(batch.shape)
    print(batch)

print("-----------------------------------------")

# Spell out each individual (context -> target) example contained in the batch:
for b in range(batch_size):       # batch dimension
    for t in range(block_size):   # time dimension
        context, target = xb[b, :t+1], yb[b, t]
        print(f"When input is {context.tolist()} the target is {target}.")
    

Inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----------------------------------------
When input is [24] the target is 43.
When input is [24, 43] the target is 58.
When input is [24, 43, 58] the target is 5.
When input is [24, 43, 58, 5] the target is 57.
When input is [24, 43, 58, 5, 57] the target is 1.
When input is [24, 43, 58, 5, 57, 1] the target is 46.
When input is [24, 43, 58, 5, 57, 1, 46] the target is 43.
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39.
When input is [44] the target is 53.
When input is [44, 53] the target is 56.
When input is [44, 53, 56] the target is 1.
When input is [44, 53, 56, 1] the target is 5

### Bigram Language Modeling
See the NLP project for more.

In [30]:
# Let's try modeling using a Bigram language model. More info can be found in NLP inference (not this project).
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)   # for reproducibility

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next token is predicted from the current token alone.

    The only parameters are a (vocab_size, vocab_size) embedding table; row i holds
    the logits over the next token given that the current token is i.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # Each token reads the logits for the next token directly off a lookup table:
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        idx:     (B, T) tensor of integer token IDs
                 (B = batch size, T = time steps / block size).
        targets: optional (B, T) tensor of integer token IDs, the expected next tokens.

        Returns (logits, loss), with C = vocab_size:
          - targets is None: logits has shape (B, T, C) and loss is None;
          - otherwise:       logits is flattened to (B*T, C) and loss is the scalar
                             cross-entropy between logits and the flattened targets.
        """
        logits = self.token_embedding_table(idx)   # (batch, time, channel) or (B, T, C)

        if targets is None:
            return logits, None

        # F.cross_entropy expects the class dimension second, so we flatten batch and
        # time together: logits become (B*T, C), targets become (B*T,).
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)

        # We are predicting what's next based on the identity of an individual token.
        # This is done for every token we have: we know the next character, so the
        # loss measures how well we are predicting it.
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()   # sampling never needs gradients; avoids building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in idx by max_new_tokens sampled tokens.

        idx:            (B, T) tensor of token IDs, the current context.
        max_new_tokens: number of tokens to append to every sequence.

        Returns a (B, T + max_new_tokens) tensor: idx followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Get predictions (no targets are passed, so the returned loss is None):
            logits = self(idx)[0]
            # Focus only on the last time step (a.k.a. our prediction):
            logits = logits[:, -1, :]   # becomes (B, C) from (B, T, C)
            # Softmax to get probabilities:
            probs = F.softmax(logits, dim=-1)   # (B, C)
            # Sample from the distribution:
            idx_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
            # Append sampled index to the running sequence:
            idx = torch.cat((idx, idx_next), dim=1)   # (B, T+1)
        return idx
    

# Build the model and run one forward pass on the current batch to get an initial loss:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# We start generation from a single token with ID 0 (the first entry of the sorted
# vocabulary), as a context of shape (1, 1):
idx = torch.zeros(1, 1, dtype=torch.long)

# Our prediction of (max_new_tokens) next tokens, as a plain list of token IDs:
pred = m.generate(idx, max_new_tokens=100)[0].tolist()

# Starting context, raw token IDs, and the decoded text:
print(idx, pred, decode(pred), sep="\n")


torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)
tensor([[0]])
[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12, 52, 55, 7, 29, 17, 9, 9, 10, 15, 22, 55, 49, 27, 23, 20, 7, 55, 11, 10, 50, 39, 2, 53, 47, 63, 61, 49, 20, 48, 45, 15, 46, 64, 40, 29, 12, 59, 2, 9, 40, 24, 21, 45, 61, 43, 60, 51, 63, 18, 22, 19, 33, 19, 54, 0, 61, 52, 37, 35, 51, 52, 62, 23, 35, 35, 43, 60, 7, 58, 16, 55, 36, 17, 56, 34, 23, 24, 45, 22]

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [31]:
# We create an AdamW optimizer over all parameters of the model "m" (learning rate 1e-3);
# the training loop uses it to perform the gradient-descent updates.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [38]:
# Training process. get_batch reads batch_size when it is called, so raising it
# here makes every training batch hold 32 sequences.
batch_size = 32

for steps in range(10_000):
    # Get a sample batch of data:
    xb, yb = get_batch("train")

    # Clear the gradients left over from the previous step:
    optimizer.zero_grad(set_to_none=True)

    # Evaluate the loss on this batch:
    logits, loss = m(xb, yb)

    # Backpropagate and take one optimizer step:
    loss.backward()
    optimizer.step()

# Loss on the last batch seen during training:
print(loss.item())

2.4950690269470215


In [40]:
# We start generation from a single token with ID 0, as a context of shape (1, 1):
idx = torch.zeros(1, 1, dtype=torch.long)

# Our prediction of (max_new_tokens) next tokens, as a plain list of token IDs:
pred = m.generate(idx, max_new_tokens=200)[0].tolist()

# Predictions AFTER TRAINING: starting context, raw token IDs, and the decoded text.
print(idx, pred, decode(pred), sep="\n")


tensor([[0]])
[0, 14, 59, 56, 56, 39, 61, 1, 39, 52, 53, 60, 43, 1, 58, 46, 53, 52, 42, 6, 1, 57, 1, 58, 46, 63, 53, 39, 1, 25, 39, 49, 5, 43, 1, 57, 43, 52, 1, 46, 47, 57, 0, 31, 1, 58, 53, 1, 42, 1, 51, 43, 56, 47, 57, 1, 51, 43, 1, 51, 43, 1, 47, 52, 39, 52, 50, 63, 11, 1, 58, 1, 51, 39, 58, 1, 47, 51, 43, 43, 63, 6, 1, 56, 43, 43, 42, 1, 20, 21, 33, 23, 21, 18, 1, 50, 43, 1, 56, 53, 59, 57, 59, 57, 54, 39, 56, 43, 6, 1, 58, 53, 56, 43, 57, 43, 57, 47, 52, 60, 39, 52, 41, 46, 43, 39, 52, 42, 8, 0, 37, 53, 52, 6, 1, 51, 43, 51, 1, 40, 50, 53, 52, 1, 45, 39, 51, 40, 43, 57, 1, 61, 43, 42, 1, 58, 46, 39, 44, 44, 47, 5, 50, 1, 47, 52, 1, 21, 1, 61, 43, 56, 1, 48, 59, 39, 49, 43, 1, 57, 54, 39, 42, 1, 57, 1, 58, 46, 43, 1, 58, 46, 39, 41, 43, 47, 60, 43, 1, 21, 1]

Burraw anove thond, s thyoa Mak'e sen his
S to d meris me me inanly; t mat imeey, reed HIUKIF le roususpare, toresesinvancheand.
Yon, mem blon gambes wed thaffi'l in I wer juake spad s the thaceive I 


# Exploring Self-Attention (Single-Head)

In [41]:
# We will construct a toy example here.

torch.manual_seed(1337)   # for reproducibility

# batch, time, channels  (# sequences in the batch,  # tokens per sequence,  # channels per token)
B = 4
T = 8
C = 2

x = torch.randn((B, T, C))   # random (B, T, C) tensor
x.shape

# We have up to 8 tokens in each sequence of the batch. Currently, they don't "talk" to each other.
#   We would like to couple the tokens.
# In particular, tokens should only share information "forward", so from previous timesteps to the current
#   timestep. Future tokens will be predicted, so cannot share information backwards.

# The simplest way is to do an average of the preceding elements. We'll get a feature-vector that summarizes the
#   history of this block so far. This is enough for now (though we lose positional information).

torch.Size([4, 8, 2])

In [45]:
# xbow = x "bag of words": xbow[b, t] is the average of x[b, 0], ..., x[b, t].
xbow = torch.zeros(B, T, C)

for b in range(B):
    for t in range(T):
        # All tokens of sequence b up to and including the "t"th token:
        xprev = x[b, :t+1]   # (t+1, C)
        # Average over the time dimension and store the result at position (b, t):
        xbow[b, t] = xprev.mean(dim=0)

print(x[0], end="\n\n")
print(xbow[0])

# Notice now that each row of the second tensor (xbow[0]) is the average of the rows of the first
#   tensor (x[0]) up to and including that row.
# But this is inefficient! We can be efficient using matrix multiplication.

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


### Improving Efficiency using Matrix Multiplication

In [48]:
torch.manual_seed(42)

# Left-multiplying by an all-ones matrix "sums" b down each of its columns:
# every row of the product holds the column totals of b.
a = torch.ones((3, 3))                                # 3x3 matrix of "1"s
b = torch.randint(0, 10, (3, 2)).to(torch.float32)    # 3x2 matrix of random integers (0 to 9)
c = torch.matmul(a, b)                                # matrix multiplication

for title, matrix in (("a: ", a), ("b: ", b), ("c: ", c)):
    print(title)
    print(matrix, end="\n ~~~~~ \n")


a: 
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
 ~~~~~ 
b: 
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
 ~~~~~ 
c: 
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
 ~~~~~ 


In [51]:
torch.manual_seed(42)

# With a lower-triangular matrix of "1"s we get something more interesting:
# row i of the product is the sum of rows 0..i of b (a running sum).
a = torch.ones(3, 3).tril()                           # 3x3 LOWER TRIANGULAR matrix of "1"s
b = torch.randint(0, 10, (3, 2)).to(torch.float32)    # 3x2 matrix of random integers (0 to 9)
c = torch.matmul(a, b)                                # matrix multiplication

for title, matrix in (("a: ", a), ("b: ", b), ("c: ", c)):
    print(title)
    print(matrix, end="\n ~~~~~ \n")

# In particular, we are now adding only the "previous" entries instead of adding all entries. But we want
#   to take an average, not just a sum.

# We can normalize "a" such that its rows add to "1", giving an average.

a: 
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
 ~~~~~ 
b: 
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
 ~~~~~ 
c: 
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
 ~~~~~ 


In [53]:
torch.manual_seed(42)

# Normalizing as explained above: dividing each row of the lower-triangular matrix by its
# row sum makes every row add up to 1, so the product becomes a running average.
a = torch.ones(3, 3).tril()                           # 3x3 LOWER TRIANGULAR matrix of "1"s
a = a / a.sum(dim=1, keepdim=True)                    # row-normalize, keeping the (3, 1) shape for broadcasting

b = torch.randint(0, 10, (3, 2)).to(torch.float32)    # 3x2 matrix of random integers (0 to 9)
c = torch.matmul(a, b)                                # matrix multiplication

for title, matrix in (("a: ", a), ("b: ", b), ("c: ", c)):
    print(title)
    print(matrix, end="\n ~~~~~ \n")

a: 
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
 ~~~~~ 
b: 
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
 ~~~~~ 
c: 
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
 ~~~~~ 


#### Vectorizing and Improving Efficiency:

In [54]:
# We now apply this matrix trick to our initial problem: a (T, T) lower-triangular
# matrix whose rows are normalized to sum to 1, so that row t averages tokens 0..t.
weights = torch.ones(T, T).tril()
weights = weights / weights.sum(dim=1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [57]:
# Our new xbow, calculated with one batched matrix multiplication instead of two loops.
# weights is broadcast across the batch:  (T, T) x (B, T, C) ==> (B, T, T) x (B, T, C) ==> (B, T, C)
xbow2 = torch.matmul(weights, x)

xbow2[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [58]:
# The loop-based result, shown for comparison with xbow2[0]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [60]:
# We can use Softmax instead, giving an identical result:
tril = torch.ones(T, T).tril()

# Start from all-zero scores, then masked_fill: wherever tril==0 (element-wise mask,
# i.e. the positions above the diagonal), set the score to -inf.
weights = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
print(weights, end="\n\n")

# softmax:      exponentiates entries and divides by the row sum, so the -inf entries become 0.
weights = F.softmax(weights, dim=-1)
print(weights, end="\n\n")

xbow3 = torch.matmul(weights, x)

xbow3[0]

# We prefer this because the initial  weights=torch.zeros((T, T))  begins with zeros, and can be thought of as
#   an "interaction strength" between the current token and each previous token.
# The -inf says we will not aggregate anything from the future tokens.

# These affinities ("interaction strength") will change over time, which is the crux of self-attention.
# Hence we prefer this method.

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])



tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

## Implementing Self-Attention

In [61]:
torch.manual_seed(1337)

# batch, time, channels
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))   # our random "input"

# Implementation from above (method 3): zero scores, future positions masked to -inf,
# then a row-wise softmax, giving a uniform average over the current and past tokens.
tril = torch.ones(T, T).tril()
weights = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
weights = F.softmax(weights, dim=-1)
print(weights, end="\n\n")

out = torch.matmul(weights, x)
out.shape

# Problem:
#   We want to somehow encode how position matters. If the current timestep is a vowel, we might be more interested
#   in previous consonants than vowels, for example, in certain positions.
# We want to gather information from the past in a data-dependent way.

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])



torch.Size([4, 8, 32])

### Key and Query Vectors

In [66]:
torch.manual_seed(1337)

# batch, time, channels
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))   # our random "input"

# Let's introduce a single attention head:
head_size = 16
# Linear projections (no bias) that turn each token into a key and a query vector:
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

# Producing "k" and "q" by forwarding on "x":  each is (B, T, head_size).
# All the tokens in all the positions of the (B, T) arrangement produce their k and q
# in parallel and independently of each other.
k, q = key(x), query(x)

# Affinities: dot product of every query with every key of the same sequence
# (transpose the last 2 dimensions of k, then multiply):
#   (B, T, 16) x (B, 16, T) ==> (B, T, T)
weights = torch.matmul(q, k.transpose(-1, -2))

# Lower triangular "1"s, used to mask out the future positions before the softmax:
tril = torch.ones(T, T).tril()
weights = F.softmax(weights.masked_fill(tril == 0, float("-inf")), dim=-1)

# Our new weights, which are batch-element dependent:
print(weights[0], end="\n\n")

out = torch.matmul(weights, x)
print(out.shape)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

torch.Size([4, 8, 32])


In [67]:
# But we don't want to aggregate the raw tokens directly. We introduce another projection, "value".
torch.manual_seed(1337)

# batch, time, channels
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))   # our random "input"

# Let's introduce a single attention head:
head_size = 16
# Linear projections (no bias) producing the key, query and value vectors of each token:
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Producing "k" and "q" by forwarding on "x":  each is (B, T, head_size).
# All the tokens in all the positions of the (B, T) arrangement produce their k and q
# in parallel and independently of each other.
k, q = key(x), query(x)

# Affinities: dot product of every query with every key of the same sequence
# (transpose the last 2 dimensions of k, then multiply):
#   (B, T, 16) x (B, 16, T) ==> (B, T, T)
weights = torch.matmul(q, k.transpose(-1, -2))

# Lower triangular "1"s, used to mask out the future positions before the softmax:
tril = torch.ones(T, T).tril()
weights = F.softmax(weights.masked_fill(tril == 0, float("-inf")), dim=-1)

# Applying the value projection to "x":  (B, T, head_size)
v = value(x)

# Our new weights, which are batch-element dependent:
print(weights[0], end="\n\n")

# Aggregating our VALUE vectors instead of the raw "x":
out = torch.matmul(weights, v)
print(out.shape)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

torch.Size([4, 8, 16])


In [80]:
# We are missing one element from the paper: scaling the attention scores by 1/sqrt(head_size).
# This is for numerical (variance) stability:

k = torch.randn((B, T, head_size))
q = torch.randn((B, T, head_size))

# Each raw score is a sum of head_size products of (roughly) unit-variance numbers, so its
# variance is about head_size; multiplying by head_size ** -0.5 scales the variance back to ~1.
weights = torch.matmul(q, k.transpose(-1, -2)) * (head_size ** -0.5)

# This is important because weights feeds into Softmax, so it's important that weights is fairly diffuse.
# If weights has very negative and very positive numbers, Softmax(weights) will converge to one-hot vectors.

In [81]:
# Variance of the keys (drawn with torch.randn, so it should be close to 1):
k.var()

tensor(1.0966)

In [82]:
# Variance of the queries (drawn with torch.randn, so it should be close to 1):
q.var()

tensor(0.9416)

In [83]:
# Variance of the scaled scores: the head_size ** -0.5 factor keeps it close to 1 as well.
weights.var()

tensor(1.0065)

In [84]:
# Here's an example of convergence of Softmax(weights) showing why we scale:

# NOT converging to one-hot vectors: with small-magnitude inputs the output stays diffuse.
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [88]:
# CONVERGING to one-hot vectors: the same inputs multiplied by 8 give a much sharper output.
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)

# (aggregates almost everything from the maximum node, which we don't want)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])