|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[20] Perplexity over time and text</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import requests
from datasets import load_dataset

from tqdm import tqdm

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (uncomment the dark-mode entries below to switch themes)

# render inline figures as SVG so they stay sharp at any zoom level
from matplotlib_inline import backend_inline
backend_inline.set_matplotlib_formats('svg')

# global figure style: hide the top/right spines, bold titles and axis labels,
# and write 300-dpi files from savefig
plt.rcParams.update({
    'axes.spines.top':   False,
    'axes.spines.right': False,
    'axes.titleweight':  'bold',
    'axes.labelweight':  'bold',
    'savefig.dpi':       300,
    # --- dark mode ---
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
})

# **What does perplexity mean?**

In [None]:
# each list item is one "situation": the model outputs (logits) over 4 categories
situations = [
    [ 1,1,1,9 ],
    [ 1,1,1,2 ],
    [ 1,1,2,2 ],
    [ 3,1,1,2 ],
    [ 9,1,1,1 ] ]

y = len(situations[0])-1 # final logit is the target category index

# bar layout derived from the data rather than hard-coded: the bars of situation i
# are centered on x = i+1 (its tick), one bar of width bar_width per logit
n_logits    = len(situations[0])
bar_width   = .2
bar_offsets = 1 + bar_width*(np.arange(n_logits) - (n_logits-1)/2)  # [.7,.9,1.1,1.3] for 4 logits

# create a figure
plt.figure(figsize=(12,4))
xlabls = [] # x-axis tick labels
label_code = [chr(ord('A')+k) for k in range(len(situations))] # 'A','B',... one letter per situation

for i,sit in enumerate(situations):

  # raw model output (logits), with a batch dimension of 1
  model_output = torch.tensor([sit],dtype=torch.float32)

  # log-softmax over the categories
  log_softmax = F.log_softmax(model_output,dim=-1)

  # negative log-likelihood of the target category
  loss = -log_softmax[0,y]

  # perplexity = exp(NLL): 1 when the target gets all the probability, n_logits for a
  # uniform distribution, and it grows without bound as p(target) -> 0
  ppl = torch.exp(loss)

  # draw the logits and mark the target bar (wherever y points, not always the last bar)
  plt.bar(bar_offsets+i,model_output[0].numpy(),width=bar_width,edgecolor='k')
  plt.text(bar_offsets[y]+i,model_output[0,y].item()+.1,'Targ',font={'size':14},ha='center',va='bottom')

  # x-axis tick label
  xlabls.append(f'"{label_code[i]}"\nppl = {ppl.item():.3f}')



# ylim leaves one unit of headroom above the largest logit for the 'Targ' marker
plt.gca().set(title='Model outputs (logits) and perplexity',ylabel='Logits',
              xticks=range(1,len(situations)+1),xticklabels=xlabls,
              ylim=[0,max(max(s) for s in situations)+1])

plt.tight_layout()
plt.savefig('ch4_proj20_part0.png')
plt.show()

# **Part 1: Perplexity over time**

In [None]:
# pretrained GPT-2 (the 124M-parameter 'gpt2' checkpoint) and its matching tokenizer
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# inference only from here on; the returned module is shown as the cell output
gpt2.eval()

In [None]:
# pick the compute device (CUDA when torch can see a GPU, otherwise the CPU)
# and move the model's parameters there; the trailing ';' hides the module repr
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
gpt2.to(device);

In [None]:
# import Frankenstein (Project Gutenberg plain-text edition, book #84)
url = 'https://www.gutenberg.org/cache/epub/84/pg84.txt'
# TODO(exercise): complete the HTTP request so that `text` holds the book as a str
# (Part 4 does the same thing with requests.get(fullurl).text)
text = requests.
# TODO(exercise): print something informative about `text`, e.g. its length
print()

In [None]:
# tokenize the book
# TODO(exercise): `tokens` is later passed to calc_perplex(tokens); Part 4 builds the
# same kind of object with tokenizer.encode(text,return_tensors='pt') -- presumably
# the same call is intended here
tokens =
# TODO(exercise): total number of tokens, as an int (it is formatted with `:,` below)
num_tokens =
print(f'Number of tokens: {num_tokens:,}')

In [None]:
# max sequence length
# TODO(exercise): presumably the model's context length; calc_perplex (Part 2) uses
# gpt2.config.n_positions as its default seq_len
seq_len =
print(f'Sequence length: {seq_len}')

# how many samples fit into the data
# TODO(exercise): number of whole, non-overlapping seq_len-token segments in `tokens`;
# it must be an int (it feeds np.zeros(nSegments) and range(nSegments) below)
nSegments =
print(f'Number of non-overlapping segments: {nSegments:,}')

In [None]:
# one perplexity value per non-overlapping segment
perplexities = np.zeros(nSegments)

# running sum of segment losses; step 6 turns it into a whole-text perplexity
sum_losses = 0.

