|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating neurons and dimensions</h1>|
|<h2>Lecture:</h2>|<h1><b>Relation between hooks and output.hidden_states</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# vector plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Pretrained GPT-2 (small): the tokenizer turns text into token ids,
# the causal-LM model is what the hooks below are attached to.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

In [None]:
# Display the module tree to locate hook targets, e.g. transformer.h[i].attn.c_proj and transformer.h[i].mlp.c_proj
model

# Implant hooks for the attention and MLP sublayer activations

In [None]:
# Storage for hooked sublayer outputs, filled by the forward hooks built below.
# Keys are f'attn_proj_{layer}' and f'mlp_proj_{layer}'.
activations = {}

def _make_store_hook(key):
  """Return a forward hook that stores the hooked module's output in activations[key]."""
  def hook(module, inputs, output):
    # detach so the stored tensor is not attached to any autograd graph
    activations[key] = output.detach()
  return hook

def implant_hook_attn(layer_number):
  """Build a forward hook storing its module's output as activations[f'attn_proj_{layer_number}'].

  Intended to be registered on the attention output projection (attn.c_proj) of that layer.
  """
  return _make_store_hook(f'attn_proj_{layer_number}')

# and mlp layers
def implant_hook_mlp(layer_number):
  """Build a forward hook storing its module's output as activations[f'mlp_proj_{layer_number}'].

  Intended to be registered on the MLP output projection (mlp.c_proj) of that layer.
  """
  return _make_store_hook(f'mlp_proj_{layer_number}')

# implant hooks on transformer block `layer2hook` (index into model.transformer.h)
layer2hook = 10

# Remove hooks left over from a previous run of this cell; otherwise every re-run
# stacks another pair of hooks onto the model and all of them fire on each forward pass.
for old_handle in globals().get('hook_handles', []):
  old_handle.remove()

# Keep the handles so the hooks can be removed later with handle.remove().
hook_handles = [
  model.transformer.h[layer2hook].attn.c_proj.register_forward_hook(implant_hook_attn(layer2hook)),
  model.transformer.h[layer2hook].mlp.c_proj.register_forward_hook(implant_hook_mlp(layer2hook)),
]

# "Hook" the activations

In [None]:
text = "Plants make fantastic pets because they don't leave a mess after you feed them."
tokens = tokenizer.encode(text, return_tensors='pt')  # token ids as a PyTorch tensor

# A single forward pass fires the hooks (filling `activations`); requesting
# output_hidden_states also returns the per-layer hidden states in `outputs`.
with torch.no_grad():
  outputs = model(tokens, output_hidden_states=True)

# Reconstructing hidden_state from attn+mlp hooks

In [None]:
# The hooked sublayer outputs must have the same shape as the hidden states,
# otherwise they cannot be added onto the residual stream in the reconstruction below.
hidden_shape = outputs.hidden_states[layer2hook].shape
print('Hidden state is size: ',hidden_shape)
print('Hooked MLP is size: ',activations[f'mlp_proj_{layer2hook}'].shape)
print('Hooked attention is size: ',activations[f'attn_proj_{layer2hook}'].shape)
assert activations[f'mlp_proj_{layer2hook}'].shape == hidden_shape, 'MLP output shape differs from hidden state'
assert activations[f'attn_proj_{layer2hook}'].shape == hidden_shape, 'attention output shape differs from hidden state'

In [None]:
# Input and output hidden states of transformer block `layer2hook`.
# Per HF GPT2Model.forward, hidden_states[i] is the input to block i and hidden_states[i+1]
# its output -- except the final entry, which also has the final layer norm (ln_f) applied
# and so cannot be reconstructed from the block's sublayer outputs alone.
assert layer2hook < model.config.n_layer - 1, 'the last hidden state includes ln_f; hook an earlier layer'
hs_curr = outputs.hidden_states[layer2hook].detach()
hs_next = outputs.hidden_states[layer2hook+1].detach()

# Extract the attention and MLP deltas (updates added to the current embedding).
# The hooks sit on c_proj, i.e. before each sublayer's residual dropout, so pass both
# through their dropout modules to mirror the forward pass. That only reproduces the
# forward pass when dropout is inactive (eval mode): in train mode a new random mask
# would be drawn here, so guard against it.
assert not model.training, 'call model.eval() so dropout is inactive'
hooked_block = model.transformer.h[layer2hook]
attn_delta = hooked_block.attn.resid_dropout(activations[f'attn_proj_{layer2hook}'])
mlp_delta  = hooked_block.mlp.dropout(activations[f'mlp_proj_{layer2hook}'])


# create a threshold mask to remove extreme activation values (helps with the demo):
# an element is dropped if its magnitude exceeds threshVal in ANY of the four tensors.
# The result is a NumPy bool array with the hidden-state shape.
threshVal = 30
exceeds = torch.stack([abs(mat) > threshVal for mat in (hs_curr, hs_next, attn_delta, mlp_delta)])
threshMask = ~exceeds.any(dim=0).numpy()

In [None]:
# reconstruct the next layer: block output = block input + attention delta + MLP delta
reconstruction = hs_curr + attn_delta + mlp_delta

# The correlations plotted below are scale-invariant (r = 1 for any scaled copy),
# so also report the actual element-wise error of the reconstruction.
max_err = (reconstruction - hs_next).abs().max().item()
print(f'Max |reconstruction - layer {layer2hook+1}|: {max_err:.3g}')

In [None]:
# comparisons!
def scatter_corr(ax, x, y, color, xlabel, ylabel):
  """Scatter y against x on `ax`, title the panel with their Pearson r, and return r."""
  r = np.corrcoef(x, y)[0, 1]
  ax.plot(x, y, 'ko', markerfacecolor=color)
  ax.set(xlabel=xlabel, ylabel=ylabel, title=f'r = {r:.3f}')
  return r

_,axs = plt.subplots(2,2,figsize=(10,8))

# Labels containing LaTeX use raw strings: '\D' in a normal string is an invalid escape sequence.

# correlate current and next layer hidden states
R = scatter_corr(axs[0,0], hs_curr[threshMask], hs_next[threshMask], [.7,.7,.9,.6],
                 f'Layer {layer2hook}', f'Layer {layer2hook+1}')

# correlate attention and MLP deltas
R = scatter_corr(axs[0,1], attn_delta[threshMask], mlp_delta[threshMask], [.7,.9,.7,.6],
                 r'Attention $\Delta$', r'MLP $\Delta$')

# correlate reconstruction and next layer
R = scatter_corr(axs[1,0], reconstruction[threshMask], hs_next[threshMask], [.7,.9,.7,.6],
                 rf'Layer {layer2hook} + $\Delta$(att) + $\Delta$(mlp)', f'Layer {layer2hook+1}')

# correlate difference of layers with sum of (attention + MLP) deltas
layer_diff = (hs_next - hs_curr)[threshMask]
attention_sum = (attn_delta + mlp_delta)[threshMask]
R = scatter_corr(axs[1,1], layer_diff, attention_sum, [.9,.7,.7,.6],
                 'Difference of layers', r"Sum of attention + MLP $\Delta$'s")

plt.tight_layout()
plt.show()