|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[49] Successive median-replacement of MLP activations</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### global matplotlib styling (the commented-out entries switch on dark mode)

# render inline figures as SVG (higher resolution than the default)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (uncomment to enable)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# hide the right and top axis spines
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

# bold axis titles and labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution of saved figures
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Tokens and activations**

In [None]:
# use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# pretrained tokenizer and model (same checkpoint name for both)
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
model = AutoModelForCausalLM.from_pretrained('gpt2-large')
model = model.to(device)

# switch to evaluation mode (kept as the last expression of the cell)
model.eval()

In [None]:
# Exercise blanks: the right-hand sides are intentionally left for the reader.
# nlayers is used in later cells to loop over model.transformer.h
# (range(0,nlayers)) and to size results with nlayers+1 hidden-state entries.
nlayers =
# nneurons is only displayed in this cell; presumably the width of the MLP
# expansion layer -- confirm against the book.
nneurons =
# last expression of the cell, so the notebook shows both values
nneurons, nlayers

In [None]:
# dictionary to store the mlp activations (filled by the hook, keyed 'L<layer index>')
mlp_values = {}

# Forward hook that stores a value under the key of layer `whichlayer`.
# The right-hand side is an exercise blank. Later cells read `.shape` of the
# stored value, index it as [0,-1,:] and pass it to np.histogram, so it is
# presumably the hooked module's output as a detached CPU tensor -- confirm.
# `whichlayer` is a global that is looked up when the hook runs, so it has to
# be assigned before the forward pass (it is, a few lines below).
def hook(module,input,output):
  mlp_values[f'L{whichlayer}'] =

# implantation surgery: attach the hook to the c_fc submodule of one block's MLP
whichlayer = 9
# keep the returned handle; a later cell calls handle.remove() after the clean pass
handle = model.transformer.h[whichlayer].mlp.c_fc.register_forward_hook(hook)

In [None]:
# prompt text
text = 'It was a dark and stormy'
# Exercise blank. Later cells use target_idx to index the log-softmax vector
# and pass it to tokenizer.decode(), so it should be a single vocabulary index
# (presumably the token id of the expected continuation -- confirm with the book).
target_idx =

# Exercise blank: tokenize `text`. The result is passed to model(tokens,...)
# in later cells and has a .shape, so presumably a tensor of token ids on the
# same device as the model -- confirm.
tokens = tokenizer.

# display for inspection (last expression of the cell)
tokens.shape, tokens, target_idx

In [None]:
# Clean (unmodified) forward pass with gradient tracking disabled.
# Exercise blank: the arguments to model(). Later cells read
# out_clean.hidden_states[i], so the hidden states need to be requested here
# (the later passes call model(tokens,output_hidden_states=True)).
with torch.no_grad():
  out_clean = model()
# remove the activation-storing hook registered earlier
handle.remove()

# log-softmax (log-probabilities, not percent), converted to a NumPy array
# Exercise blank: the arguments to log_softmax. Later cells index logsm_clean
# with target_idx and plot it against the vocabulary, so it should be a 1D
# vector over the vocab (later cells use out.logits[0,-1,:] with dim=-1).
logsm_clean = F.log_softmax().detach().cpu().numpy()

# check some sizes
# NOTE: the key 'L9' is hard-coded; it matches whichlayer = 9 from the hook cell
print(f"mlp_values['L9'] has shape {mlp_values['L9'].shape}")
print(f'Output logits have shape {out_clean.logits.shape}')

In [None]:
# Part 1 figure: log-softmax of every vocab element, with the target highlighted
plt.figure(figsize=(10,4))

# all the log-sm values
plt.plot(logsm_clean,'k.',markersize=2,alpha=.3)

# the target and nontarget values
# (exercise blank: x and y of the highlighted point; the legend label is the
#  decoded target token)
plt.plot(,,'gs',label=tokenizer.decode(target_idx))

# make the graph look pretty :D
plt.gca().set(xlabel='Vocab elements',ylabel='Log softmax',xlim=[0,model.config.vocab_size])
# (exercise blank: the expression inside {} is left for the reader)
plt.title(f'Predicted next token is "{}"',fontweight='bold')
plt.legend()

# save to disk and display
plt.tight_layout()
plt.savefig('ch7_proj49_part1.png')
plt.show()

# **Part 2: Exploring median-based replacement**

In [None]:
# copy of the data to manipulate
# (exercise blank: the tensor to clone; the "pre-replace" histogram below uses
#  mlp_values[f'L{whichlayer}'][0,-1,:] as the original data)
acts_replace = .clone()