for segi in tqdm(range(nSegments)):

  # 1) start and end indices
  # TODO(exercise): segment segi covers seq_len tokens; segments do not overlap
  start =
  end = start +

  # 2) extract the data and push to `device` (the GPU when one is available)
  # TODO(exercise): slice tokens[start:end]; presumably keep the leading batch dimension
  X = .to(device)

  # 3) forward pass
  # NOTE(review): a Hugging Face causal LM populates `outputs.loss` only when labels
  # are passed (presumably labels=X here) -- confirm against the transformers docs
  with torch.no_grad():
    outputs = gpt2

  # 4) accumulated loss for later perplexity calculations
  sum_losses +=

  # 5) per-segment perplexity
  # NOTE(review): `perplexities` is a numpy array; converting the tensor with .item()
  # (as the demo cell does with ppl.item()) avoids device/type issues
  perplexities[segi] = torch.exp(

# 6) calculate perplexity
# TODO(exercise): exp of the average segment loss (sum_losses / nSegments), i.e. the
# "perplexity of average" line in the next plot
ave_perplexity = torch.exp(

In [None]:
plt.figure(figsize=(10,4))

# TODO(exercise): supply the data for each call
#   plot:             the per-segment values in `perplexities`
#   first axhline:    ave_perplexity (exp of the mean loss = "perplexity of average")
#   second axhline:   the mean of `perplexities` ("average of perplexities")
plt.plot(label='Segment perplexities')
plt.axhline(label='Perplexity of average')
plt.axhline(label='Average of perplexities')

plt.legend()
plt.gca().set(xlabel='Segment position',ylabel='Perplexity',xlim=[-2,nSegments+2])

plt.tight_layout()
plt.savefig('ch4_proj20_part1.png')
plt.show()

# **Part 2: A perplexity function**

In [None]:
def calc_perplex(tokens,model=gpt2,seq_len=gpt2.config.n_positions):
  """Perplexity of `model` on a token sequence, computed over non-overlapping segments.

  Args:
    tokens: token ids for the whole text (presumably the (1, n_tokens) tensor from
      tokenizer.encode(..., return_tensors='pt') -- TODO confirm).
    model: causal language model to evaluate; defaults to `gpt2`.
    seq_len: segment length in tokens; defaults to GPT-2's context length.
      NOTE: both defaults are evaluated once, when this cell runs, so re-assigning
      `gpt2` later does not change them.

  Returns:
    exp of the average loss across all segments (a torch scalar; callers store it
    into numpy arrays).
  """

  # number of segments in the total token sequence
  nSegments =

  # initialize losses
  sum_losses = 0.

  for i in range(nSegments):

    # find start and end indices
    start =
    end =

    # get the token sequence (with batch dimension)
    X =

    # forward pass
    # TODO(exercise): call the `model` argument, not the global gpt2, so that Part 5
    # actually compares different models
    with torch.no_grad():
      outputs =

    # calculate and store this segment's loss
    sum_losses +=

  # after the segments loop: perplexity = exp(average loss across all segments)
  perplexity = torch.exp(  /  )

  return perplexity

In [None]:
# whole-book perplexity from the function, using its defaults (gpt2, full context
# length); compare with Part 1's loop result in the next cell
calc_perplex(tokens)

In [None]:
# Part 1's loop-based result, for comparison with calc_perplex(tokens)
ave_perplexity

# **Part 3: Impact of sequence length**

In [None]:
# sequence lengths (in tokens) to compare
# TODO(exercise): a sequence of ints; each one is passed to calc_perplex as seq_len
seq_lengths =

perp_by_len = np.zeros(len(seq_lengths))

for i in tqdm(range(len(seq_lengths))):
  # TODO(exercise): evaluate the book's tokens with seq_len=seq_lengths[i]
  perp_by_len[i] = calc_perplex(

print('\n')
for i in range(len(seq_lengths)):
  # TODO(exercise): fill in the f-string fields (empty {} is a SyntaxError)
  print(f'Sequence length {} has perplexity {}')

In [None]:
plt.figure(figsize=(10,4))

# min-max scaled
# TODO(exercise): rescale perp_by_len to [0,1] so it can index the plasma colormap
# (Part 4 applies the same formula to ppls)
perpl_scaled =

# TODO(exercise): bar positions and heights, one bar per entry of seq_lengths
plt.bar(,,edgecolor='k',linewidth=.5,
        color=plt.cm.plasma(perpl_scaled))
plt.gca().set(title='Perplexities for different sequence lengths')

plt.tight_layout()
plt.savefig('ch4_proj20_part3.png')
plt.show()

# **Part 4: Perplexities in different texts**

In [None]:
# all books have the same url format; they are unique by numerical code
baseurl='https://www.gutenberg.org/cache/epub/'

# Project Gutenberg book codes with short titles (titles are used as display labels);
# the full url is baseurl + code + '/pg' + code + '.txt'
bookurls = [
    # code       title
    ['84',    'Frankenstein'    ],
    ['64317', 'GreatGatsby'     ],
    ['11',    'AliceWonderland' ],
    ['1513',  'RomeoJuliet'     ],
    ['76',    'HuckFinn'        ],
    ['219',   'HeartDarkness'   ],
    ['2591',  'GrimmsTales'     ],
    ['2148',  'EdgarAllanPoe'   ],  # author's name is spelled "Allan"
    ['36',    'WarOfTheWorlds'  ],
    ['829',   'GulliversTravels']
]

In [None]:
# one perplexity value per book in bookurls
ppls = np.zeros(len(bookurls))

for i,(code,title) in enumerate(bookurls):

  # get the text tokens
  # NOTE: this overwrites `text` and `tokens` from Part 1; after the loop they hold the
  # last book's data
  # NOTE(review): requests.get has no timeout, so a stalled download hangs the cell
  fullurl = baseurl + code + '/pg' + code + '.txt'
  text = requests.get(fullurl).text
  tokens = tokenizer.encode(text,return_tensors='pt')

  # just the first 50k tokens for speed and direct comparison
  # TODO(exercise): keep only the first 50,000 token ids
  tokens =

  # calculate perplexity
  # TODO(exercise): call calc_perplex on the truncated tokens
  ppls[i] = calc_perplex

  # TODO(exercise): report this book's title and perplexity
  print

In [None]:
plt.figure(figsize=(10,4))

# min-max scaling for coloring the bars
# NOTE(review): this divides by zero (NaN colors) if all perplexities are equal
ppls_scaled = (ppls - ppls.min()) / (ppls.max() - ppls.min())

# and draw the bars
# TODO(exercise): one bar per book colored with plt.cm.plasma(ppls_scaled), with the
# titles from bookurls as x tick labels (the Part 6 plot cell shows the same pattern)


plt.tight_layout()
plt.savefig('ch4_proj20_part4.png')
plt.show()

# **Part 5: Perplexities in different models**

In [None]:
# dictionary of modelname:identifier
model_ids = {
    'small':  'gpt2',        # 124M
    'medium': 'gpt2-medium', # 355M
    'large':  'gpt2-large',  # 774M
    'xl':     'gpt2-xl'      # 1.6B
}

# load all models into a dictionary, each on the same device as gpt2
# NOTE(review): all four models stay resident at once; the xl model alone has ~1.6B
# parameters, so a small GPU may run out of memory here
models = {}
for name, model_id in model_ids.items():  # 'model_id', not 'id': don't shadow the builtin
  models[name] = AutoModelForCausalLM.from_pretrained(model_id).to(device)
  # switch to eval mode (disables dropout); these models are only used for inference
  models[name].eval()

In [None]:
# one perplexity value per GPT-2 size, in the order of models.items()
perp_models = np.zeros(len(models))

for i,(name, model) in enumerate(models.items()):
  # TODO(exercise): evaluate `tokens` with this model.
  # NOTE(review): after Part 4's loop, `tokens` holds the tokens of the last book
  # processed there (GulliversTravels) -- confirm that is the intended text
  perp_models[i] = calc_perplex(,model=
  print(f'Perplexity of {} for GPT2-{}')

In [None]:
plt.figure(figsize=(10,4))

# min-max scaling for coloring the bars
perp_models_scaled = (perp_models - perp_models.min()) / \
                     (perp_models.max() - perp_models.min())

# TODO(exercise): draw one bar per model colored with plt.cm.plasma(perp_models_scaled)
# and labeled with list(models.keys()) (the Part 6 plot cell shows the full pattern)

plt.tight_layout()
plt.savefig('ch4_proj20_part5.png')
plt.show()

# **Part 6: Perplexities in different models (Wikitext)**

In [None]:
# note: over-writing 'tokens' from earlier
text = load_dataset('wikitext','wikitext-2-raw-v1',split='test')

# join the text samples
tokens =
torch.numel(tokens)

In [None]:
# one Wikitext perplexity per GPT-2 size (overwrites the Part 5 results)
perp_models = np.zeros(len(models))

for i,(name, model) in enumerate(models.items()):
  # TODO(exercise): same as Part 5 (calc_perplex with model=model), on the Wikitext tokens;
  # then print the result, like the Part 5 loop does
  perp_models[i] =
  print(f'Perplexity of

In [None]:
plt.figure(figsize=(10,4))

# min-max scaling for coloring the bars
# TODO(exercise): same min-max formula as in Part 5, applied to perp_models
perp_models_scaled =

# one bar per model, colored by its scaled perplexity
plt.bar(range(len(perp_models)),perp_models,edgecolor='k',linewidth=.5,
        color=plt.cm.plasma(perp_models_scaled))
plt.gca().set(xticks=range(len(perp_models)),xticklabels=list(models.keys()),
              ylabel='Perplexity',title='Perplexities of different GPT2 models (Wikitext)')

plt.tight_layout()
plt.savefig('ch4_proj20_part6.png')
plt.show()