|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[32] Patching hidden states in indirect object identification</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from scipy.optimize import curve_fit

from tqdm import tqdm

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

In [None]:
### matplotlib adjustments (uncomment the color entries below for dark mode)

# render inline figures as SVG so they stay sharp at any zoom level
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# notebook-wide figure style
plot_style = {
  # dark-mode palette (disabled):
  # 'figure.facecolor': '#282a2c',
  # 'figure.edgecolor': '#282a2c',
  # 'axes.facecolor':   '#282a2c',
  # 'axes.edgecolor':   '#DDE2F4',
  # 'axes.labelcolor':  '#DDE2F4',
  # 'xtick.color':      '#DDE2F4',
  # 'ytick.color':      '#DDE2F4',
  # 'text.color':       '#DDE2F4',

  # hide the right and top frame lines
  'axes.spines.right': False,
  'axes.spines.top':   False,
  # bold titles and axis labels
  'axes.titleweight':  'bold',
  'axes.labelweight':  'bold',
  # resolution of image files written by plt.savefig
  'savefig.dpi':       300,
}
plt.rcParams.update(plot_style)

# **Part 1: The IOI task**

In [None]:
# compute device: GPU when available, otherwise CPU
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# the 'gpt2' tokenizer is paired with the 'gpt2-xl' weights (GPT-2 sizes share one BPE vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
model = model.to(device)

# number of transformer blocks; each one is hooked individually in Parts 2-3
n_layers = model.config.n_layer

# inference mode (disables dropout); left as the last expression so the module summary is shown
model.eval()

In [None]:
# Two IOI prompts that differ only in which name is repeated as the giver:
#   "ME": "... Mike gave a coffee to" -> the indirect object is "Emma"
#   "EM": "... Emma gave a coffee to" -> the indirect object is "Mike"
text_ME = 'When Mike and Emma went to the cafe, Mike gave a coffee to'
text_EM = 'When Mike and Emma went to the cafe, Emma gave a coffee to'

# token id(s) of the two candidate next words
# NOTE(review): presumably the space-prefixed names (' Mike', ' Emma'), because GPT-2 BPE
# folds the leading space into the token -- confirm when filling in
target_M = tokenizer.
target_E = tokenizer.

# token ids of each full prompt (the model inputs in the next cell)
tokensME = tokenizer.
tokensEM = tokenizer.

