The material in this notebook follows Andrej Karpathy's nanoGPT tutorial (https://www.youtube.com/watch?v=kCc8FmEb1nY)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
#Loading the TinyShakespeare Dataset
# IPython shell escape (runs wget): downloads input.txt into the notebook's working directory
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# read the whole corpus into a single Python string
with open('input.txt', encoding='utf-8') as fh:
  text = fh.read()

In [None]:
print('The dataset contains', len(text), 'characters')  # total number of characters in the corpus

In [None]:
# the dataset is one big str (not a list of lines; nothing has been split yet); show its type and first 1000 chars
print(type(text), text[:1000], sep='\n')

In [None]:
# the sorted set of distinct characters ("atoms") that make up the dataset is our vocabulary
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars), vocab_size, sep='\n')

In [None]:
# TOKENIZER
# Map characters to integers and vice versa.
# This is a character-level tokenizer; state-of-the-art models tend to use
# subword-level tokenizers instead.
#
# Note that although simple, character-level tokenizers produce much longer
# sequences than other tokenizers for the same text.

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character


def encode(s):
  """Convert a string `s` into a list of integer token ids, one per character.

  Raises KeyError if `s` contains a character that is not in the vocabulary.
  """
  return [stoi[c] for c in s]


def decode(l):  # noqa: E741 -- parameter name kept from the original lambda
  """Convert an iterable `l` of integer token ids back into a single string."""
  return ''.join(itos[i] for i in l)


print(encode('hello there'))
print(decode(encode('hello there')))

In [None]:
# encode the whole corpus once and keep it as a 1-D tensor of token ids
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

# histogram of token ids with one bin per vocabulary entry, i.e. the character frequencies
plt.hist(data, bins=vocab_size)
plt.show()

In [None]:
# Count every token id and report the most frequent character, computed from the
# data instead of hard-coding an index into `itos`.
most_common_id = int(torch.bincount(data, minlength=vocab_size).argmax())
print(f'The most occuring character is: "{itos[most_common_id]}"')

In [None]:
## Training and validation split of the data: first 90% for training, last 10% for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In most applications, the training dataset is far too large to be fed into the transformer all at once. Instead, training is performed by randomly sampling blocks (subsequences) of a specified size from the training set; this size is referred to here as the $\texttt{block\_size}$ or the $\texttt{context\_length}$. These blocks are ordered and contain multiple pieces of information. For example, consider  $\texttt{block\_size} = N$, e.g., 

$[1,...,N]$

As we are interested in predicting next tokens, we can create the following list of training examples from that single block with the format $\texttt{target}$ | $\texttt{input}$:

$1|$

$2|1$

$3|2,1$

...

$N | N-1,....,1$

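Hence, a single block yields training examples with contexts of every size from $1$ all the way up to $\texttt{block\_size}$, so the model also learns to predict from very short contexts.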

In [None]:
# here is an example of the above
block_size = 8 # context length
# inputs and their next-token targets (the same block shifted by one position)
x, y = train_data[:block_size], train_data[1:block_size+1]
print(f'The first training block is {train_data[:block_size+1]}')
print(f'From this block, we can construct the following prediction cases.')
for t, target in enumerate(y):
  context = x[:t+1]
  print(f'When the input is {context}, the target is {target}.')

Recall the analogy with joint probabilities. For a sequence $\{t_1,...,t_N\}$, the joint probability has the following factorization:

$p(t_1,...,t_N) = p(t_1) p(t_2|t_1) p(t_3| t_2,t_1) ... p(t_{N} | t_{N-1},...,t_1)$ 

Transformers are not trained on individual blocks, but rather on minibatches of multiple blocks. Hence, the $\texttt{batch\_size}$ is also an important hyperparameter. The primary reason for doing this is efficiency: the blocks in a minibatch are processed independently and in parallel.  

In [None]:
#Some default hyperparameters, which we will change throughout
batch_size, block_size = 32, 8                        # sequences per batch / max context length
max_iters, eval_interval, eval_iters = 3000, 300, 200  # training steps / eval period / batches per eval
learning_rate = 1e-3
n_embd = 32                                           # embedding dimension
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
print(f'{device}')  # 'cuda' when a GPU is available, otherwise 'cpu'

In [None]:
torch.manual_seed(1337)  # reproducible sampling below
# independent sequences processed in parallel / maximum context length for predictions
batch_size, block_size = 4, 8

def get_batch(opt):
  """Sample a random minibatch of (input, target) blocks from one data split.

  opt: 'train' selects train_data; any other value selects val_data.
  Returns (x, y), both integer tensors of shape (batch_size, block_size) on `device`;
  y is x shifted one position ahead, i.e. the next-token targets.
  """
  source = val_data if opt != 'train' else train_data
  # random start offsets, chosen so the shifted target block still fits inside the split
  starts = torch.randint(len(source) - block_size, (batch_size,))
  inputs = torch.stack([source[s:s + block_size] for s in starts])
  targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
  return inputs.to(device), targets.to(device)

# draw one minibatch from the training split
xb, yb = get_batch('train')

print('inputs:', xb.shape, xb, sep='\n')
print('targets:', yb.shape, yb, sep='\n')

print('----')

# every (batch row b, time step t) pair is an independent prediction example
for b in range(batch_size):
  for t in range(block_size):
    context, target = xb[b, :t+1], yb[b, t]
    print(f'When input is {context.tolist()}, the target is: {target}')


In [None]:
 ## Helper functions

@torch.no_grad()
def estimate_loss(model):
  """Estimate the mean loss on the 'train' and 'val' splits.

  For each split, averages the loss over `eval_iters` freshly sampled batches with
  the model in eval mode. The model is switched back to training mode before returning.

  Returns a dict {'train': ..., 'val': ...} of 0-dim tensors.
  """
  model.eval()
  out = {}
  for split in ('train', 'val'):
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      _, loss = model(*get_batch(split))
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

In order to build up to a Transformer network, we will first recall the Bigram language model.

In [None]:
class BigramLanguageModel(nn.Module):
  """Bigram language model: the next-token logits depend only on the current token."""

  def __init__(self, vocab_size):
    super().__init__()
    # Each token directly reads off the logits for the next token from a lookup table.
    # The table plays the role of a transition matrix whose entries are logits instead of
    # probabilities: entry (i, j) is the logit for moving from current token i to next token j,
    # so row i holds the (unnormalised) distribution over the next token.
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) #(C,C)

  def forward(self, idx, targets=None):
    """Compute next-token logits and, optionally, the cross-entropy loss.

    idx: integer tensor (B,T) of token ids; targets: optional integer tensor (B,T).
    Returns (logits, loss). Without targets, logits is (B,T,C) and loss is None.
    With targets, logits is flattened to (B*T,C) and loss is a scalar tensor.
    """
    # B = batch_size, T = time, C = channels = vocab_size
    logits = self.token_embedding_table(idx) #(B,T,C)
    # logits[b,t,:] are the logits for the token that follows idx[b,t]

    if targets is None:
      loss = None
    else:
      # F.cross_entropy expects (N,C) logits and (N,) targets
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets) # scalar (mean over the B*T positions)

    return logits, loss

  @torch.no_grad() # sampling never needs gradients, so don't build an autograd graph
  def generate(self, idx, max_new_tokens):
    """Sample `max_new_tokens` tokens autoregressively and append them to `idx`.

    idx: integer tensor (B,T) with the starting context.
    Returns an integer tensor of shape (B, T + max_new_tokens).
    """
    # The bigram model only needs the latest token, but we keep (and feed) the whole history
    # so that the same loop works for the models with longer context later on.
    for _ in range(max_new_tokens):
      logits, _ = self(idx) #(B,T,C); loss is None without targets
      logits = logits[:,-1,:] # only the last time step predicts the next token -> (B,C)
      probs = F.softmax(logits, dim=-1) #(B,C)
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


# put the model on the same device that get_batch() moves xb/yb to, otherwise this fails on CUDA
m = BigramLanguageModel(vocab_size).to(device)
logits, loss = m(xb,yb)

In [None]:
# display the loss of the untrained model (a 0-dim tensor); for reference, a uniform
# prediction over the vocabulary would give a loss of ln(vocab_size)
loss

In [None]:
# extend each sequence in xb by 20 tokens sampled from the untrained model
preds = m.generate(xb, 20)

# display the first sequence (its original context followed by the sampled tokens) as text
decode(preds[0].tolist())

In [None]:
# start from token id 0 (the first character of the sorted vocabulary), on the model's own device
idx = torch.zeros((1,1), dtype=torch.long, device=next(m.parameters()).device)
print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))

Obviously, this purely random model spits out gibberish. Now we will train this model on the tinyShakespeare dataset.

In [None]:
# define some hyperparameters
batch_size, block_size = 32, 8                        # sequences per batch / max context length
max_iters, eval_interval, eval_iters = 3000, 300, 200
learning_rate = 1e-3
n_embd = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# PyTorch optimizer (AdamW); use the learning_rate hyperparameter rather than repeating its value
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [None]:
# train the bigram model: 10000 AdamW steps on random training minibatches
for _ in range(10000):
  xb, yb = get_batch('train')              # minibatch
  logits, loss = m(xb, yb)                 # forward pass + loss
  optimizer.zero_grad(set_to_none=True)    # drop gradients from the previous step
  loss.backward()                          # backpropagate
  optimizer.step()                         # update parameters

print(loss.item())  # loss on the last minibatch

In [None]:
# start from token id 0 (the first character of the sorted vocabulary), on the model's own device
idx = torch.zeros((1,1), dtype=torch.long, device=next(m.parameters()).device)
print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))

While the model's output is still incomprehensible, it is considerably improved over the untrained model: at least some of the structure of the text has been learned. Bigram models are quite poor, but they at least provide a benchmark for comparison. We need longer contexts than just one token.

## Efficient implementations of self-attention use the following trick

The trick is to enforce causal interactions via masking. By causal, we mean that the future cannot communicate with the past: each token may only use information from itself and earlier tokens. This is a form of autoregression and highlights the role of lower triangular matrices.

In [None]:
# toy example: random "embeddings" for B sequences of T tokens with C channels each
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn((B, T, C))
x.shape

In [None]:
# We want a causal temporal average ("bag of words" over the past):
#   xbow[b,t] = mean_{i<=t} x[b,i]
# Averaging discards the order of the previous tokens, so some information is lost.
xbow = torch.zeros((B,T,C))
for b in range(B): # each sequence in the batch
  for t in range(T): # each time step
    xbow[b,t] = x[b,:t+1].mean(dim=0) # (t+1,C) -> (C); the window includes the t-th entry

In [None]:
# A more efficient formulation than for loops:
# a row-normalised lower-triangular matrix acts as a causal averaging operator
a = torch.tril(torch.ones(3,3))
a /= a.sum(dim=1, keepdim=True)
a

In [None]:
wei = torch.tril(torch.ones(T,T)) # temporal averaging weights
wei /= wei.sum(dim=1, keepdim=True) # every row sums to one

# (T,T) @ (B,T,C) -> (B,T,C): wei is broadcast over the batch, so all sequences are averaged at once
xbow2 = torch.matmul(wei, x)

In [None]:
# sanity check: the matrix-multiplication version matches the explicit loops
print( torch.allclose(xbow,xbow2) )
print(xbow[0], xbow2[0])

In [None]:
# We can write the same averaging operator with a softmax: the future (upper-triangular)
# entries become -inf and get zero weight, and the remaining equal zeros become a uniform average.
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T)).masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

# The averaging operator 'wei' is currently a uniform average.
# In reality, we may want a more general weighted average (still causal) whose weights can be learned.

## Expanding the Bigram Model
We will begin to expand the previous bigram model with seemingly redundant information. These "redundant" aspects, such as token and position embeddings become important when dealing with Transformer models.

In [None]:
# define some hyperparameters
batch_size, block_size = 32, 8                        # sequences per batch / max context length
max_iters, eval_interval, eval_iters = 3000, 300, 200
learning_rate = 1e-3
n_embd = 32                                           # embedding dimension
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class BigramLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    # Each token directly reads off the logits for the next token from a lookup table.
    # This lookup table is basically a transition probability matrix, where the probabilities are replaced by logits.
    # Hence, entry (i,j) of the table corresponds to the logit associated with transitioning from the current token i the next token j.
    # Thus the i-th row contains the logits (probability distribution over the vocabulary) for the next token.

    # first we embed each input token into euclidean space (rather than using a one-hot vector)
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd) #(vocab_size,n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd) #(block_size, n_embd)
    # Note the token_embedding only embeds information about the value of the token, but not its position in the context
    # Hence, we also include a position_embedding in the same space, basically as an additional learnable degree of freedom to account for the positioning within the context.
    # Importantly, the position_embedding is only used for position information and not for conveying token value. Hence it is the same within the batch.

    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    # Let B = batch_size, T = time, C = channels = vocab_size 
    # idx and targets are both integer tensors of dimension (B,T) 
    B,T = idx.shape
    tok_embd = self.token_embedding_table(idx) #(B,T,n_embd)
    pos_embd = self.position_embedding_table(torch.arange(T, device=device)) #(T, n_embd)
    x = tok_embd+pos_embd #(B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    logits = self.lm_head(x) #(B,T,vocab_size) --> corresponds to logistic model on the embedded tokens

    # logits(b,t,:) contains the logits for predicting the next token given that we are currently at token t in batch b.
    # these logits correspond to a probability distribution over the vocabulary.
    
    # we need to reshape things to work with F.cross_entropy
    
    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T) 
      loss = F.cross_entropy(logits, targets) #(B,T)      

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is of dim (B,T) whose (b,t)th entry corresponds to the vocabulary index in batch b at time t
    # max_new_tokens is how many new tokens to sample

    ## Note: In this simple bigram model, we only need to keep track of the latest indices and not the entire history.
    ##       In the following code, we actually keep track of all of the indices in 'idx' and continually append to it.
    ##       This is because we will need this capability in subsequent models that have more context.
    for _ in range(max_new_tokens):
      # make the predictions (probability distributions over next entries)
      logits, loss = self(idx) # logits is (B,T,C), loss is (B*T)
      # focus on the last time step because that is the only context needed for bigram predictions
      logits = logits[:,-1,:] # (B,C)
      # convert logits to probs
      probs = F.softmax(logits, dim=-1) #(B,C)
      # sample
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      # appending new sampled indices to our current set
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


# put the model on the same device that get_batch() moves xb/yb to, otherwise this fails on CUDA
m = BigramLanguageModel().to(device)
logits, loss = m(xb,yb)

## The crux of attention

In [None]:
#recap
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)
# x[i] is a (T,C) matrix whose t-th row x[i,t,:] is the C-dimensional embedding of the t-th token,
# so the rows of x[i] index the time steps

tril = torch.tril(torch.ones(T, T))
wei = F.softmax(torch.zeros((T, T)).masked_fill(tril == 0, float('-inf')), dim=-1)
out = wei @ x # causal time average of every embedding dimension

With attention, we want to achieve something similar but more flexible. We still want information to propagate in a causal way -- meaning that future tokens cannot influence the past. In contrast with the previous time averaging, however, we want this to be achieved in a data-dependent way.

To do this, we abstract the notion of weights and averaging to the notion of affinities and similarity. Such concepts have been fruitfully applied in manifold learning and kernel methods on graphs. The transformer perspective is as follows.

Each token emits a query vector and a key vector in the same space. Key vectors contain information pertaining to the token's identity and query vectors contain information about what the token is looking for. We use "information" here in the loose sense, as these key and query vectors are constructed by a linear transformation of the input, which is learned from data. Each token communicates with the other tokens by taking the dot product between its own query vector and their key vectors (with causal masking, only the preceding tokens and itself). These dot products are the "affinities".

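The key and query vectors are created by a linear transformation of the input vector. Why not use an embedding table? Perhaps to reduce the number of parameters; moreover, the input to an attention layer is a continuous vector (which, in deeper layers, already mixes information from other tokens) rather than a discrete token id, so there is nothing to look up in a table.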

Attention is a communication mechanism between tokens in a sequence. We can interpret this as a directed graph.

In [None]:
# A single attention head
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

head_size = 16
# three bias-free projections C -> head_size (each nn.Linear stores a (head_size, C) weight);
# they are created in key, query, value order so the seeded initialisation is reproducible
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q, v = key(x), query(x), value(x) # each (B,T,head_size)

# affinity of every query with every key, scaled by 1/sqrt(head_size): (B,T,hs) @ (B,hs,T) -> (B,T,T)
wei = (q @ k.transpose(-2, -1)) * head_size**-0.5

tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1) # causal mask, then normalise each row
out = wei @ v # (B,T,head_size)

In [None]:
out.shape # (B, T, head_size)

In [None]:
## Formalizing

class Head(nn.Module):
  """ One-headed causal self-attention """

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    # lower-triangular causal mask; registered as a buffer, so it is not a trainable parameter
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

  def forward(self, x):
    B,T,C = x.shape #batch, time, channels (C = n_embd)
    k = self.key(x)   #(B,T,head_size)
    q = self.query(x) #(B,T,head_size)
    v = self.value(x) #(B,T,head_size)

    # Scaled dot-product attention: scale by 1/sqrt(head_size), the key dimension, not by
    # 1/sqrt(C); otherwise the softmax inputs are mis-scaled whenever head_size != n_embd.
    head_size = k.shape[-1]
    wei = q @ k.swapaxes(-2,-1) * head_size**(-0.5) # (B,T,hs) @ (B,hs,T) = (B,T,T)
    wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # no attending to future tokens

    wei = F.softmax(wei, dim=-1) #(B,T,T)
    out = wei @ v #(B,T,head_size)
    return out


In [None]:
class SelfAttentionModel(nn.Module):
  """Token + position embeddings -> one causal self-attention head -> linear LM head."""

  def __init__(self):
    super().__init__()

    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.sa_head = Head(n_embd) # a single head whose size equals the embedding size
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """Return (logits, loss) for token ids `idx` of shape (B,T) with T <= block_size.

    Without targets: logits is (B,T,vocab_size) and loss is None.
    With targets (B,T): logits is flattened to (B*T,vocab_size) and loss is a scalar tensor.
    """
    B,T = idx.shape
    tok_embd = self.token_embedding_table(idx) #(B,T,n_embd)
    # build the positions on the input's device so this works wherever the model lives
    pos_embd = self.position_embedding_table(torch.arange(T, device=idx.device)) #(T,n_embd)
    x = tok_embd+pos_embd #(B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    x = self.sa_head(x) # tokens exchange information -> (B,T,n_embd)
    logits = self.lm_head(x) #(B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets) # scalar (mean over the B*T positions)

    return logits, loss

  @torch.no_grad() # sampling never needs gradients
  def generate(self, idx, max_new_tokens):
    """Sample `max_new_tokens` tokens autoregressively and append them to `idx` (B,T).

    Returns an integer tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # the position table has block_size rows, so condition on at most the last block_size tokens
      idx_cond = idx[:,-block_size:]
      logits, _ = self(idx_cond) #(B,T,vocab_size); loss is None without targets
      logits = logits[:,-1,:] #(B,vocab_size)
      probs = F.softmax(logits, dim=-1) #(B,vocab_size)
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


# put the model on the same device that get_batch() moves xb/yb to, otherwise this fails on CUDA
m = SelfAttentionModel().to(device)
logits, loss = m(xb,yb)

In [None]:
# define some hyperparameters (m was built before this cell, so block_size/n_embd here do not resize it)
torch.manual_seed(1337)
batch_size, block_size = 32, 8
max_iters, eval_interval, eval_iters = 10000, 300, 200
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 32

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
for step in range(max_iters):
  xb, yb = get_batch('train')              # minibatch
  logits, loss = m(xb, yb)                 # forward pass + loss
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if step % 1000 != 0:
    continue
  # every 1000 steps: report the estimated train/val losses
  outs = estimate_loss(m)
  print(f"iter {step} | train: {outs['train']} | test: {outs['val']}")

print(loss.item())

In [None]:
# start from token id 0 (the first character of the sorted vocabulary), on the model's own device
idx = torch.zeros((1,1), dtype=torch.long, device=next(m.parameters()).device)
print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))

In [None]:
class MultiHeadAttention(nn.Module):
  """ multiple self-attention heads in parallel """

  def __init__(self, num_heads, head_size):
    super().__init__()
    heads = [Head(head_size) for _ in range(num_heads)]
    self.heads = nn.ModuleList(heads)

  def forward(self, x):
    # every head sees the same input; their (B,T,head_size) outputs are placed side by side
    # along the channel axis -> (B,T,num_heads*head_size)
    per_head = [head(x) for head in self.heads]
    return torch.cat(per_head, dim=-1)

mtmp = MultiHeadAttention(4, n_embd // 4) # quick instantiation check: 4 heads of size n_embd/4

In [None]:
class MultiHeadSelfAttentionModel(nn.Module):
  """Token + position embeddings -> 4 parallel causal self-attention heads -> linear LM head."""

  def __init__(self):
    super().__init__()
    num_heads = 4
    if n_embd % num_heads != 0:
      # the heads' outputs are concatenated, so they must add back up to exactly n_embd channels
      raise ValueError(f'n_embd ({n_embd}) must be divisible by {num_heads}')

    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.sa_heads = MultiHeadAttention(num_heads, n_embd // num_heads)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """Return (logits, loss) for token ids `idx` of shape (B,T) with T <= block_size.

    Without targets: logits is (B,T,vocab_size) and loss is None.
    With targets (B,T): logits is flattened to (B*T,vocab_size) and loss is a scalar tensor.
    """
    B,T = idx.shape
    tok_embd = self.token_embedding_table(idx) #(B,T,n_embd)
    # build the positions on the input's device so this works wherever the model lives
    pos_embd = self.position_embedding_table(torch.arange(T, device=idx.device)) #(T,n_embd)
    x = tok_embd+pos_embd #(B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    x = self.sa_heads(x) #(B,T,n_embd)
    logits = self.lm_head(x) #(B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets) # scalar (mean over the B*T positions)

    return logits, loss

  @torch.no_grad() # sampling never needs gradients
  def generate(self, idx, max_new_tokens):
    """Sample `max_new_tokens` tokens autoregressively and append them to `idx` (B,T).

    Returns an integer tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # the position table has block_size rows, so condition on at most the last block_size tokens
      idx_cond = idx[:,-block_size:]
      logits, _ = self(idx_cond) #(B,T,vocab_size); loss is None without targets
      logits = logits[:,-1,:] #(B,vocab_size)
      probs = F.softmax(logits, dim=-1) #(B,vocab_size)
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


# put the model on the same device that get_batch() moves xb/yb to, otherwise this fails on CUDA
m = MultiHeadSelfAttentionModel().to(device)
logits, loss = m(xb,yb)

In [None]:
# define some hyperparameters (m was built before this cell, so block_size/n_embd here do not resize it)
torch.manual_seed(1337)
batch_size, block_size = 32, 8
max_iters, eval_interval, eval_iters = 10000, 300, 200
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 32

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
for step in range(max_iters):
  xb, yb = get_batch('train')              # minibatch
  logits, loss = m(xb, yb)                 # forward pass + loss
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if step % 1000 != 0:
    continue
  # every 1000 steps: report the estimated train/val losses
  outs = estimate_loss(m)
  print(f"iter {step} | train: {outs['train']} | test: {outs['val']}")

print(loss.item())

In [None]:
# start from token id 0 (the first character of the sorted vocabulary), on the model's own device
idx = torch.zeros((1,1), dtype=torch.long, device=next(m.parameters()).device)
print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))

In [None]:
class FeedForward(nn.Module):
  """ per-position computation: a single linear layer followed by a ReLU """

  def __init__(self, n_embd):
    super().__init__()
    layers = [nn.Linear(n_embd, n_embd), nn.ReLU()]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    # nn.Linear acts on the last axis, so every (batch, time) position is processed independently
    return self.net(x)

In [None]:
class MultiHeadSelfAttentionModel(nn.Module):
  """Embeddings -> 4 causal self-attention heads (communication) -> feed-forward (computation) -> LM head."""

  def __init__(self):
    super().__init__()
    num_heads = 4
    if n_embd % num_heads != 0:
      # the heads' outputs are concatenated, so they must add back up to exactly n_embd channels
      raise ValueError(f'n_embd ({n_embd}) must be divisible by {num_heads}')

    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.sa_heads = MultiHeadAttention(num_heads, n_embd // num_heads)
    self.ffwd = FeedForward(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """Return (logits, loss) for token ids `idx` of shape (B,T) with T <= block_size.

    Without targets: logits is (B,T,vocab_size) and loss is None.
    With targets (B,T): logits is flattened to (B*T,vocab_size) and loss is a scalar tensor.
    """
    B,T = idx.shape
    tok_embd = self.token_embedding_table(idx) #(B,T,n_embd)
    # build the positions on the input's device so this works wherever the model lives
    pos_embd = self.position_embedding_table(torch.arange(T, device=idx.device)) #(T,n_embd)
    x = tok_embd+pos_embd #(B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    x = self.sa_heads(x) # tokens gather information from each other -> (B,T,n_embd)
    x = self.ffwd(x) # each token processes what it gathered -> (B,T,n_embd)
    logits = self.lm_head(x) #(B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets) # scalar (mean over the B*T positions)

    return logits, loss

  @torch.no_grad() # sampling never needs gradients
  def generate(self, idx, max_new_tokens):
    """Sample `max_new_tokens` tokens autoregressively and append them to `idx` (B,T).

    Returns an integer tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # the position table has block_size rows, so condition on at most the last block_size tokens
      idx_cond = idx[:,-block_size:]
      logits, _ = self(idx_cond) #(B,T,vocab_size); loss is None without targets
      logits = logits[:,-1,:] #(B,vocab_size)
      probs = F.softmax(logits, dim=-1) #(B,vocab_size)
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


# put the model on the same device that get_batch() moves xb/yb to, otherwise this fails on CUDA
m = MultiHeadSelfAttentionModel().to(device)
logits, loss = m(xb,yb)

In [None]:
# define some hyperparameters (note the smaller learning rate for this run;
# m was built before this cell, so block_size/n_embd here do not resize it)
torch.manual_seed(1337)
batch_size = 32
block_size = 8
max_iters = 10000
eval_interval = 300
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
for step in range(max_iters):
  #minibatch
  xb, yb = get_batch('train')

  #loss
  logits, loss = m(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  if step % 1000 == 0:
    outs = estimate_loss(m)
    print(f"iter {step} | train: {outs['train']} | test: {outs['val']}")

print(loss.item())

# sample from the trained model, starting from token id 0 on the model's own device
idx = torch.zeros((1,1), dtype=torch.long, device=next(m.parameters()).device)
print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))

## Full Transformer Model

In [None]:
# define some hyperparameters for the full Transformer
torch.manual_seed(1337)
batch_size, block_size = 128, 256                     # sequences per batch / max context length
max_iters, eval_interval, eval_iters = 5000, 500, 200
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

n_embd = 180 # must be divisible by n_head (180 / 6 = 30 channels per head)
n_head, n_layer = 6, 6 # attention heads per block / number of Transformer blocks
dropout = 0.2

In [None]:
## Formalizing

class Head(nn.Module):
  """ One-headed causal self-attention with dropout on the attention weights """

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    # lower-triangular causal mask; registered as a buffer, so it is not a trainable parameter
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape #batch, time, channels (C = n_embd)
    k = self.key(x)   #(B,T,head_size)
    q = self.query(x) #(B,T,head_size)
    v = self.value(x) #(B,T,head_size)

    # Scaled dot-product attention: scale by 1/sqrt(head_size), the key dimension, not 1/sqrt(C)
    wei = q @ k.swapaxes(-2,-1) * k.shape[-1]**(-0.5) # (B,T,hs) @ (B,hs,T) = (B,T,T)
    wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # no attending to future tokens

    wei = F.softmax(wei, dim=-1) #(B,T,T)
    wei = self.dropout(wei) # randomly drop some attention weights while training (regularisation)
    out = wei @ v #(B,T,head_size)
    return out

class MultiHeadAttention(nn.Module):
  """ multiple self-attention heads in parallel, followed by a linear projection and dropout """

  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
    self.proj = nn.Linear(n_embd, n_embd) # mixes the concatenated head outputs back into n_embd channels
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    # stack the heads' outputs along the channel axis -> (B,T,num_heads*head_size)
    concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
    projected = self.proj(concatenated)
    return self.dropout(projected)

class FeedForward(nn.Module):
  """ position-wise MLP: expand to 4*n_embd, ReLU, project back to n_embd, then dropout """

  def __init__(self, n_embd):
    super().__init__()
    hidden = 4 * n_embd # the expansion factor of 4 is an empirical choice
    layers = [nn.Linear(n_embd, hidden), nn.ReLU(), nn.Linear(hidden, n_embd), nn.Dropout(dropout)]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """ Transformer block: communication (self-attention) followed by computation (feed-forward) """

  def __init__(self, n_embd, n_head):
    super().__init__()
    # communication: tokens match keys and queries to gather values from each other;
    # each of the n_head heads works in an n_embd // n_head dimensional subspace
    self.sa = MultiHeadAttention(n_head, n_embd // n_head)
    # computation: every token processes the information it gathered
    self.ffwd = FeedForward(n_embd)
    self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

  def forward(self, x):
    # pre-norm residual connections around both sub-layers
    attended = x + self.sa(self.ln1(x))
    return attended + self.ffwd(self.ln2(attended))

In [None]:
class MultiHeadSelfAttentionModel(nn.Module):
  """Decoder-only Transformer: embeddings -> n_layer Blocks -> final LayerNorm -> linear LM head."""

  def __init__(self):
    super().__init__()
    if n_embd % n_head != 0:
      # each Block splits n_embd into n_head heads of size n_embd // n_head, and the
      # concatenated heads must add back up to n_embd for the output projection
      raise ValueError(f'n_embd ({n_embd}) must be divisible by n_head ({n_head})')

    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd) # final layer norm before the LM head
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """Return (logits, loss) for token ids `idx` of shape (B,T) with T <= block_size.

    Without targets: logits is (B,T,vocab_size) and loss is None.
    With targets (B,T): logits is flattened to (B*T,vocab_size) and loss is a scalar tensor.
    """
    B,T = idx.shape
    tok_embd = self.token_embedding_table(idx) #(B,T,n_embd)
    # build the positions on the input's device so this works wherever the model lives
    pos_embd = self.position_embedding_table(torch.arange(T, device=idx.device)) #(T,n_embd)
    x = tok_embd+pos_embd #(B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    x = self.blocks(x) #(B,T,n_embd)
    x = self.ln_f(x)
    logits = self.lm_head(x) #(B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets) # scalar (mean over the B*T positions)

    return logits, loss

  @torch.no_grad() # sampling never needs gradients
  def generate(self, idx, max_new_tokens):
    """Sample `max_new_tokens` tokens autoregressively and append them to `idx` (B,T).

    Dropout layers follow the module's train/eval mode; call .eval() first for deterministic-mode sampling.
    Returns an integer tensor of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # the position table has block_size rows, so condition on at most the last block_size tokens
      idx_cond = idx[:,-block_size:]
      logits, _ = self(idx_cond) #(B,T,vocab_size); loss is None without targets
      logits = logits[:,-1,:] #(B,vocab_size)
      probs = F.softmax(logits, dim=-1) #(B,vocab_size)
      idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
      idx = torch.cat((idx, idx_next), dim=1) #(B,T+1)
    return idx


m = MultiHeadSelfAttentionModel()
m.to(device) # nn.Module.to moves parameters and buffers (e.g. the causal masks) in place; the notebook displays the returned module

In [None]:
print('This model has', sum(p.numel() for p in m.parameters()), 'parameters')  # total trainable scalars

In [None]:
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
for step in range(max_iters):
  #minibatch
  xb, yb = get_batch('train')

  #loss
  logits, loss = m(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  # Each loss estimate costs 2 * eval_iters forward passes, so evaluate only every eval_interval
  # steps (the hyperparameter defined for this run) and once more after the final update.
  if step % eval_interval == 0 or step == max_iters - 1:
    outs = estimate_loss(m)
    print(f"iter {step} | train: {outs['train']:.4f} | test: {outs['val']:.4f}")

print(loss.item())

In [None]:
# Sample from the trained Transformer. The model is in training mode after the training loop,
# so switch to eval mode to disable dropout while sampling, skip gradient tracking, and
# restore training mode afterwards.
idx = torch.zeros((1,1), dtype=torch.long)
idx = idx.to(device)
m.eval()
with torch.no_grad():
  print( decode( m.generate(idx, max_new_tokens=500)[0].tolist() ))
m.train()