# find the median and mean (exercise blanks: the arguments)
med = torch.median()
mean = torch.

# find the top 10% and replace with median
# (exercise blank: the topk arguments; only the .indices of the result are used)
idx = torch.topk().indices
acts_replace[idx] = med

# show the two histograms
plt.figure(figsize=(10,4))
# 41 bin edges between -5 and 2, shared by both histograms
binedges = np.linspace(-5,2,41)

# pre-replace histogram
y,x = np.histogram(mlp_values[f'L{whichlayer}'][0,-1,:],binedges)
# plotted at the left bin edges, shifted by -.02 (presumably so the two curves
# do not sit exactly on top of each other -- confirm)
plt.plot(x[:-1]-.02,y,'gs-',linewidth=.5,markerfacecolor=[.7,.9,.7,.7],label='Original')

# post-replace histogram (exercise blanks: histogram inputs and plot arguments)
y,x = np.histogram(,)
plt.plot(,label='Replaced')
# vertical reference lines at the median and the mean
plt.axvline(med,linestyle='--',color='m',label='Median')
plt.axvline(mean,linestyle=':',color='k',label='Mean')


# x-axis limits are the first and last bin edges
plt.gca().set(xlabel='Data value',ylabel='Count',xlim=binedges[[0,-1]],title='Impact of replacement')
plt.legend()

# save to disk and display
plt.tight_layout()
plt.savefig('ch7_proj49_part2.png')
plt.show()

# **Part 3: Impact of replacement on hidden states and logits**

