|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[37] Token prediction and attention KL divergences</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import requests

from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

In [None]:
### matplotlib adjustments (to use dark mode, uncomment the color entries below)

# render inline figures as SVG (vector graphics, sharper than PNG)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

plot_style = {
  # dark-mode colors
  # 'figure.facecolor': '#282a2c',
  # 'figure.edgecolor': '#282a2c',
  # 'axes.facecolor':   '#282a2c',
  # 'axes.edgecolor':   '#DDE2F4',
  # 'axes.labelcolor':  '#DDE2F4',
  # 'xtick.color':      '#DDE2F4',
  # 'ytick.color':      '#DDE2F4',
  # 'text.color':       '#DDE2F4',

  # hide the top and right axis lines
  'axes.spines.right': False,
  'axes.spines.top':   False,

  # bold titles and axis labels
  'axes.titleweight': 'bold',
  'axes.labelweight': 'bold',

  # resolution of figures written with savefig
  'savefig.dpi': 300,
}
plt.rcParams.update(plot_style)

# **Part 1: Model, tokens, attention projections**

In [None]:
# load the GPT-2 (large) causal language model and its matching tokenizer
model_id = 'gpt2-large'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# switch to evaluation mode (disables dropout)
model.eval()

In [None]:
n_layers = len(model.transformer.h) # number of transformer blocks (the hooks below index model.transformer.h)

In [None]:
# storage for captured attention-output activations, keyed 'attn_L<layer>'
att_projs = {}

def implant_hook(layer_number):
  """Build a forward hook that records a module's output in ``att_projs``.

  The returned callable has the (module, inputs, output) signature expected by
  ``register_forward_hook``; on every forward pass it stores ``output.detach()``
  under the key ``f'attn_L{layer_number}'``, overwriting any previous entry.
  """
  key = f'attn_L{layer_number}'

  def hook(_module, _inputs, output):
    # detach so the stored tensor does not keep an autograd graph alive
    att_projs[key] = output.detach()

  return hook

# attach one hook per transformer block, on the attention output projection (c_proj);
# the handles are kept so the hooks can later be detached with handle.remove()
handles = [ model.transformer.h[i].attn.c_proj.register_forward_hook(implant_hook(i))
            for i in range(n_layers) ]

In [None]:
# Alice's Adventures in Wonderland (Project Gutenberg eBook #11)
response = requests.get('https://www.gutenberg.org/cache/epub/11/pg11.txt',timeout=60)
response.raise_for_status() # fail loudly instead of tokenizing an HTTP error page
text = response.text

allTokens = tokenizer.encode(text,return_tensors='pt')
max_seq_len = model.config.max_position_embeddings

# get context-length from middle of the book
start_idx = len(allTokens[0])//2
end_idx = start_idx + max_seq_len

# next-token targets reach one token past the context window (allTokens[0,end_idx]),
# so the book must extend beyond end_idx
if allTokens.shape[1] <= end_idx:
  raise ValueError(f'Text has {allTokens.shape[1]} tokens; need more than {end_idx} '
                   'for a full context window plus next-token targets.')

tokens = allTokens[:,start_idx:end_idx]
tokens.shape

In [None]:
# one forward pass over the context window; the hooks fill att_projs as a side effect
with torch.no_grad():
  outs = model(tokens)

# report the size of each captured attention output
for name,acts in att_projs.items():
  print(f'{name:>8} has shape {list(acts.shape)}')

# **Part 2: Token prediction accuracy**

In [None]:
# next-token targets for every context position. The target of the last position lies
# just past the context window, so targets are taken from 'allTokens', not 'tokens'!
targets = allTokens[0,start_idx+1:start_idx+1+max_seq_len]

# top-1 prediction at every position (argmax over the vocabulary logits)
predictions = outs.logits[0].argmax(dim=-1)

# boolean numpy vector: True where the max-logit token matches the next token
prediction_acc = (predictions == targets).numpy()

# length (in characters) of each target token
toklens = np.array([ len(tokenizer.decode(t)) for t in targets.tolist() ],dtype=float)

print(f'Model correctly predicted {prediction_acc.mean():.2%} ({prediction_acc.sum()}/{len(prediction_acc)})')

In [None]:
# vertical jitter so overlapping correct/incorrect markers remain visible
jitter = np.random.normal(0,.05,max_seq_len)

