|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Interfering with attention</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Token prediction after head ablations</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

from tqdm import tqdm

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Model, hook, tokens

In [None]:
# pick the first GPU when torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# pretrained GPT-2 tokenizer and language model, with the model placed on `device`
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)

# switch to evaluation mode (kept as the last expression so the cell still echoes the model)
model.eval()

In [None]:
# some useful variables
# NOTE: exercise blanks -- the three right-hand sides are intentionally left for the student.
# nheads is used later as the number of heads looped over per layer (range(nheads)).
# n_emb and head_dim are presumably the embedding width and the per-head width -- TODO confirm.
nheads =
n_emb =
head_dim =

In [None]:
def implant_hook(layer_number):
  # Factory: returns hook4attn, a closure that remembers `layer_number`.
  # NOTE: exercise skeleton -- the two `if` conditions and several right-hand
  # sides below are intentionally blank and must be completed before this runs.
  def hook4attn(module,input):
    # (module, input) is the forward pre-hook signature; whatever is returned
    # at the bottom is what the hooked module receives as its input.
    # `input` presumably arrives as a tuple (see the commented print below) -- TODO confirm.

    # print some useful information
    # print(len(input),type(input),input[0].shape)

    # modify the activation only for this layer
    if

      # reshape so we can index heads
      head_tensor = input

      # specify the value to replace
      if
        value2replace = 0
      else:


      # then replace
      head_tensor

      # print confirmation
      # print(f'Zeroed out L{layer_number}, H{head2ablate}')

      # reshape back to tensor
      head_tensor =

      # return a tuple matching the original
      input =

    # returned unchanged when the layer condition above is not met
    return input
  return hook4attn


# Build one hook per layer (0 .. n_layer-1), each bound to its own layer index.
# NOTE: exercise blanks -- the module to hook, the `h` assignment, and the list
# being appended to (presumably `handles`) are left for the student to fill in.
handles = []
for layeri in range(model.config.n_layer):
  register_forward_pre_hook(implant_hook(layeri))
  .append(h)

In [None]:
tokens = tokenizer.encode('Berlin is the capital of',return_tensors='pt')
nbatches,ntokens =


for i in range(ntokens):
  print(f'Token position {i:2} is index {} and is "{}"')

In [None]:
# target and semantically related nontarget
# NOTE: exercise blanks -- both right-hand sides are left for the student.
# A later plotting cell labels target_idx as 'Germany' and nontarget_idx as 'France'.
nontarget_idx =
target_idx =

# confirm single-tokens
# (bare tuple as the last expression, so the notebook displays both values)
nontarget_idx,target_idx

# Exercise 2: Confirm accuracy and get clean logits

In [None]:
# NOTE: exercise blanks -- the two settings, the forward pass, and the softmax
# are all left for the student to write.
# layer2ablate / head2ablate are presumably read by the hook as globals;
# for a clean (un-ablated) pass they likely need values that match no real
# layer/head -- TODO confirm against the completed hook.
layer2ablate =
head2ablate =

# forward pass

# calculate softmax probability in percent
sm_clean =

In [None]:
# NOTE: exercise skeleton -- the y-data arguments of the three plot calls and
# the {} field in the title are intentionally blank.
plt.figure(figsize=(10,4))

# all the log-sm values
plt.plot(,'k.',markersize=2,alpha=.3)

# the target and nontarget values
plt.plot(target_idx,,'gs',label='Germany')
plt.plot(nontarget_idx,,'ro',label='France')

# make the graph look pretty :D
# (x-axis spans the full vocabulary: one point per vocab element)
plt.gca().set(xlabel='Vocab elements',ylabel='Log softmax',xlim=[0,model.config.vocab_size])
plt.title(f'Predicted next token is "{}"')
plt.legend()

plt.show()

# Exercise 3: Zero-out attention heads for all token indices

In [None]:
# ablation-mode flag; presumably read inside the hook to choose between a
# zero replacement and the alternative branch -- TODO confirm against the completed hook
replaceWithZeros = True

In [None]:
# NOTE: exercise skeleton -- the array shape, both loop ranges, the forward
# pass, and the right-hand sides below are intentionally incomplete.
# Last-axis slots as written below: 0 and 1 hold the target / nontarget values
# (presumably in that order), 2 holds the predicted next token.
resultsZero = np.zeros

