|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[26] Current layer = previous layer + adjustments</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import pearsonr,spearmanr

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# collect the style settings first, then apply them in one call;
# uncomment the color entries to switch to dark mode
plot_style = {
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    'savefig.dpi':300
}
plt.rcParams.update(plot_style)

# **Part 1: Hooks and hidden states**

In [None]:
# pretrained GPT-2, set up to also return the hidden states
model = AutoModelForCausalLM.from_pretrained(
    'gpt2',
    output_hidden_states=True,
)

# the matching tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# switch the model to evaluation mode
model.eval()

In [None]:
# initialize the dictionary containing data
transf_out = {}

# the hook function: stores the layer's hidden-state output as a detached tensor
def hook(module, input, output):
  # module: the layer the hook is attached to
  # input : inputs passed into that layer during the forward pass
  # output: output of that layer

  # the output may be a tuple/list whose first element is the hidden-state
  # tensor, or the bare tensor itself; indexing a bare tensor with [0] would
  # silently drop its first axis, so unwrap only when it is a sequence
  hidden = output[0] if isinstance(output,(tuple,list)) else output

  # detach so the stored tensor does not hold onto the autograd graph
  transf_out['data'] = hidden.detach()

# pick one transformer block and attach the hook to it;
# keep the returned handle so the hook can be removed later
layi = 5
hooked_block = model.transformer.h[layi]
handle = hooked_block.register_forward_hook(hook)

In [None]:
# example sentence, tokenized into a tensor of token indices
text = "Plants make fantastic pets because they don't leave a mess after you feed them."
tokens = tokenizer.encode(text,return_tensors='pt')

# forward pass (without gradient tracking) to trigger the hook,
# which writes the hooked layer's output into transf_out['data']
with torch.no_grad():
  outputs = model(tokens)

# remove the hook so later forward passes no longer overwrite transf_out['data']
handle.remove()

In [None]:
# compare the size of hidden state layi+1 with the size of the tensor
# captured by the hook on block layi (expected to match)
print('     Hidden states is size:',outputs.hidden_states[layi+1].shape)
print('Transformer output is size:',transf_out['data'].shape)

In [None]:
# one embedding dimension, traced across all token positions
dimnum = 123
hs_trace = outputs.hidden_states[layi+1][0,:,dimnum]
hook_trace = transf_out['data'][0,:,dimnum]

fig,ax = plt.subplots(figsize=(10,4))

# hidden state as connected circles, hooked output as red crosses on top
ax.plot(hs_trace,'o-',markeredgecolor='k',
        color=[.7,.7,.9],markerfacecolor=[.7,.7,.9],markersize=10,label='Hidden state')
ax.plot(hook_trace,'r+',markersize=14,markeredgewidth=2,label='Transformer output')

ax.set(xlabel='Token position',ylabel='Activation value',title=f'Hidden state vs. transformer output from dimension {dimnum}')
ax.legend()

# finalize, save, and display
fig.tight_layout()
fig.savefig('ch5_proj26_part1.png')
plt.show()

In [None]:
# element-wise difference between hidden state layi+1 and the hooked output
# (expected to be all zeros)
outputs.hidden_states[layi+1] - transf_out['data']

# **Part 2: Final hidden layer normalization**

In [None]:
# number of layers, read from the model configuration
n_layers = model.config.n_layer

In [None]:
# hook all hidden layers
transf_out = {}

# hook factory: returns a forward hook that stores the layer's hidden-state
# output in transf_out under the key 'layer_<layeri>'
def outerHook(layeri):
  def hook(module, input, output):
    # the output may be a tuple/list whose first element is the hidden-state
    # tensor, or the bare tensor itself; indexing a bare tensor with [0] would
    # silently drop its first axis, so unwrap only when it is a sequence
    hidden = output[0] if isinstance(output,(tuple,list)) else output

    # detach so the stored tensor does not hold onto the autograd graph
    transf_out[f'layer_{layeri}'] = hidden.detach()
  return hook

# attach one hook per transformer block and collect the handles for later removal
handles = [
  model.transformer.h[block_idx].register_forward_hook(outerHook(block_idx))
  for block_idx in range(model.config.n_layer)
]

In [None]:
# forward pass (without gradient tracking) to trigger the hooks;
# each hook stores its layer's output in transf_out
with torch.no_grad():
  outputs = model(tokens)

# remove the hooks so later forward passes no longer write to transf_out
for h in handles:
  h.remove()

In [None]:
# inspect what the hooks stored: the dictionary keys (one per hooked layer)
# and the size of one stored tensor
print('Here are the keys in transf_out:\n ',transf_out.keys(),'\n')
print("transf_out['layer_3'] has size:\n ",transf_out['layer_3'].shape)