fig,ax = plt.subplots(figsize=(10,3))
ax.plot(prediction_acc+jitter,'ws',markerfacecolor='k',alpha=.7)
ax.set(yticks=[0,1],yticklabels=['Incorrect','Correct'],ylim=[-.5,1.5],
       xlabel='Token position',title='Accuracy by token position')

fig.tight_layout()
fig.savefig('ch6_proj37_part2a.png')
plt.show()

In [None]:
# target-token lengths split by whether the model predicted that token correctly
lens_wrong = toklens[~prediction_acc]
lens_right = toklens[prediction_acc]

fig,ax = plt.subplots(figsize=(10,3))

# each density is plotted at the left edge of its histogram bin
dens,edges = np.histogram(lens_wrong,bins='fd',density=True)
ax.plot(edges[:-1],dens,'rs-',markersize=10,markerfacecolor=[.9,.5,.5],
        label=f'Incorrect (ave. {lens_wrong.mean():.2f} chars)')

dens,edges = np.histogram(lens_right,bins='fd',density=True)
ax.plot(edges[:-1],dens,'go-',markersize=10,markerfacecolor=[.5,.9,.5],
        label=f'Correct (ave. {lens_right.mean():.2f} chars)')

ax.legend()
ax.set(xlabel='Token length (characters)',ylabel='Density',
       title='Token length distributions by prediction accuracy')

fig.tight_layout()
fig.savefig('ch6_proj37_part2b.png')
plt.show()

# **Part 3: Attention projection norms by accuracy (one layer)**

In [None]:
layeri = 4

# attention-output ("adjustment") vectors at this layer and their log-norms
# C = correct prediction; I = incorrect prediction
C = att_projs[f'attn_L{layeri}'][0,prediction_acc,:]
C_norms = torch.norm(C,dim=-1).log() # using method here instead of function

I = att_projs[f'attn_L{layeri}'][0,~prediction_acc,:]
I_norms = torch.norm(I,dim=-1).log()

# histogram bins shared by both groups (31 edges -> 30 bins). Tensor reductions are
# used instead of Python's builtin min/max, which iterate element by element.
allnorms = torch.cat((C_norms,I_norms))
minmax = [ allnorms.min().item(), allnorms.max().item() ]
histbins = torch.linspace(minmax[0],minmax[1],31)

# histogram data with counts
yC,_ = torch.histogram(C_norms,bins=histbins,density=False)
yI,_ = torch.histogram(I_norms,bins=histbins,density=False)

# then convert to probability (not densities)
yC = yC/yC.sum()
yI = yI/yI.sum()


plt.figure(figsize=(10,4))
plt.plot(histbins[:-1],yC,linewidth=2,label='Correct')
plt.plot(histbins[:-1],yI,linewidth=2,label='Incorrect')

plt.legend() # without this the 'Correct'/'Incorrect' labels are never shown
plt.gca().set(xlabel='Log-norm',ylabel='Probability')

plt.tight_layout()
plt.savefig('ch6_proj37_part3.png')
plt.show()

In [None]:
# same histogram, three normalizations: raw counts, density, and probability
counts,_ = torch.histogram(C_norms,bins=histbins,density=False)
dens,_ = torch.histogram(C_norms,bins=histbins,density=True)
binwidth = histbins[1]-histbins[0] # bins are evenly spaced (linspace)

print(f'Sum of counts: {counts.sum()}')
print(f'Sum of density: {dens.sum():.2f}')
# a density integrates (sum * bin width) to 1; it does not sum to 1
print(f'Sum of bin-normalized density: {dens.sum() * binwidth:.2f}')
print(f'Sum of count-scaled probability: {sum(counts/counts.sum()):.2f}')

# **Part 4: Asymmetric and symmetric KL divergence**

In [None]:
# naive call: empty bins in yC give log(0) = -inf (so the result can be inf/nan),
# and the default reduction ('mean') averages over elements instead of summing.
# Try adding reduction='batchmean' or reduction='sum' to compare (behavior can vary by torch version).
F.kl_div(input=torch.log(yC),target=yI)
# torch.__version__ # to check version

In [None]:
# F.kl_div(input, target) computes KL(target || q), where input = log(q):
#   the first argument is the LOG-probability of the reference distribution q,
#   the second argument is the "target" distribution p.
# reduction='sum' gives the KL of a single pair of distributions; 'batchmean' would
# divide by input.size(0), which for a 1-D histogram is the number of bins.
# The small epsilon avoids log(0) in empty bins.
c2i = F.kl_div(torch.log(yC+1e-15),yI,reduction='sum')
i2c = F.kl_div(torch.log(yI+1e-15),yC,reduction='sum')