# loop over layers and heads
# (the loop variables reuse the names layer2ablate / head2ablate, so each
#  iteration overwrites those globals)
for layer2ablate in tqdm
  for head2ablate in

    # forward pass


    # softmax
    sm =

    # sm logits for target and nontarget
    resultsZero[layer2ablate,head2ablate,0] = sm
    resultsZero[layer2ablate,head2ablate,1] = sm

    # and the predicted next token
    resultsZero[layer2ablate,head2ablate,2] =

In [None]:
fig,axs = plt.subplots(1,2,figsize=(10,4))

clim = 5

h = axs[0].imshow(,vmin=-clim,vmax=clim,cmap=mpl.cm.plasma,aspect='auto')
axs[0].set(xlabel='Layer',ylabel='Head',yticks=range(0,nheads,2),title='%$\Delta$ in prob. for target word')
fig.colorbar(h,ax=axs[0],pad=.01)

axs[1].set(xlabel='Layer',ylabel='Head',yticks=range(0,nheads,2),title='%$\Delta$ in prob. for non-target word')
fig.colorbar(h,ax=axs[1],pad=.01)

plt.suptitle('Change in token selection probability from clean model',fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# unique values in slot 2 of resultsZero (the predicted-next-token slot)
# NOTE(review): as written only the array is passed to np.unique; the two-name
# unpacking suggests an extra argument (e.g. return_counts) is meant to be
# filled in after the trailing comma -- confirm.
u,c = np.unique(resultsZero[:,:,2],)


# Exercise 4: Repeat with head mean imputation

In [None]:
# Results array: (layer, head, 4). Going by the comments below, the four slots
# are presumably target, nontarget, imputed mean, and predicted token -- TODO confirm order.
# NOTE: exercise skeleton -- `sm` and all result assignments are left blank.
# NOTE(review): no visible cell changes replaceWithZeros from True before this
# loop; confirm the hook is in mean-imputation mode when this runs.
resultsMean = np.zeros((model.config.n_layer,nheads,4))

# loop over layers and heads
for layer2ablate in tqdm(range(model.config.n_layer),desc='Layers...'):
  for head2ablate in range(nheads):

    # forward pass
    # (no_grad: inference only, no gradient bookkeeping)
    with torch.no_grad():
      out = model(tokens)

    # log-softmax
    sm =

    # log-sm logits for target and nontarget


    # the empirical mean value that was imputed


    # and the predicted next token




In [None]:
# create the figure

In [None]:
# print the unique values and their counts

In [None]:
# head-averaged activations
fig,axs = plt.subplots(1,2,figsize=(12,5))

axs[0].plot(,'ko',markerfacecolor=[.9,.7,.7,.6])
axs[0].set(xlabel='Heads $\\times$ layer (index)',ylabel='Head mean',title='As scatter plot')

h = axs[1].imshow(,vmin=-.05,vmax=.05,cmap=mpl.cm.plasma,aspect='auto')
axs[1].set(xlabel='Layer',ylabel='Head',yticks=range(0,nheads,2),title='As image')
fig.colorbar(h,ax=axs[1],pad=.02,fraction=.05)

plt.suptitle('Head activation averages',fontweight='bold')
plt.tight_layout()
plt.show()

# Exercise 5: Comparisons

In [None]:
# setup the figure
fig = plt.figure(figsize=(13,4))
gs  = GridSpec(1,3,figure=fig)
axs = [ fig.add_subplot(gs[:2]) , fig.add_subplot(gs[-1]) ]


### histograms
nbins = 20

y,x = np.histogram(resultsZero - sm_clean
axs[0].plot(x[:-1],y,'.-',linewidth=2,markersize=10,label='Zero target')

y,x = np.histogram(
axs[0].plot(x[:-1],y,'.-',linewidth=2,markersize=10,label='Mean target')

y,x = np.histogram(
axs[0].plot(x[:-1],y,linewidth=2,label='Zero nontarget')

y,x = np.histogram(
axs[0].plot(x[:-1],y,linewidth=2,label='Mean nontarget')

axs[0].set(xlabel='Token probability ($\Delta$ from clean model)',ylabel='Count',ylim=[-1,None],
           title='Histograms of $\Delta$ softmax')
axs[0].legend(fontsize=15)


# difference heat map
h = axs[1].imshow( ,vmin=-1,vmax=1,cmap=mpl.cm.plasma,aspect='auto')
axs[1].set(xlabel='Layer',ylabel='Head',yticks=range(0,nheads,2),title='$\Delta$ target: (mean - zero)')
fig.colorbar(h,ax=axs[1],pad=.02,fraction=.05)

plt.tight_layout()
plt.show()