|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>How to modify activations</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: replacing attention, MLP, and hidden states</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

import matplotlib_inline.backend_inline
# render inline matplotlib figures as SVG
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# pretrained GPT-2 tokenizer and language-model head
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# evaluation mode (disables dropout); eval() returns the model itself,
# so as the cell's last expression Jupyter displays the model summary
model.eval()

# Exercise 1: Zero-out the third attention head in K

In [None]:
# encode one sentence; return_tensors='pt' gives a tensor whose row 0 holds the token indices
tokens = tokenizer.encode('I wonder how many tokens are in pomegranate.',return_tensors='pt')

# list every token index next to the text it decodes to
for tokenIdx in tokens[0]:
  print(f'Token index {tokenIdx:5} is "{tokenizer.decode(tokenIdx)}"')

In [None]:
# model dimensions read from the config
nheads = model.config.n_head
n_emb = model.config.n_embd
head_dim = n_emb // nheads   # number of dimensions belonging to each attention head

# the 3rd head (0-based index 2) occupies one contiguous slice [h3_start, h3_end)
whichHead_idx = 2
h3_start, h3_end = whichHead_idx*head_dim, (whichHead_idx+1)*head_dim

print(f'Attention head {whichHead_idx+1} starts at \nindex {h3_start} and ends at index {h3_end-1}.')

In [None]:
# initialize activations dictionary
activations = {}


def implant_hook(layer_number):
  """Return a forward hook for an attention c_attn layer that zeroes one head's slice of K.

  The hook splits the layer output into Q, K and V (each n_emb wide along dim 2), zeroes
  K[:, :, h3_start:h3_end] on a copy, stores the recombined tensor as a NumPy array in
  activations['qkv'], and returns it so it replaces the layer's original output.

  Note: layer_number is not used inside the hook; n_emb, h3_start, h3_end and
  activations are looked up in the notebook's global scope each time the hook runs.
  """

  def hook(module, input, output):
    # the fused projection is laid out as [Q | K | V] along dim 2
    q = output[:,:,:n_emb]
    k = output[:,:,n_emb:2*n_emb]
    v = output[:,:,2*n_emb:]

    # k is a view into output, so zero the head on a copy instead
    kZeroed = k.clone()
    kZeroed[:,:,h3_start:h3_end] = 0

    # reassemble and keep a detached NumPy copy for plotting
    newQKV = torch.cat((q,kZeroed,v),dim=2)
    activations['qkv'] = newQKV.detach().numpy()

    # a non-None return value replaces the module's output
    return newQKV

  return hook


layer2modify = 3
targetModule = model.transformer.h[layer2modify].attn.c_attn
hookHandle = targetModule.register_forward_hook(implant_hook(layer2modify))

In [None]:
# confirm
model(tokens)

plt.figure(figsize=(12,4))
plt.plot(range(n_emb),activations['qkv'][0,5,:n_emb],'ks',markerfacecolor=[.7,.7,.9,.5],label='Q')
plt.plot(range(n_emb,2*n_emb),activations['qkv'][0,5,n_emb:n_emb*2],'ks',markerfacecolor=[.7,.9,.7,.5],label='K')
plt.plot(range(2*n_emb,3*n_emb),activations['qkv'][0,5,n_emb*2:],'ks',markerfacecolor=[.9,.7,.7,.5],label='V')

plt.legend()
plt.gca().set(xlim=[-5,n_emb*3+4],xlabel='Index into QKV matrix',ylabel='Activation value',
              title=f'Activations to the token "{tokenizer.decode(tokens[0,5])}" in layer {layer2modify}')
plt.show()

In [None]:
# detach the Exercise 1 hook so later forward passes use the unmodified attention layer
hookHandle.remove()

# Exercise 2: Replace even-indexed MLP neurons with noise

In [None]:
# initialize activations dictionary
activations = {}