print(f'Target is "incorrect": {c2i:.4f}')
print(f'Target is "correct"  : {i2c:.4f}')

In [None]:
# symmetric approach (Jensen-Shannon divergence):
#   JSD = [ KL(yC || M) + KL(yI || M) ] / 2,  with M = (yC+yI)/2
# F.kl_div(log(q), p) = KL(p || q), so the mixture M goes (as a log) in the first
# argument and each histogram is the target in the second.
pAve = (yC+yI)/2
logAve = torch.log(pAve + 1e-15) # epsilon only matters in bins empty in both histograms

symKL = F.kl_div(logAve, yC, reduction='sum') + \
        F.kl_div(logAve, yI, reduction='sum')
symKL /= 2

print(f'Average KLs : {(c2i+i2c)/2:.4f}')
print(f'Symmetric KL: {symKL:.4f}')

In [None]:
# JSD function
def symmetric_KL(d1,d2,nbins=31):
  """Jensen-Shannon divergence between the histograms of two 1-D samples.

  Both samples are binned on a common grid spanning their pooled range, each
  histogram is normalized to probabilities P1, P2, and
      JSD = 0.5*KL(P1 || M) + 0.5*KL(P2 || M),  M = (P1+P2)/2
  is returned (natural log, so 0 <= JSD <= ln 2).

  Parameters
  ----------
  d1, d2 : 1-D floating-point torch tensors with the samples to compare.
  nbins  : number of histogram bin EDGES passed to torch.linspace; the
           histograms therefore have nbins-1 bins.

  Returns
  -------
  0-d torch tensor with the divergence.
  """

  # common bin edges spanning both samples (plain floats for linspace)
  lo = torch.minimum(d1.min(),d2.min()).item()
  hi = torch.maximum(d1.max(),d2.max()).item()
  histbins = torch.linspace(lo,hi,nbins,dtype=d1.dtype)

  # histogram data in probabilities
  yd1,_ = torch.histogram(d1,bins=histbins)
  yd2,_ = torch.histogram(d2,bins=histbins)
  yd1 = yd1/yd1.sum()
  yd2 = yd2/yd2.sum()

  # get the average (mixture) probability distribution
  pAve = (yd1+yd2)/2

  # F.kl_div(input, target) = KL(target || exp(input)), so the mixture is the (log)
  # input and each histogram is the target. 'sum' because these are single
  # distributions ('batchmean' would divide by the number of bins). The epsilon
  # only guards bins that are empty in both histograms, where the target is 0.
  logAve = torch.log(pAve + 1e-15)
  symKL = F.kl_div(logAve, yd1, reduction='sum') + \
          F.kl_div(logAve, yd2, reduction='sum')
  return symKL/2

In [None]:
# swapping the arguments should give the same divergence
kl_CI = symmetric_KL(C_norms,I_norms)
kl_IC = symmetric_KL(I_norms,C_norms)
kl_CI, kl_IC

In [None]:
# initializations: values for symmetric_KL's 'nbins' argument, which counts bin
# EDGES -- each histogram therefore has one fewer bin than the value passed
bincounts = np.arange(10,51)
kl_divs = np.zeros(len(bincounts))

# run the experiment!
for bini,nedges in enumerate(bincounts):
  kl_divs[bini] = symmetric_KL(C_norms,I_norms,int(nedges))


# visualize against the actual number of bins (edges - 1)
plt.figure(figsize=(8,4))
plt.plot(bincounts-1,kl_divs,'kh',markersize=10,markerfacecolor=[.7,.7,.9])
plt.gca().set(xlabel='Number of histogram bins',ylabel='Divergence value',title='Impact of discretization on symmetric KL divergence')

plt.tight_layout()
plt.savefig('ch6_proj37_part4.png')
plt.show()

# **Part 5: Statistical significance of KL**

In [None]:
# observed symmetric KL between the correct- and incorrect-prediction log-norms
kldiv = symmetric_KL(C_norms,I_norms)

# null hypothesis: accuracy labels are exchangeable across token positions
n_permutes = 1000
kl_perms = torch.zeros(n_permutes)

