|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[17] Probabilistic token selection</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render inline figures as svg (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# global figure styling applied to every plot below
plt.rcParams.update({
    # bold titles and axis labels
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    # hide the top and right axis spines
    'axes.spines.top':   False,
    'axes.spines.right': False,
    # resolution of saved figures
    'savefig.dpi':300,
    # dark-mode colors (uncomment to enable)
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
})

# **Part 1: PyTorch's multinomial function**

In [None]:
# weights to sample from (torch.multinomial needs a float tensor)
vect = torch.tensor([1,2,5]).float()

# draw one index from the vector
torch.multinomial(input=vect,num_samples=1)

In [None]:
# draw 10 indices (with replacement), then show the vector values at those indices
sampled_idx = torch.multinomial(vect,10,replacement=True)
vect[sampled_idx]

In [None]:
# 10k samples!
mn = torch.multinomial(vect,num_samples=10000,replacement=True)

# the distinct sampled indices and how many times each was drawn
vals,counts = mn.unique(return_counts=True)

# report each sampled index with its count and its share of all draws
for v,c in zip(vals,counts):
  print(f'"{v}" was sampled {c} times ({c/len(mn):.2%})')

In [None]:
# treat the vector as if it contains (scaled) probability values

print('Value | Observed | Expected')
print('------+----------+----------')

# normalizers for the observed (count-based) and expected (weight-based) columns
n_samples = len(mn)
total_weight = torch.sum(vect)

# again with more information
# vals holds only the indices that were actually sampled, so each index's weight
# is looked up in vect directly; pairing by position would shift the expected
# column if some index was never drawn
for v,c in zip(vals,counts):

  observedFrequency = c/n_samples
  expectedFrequency = vect[v]/total_weight

  print(f' "{v}"  |  {observedFrequency:.2%}  |  {expectedFrequency:.2%}')

In [None]:
# NOTE: the trailing ?? is IPython/Jupyter syntax (not valid in a plain Python script);
# it asks the notebook to show the help/source information for torch.multinomial
torch.multinomial??

# **Part 2: Common errors with torch.multinomial**

In [None]:
# # error 1: requires torch tensor
# torch.multinomial([1.,2,.3],1)
# torch.multinomial(np.array([1.,2,.3]),1)

In [None]:
# # error 2: default is no replacement
# torch.multinomial(vect,len(vect)+1)

In [None]:
# # error 3: only floats
# torch.multinomial(torch.tensor([1,1,1]),1)

In [None]:
# # error 4: only non-negative numbers
# torch.multinomial(torch.tensor([-1,1.,1]),1)

# **Part 3: Deterministic token selection**

In [None]:
# pretrained GPT-2: the tokenizer and the causal language model
tokenizer = AutoTokenizer.from_pretrained('gpt2')
llm = AutoModelForCausalLM.from_pretrained('gpt2')

# put the model into eval mode
llm.eval()

In [None]:
# the prompt text, encoded into a PyTorch tensor of token indices
txt = "I don't want to grow"
tokens = tokenizer.encode(text=txt,return_tensors='pt')
tokens

In [None]:
# push the tokens through the model (gradient tracking switched off)
with torch.no_grad():
  outputs = llm(input_ids=tokens)
outputs

In [None]:
# size of the model's output logits tensor
outputs.logits.shape

In [None]:
# logits at the last position of the first sequence
logits = outputs.logits[0,-1].detach()

# softmax turns the logits into probabilities
logits_sm = torch.softmax(logits,dim=-1)
logits.shape

In [None]:
# index of the largest logit
max_logit = torch.argmax(logits)

# report its index, its softmax probability, and the decoded token
print(f'The maximum logit is #{max_logit} with a softmax probability of {logits_sm[max_logit]:.2%}')
print(f'The max word is "{tokenizer.decode(max_logit)}"')

In [None]:
# plot the raw and softmax logits
_,axs = plt.subplots(1,2,figsize=(12,3))

# one entry per panel: (axis, data, marker format, y-axis label, title)
panels = (
  (axs[0],logits,   'gh','Output logits (raw)','A) Raw output logits'),
  (axs[1],logits_sm,'bh','Softmax probability','B) Softmax probabilities'),
)

for ax,data,fmt,ylab,ttl in panels:
  # red marker at the max-logit token, then all values
  ax.plot(max_logit,data[max_logit],'rh')
  ax.plot(data,fmt,linewidth=.2,markerfacecolor=[.7,.9,.7,.3])
  ax.set(xlabel='Token index',ylabel=ylab,title=ttl)

plt.tight_layout()
plt.savefig('ch4_proj17_part3.png')
plt.show()

# **Part 4: Multinomial probabilistic selection**

In [None]:
# draw 10 token indices (with replacement) from the softmax probabilities
sampled_tokens = torch.multinomial(logits_sm,10,replacement=True)

# show the prompt followed by each sampled token and that token's probability
for t in sampled_tokens:
  print(f'{txt}"{tokenizer.decode(t)}"\t({logits_sm[t]:.2%})')

# **Part 5: Top-k sampling**

In [None]:
# the k largest softmax probabilities and their token indices
k = 10
top_k = torch.topk(logits_sm,k)

print(txt,'___\n')

# one line per candidate: token index, probability, decoded text
for val,tok in zip(top_k.values,top_k.indices):
  print(f'{tok:5} ({val:5.1%}) is "{tokenizer.decode(tok)}"')

In [None]:
# sum the softmax probabilities at the top-k token indices
print(f'The top {k} options account for {logits_sm[top_k.indices].sum():.1%} of the probability mass.')

In [None]:
# softmax over only the top-k logits, so these k probabilities sum to 1
sm_top_k = torch.softmax(logits[top_k.indices],dim=-1)

# what they look like:
print(sm_top_k)

# confirm: sum to 1
print(torch.sum(sm_top_k))

In [None]:
# draw 10 independent samples from the renormalized top-k probabilities
# replacement=True keeps the draws independent: without it, 10 draws from k=10
# options return every option exactly once, and fewer than 10 options is an error
for t in torch.multinomial(sm_top_k,10,replacement=True):
  tidx = top_k.indices[t] # token index into vocab (not top-k)
  print(f'{txt}"{tokenizer.decode(tidx)}"\t({sm_top_k[t]:6.2%})')

# **Part 6: Top-p sampling**

In [None]:
# threshold
p_thresh = .9

# sort the probabilities
sorted_sm,sorted_idx = torch.sort(logits_sm,descending=True)

# running total of the sorted probabilities
cumulative_probs = torch.cumsum(sorted_sm,dim=-1)

# 0-based position of the first token at which the running total reaches the threshold
tokens2keep = torch.where(cumulative_probs>=p_thresh)[0][0]

# print the results (the position is 0-based, so the number of tokens is position+1)
print(f'{tokens2keep+1} tokens account for {cumulative_probs[tokens2keep]:.1%} of the probability mass.')

In [None]:
# vocab indices of the kept tokens (the slice end is tokens2keep+1, so position tokens2keep is included)
tokens_idx = sorted_idx[:tokens2keep+1]

# softmax over only the kept logits gives probabilities for the kept tokens
sm_top_p = torch.softmax(logits[tokens_idx],dim=-1)

In [None]:
# draw 10 independent samples from the renormalized top-p probabilities
# replacement=True is needed because the number of kept tokens depends on the
# prompt and threshold: sampling 10 without replacement is an error when fewer
# than 10 tokens were kept
for t in torch.multinomial(sm_top_p,10,replacement=True):
  tidx = tokens_idx[t] # token index into vocab (not top-p)
  print(f'{txt}"{tokenizer.decode(tidx)}"\t({sm_top_p[t]:6.2%})')

# **Part 7: Chat with GPT2**

In [None]:
n_new_tokens = 15

# start from a prompt
txt = "I don't want to grow"
print(f'Start prompt: {txt}')

# tokenize the prompt once; each sampled token is appended to this tensor so the
# model sees exactly the tokens it generated (decoding to text and re-encoding
# is not guaranteed to give back the same token sequence)
tokens = tokenizer.encode(txt,return_tensors='pt')

# loop over the new tokens
for toki in range(n_new_tokens):

  # forward pass, get final logits (no gradients are tracked, so no detach needed)
  with torch.no_grad():
    logits = llm(tokens).logits[0,-1,:]

  # softmax
  logits_sm = F.softmax(logits,dim=-1)

  # pick a new token
  newtok = torch.multinomial(logits_sm,1)

  # append the new token to the sequence (unsqueeze adds the leading batch dimension)
  tokens = torch.cat((tokens,newtok.unsqueeze(0)),dim=1)

  # concatenate and print
  txt += tokenizer.decode(newtok)
  print(f'New token {toki+1:2}: {txt}')