In [None]:
# per layer: norm of the difference between hidden state i+1 and the
# output hooked from block i, gathered into one tensor
norms = torch.tensor([
  torch.norm(outputs.hidden_states[layer_idx+1] - transf_out[f'layer_{layer_idx}']).item()
  for layer_idx in range(n_layers)
])

In [None]:
fig,ax = plt.subplots(figsize=(8,3))

# one marker per layer
ax.plot(norms,'kh',markerfacecolor=[.9,.7,.7],markersize=14)
ax.set(xlabel='Hidden layer',ylabel='Norm of difference',
       title='Differences between hidden_state and hook output')

# finalize, save, and display
fig.tight_layout()
fig.savefig('ch5_proj26_part2.png')
plt.show()

In [None]:
# index of the final transformer block
lastL = model.config.n_layer - 1

# apply the model's final layer norm to the hooked output of the last block;
# no_grad keeps this inference-only computation out of the autograd graph,
# consistent with the forward passes elsewhere in this file
with torch.no_grad():
  layerNorm_final = model.transformer.ln_f( transf_out[f'layer_{lastL}'] )

# difference from the last hidden state (expected to be all zeros)
outputs.hidden_states[lastL+1] - layerNorm_final

# **Part 3: Hooking attention and MLP adjustments**

In [None]:
# fresh dictionary for the hooked activations
activations = {}

# hook for the attention projection: keep its detached output under 'attn'
def hook_att(module, input, output):
  activations.update(attn=output.detach())

# hook for the mlp projection: keep its detached output under 'mlp'
def hook_mlp(module, input, output):
  activations.update(mlp=output.detach())


# choose the transformer block whose projection layers get hooked
layer2hook = 10
target_block = model.transformer.h[layer2hook]

# attach the hooks to the attention and mlp projections
# (note: the returned handles are not stored)
target_block.attn.c_proj.register_forward_hook(hook_att)
target_block.mlp.c_proj.register_forward_hook(hook_mlp)

In [None]:
# forward pass (without gradient tracking) to trigger the two projection hooks,
# which write into the "activations" dictionary
with torch.no_grad():
  outputs = model(tokens)

# inspect what was stored: the keys and the size of each tensor
print('Here are the keys in "activations":\n ',activations.keys(),'\n')
print("activations['attn'] has size:\n ",activations['attn'].shape,'\n')
print("activations['mlp'] has size:\n ",activations['mlp'].shape)

In [None]:
_,axs = plt.subplots(1,2,figsize=(10,4))

# hooked attention and mlp outputs for the first sequence, one row per token
attn_by_token = activations['attn'][0]
mlp_by_token = activations['mlp'][0]

# panel A: one pair of correlation coefficients per token
for toki,(attn_vec,mlp_vec) in enumerate(zip(attn_by_token,mlp_by_token)):

  # Pearson and Spearman correlations between the two vectors
  r_p = pearsonr(attn_vec,mlp_vec).statistic
  r_s = spearmanr(attn_vec,mlp_vec).statistic

  # one marker per coefficient, joined by a thin dashed line
  axs[0].plot(toki,r_p,'ks',markerfacecolor=[.9,.7,.7],markersize=10)
  axs[0].plot(toki,r_s,'ko',markerfacecolor=[.7,.9,.7],markersize=10)
  axs[0].plot([toki,toki],[r_p,r_s],'--',color=[.3,.3,.3],linewidth=.7,zorder=-10)

# the first two lines drawn are the Pearson and Spearman markers, in that order
axs[0].legend(['Pearson','Spearman'])
axs[0].axhline(0,linestyle=':',color=[.7,.7,.7],zorder=-20)
axs[0].set(xlabel='Token index',ylabel='Correlation (r)',
           title='A) Attention vs. MLP adjustments')

# panel B: scatter of the two vectors left over from the final loop iteration
axs[1].plot(attn_vec,mlp_vec,'kh',markerfacecolor=[.7,.7,.9,.7],markersize=10)
axs[1].grid(linestyle='--',color=[.7,.7,.7],linewidth=.7)
axs[1].set(xlabel='Attention output',ylabel='MLP output',
           title=f'B) Data from token #{toki}')

plt.tight_layout()
plt.savefig('ch5_proj26_part3.png')
plt.show()

# **Part 4: Reconstructing hidden state L from L-1**

In [None]:
# the two consecutive hidden states around the hooked block
hs_pre,hs_pst = outputs.hidden_states[layer2hook:layer2hook+2]

# previous hidden state plus the attention adjustment, then plus the mlp adjustment
recon = hs_pre + activations['attn'] + activations['mlp']

# report the size of every tensor involved
for size_label,size_tensor in (
    ('  Attention size:', activations['attn']),
    ('        MLP size:', activations['mlp']),
    ('Hidden state pre:', hs_pre),
    ('Hidden state pst:', hs_pst),
):
  print(size_label, size_tensor.shape)