def hook(module, input, output):
  """Forward hook for the MLP expansion layer (c_fc): overwrite even units of token 4 with noise.

  The even-indexed units (0, 2, 4, ...) at sequence position 4 are replaced by standard-normal
  noise shifted by +10. The write goes into `output` itself. A detached NumPy copy is stored in
  activations['mlp'], and the same tensor is returned.
  """

  # a view onto the units to overwrite: every batch item, token position 4, even-indexed units
  target = output[:,4,::2]

  # writing through the view modifies `output` in place; copying the tensor first
  # (as in the video) and returning the copy works just as well
  target.copy_(torch.randn_like(target) + 10)

  # keep a detached copy for plotting
  activations['mlp'] = output.detach().numpy()

  # return the (modified) tensor so it becomes the layer's output
  return output


# attach the noise hook to the MLP expansion layer (c_fc) of transformer block h[5];
# keep the handle so the hook can be removed afterwards
hookHandle = model.transformer.h[5].mlp.c_fc.register_forward_hook(hook)

In [None]:
# confirm
model(tokens)

plt.figure(figsize=(12,4))
plt.plot(activations['mlp'][0,4,:],'ks',markerfacecolor=[.7,.7,.9,.5])

plt.gca().set(xlim=[-5,activations['mlp'].shape[-1]+4],xlabel='Index into MLP expansion',ylabel='Activation value',
              title=f'Activations to the token "{tokenizer.decode(tokens[0,4])}" in layer 5')
plt.show()

In [None]:
# detach the Exercise 2 noise hook so block h[5]'s MLP is back to normal
hookHandle.remove()

# Exercise 3: Scale the hidden-state activations

In [None]:
# scaling factor
scaling_factor = .1

def hook(module, input, output):
  """Forward hook for a transformer block: multiply its hidden-state output by `scaling_factor`.

  `scaling_factor` is read from the global scope each time the hook fires, so reassigning it
  later changes the scaling without re-registering the hook.

  Accepts either output format a block may return:
    - a tuple whose first element is the hidden-state tensor (other elements pass through unchanged);
    - a bare hidden-state tensor.
  The scaling is done in place, and the result is returned so it replaces the block's output.
  """

  # a bare tensor (possible depending on the transformers version -- TODO confirm for the
  # installed one): output[0] would select only the first batch item, so scale it whole
  if isinstance(output, torch.Tensor):
    return output.mul_(scaling_factor)

  # extract the hidden states
  hs = output[0]

  # scaling via matrix-scalar multiplication (in place)
  hs.mul_(scaling_factor)

  # reconstruct and output
  return (hs,*output[1:])

# note: for tuple outputs it's not necessary to create a separate variable hs; you could also use:
# > output[0].mul_(scaling_factor)
# > return output

# attach the scaling hook to the output of the whole transformer block h[8]
hookHandle = model.transformer.h[8].register_forward_hook(hook)

In [None]:
# confirm
# confirm: forward pass with the scaling hook active, also returning every hidden state
out = model(tokens,output_hidden_states=True)
hs = out.hidden_states

# report how many hidden states came back and the size of one of them
print(f'There are {len(hs)} hidden_states.')
hsSize = list(hs[3].shape)
print(f'Each hidden state is of size {hsSize}')

In [None]:
_,axs = plt.subplots(1,2,figsize=(10,3.5))

for i in range(len(hs)):

  # one color per hidden state along the plasma colormap
  layerColor = mpl.cm.plasma(i/len(hs))

  # hidden state i for token 4: a 1-D vector of length n_emb
  # (per the transformers docs, index 0 is the embedding output)
  thisBlock = hs[i][0,4,:].detach().numpy()

  # plot all the data
  axs[0].plot(np.ones(n_emb)*i,thisBlock,'ks',markerfacecolor=layerColor)

  # plot the norm: np.linalg.norm of a 1-D array is its Euclidean (vector) norm
  axs[1].plot(i,np.linalg.norm(thisBlock),'ks',markerfacecolor=layerColor)

