|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[18] Token prediction accuracy</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

# to get wikitext dataset
from datasets import load_dataset

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render inline notebook figures as SVG (vector output, sharper than raster)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# global rcParams overrides applied to every figure created below
rc_overrides = {
    # dark-mode colors (uncomment to enable)
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',

    # hide the right and top axis spines
    'axes.spines.right': False,
    'axes.spines.top': False,

    # bold titles and axis labels
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',

    # resolution used by plt.savefig
    'savefig.dpi': 300,
}
plt.rcParams.update(rc_overrides)

# **Part 1: Token prediction**

In [None]:
# pretrained GPT-2 weights and the tokenizer that goes with them
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# switch to evaluation mode (e.g., dropout off); the returned model is shown as cell output
model.eval()

In [None]:
# sample sentence, paraphrased from https://en.wikipedia.org/wiki/Cubism
text = 'Cubism is an art movement that sparked innovations in music and architecture'

# encode to a PyTorch tensor of token IDs (row 0 holds this sentence)
tokens = tokenizer.encode(text, return_tensors='pt')

# list each token ID next to the text fragment it decodes to
for tok_id in tokens[0]:
  fragment = tokenizer.decode(tok_id)
  print(f'Token {tok_id:5} is "{fragment}"')

In [None]:
# forward pass (inference only, so no gradient tracking)
with torch.no_grad():
  outputs = model(tokens)

# find max-logit prediction for one token
whichToken = 5

# the logits at position whichToken are the model's predictions for the
# subsequent token, i.e., they should be compared against tokens[0,whichToken+1]
token_logits = outputs.logits[0,whichToken,:]
maxlogit = torch.argmax(token_logits)

# visualize the logits for every vocabulary entry ...
plt.figure(figsize=(12,4))
plt.plot(token_logits,'.',markerfacecolor=[.9,.7,.7,.3])

# ... then mark the predicted next-token. It is drawn last and with a higher
# zorder so the dense cloud of vocab logits cannot paint over it.
plt.plot(maxlogit,token_logits[maxlogit],'go',markersize=10,zorder=3)

# x-range follows the length of the logit vector (the model's output vocab),
# with a small margin on both sides
num_logits = token_logits.shape[-1]
plt.gca().set(title=f'Model prediction is "{tokenizer.decode(maxlogit)}" (target is "{tokenizer.decode(tokens[0,whichToken+1])}")',
              xlabel='Token index',ylabel='Model output logit',xlim=[-50,num_logits+49])

plt.tight_layout()
plt.savefig('ch4_proj18_part1.png')
plt.show()

# **Part 2: Per-token predictions**

In [None]:
# find the token associated with a quote to insert into the sentences
quotetok = tokenizer.encode('"',return_tensors='pt')[0]

# one entry per scored position. The final input token has no "next" token to
# compare against, so there are len(tokens[0])-1 predictions (the -1 must be
# applied to the length, not to the token IDs).
accuracy = np.zeros(len(tokens[0])-1,dtype=int)

# printed marks indexed by accuracy[idx]: 0 -> 'X' (wrong), 1 -> check mark (right).
# The check mark is written as an escape so file-encoding mishaps cannot garble it.
marks = 'X\u2713'

# loop over tokens
for idx in range(len(tokens[0])-1):

  # model's top prediction for the token after position idx, kept 1-D for torch.cat
  max_logit = torch.argmax( outputs.logits[0,idx,:] ,dim=-1).unsqueeze(0)

  # 1 if the prediction matches the actual next token, else 0
  accuracy[idx] = max_logit[0] == tokens[0][idx+1]

  # reconstruct the sentence tokens
  sentence = torch.cat(
      (tokens[0][:idx+1],              # original first few tokens
       quotetok, max_logit, quotetok), # predicted token, padded by quotes
      dim=-1)

  # decode and print
  print(f"{marks[accuracy[idx]]}: {tokenizer.decode(sentence)}")

print(f'\nModel correctly predicted {accuracy.sum()}/{len(accuracy)} tokens ({accuracy.mean():.2%})')

# **Part 3: Top-k predicted tokens**

In [None]:
k = 10

# the k largest logits at position whichToken (token_logits is defined in Part 1)
topk = torch.topk(token_logits,k,dim=-1)

print('Original text:\n  ',text,'\n\nGPT2 predictions:')

# every candidate shares the same prefix: the tokens up to and including whichToken
prefix = tokens[0][:whichToken+1]

for candidate in topk.indices:
  # splice the candidate token, wrapped in quote tokens, onto the prefix
  pieces = (prefix, quotetok, candidate.view(1), quotetok)
  sentence = torch.cat(pieces, dim=-1)
  print('  ',tokenizer.decode(sentence))

# **Part 4: Token prediction accuracy in wikitext**

In [None]:
# WikiText-2 (raw) test split, joined into one long document and tokenized in one go
wikitxt = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
full_text = '\n\n'.join(wikitxt['text'])
tokens = tokenizer.encode(full_text, return_tensors='pt')
num_tokens = tokens.numel()

print(f'There are {num_tokens:,} tokens in the wikitext dataset')

In [None]:
# display one raw sample from the dataset (the final expression is the cell's output)
sample_idx = 1112
wikitxt[sample_idx]['text']

In [None]:
# GPT-2's config stores its maximum sequence length as n_positions
model_config = model.config
max_seq_len = model_config.n_positions
print(f'The model can accept up to {max_seq_len} tokens per sequence per forward pass.')

In [None]:
# score one context window's worth of tokens (fewer if the text is shorter,
# which would otherwise cause an IndexError on the next-token lookup)
seq_len = min(max_seq_len, tokens.shape[1])

# forward pass
with torch.no_grad():
  outputs = model(tokens[:,:seq_len])

# logits at position i predict token i+1: drop the final position's prediction
# (no target for it) and compare against the inputs shifted left by one.
# Vectorized instead of a Python loop over positions; same per-position result.
predictions = outputs.logits[0,:-1,:].argmax(dim=-1)
targets = tokens[0,1:seq_len]

# 1 = correct, 0 = incorrect, as an int numpy array for the summary and plot
accuracy = (predictions == targets).cpu().numpy().astype(int)

print(f'\nModel correctly predicted {accuracy.sum()}/{len(accuracy)} tokens ({accuracy.mean():.2%})')

In [None]:
# decode the token window that was fed to the model above
assessed_text = tokenizer.decode(tokens[0,:max_seq_len])
print(assessed_text)

In [None]:
# small vertical jitter so the stacked 0/1 outcomes don't overlap completely
jitter = np.random.normal(0,.03,len(accuracy))

# visualize per-position correctness
plt.figure(figsize=(10,3))
plt.plot(jitter+accuracy,'ko',markerfacecolor=[.7,.7,.9,.4])

ax = plt.gca()
ax.set(xlabel='Token position',
       yticks=[0,1],yticklabels=['Incorrect','Correct'],
       title=f'Average categorical accuracy: ({accuracy.mean():.2%})')

plt.tight_layout()
plt.savefig('ch4_proj18_part4.png')
plt.show()