In [None]:
# this vector is used for the rest of the project
# (exercise blank: the linspace arguments; later cells multiply the vector by
#  100 for "percent" axes, so the values are presumably proportions -- confirm)
pcts_to_replace = np.linspace(

# initialize results matrices
# (exercise blanks: the shapes; the matching arrays in Parts 4 and 5 are
#  np.zeros(len(pcts_to_replace)) and np.zeros((len(pcts_to_replace),nlayers+1)))
log_sm_targ_ex = np.zeros()
hs_diff_norms_ex = np.zeros((,))


# one iteration per replacement level
# (exercise blank: the range; Parts 4 and 5 use range(len(pcts_to_replace)))
for repli in range(


  ### ------ hook ------ ###
  # Forward hook that edits the module output in place and returns it.
  # Only output[0,-1,:] is read and modified.
  def replace_hook(module,input,output):

    # find the median (exercise blank)
    vals = output[0,-1,:]
    med =

    # replace top p% magnitude (exercise blank: the topk call that selects idx)
    idx = torch.topk(
    output[0,-1,idx] = med
    return output

  # put hooks in all layers
  # (exercise blank: the submodule to hook and the registration call; the
  #  handles are collected so the hooks can be removed after the forward pass)
  handles = []
  for layeri in range(0,nlayers):
    h = model.transformer.h[layeri].mlp.)
    handles.append(h)
  ### ------ end of hook ------ ###


  # forward pass to get output logits, and remove hook
  with torch.no_grad():
    out = model(tokens,output_hidden_states=True)
  # remove hooks
  # (exercise blank; Parts 4 and 5 do this with `for h in handles: h.remove()`.
  #  Without it, the hooks from every iteration stay attached to the model.)

  # log-softmax of target
  # (exercise blank: the log_softmax arguments; Parts 4 and 5 use
  #  out.logits[0,-1,:].detach() with dim=-1)
  log_sm_targ_ex[repli] = F.log_softmax()[target_idx]

  # norm of hidden-states differences
  # (exercise blanks; Parts 4 and 5 take the [-1,:] row of each hidden state
  #  for both passes and store torch.norm(hs-hs_clean))
  for i in range(nlayers+1):
    hs = out.hidden_states[i].cpu().squeeze()[,]
    hs_clean = out_clean.hidden_states
    hs_diff_norms_ex[repli,i] =


In [None]:
# Part 3 figure: three panels in one row
fig,axs = plt.subplots(1,3,figsize=(12,3))

# heatmap of hidden-states changes
# (exercise blank: the imshow arguments; the Part 5 figure draws the same kind
#  of heatmap with aspect='auto', origin='lower' and a layer-by-percent extent)
h = axs[0].imshow()
fig.colorbar(h,ax=axs[0],pad=.01)
axs[0].set(xlabel='Hidden states layer',ylabel='Percent neurons replaced',title='A) Impact on hidden states')

# accumulation of hidden-states impacts
# (exercise blank: the matrix and the axis to average over; per the panel
#  title, the average is taken over the replacement percentages)
axs[1].plot(np.mean(,axis=),'kh',markerfacecolor=[.7,.9,.7,.7],markersize=10)
axs[1].set(xlabel='Hidden state layer',ylabel='Vector norm difference',title='B) Averaging over replacement %')

# impact on output logits (log-softmax-transformed)
# (exercise blank: the x and y data)
axs[2].plot(,,'kh',markerfacecolor=[.7,.7,.9,.7],markersize=10)
# dashed reference line: target log-softmax from the clean forward pass
axs[2].axhline(logsm_clean[target_idx],color='k',linestyle='--')
axs[2].set(xlabel='Percent neurons replaced',ylabel='Log softmax to target',title='C) Impact on output logits')

# save to disk and display
plt.tight_layout()
plt.savefig('ch7_proj49_part3.png')
plt.show()

# **Part 4: Comparison with post-GELU expansion**

In [None]:
# results for this part: one target log-softmax per replacement level, and one
# hidden-state difference norm per (replacement level, hidden-state index)
log_sm_targ_ge = np.zeros(len(pcts_to_replace))
hs_diff_norms_ge = np.zeros((len(pcts_to_replace),nlayers+1))


# one iteration per replacement level
for repli in range(len(pcts_to_replace)):


  ### ------ hook ------ ###
  # Forward hook that edits the module output in place and returns it.
  # Only output[0,-1,:] is read and modified.
  def replace_hook(module,input,output):

    # find the median (exercise blank)
    vals = output[0,-1,:]
    med =

    # replace top p% (exercise blanks: the index selection and the new value)
    idx =
    output[0,-1,idx] =
    return output

  # put hooks in all layers
  # (exercise blank: the submodule to hook and the registration call; going by
  #  the section title, presumably the MLP's activation-function submodule --
  #  confirm with the book)
  handles = []
  for layeri in range(0,nlayers):
    h = model.transformer.h[layeri].
    handles.append(h)
  ### ------ end of hook ------ ###


  # forward pass to get output logits, and remove hook
  with torch.no_grad():
    out = model(tokens,output_hidden_states=True)
  # detach every hook so they do not pile up across iterations
  for h in handles: h.remove()

  # log-softmax of target (taken from out.logits[0,-1,:])
  log_sm_targ_ge[repli] = F.log_softmax(out.logits[0,-1,:].detach(),dim=-1)[target_idx]

  # norm of hidden-states differences
  # (modified pass minus clean pass, using the [-1,:] row of each hidden state)
  for i in range(nlayers+1):
    hs = out.hidden_states[i].cpu().squeeze()[-1,:]
    hs_clean = out_clean.hidden_states[i].cpu().squeeze()[-1,:]
    hs_diff_norms_ge[repli,i] = torch.norm(hs-hs_clean)


In [None]:
# Part 4 figure: 2x2 grid comparing pre-GELU and post-GELU replacement
fig,axs = plt.subplots(2,2,figsize=(10,6))

# panel A heatmap (exercise blank: the imshow arguments)
h = axs[0,0].imshow()
fig.colorbar(h,ax=axs[0,0],pad=.01)
axs[0,0].set(xlabel='Hidden states layer',ylabel='Percent neurons replaced',title='A) Replacing pre-GELU neurons')

# panel B heatmap (exercise blank: the imshow arguments)
h = axs[1,0].imshow()
fig.colorbar(h,ax=axs[1,0],pad=.01)
axs[1,0].set(xlabel='Hidden states layer',ylabel='Percent neurons replaced',title='B) Replacing post-GELU neurons')



# panel C: one line per condition (exercise blanks: data and line styles)
axs[0,1].plot(,label='Pre-GELU')
axs[0,1].plot(,label='Post-GELU')
axs[0,1].set(xlabel='Hidden state layer',ylabel='Vector norm difference',title='C) Impact on hidden states')
axs[0,1].legend()

# panel D: one line per condition (exercise blanks: data and line styles)
# NOTE(review): this legend label is 'pre-GELU' (lowercase p) while panel C
# uses 'Pre-GELU' -- consider making them consistent.
axs[1,1].plot(,label='pre-GELU')
axs[1,1].plot(,label='Post-GELU')
# dashed reference line: target log-softmax from the clean forward pass
axs[1,1].axhline(logsm_clean[target_idx],color='k',linestyle='--',label='Clean model')
axs[1,1].set(xlabel='Percent neurons replaced',ylabel='Log softmax to target',title='D) Impact on output logits')
axs[1,1].legend(fontsize=8)


# save to disk and display
plt.tight_layout()
plt.savefig('ch7_proj49_part4.png')
plt.show()

# **Part 5: Comparison with MLP projections**

In [None]:
# results for this part: one target log-softmax per replacement level, and one
# hidden-state difference norm per (replacement level, hidden-state index)
log_sm_targ_pr = np.zeros(len(pcts_to_replace))
hs_diff_norms_pr = np.zeros((len(pcts_to_replace),nlayers+1))


# one iteration per replacement level
for repli in range(len(pcts_to_replace)):


  ### ------ hook ------ ###
  # Forward hook that edits output[0,-1,idx] in place and returns the output.
  def replace_hook(module,input,output):

    # find the median (exercise blanks: the values to use and their median)
    vals =
    med =

    # replace top-magnitude p% (exercise blank: the topk call that selects idx)
    idx = torch.topk
    output[0,-1,idx] = med
    return output

  # put hooks in all layers
  # (exercise blank: the MLP submodule to hook and the registration call; going
  #  by the section title, presumably the MLP projection layer -- confirm)
  handles = []
  for layeri in range(0,nlayers):
    h = model.transformer.h[layeri].mlp.
    handles.append(h)
  ### ------ end of hook ------ ###


  # forward pass to get output logits, and remove hook
  with torch.no_grad():
    out = model(tokens,output_hidden_states=True)
  # detach every hook so they do not pile up across iterations
  for h in handles: h.remove()

  # log-softmax of target (taken from out.logits[0,-1,:])
  log_sm_targ_pr[repli] = F.log_softmax(out.logits[0,-1,:].detach(),dim=-1)[target_idx]

  # norm of hidden-states differences
  # (modified pass minus clean pass, using the [-1,:] row of each hidden state)
  for i in range(nlayers+1):
    hs = out.hidden_states[i].cpu().squeeze()[-1,:]
    hs_clean = out_clean.hidden_states[i].cpu().squeeze()[-1,:]
    hs_diff_norms_pr[repli,i] = torch.norm(hs-hs_clean)

In [None]:
# Part 5 figure: 2x2 grid -- heatmaps in the left column, line plots in the right
fig,axs = plt.subplots(2,2,figsize=(10,6))

# element-wise average of the two expansion-layer results (pre- and post-GELU)
combi_exp = (hs_diff_norms_ex+hs_diff_norms_ge)/2

# --- left column: heatmaps of hidden-state difference norms ---
# each entry: (axis, matrix, y-axis label, panel title)
for ax,mat,ylab,ttl in [
    (axs[0,0],combi_exp,'Percent neurons replaced','A) Average of pre- and post-GELU'),
    (axs[1,0],hs_diff_norms_pr,'Percent projections replaced','B) Replacing projection dimensions'),
  ]:
  # color scale is clipped at 80% of the matrix maximum
  h = ax.imshow(
      mat,
      aspect='auto',
      cmap='magma',
      origin='lower',
      vmax=mat.max()*.8,
      extent=[0,nlayers,100*pcts_to_replace[0],100*pcts_to_replace[-1]],
  )
  fig.colorbar(h,ax=ax,pad=.01)
  ax.set(xlabel='Hidden states layer',ylabel=ylab,title=ttl)

# --- right column: one line per condition in each panel ---
# each entry: (legend label, norms matrix, target log-softmax, line format, marker face color)
for lab,norms,logsm,fmt,face in [
    ('Expansion', hs_diff_norms_ex,log_sm_targ_ex,'go-',[.7,.9,.7,.7]),
    ('Post-GELU', hs_diff_norms_ge,log_sm_targ_ge,'rs-',[.9,.7,.7,.7]),
    ('Projection',hs_diff_norms_pr,log_sm_targ_pr,'b^-',[.7,.7,.9,.7]),
  ]:
  # panel C: norms averaged over the replacement levels (axis 0)
  axs[0,1].plot(np.mean(norms,axis=0),fmt,linewidth=.5,markerfacecolor=face,markersize=10,label=lab)
  # panel D: target log-softmax against percent replaced
  axs[1,1].plot(100*pcts_to_replace,logsm,fmt,linewidth=.5,markerfacecolor=face,markersize=10,label=lab)

# panel C labels and legend
axs[0,1].set(xlabel='Hidden state layer',ylabel='Vector norm difference',title='C) Impact on hidden states')
axs[0,1].legend()

# panel D: dashed reference line for the clean model, then labels and legend
axs[1,1].axhline(logsm_clean[target_idx],color='k',linestyle='--',label='Clean model')
axs[1,1].set(xlabel='Percent replaced',ylabel='Log softmax to target',title='D) Impact on output logits')
axs[1,1].legend(fontsize=9)


# save to disk and display
plt.tight_layout()
plt.savefig('ch7_proj49_part5.png')
plt.show()