|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Modifying MLP</h1>|
|<h2>Lecture:</h2>|<h1><b>Successive median-replacement of MLP neurons</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# pretrained GPT-2 large and its matching tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
model = model.to(device)

# evaluation mode (no dropout); the returned module is displayed as cell output
model.eval()

In [None]:
# model dimensions: number of transformer blocks, and number of MLP "neurons"
# (= output width of c_fc; HF GPT-2 stores c_fc as Conv1D with weight (in_features, out_features))
nlayers, nneurons = (len(model.transformer.h),
                     model.transformer.h[3].mlp.c_fc.weight.shape[-1])
(nneurons, nlayers)

# Forward pass and get activations

In [None]:
# dictionary to store the mlp activations, keyed by layer name (e.g., 'L9')
mlp_values = {}

def make_hook(layer_key):
  """Return a forward hook that stores the module output for the first batch item under `layer_key`.

  The key is bound when the hook is created, so later changes to `whichlayer`
  cannot redirect where the activations are stored.
  """
  def hook(module, input, output):
    # output[0] selects the first (here: only) sequence in the batch
    mlp_values[layer_key] = output[0].detach().cpu() # detach from the computational graph
  return hook

# surgery ;)
whichlayer = 9
hook = make_hook(f'L{whichlayer}')

# if this cell is re-run before the previous hook was removed, remove it first;
# otherwise the old handle is overwritten and that hook can never be detached
if 'handle' in globals():
  handle.remove()

handle = model.transformer.h[whichlayer].mlp.c_fc.register_forward_hook(hook)

In [None]:
# prompt, and the token whose next-token prediction we will track
text = 'It was a dark and stormy'
target_idx = tokenizer.encode(' night')[0]  # first token id of ' night'

# alternative prompt/target pair to try:
# text = 'The cat sat on the'
# target_idx = tokenizer.encode(' mat')[0]

# token ids as a PyTorch tensor, moved to the model's device
tokens = tokenizer.encode(text,return_tensors='pt')
tokens = tokens.to(device)

(tokens.shape, tokens, target_idx)

In [None]:
# clean (unmodified) forward pass; the capture hook on c_fc records the layer's activations
try:
  with torch.no_grad():
    out_clean = model(tokens,output_hidden_states=True)
finally:
  # remove the capture hook even if the forward pass fails, so later passes are not recorded
  handle.remove()

# log-softmax over the vocabulary at the final token position (clean baseline)
logsm_clean = F.log_softmax(out_clean.logits[0,-1,:],dim=-1).detach().cpu().numpy()

# captured c_fc activations for the first sequence: (seq_len, nneurons)
mlp = mlp_values[f'L{whichlayer}']
mlp.shape

In [None]:
# histogram of all captured activations; these are c_fc outputs, i.e., before the GELU
fig,ax = plt.subplots(figsize=(8,4))
ax.hist(mlp.flatten(),np.linspace(-7,7,101),edgecolor='k',facecolor='gray')

ax.set(xlabel='Activation value',ylabel='Count (a.u.)',yticks=[],
       title=f'MLP activations from layer {whichlayer}')
plt.show()

In [None]:
# explore median-based replacement on the final token's activation vector:
# the largest 10% of values are overwritten with the median of all values
t = mlp[-1,:]                          # final token position
trepl = t.clone()                      # work on a copy so the original stays intact
med = torch.median(t)
idx = t.topk(int(.1*nneurons)).indices # positions of the largest 10% of values
trepl[idx] = med

# overlay the two histograms, nudged left/right so their markers do not overlap
plt.figure(figsize=(10,4))
binedges = np.linspace(-6,3,41)

for data,shift,fmt,lbl in ((t,-.02,'ks-','Original'),(trepl,.02,'bo-','Replaced')):
  y,x = np.histogram(data,binedges)
  plt.plot(x[:-1]+shift,y,fmt,markerfacecolor='w',label=lbl)

plt.gca().set(xlabel='Data value',ylabel='Count',xlim=binedges[[0,-1]],title='Impact of replacement')
plt.legend()
plt.show()

In [None]:
fig,ax = plt.subplots(figsize=(10,4))

# log-softmax value of every vocabulary item (clean run)
ax.plot(logsm_clean,'k.',markersize=2,alpha=.3)

# highlight the target token
ax.plot(target_idx,logsm_clean[target_idx],'gs',label=tokenizer.decode(target_idx))

# axes, and the model's top prediction as the title
predicted = tokenizer.decode(np.argmax(logsm_clean))
ax.set(xlabel='Vocab elements',ylabel='Log softmax',xlim=[0,model.config.vocab_size])
ax.set_title(f'Predicted next token is "{predicted}"',fontweight='bold')
ax.legend()

plt.show()

# Median-replace the top 10% to 90% of MLP neurons, one layer at a time

In [None]:
# log-softmax of the target token after median-replacing the largest MLP activations
# rows: transformer blocks; columns: replacement rates 10%, 20%, ..., 90%
replrates = range(1,10)  # in units of 10%
log_sm = np.zeros((nlayers,len(replrates)))

def make_replace_hook(nrepl):
  """Return a c_fc forward hook that sets the `nrepl` largest final-token activations to their median.

  Only the last token position of the first sequence is modified; the modified
  output tensor is returned so it replaces the module's output.
  """
  def replace_hook(module, input, output):
    # activations of the final token (view into the output tensor)
    vals = output[0,-1,:]
    med = torch.median(vals)  # computed before any values are overwritten

    # replace the top-nrepl values with the median
    idx = torch.topk(vals,nrepl).indices
    output[0,-1,idx] = med
    return output
  return replace_hook

# loop over layers
for layeri in range(nlayers):

  for replrate in replrates:

    # integer arithmetic avoids float round-off in (replrate/10)*nneurons
    nrepl = replrate*nneurons//10

    # replace this layer
    handle = model.transformer.h[layeri].mlp.c_fc.register_forward_hook(make_replace_hook(nrepl))

    # forward pass to get output logits; always remove the hook, even on failure
    try:
      with torch.no_grad():
        out = model(tokens)
    finally:
      handle.remove()

    # .item() copies the scalar to the host, so this works for CPU and CUDA tensors alike
    log_sm[layeri,replrate-1] = F.log_softmax(out.logits[0,-1,:],dim=-1)[target_idx].item()

In [None]:
# effect of median-replacement on the target token, per layer (x) and replacement rate (dots)
# positive values: the target's log-softmax dropped relative to the clean run
logsm_diff = logsm_clean[target_idx] - log_sm   # (nlayers, number of replacement rates)
jitter_rng = np.random.default_rng(1)          # fixed seed -> reproducible horizontal jitter

plt.figure(figsize=(10,4))

for i in range(nlayers):
  # small horizontal jitter so the dots of one layer do not sit on top of each other
  xjitter = jitter_rng.standard_normal(logsm_diff.shape[1])/20 + i
  plt.plot(xjitter,logsm_diff[i,:],'ko',alpha=.4,
           markerfacecolor=mpl.cm.plasma(i/nlayers),markersize=10)

plt.axhline(0,color='k',linewidth=.5,linestyle='--',zorder=-13)

# keep the original zoom, but widen it when a point would otherwise be clipped
pad = .02
ylim = [min(-.1,logsm_diff.min()-pad),max(.2,logsm_diff.max()+pad)]
plt.gca().set(xlabel='Transformer block',ylabel='Log-softmax difference (clean - replaced)',
              ylim=ylim,title='Impact of median-replacing MLP neurons on the target token')
plt.show()