In [None]:
# forward passes through the clean (unpatched) model; no gradients are needed
with torch.no_grad():
  # NOTE(review): these calls must pass output_hidden_states=True, because .hidden_states
  # is read below (Hugging Face leaves it as None otherwise)
  outME = model(
  outEM = model(

# per-layer hidden states of the "ME" prompt
# (presumably the source of the patched activations in Parts 2-3 -- confirm)
hs_ME = outME.hidden_states
# inspect the available output fields and the shape of one hidden-state tensor
outME.keys(), outME.hidden_states[3].shape

In [None]:
# predicted next words: the highest-scoring vocab token for each prompt
# NOTE(review): presumably argmax of the final-position logits (logits[0,-1,:]) -- confirm when filling in
nextword_ME = torch.argmax(
nextword_EM = torch.argmax(

# print each prompt followed by its predicted continuation in quotes
print(f'{text_ME}"{tokenizer.decode(nextword_ME)}"')
print(f'{text_EM}"{tokenizer.decode(nextword_EM)}"')

In [None]:
# full logit tensors of both clean runs
# NOTE(review): the plots below presumably use the final-token slice (e.g. logits[0,-1,:]) -- confirm
logits_ME = outME.logits
logits_EM =


# setup the figure: 1x5 grid -> two wide panels (2 columns each) and one narrow scatter panel
fig = plt.figure(figsize=(12,3))
gs = GridSpec(1,5,figure=fig)
ax1 = fig.add_subplot(gs[:2])
ax2 = fig.add_subplot(gs[2:4])
ax3 = fig.add_subplot(gs[-1])

# plot log-sm from "EM" sentence
# NOTE(review): the y-label says 'Logit value' -- confirm whether raw logits or log-softmax are plotted
# markers: green circle = "Mike" token, red square = "Emma" token, gray dots = whole vocab
ax1.plot('go',label='"Mike"')
ax1.plot('rs',label='"Emma"')
ax1.plot(,'k.',alpha=.2)
ax1.legend(fontsize=8)
# title shows the last 21 characters of the prompt ("Emma gave a coffee to")
ax1.set(xlabel='Vocab index',ylabel='Logit value',
           title='A) '+text_EM[-21:]+'...',xlim=[-100,tokenizer.vocab_size+100])

# plot log-sm from "ME" sentence (same layout as panel A)
ax2.plot(,'go',label='"Mike"')
ax2.plot(,'rs',label='"Emma"')
ax2.plot(,'k.',alpha=.2)
ax2.legend(fontsize=8)
ax2.set(xlabel='Vocab index',ylabel='Logit value',
           xlim=[-100,tokenizer.vocab_size+100],title='B) '+text_ME[-21:]+'...')

# how they relate to each other: ME logits (x) vs. EM logits (y), one dot per vocab token
ax3.plot(,,'k.',alpha=.3)
ax3.set(xlabel='ME logits',ylabel='EM logits',title='C) ME vs. EM')

plt.tight_layout()
plt.savefig('ch5_proj32_part1.png')
plt.show()

In [None]:
# IOI score = logit("Mike") - logit("Emma") at the final position, so positive values prefer
# "Mike" and negative values prefer "Emma" (the "Prefer" labels of the Part 3 plot use this convention).
# NOTE(review): presumably indexed as logits[0,-1,target_M] - logits[0,-1,target_E] -- fill in
IOI_score_ME = outME... - outME...
IOI_score_EM = outEM... - outEM...

# if the model solves the task: "ME" should be negative (answer Emma), "EM" positive (answer Mike)
print(f'IOI score for text "ME": {IOI_score_ME:6.3f}')
print(f'IOI score for text "EM": {IOI_score_EM:6.3f}')

# **Part 2: IOI with hidden-state patching**

In [None]:
# pick one layer (index of the transformer block whose output gets patched)
layeri =

# Forward hook for block `layeri`: overwrites the block's output hidden state at batch 0, last
# token position, and returns the modified output tuple. PyTorch uses a non-None return value
# from a forward hook as the module's output.
# NOTE(review): assumes the block returns a tuple whose first element is the hidden state -- this
# depends on the installed transformers version
def hookfun(module, input, output):
  hs = output[0].clone() # make a copy (the block's own output tensor is not edited in place)
  # hidden_states[0] is the embedding output, so block `layeri`'s output is hidden_states[layeri+1]
  hs[0,-1,:] =  # index +1!
  output = (hs,*output[1:])
  return output

# implant the hook
handle = model.transformer.h[layeri].register_forward_hook(hookfun)

# forward pass with hook (the "EM" prompt, with the patched activation)
with torch.no_grad():
  outEM_patch = model(

# remove the hook so later forward passes are clean
# NOTE(review): if the forward pass raises, the hook stays attached; try/finally would guarantee removal
handle.remove()

# now for the IOI score test (same Mike-minus-Emma definition as for the clean runs)
IOI_score =

In [None]:
# IOI score of the "EM" sentence without vs. with the hidden-state patch
print(f'  Clean IOI score: {IOI_score_EM:6.3f}\n'
      f'Patched IOI score: {IOI_score:6.3f}')

In [None]:
# logits of the patched "EM" run (presumably the final-token slice, as in Part 1 -- confirm)
logits_EM_patch = outEM_patch.

# setup the figure: clean logits, patched logits, and their difference
fig,axs = plt.subplots(1,3,figsize=(12,3))

# plot log-sm from "EM" sentence from the clean model
# NOTE(review): the y-label says 'Logit value' -- confirm logits vs. log-softmax
axs[0].plot(,'k.',alpha=.2)
axs[0].plot(,'go',label='"Mike"')
axs[0].plot(,'rs',label='"Emma"')
axs[0].legend(fontsize=8)
axs[0].set(xlabel='Vocab index',ylabel='Logit value',
           title='A) Clean model',xlim=[-100,tokenizer.vocab_size+100])

# plot log-sm from "EM" sentence from the patched model (same layout as panel A)
axs[1].plot(,'k.',alpha=.2)
axs[1].plot(,'go',label='"Mike"')
axs[1].plot(,'rs',label='"Emma"')
axs[1].legend(fontsize=8)
axs[1].set(xlabel='Vocab index',ylabel='Logit value',
           title='B) Patched model',xlim=[-100,tokenizer.vocab_size+100])

# impact of patching on all token logits
# NOTE(review): presumably logits_diff = clean - patched (the next cell labels it "C-P"), so
# positive values mean the patch suppressed a token -- matching the y-label below
logits_diff =
axs[2].plot(,'k.',alpha=.3)
axs[2].plot(,'go',label='"Mike"')
axs[2].plot(label='"Emma"')
axs[2].legend(fontsize=8)
axs[2].set(xlabel='Vocab index',ylabel='<-- boost ---- suppress -->',ylim=[-6,6],
           title='C) Manipulation effect',xlim=[-100,tokenizer.vocab_size+100])

plt.tight_layout()
plt.savefig('ch5_proj32_part2.png')
plt.show()

In [None]:
bigFX = torch.topk(
for t in bigFX[1]:
  print(f'Î” (C-P) = {:6.3f} for "{}"')

# **Part 3: IOI experiment over layers**

In [None]:
# initializations
# one row per patched layer, two sanity-check norms per row (see the comments below)
confirmManipulation = np.zeros((n_layers,2))
# IOI score of the patched "EM" run, one entry per patched layer
IOI_scores = np.zeros(n_layers)

# loop over layers: patch each transformer block in turn, run the "EM" prompt, and score it
for layeri in tqdm(range(n_layers)):

  # patch this layer (same hook as in Part 2; re-defined each iteration, registered on
  # block `layeri`, and removed before the next iteration)
  def hookfun(module,input,output):
    hs = output[0].clone()
    hs[0,-1,:]
    output = (hs,*output[1:])
    return output

  # implant the hook
  handle = model.transformer.h[layeri].register_forward_hook(hookfun)

  # forward pass with hook
  # NOTE(review): this rebinds `outEM`, overwriting the clean Part-1 output; re-running earlier
  # cells afterwards would silently use a patched result -- consider a distinct name
  with torch.no_grad():
    outEM = model
  hs_EM = outEM.hidden_states

  # remove the hook
  # NOTE(review): if the forward pass raises, the hook stays attached; try/finally would guarantee removal
  handle.remove()

  # confirmation: first element should be zero, second non-zero
  # (presumably: patched-vs-ME difference at the patched layer/position -> 0, and at an unpatched
  # layer or position -> non-zero -- TODO fill in)
  confirmManipulation[layeri,0] = torch.norm(
  confirmManipulation[layeri,1] = torch.norm(

  # now for the IOI score
  # NOTE(review): IOI_scores is a numpy array; convert the (possibly CUDA) tensor, e.g. with .item()
  IOI_scores[layeri] =

In [None]:
# sanity check :)
# column 0 should be ~0 and column 1 non-zero for every layer (see the loop above)
confirmManipulation

In [None]:
# visualization: IOI score of the patched "EM" run as a function of the patched layer
plt.figure(figsize=(11,4))

# plot the logit differences for the "clean" runs (no patching)
# (horizontal reference lines, presumably at IOI_score_EM and IOI_score_ME -- fill in)
plt.axhline(label='Clean "EM"')
plt.axhline(label='Clean "ME"')

# then for the experiment results (one marker per patched layer)
plt.plot(,'ko',markerfacecolor=[.9,.7,.9],markersize=10)
plt.plot(

# the dividing line: IOI score > 0 prefers "Mike", < 0 prefers "Emma"
plt.axhline(0,linestyle='--',color='gray',linewidth=.5)
plt.text(0,.1,'Prefer "Mike"',fontsize=12,va='bottom')
plt.text(0,-.1,'Prefer "Emma"',fontsize=12,va='top')

plt.gca().set(xlabel='Transformer layer (index)',ylabel='IOI score',title='Laminar profile of patch manipulation')

plt.tight_layout()
plt.savefig('ch5_proj32_part3.png')
plt.show()

# **Part 4: Curve-fitting with scipy**

In [None]:
# sigmoid function
def sigmoid_fun(x,A,x0,k,b):
  """Four-parameter sigmoid curve; the model function passed to scipy's curve_fit below.

  Exercise stub: the return expression is left blank, so as written this returns None.
  NOTE(review): with A as the maximum and b as the minimum, one consistent form is
  b + (A-b)/(1+np.exp(-k*(x-x0))) -- confirm against the book before filling in.
  """
  # params:
  #   A: maximum value
  #  x0: x-value of midpoint
  #   k: curve steepness
  #   b: minimum value
  return

# create some data: a sigmoid plus unit-variance Gaussian noise
# (the true parameters appear in the "Truth" column of the fit table below: A=10, x0=0, k=1, b=2)
# NOTE(review): no random seed is set, so the noise and fitted values change on every run
x = np.linspace
y = sigmoid_fun
y += np.random.randn(len(x))

# visualize them
plt.figure(figsize=(11,4))
plt.plot(x,y,'ko',markerfacecolor=[.9,.7,.9],markersize=10)
plt.gca().set(xlabel='x',ylabel='y',title='Simulated data')

plt.tight_layout()
plt.savefig('ch5_proj32_part4a.png')
plt.show()

In [None]:
# initial parameter guesses [A, x0, k, b]
p0 = []

# fit the sigmoid function to data
# (curve_fit returns the optimal parameters and their estimated covariance matrix)
est_params,pcov =

# compare true vs. estimated parameters (est_params is ordered like p0: [A, x0, k, b])
print('    Truth | Estim.')
print('---+------+--------')
print(f' A |  10  | {}')
print(f'x0 |   0  | {}')
print(f' k |   1  | {}')
print(f' b |   2  | {}')

In [None]:
# high-res model predictions
# (presumably sigmoid_fun evaluated with est_params on a denser x grid than the data -- fill in)
yHat = sigmoid_fun

# visualization: data, fitted curve, and the fitted midpoint
plt.figure(figsize=(11,4))

plt.plot(x,y,'ko',markerfacecolor=[.9,.7,.9],markersize=10,label='Data')
plt.plot(label='Model')
# vertical line at the fitted midpoint x0 (est_params[1])
plt.axvline(x=,color='m',linestyle='--',label='x0')

plt.legend()
plt.gca().set(xlabel='x',ylabel='y',title=f'Sigmoid fit to data')

plt.tight_layout()
plt.savefig('ch5_proj32_part4b.png')
plt.show()

# **Part 5: Curve-fitting IOI scores**

In [None]:
# remove the final score
# NOTE(review): this cell is empty; presumably it drops the last entry of IOI_scores before the
# sigmoid fit. A possible reason (unverified): in Hugging Face GPT-2 the last hidden_states entry
# is taken after the final layer norm, so patching the last block with it is not like-for-like.


In [None]:
### now for the real data
# x-axis: layer indices (presumably one per remaining IOI score)
x = np.arange

# initial parameter guess, ordered [A, x0, k, b] like sigmoid_fun's parameters
p0 = []

# fit function to data (curve_fit returns optimal parameters and their covariance)
est_params,pcov = curve_fit

# high-res model predictions from the fitted parameters
yHat = sigmoid_fun(

In [None]:
# visualization: layer-wise IOI scores together with the fitted sigmoid
plt.figure(figsize=(11,4))

# plot the logit differences for the "clean" runs (no patching)
# (.cpu() because the clean IOI scores are torch tensors)
plt.axhline(IOI_score_EM.cpu(),color='b',label='Clean "EM"')
plt.axhline(IOI_score_ME.cpu(),color='r',label='Clean "ME"')

# then for the experiment results
plt.plot(label='Experiment results')

plt.plot(label='Model')
# fitted midpoint layer x0 = est_params[1]
# NOTE(review): the label prints the unrounded float; consider a format spec such as :.2f
plt.axvline(x=,color='m',linestyle='--',label=f'x0 (L{est_params[1]})')

# the dividing line: IOI score > 0 prefers "Mike", < 0 prefers "Emma"
plt.axhline(0,linestyle='--',color='gray',linewidth=.5)
plt.text(0,.1,'Prefer "Mike"',fontsize=12,va='bottom')
plt.text(0,-.1,'Prefer "Emma"',fontsize=12,va='top')

plt.gca().set(xlabel='Transformer layer (index)',ylabel='IOI score',title=f'Data and sigmoid fit')

plt.tight_layout()
plt.savefig('ch5_proj32_part5.png')
plt.show()