for permi in range(n_permutes):
  # shuffle the correct/incorrect labels
  fake_predictions = np.random.permutation(prediction_acc)

  # log-norms of the two shuffled groups
  C = att_projs[f'attn_L{layeri}'][0,fake_predictions,:]
  I = att_projs[f'attn_L{layeri}'][0,~fake_predictions,:]
  C_norms_p = C.norm(dim=-1).log()
  I_norms_p = I.norm(dim=-1).log()

  kl_perms[permi] = symmetric_KL(C_norms_p,I_norms_p)

# p-value: proportion of permuted divergences larger than the observed one
pval = torch.count_nonzero(kl_perms>kldiv) / n_permutes

In [None]:
# null distribution of permuted divergences, with the observed value overlaid
fig,ax = plt.subplots(figsize=(10,3))
ax.hist(kl_perms,bins='fd',edgecolor='k',linewidth=.3,color=[.5,.3,.3],label='$H_0$ dist.')
ax.axvline(kldiv,linewidth=4,color='k',label='Observed KL')

ax.set(xlabel='KL divergence',ylabel='Count',title=f'Statistical evaluation of KL (p = {pval:.3f})')
ax.legend()

fig.tight_layout()
fig.savefig('ch6_proj37_part5.png')
plt.show()

# **Part 6: Laminar distribution of KL divergences**

In [None]:
n_permutes = 1000

_,axs = plt.subplots(1,3,figsize=(12,3))


def _layer_log_norms(layer,mask):
  """Log of the per-token L2 norms of layer `layer`'s attention output where `mask` is True."""
  return torch.norm(att_projs[f'attn_L{layer}'][0,mask,:],dim=-1).log()


for layeri in tqdm(range(n_layers)):

  # adjustment vectors and their norms
  C_norms = _layer_log_norms(layeri,prediction_acc)
  I_norms = _layer_log_norms(layeri,~prediction_acc)

  # panel A: means of the norms
  axs[0].plot(layeri,I_norms.mean(),'rh',markerfacecolor=[.9,.5,.5,.7],markersize=7)
  axs[0].plot(layeri,C_norms.mean(),'bs',markerfacecolor=[.5,.5,.9,.7],markersize=7)

  diff = C_norms.mean() - I_norms.mean()


  # observed KL distance
  kldiv = symmetric_KL(C_norms,I_norms).item()

  # H0 distribution via permutation testing. Permuted norms get their own names
  # so the observed C_norms/I_norms are not overwritten.
  kl_perms = np.zeros(n_permutes)
  for permi in range(n_permutes):

    # permuted predictions
    fake_predictions = np.random.permutation(prediction_acc)
    C_norms_p = _layer_log_norms(layeri,fake_predictions)
    I_norms_p = _layer_log_norms(layeri,~fake_predictions)

    # calculate and store KL
    kl_perms[permi] = symmetric_KL(C_norms_p,I_norms_p).item()

  # p-value is number of H0 values greater than observed value
  pval = (kl_perms>kldiv).sum() / n_permutes

  # panel B: range and mean of the H0 distribution
  axs[1].plot([layeri,layeri],[kl_perms.min(),kl_perms.max()],'k',linewidth=.4)
  axs[1].plot(layeri,kl_perms.mean(),'ks',markersize=4)

  # draw the observed KL (panel B) and mean difference (panel C) according to
  # significance: significant layers get larger, more saturated markers
  significant = pval <= .05
  msize = 10 if significant else 6

  axs[1].plot(layeri,kldiv,'kh',markersize=msize,
              markerfacecolor=[.5,.9,.5,.7] if significant else [.9,.5,.5,.7])

  if diff>0:
    axs[2].plot(layeri,diff,'bs',markersize=msize,
                markerfacecolor=[.3,.3,.9,.7] if significant else [.7,.7,.9,.7])
  else:
    axs[2].plot(layeri,diff,'rh',markersize=msize,
                markerfacecolor=[.9,.3,.3,.7] if significant else [.9,.7,.7,.7])


# finalize the figure (panel A legend follows the plotting order: Incorrect, then Correct)
axs[0].legend(['Incorrect','Correct'])
axs[0].set(xlabel='Transformer layer',ylabel='Mean of norms',title='A) Distribution means')
axs[1].set(xlabel='Transformer layer',ylabel='KL divergence',title='B) Observed KL with $H_0$ distribution')
axs[2].axhline(0,linestyle='--',color='k',linewidth=.4)
axs[2].set(xlabel='Transformer layer',ylabel='Difference',title='C) Difference of means')


plt.tight_layout()
plt.savefig('ch6_proj37_part6.png')
plt.show()