In [None]:
# demonstrate the reconstruction accuracy:
# the element-wise error and its norm
recon-hs_pst, torch.norm(recon-hs_pst)

In [None]:
# largest absolute activation in each hidden state (skipping the first token)
peak_pre = hs_pre[0,1:,:].abs().max().item()
peak_pst = hs_pst[0,1:,:].abs().max().item()

# common axis limit: the larger peak, increased a bit for marker size
ax_lim = 1.1 * max(peak_pre,peak_pst)

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3.5))

# A) hidden state before vs. after the hooked block (first token skipped in all panels)
axs[0].plot(hs_pre[0,1:,:].flatten(),hs_pst[0,1:,:].flatten(),'ko',markerfacecolor=[.9,.7,.7,.5],markersize=8)
axs[0].set(xlabel=f'Hidden layer {layer2hook}',ylabel=f'Hidden layer {layer2hook+1}',
           title='A) Hidden state activations',xlim=[-ax_lim,ax_lim],ylim=[-ax_lim,ax_lim])

# B) attention vs. mlp adjustments
# (the panel letter belongs in the title only, so the x-label has no 'B) ' prefix)
axs[1].plot(activations['attn'][0,1:,:].flatten(),activations['mlp'][0,1:,:].flatten(),'ks',markerfacecolor=[.7,.7,.9,.5],markersize=8)
axs[1].set(xlabel=f'Attention layer {layer2hook}',ylabel=f'MLP layer {layer2hook}',title='B) Adjustments',
           xlim=[-ax_lim,ax_lim],ylim=[-ax_lim,ax_lim])

# C) actual hidden state vs. its reconstruction
axs[2].plot(hs_pst[0,1:,:].flatten(),recon[0,1:,:].flatten(),'kh',markerfacecolor=[.7,.9,.7,.5],markersize=8)
axs[2].set(xlabel=f'Hidden layer {layer2hook+1}',ylabel=f'HL({layer2hook}) + attn + mlp',title='C) Reconstructed',
           xlim=[-ax_lim,ax_lim],ylim=[-ax_lim,ax_lim])

# same grid style on every panel
for ax in axs:
  ax.grid(linestyle='--',color=[.3,.3,.3],linewidth=.5)


plt.tight_layout()
plt.savefig('ch5_proj26_part4.png')
plt.show()

# **Part 5: Overwriting hooked activations**

In [None]:
# size of the attention activations currently stored by the hook
activations['attn'].shape

In [None]:
# the projection hooks were never removed, so a new forward pass overwrites
# the stored activations; no_grad because this is inference only
with torch.no_grad():
  model(tokenizer.encode('I like corn',return_tensors='pt'))

# size of the attention activations after the new forward pass
activations['attn'].shape

In [None]:
# the projection hooks were registered without keeping their handles,
# so there are no handles to remove; must reimport :(
# (eval() returns the model itself, so the call can be chained)
model = AutoModelForCausalLM.from_pretrained('gpt2',output_hidden_states=True).eval()

In [None]:
# list that grows by one element per forward pass
activations = []

# hook that appends (rather than overwrites) the layer's output
def mlp_hook(module, inp, out):
  # detach before storing, like the other hooks in this file, so the list
  # does not keep the autograd graph of every forward pass alive
  activations.append(out.detach())

# hook the MLP in layer 4 (keep the handle this time)
mlp_projection = model.transformer.h[4].mlp.c_proj
handle = mlp_projection.register_forward_hook(mlp_hook)

In [None]:
# run some text through the model to trigger the hook;
# no_grad because this is inference only (no autograd graph is needed)
with torch.no_grad():
  model( tokenizer.encode('I like chocolate.',return_tensors='pt') )

In [None]:
# note: just a list, not a dictionary!
# (mlp_hook appends one element each time it is triggered)
activations

In [None]:
# container type and element count, then the shape of every stored tensor
print(f'"activations" is a {type(activations)} that contains {len(activations)} elements \n')
for elem_idx,stored in enumerate(activations):
  print(f'Element {elem_idx} has shape {stored.shape}')

In [None]:
# run the model three more times, each pass triggering the hook once;
# no_grad because this is inference only (no autograd graph is needed)
with torch.no_grad():
  for sentence in ('I still like chocolate.',
                   'You know the shape my breath will take before I let it out.',
                   'Four score and seven years ago.'):
    model( tokenizer.encode(sentence,return_tensors='pt') )

In [None]:
# same summary as before, now covering every forward pass so far
print(f'"activations" is a {type(activations)} that contains {len(activations)} elements \n')
for position,tensor_i in enumerate(activations):
  print(f'Element {position} has shape {tensor_i.shape}')