axs[0].set(xlabel='Hidden state layer',ylabel='Activation value',title='Hidden state activations for token #4')
axs[1].set(xlabel='Hidden state layer',ylabel='Vector norm',title='Hidden state norms from token #4')

plt.tight_layout()
plt.show()

# Exercise 4: Scale up

In [None]:
# now scale by 10x: the hook still registered on h[8] reads the global scaling_factor on every call
scaling_factor = 10
out = model(tokens,output_hidden_states=True)
hs = out.hidden_states


fig,axs = plt.subplots(1,2,figsize=(10,3.5))
nStates = len(hs)

for layerIdx,layerStates in enumerate(hs):

  # token 4's hidden-state vector and its colormap color
  tokenVec = layerStates[0,4,:].detach().numpy()
  dotColor = mpl.cm.plasma(layerIdx/nStates)

  # left: every activation value; right: the vector's norm
  axs[0].plot(np.ones(n_emb)*layerIdx,tokenVec,'ks',markerfacecolor=dotColor)
  axs[1].plot(layerIdx,np.linalg.norm(tokenVec),'ks',markerfacecolor=dotColor)

axs[0].set(xlabel='Hidden state layer',ylabel='Activation value',title='Hidden state activations')
axs[1].set(xlabel='Hidden state layer',ylabel='Vector norm',title='Hidden state norms')

plt.tight_layout()
plt.show()

In [None]:
# detach the scaling hook from block h[8]
hookHandle.remove()

# Bonus! Example of a hook's output variable that contains more than just the hidden states

In [None]:
# just a hook to print

def hook(module, input, output):
  """Diagnostic forward hook: print what a module's forward pass returned.

  For a tuple/list output, prints its type and length and then, for each element, either its
  shape (when it has a .shape, e.g. a tensor) or its type name (e.g. None or a nested tuple).
  This avoids an AttributeError on elements without a shape.
  A bare tensor output is reported by its size rather than being iterated row by row.
  Returns None, so the module's output is left unchanged.
  """

  # a bare tensor: len() and iteration would walk the batch dimension, so report it directly
  if isinstance(output, torch.Tensor):
    print(f'output is type {type(output)} with size {list(output.shape)}.')
    return

  # print info about the output variable
  print(f'output is type {type(output)} and has {len(output)} element(s).')

  # info about each element of output
  for i, element in enumerate(output):
    elementShape = getattr(element, 'shape', None)
    if elementShape is not None:
      print(f'Element {i} has size {list(elementShape)}')
    else:
      print(f'Element {i} is {type(element).__name__} (no shape attribute)')

# attach the printing hook to block h[8]; it is not removed in the cells that follow,
# so it prints on every later forward pass through this block
hookHandle = model.transformer.h[8].register_forward_hook(hook)

In [None]:
text = [
  'Here is the first sentence',
  'Here is another one of a different length.',
  'Shall we go for three?',
]

# the sentences differ in length, so they must be padded to stack into one tensor;
# reuse end-of-sequence as the pad token (presumably GPT-2 defines none of its own -- confirm)
tokenizer.pad_token = tokenizer.eos_token

# note: this replaces the earlier single-sentence `tokens` tensor with a dict-like batch encoding
tokens = tokenizer(text,padding=True,return_tensors='pt')
tokens

In [None]:
# batched forward pass; **tokens unpacks the encoding's fields (presumably input_ids and
# attention_mask) as keyword arguments. The hook on h[8] prints the block's output structure.
model(**tokens);

In [None]:
# switch to the eager (non-fused) attention implementation first. Keep this order:
# presumably fused attention cannot return attention weights, and some transformers
# versions reject output_attentions=True unless eager is set -- TODO confirm for installed version
model.config._attn_implementation = 'eager'
# from here on, forward passes also return attention weights and all hidden states
# (these config changes persist for any later cells)
model.config.output_attentions = True
model.config.output_hidden_states = True
# same batched forward pass; compare the hook's printout with the previous cell
model(**tokens);