|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[46] Statistics-based lesioning in MLP neurons</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import scipy.stats as stats
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [None]:
### matplotlib adjustments (dark-mode settings are kept below, commented out)

# render inline figures as SVG, which stays sharp at any zoom level
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# house style: no top/right spines, bold titles and axis labels, 300-dpi saved figures
plot_style = {
    'axes.spines.top':   False,
    'axes.spines.right': False,
    'axes.titleweight':  'bold',
    'axes.labelweight':  'bold',
    'savefig.dpi':       300,
}

# to switch to dark mode, uncomment these overrides
# plot_style.update({
#     'figure.facecolor': '#282a2c',
#     'figure.edgecolor': '#282a2c',
#     'axes.facecolor':   '#282a2c',
#     'axes.edgecolor':   '#DDE2F4',
#     'axes.labelcolor':  '#DDE2F4',
#     'xtick.color':      '#DDE2F4',
#     'ytick.color':      '#DDE2F4',
#     'text.color':       '#DDE2F4',
# })

plt.rcParams.update(plot_style)

# **Part 1: Get MLP activations for him vs. her**

In [None]:
# Load the pre-trained BERT-large (uncased) masked-LM and its matching tokenizer
MODEL_NAME = 'bert-large-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model     = BertForMaskedLM.from_pretrained(MODEL_NAME)

# evaluation mode; as the cell's last expression, the returned model is displayed
model.eval()

In [None]:
# Number of neurons (units) in the MLP expansion layer.
# TODO(exercise): complete the attribute access below. Later cells use `nneurons`
# as an integer (an x-axis limit and the divisor for percentages).
nneurons = model.bert.
print(f'There are {nneurons} units it the expansion layer.')

In [None]:
# FYI, for hook location
# Displays the MLP "expansion" sub-module of transformer block 4 so that its
# children (e.g., the `.dense` layer hooked in Part 3) are visible.
model.bert.encoder.layer[4].intermediate

In [None]:
# dictionary to store the mlp activations, keyed by layer name (e.g., 'L9')
mlp_values = {}

def hook(module,input,output):
  """Forward hook: save the hooked module's output into mlp_values.

  The dictionary key is built when the hook fires, so it reads the *global*
  `whichlayer` at call time, not at registration time.
  TODO(exercise): complete the right-hand side. Later cells call `.shape` on the
  stored value and index it with three subscripts (presumably sentence, token, neuron).
  """
  mlp_values[f'L{whichlayer}'] =

# surgery ;)
# choose the transformer block to probe, then register `hook` on it
whichlayer = 9
# TODO(exercise): complete the module path and register `hook`; the returned
# handle is removed after the forward pass (`handle.remove()` in a later cell).
handle = model.bert.encoder.

In [None]:
# generated by Claude.ai
# 54 sentences, each containing the target pronoun "him" exactly once
sentences = [
    "I saw him at the market.",
    "She gave him the book.",
    "They asked him for advice.",
    "We invited him to dinner.",
    "The dog followed him home.",
    "They asked him to join.",
    "He saw him at the park yesterday.",
    "Did you give him your address?",
    "I haven't seen him in ages.",
    "I told him the truth.",
    "They congratulated him on his success.",
    "She recognized him immediately.",
    "The teacher praised him for his work.",
    "I met him last summer.",
    "The child hugged him tightly.",
    "They warned him about the danger.",
    "She drove him to the airport.",
    "We waited for him for hours.",
    "The cat scratched him accidentally.",
    "They surprised him with a gift.",
    "She called him on the phone.",
    "The jury found him not guilty.",
    "I remembered him from school.",
    "They elected him as president.",
    "She forgave him for his mistake.",
    "The police questioned him yesterday.",
    "I helped him with his homework.",
    "They spotted him in the crowd.",
    "She visited him in the hospital.",
    "The manager promoted him last week.",
    "I trusted him completely.",
    "They respected him for his honesty.",
    "She taught him how to swim.",
    "The bird attacked him suddenly.",
    "I greeted him warmly.",
    "They supported him through difficult times.",
    "She ignored him at the party.",
    "The judge sentenced him to community service.",
    "I photographed him during the event.",
    "They believed him despite the evidence.",
    "She surprised him on his birthday.",
    "The guard stopped him at the entrance.",
    "I missed him terribly.",
    "They watched him leave the building.",
    "She accompanied him to the concert.",
    "The crowd cheered him enthusiastically.",
    "I described him to the police.",
    "They thanked him for his help.",
    "She admired him for his courage.",
    "The committee nominated him for the award.",
    "I married him last spring.",
    "They informed him about the changes.",
    "She introduced him to the parents.",
    "The author based the character on him.",
]

## same sentences but with "her"
# Derived mechanically from the "him" list, so sentence i and sentence i+54 form an
# exact minimal pair (first half = "him", second half = "her"). Only the pronoun
# "him" is swapped; words like "his" are kept (e.g., "They congratulated her on his
# success."). This keeps exactly one "her" per sentence, which the target-index
# lookup relies on.
sentences += [s.replace(' him', ' her') for s in sentences]

print(f'There are {len(sentences)} sentences.')

In [None]:
# identify the target token
# TODO(exercise): get the vocabulary ID(s) for "him" and "her"; presumably these
# are what torch.isin later searches for in each tokenized sentence.
target_token_him = tokenizer.encode('him'
target_token_her =
print(f'The target token indices are {target_token_him} and {target_token_her}\n')

# tokenize
# TODO(exercise): tokenize all sentences as one batch of PyTorch tensors. The
# sentences differ in length, so the batch presumably needs padding; later cells
# call model(**tokens), so this must be a mapping of model inputs.
tokens =

In [None]:
# display the tokenized batch
tokens

In [None]:
# prepare a vector of target indices per sentence as a torch tensor
# TODO(exercise): one entry per sentence. If these values are later used to index
# activation arrays, the tensor needs an integer dtype (torch.zeros defaults to float).
target_indices = torch.zeros()

# loop over sentences: find the token position of "him"/"her" in each one
for senti in range(len(sentences)):
  # TODO(exercise): boolean mask of which tokens in this sentence are a target token
  targBool = torch.isin(,)
  target_indices[senti] = torch.where()[0]
  # torch.where()[0] works here b/c each sentence contains exactly one occurrence of the target.

target_indices

In [None]:
# forward pass without gradient tracking; the hook registered earlier fills mlp_values
with torch.no_grad():
  # TODO(exercise): run the model on the tokenized sentences
  model

# detach the hook so later forward passes are not recorded
handle.remove()

mlp_values[f'L{whichlayer}'].shape

In [None]:
# loop through sentences to get target activations
# acts: one row per sentence and one column per neuron (presumably the last
# dimension of the stored activations). Only the target-token position is kept.

acts = np.zeros((len(sentences),mlp_values[].shape[]))

# get the activations per sentence
# TODO(exercise): index [sentence, target-token position, all neurons]; the
# per-sentence target positions were computed in target_indices
for senti in range(len(sentences)):
  acts[senti,:] = mlp_values[f'L{whichlayer}'][,,]

acts.shape

# **Part 2: T-test in one layer**

In [None]:
# t-test and find significant neurons via FDR (correction for multiple comparisons)
# The first half of `sentences` is the "him" set and the second half is the "her"
# set; axis=0 runs one independent-samples t-test per neuron (column of acts).
# TODO(exercise): pass the two halves of acts, then FDR-correct tres.pvalue and
# threshold it to get a boolean mask over neurons.
tres = stats.ttest_ind(acts[,],acts[,],axis=0)
issig = stats.()<

# find the neurons
# With the "him" rows passed first, t > 0 means him > her and t < 0 means her > him.
himNeurons = issig &
herNeurons =

In [None]:
# setup the figure: t-values on the left two-thirds, pie chart on the right third
fig = plt.figure(figsize=(12,4))
gs = GridSpec(1,3,figure=fig)
ax0 = fig.add_subplot(gs[:2])
ax1 = fig.add_subplot(gs[2])

# draw the scatter plot
# TODO(exercise): each call presumably takes neuron indices (x) and t-values (y)
# for one subset: non-significant, him > her, and her > him
ax0.plot(,label='Non-sig.')
ax0.plot(,label='him > her')
ax0.plot(,label='her > him')

ax0.legend()
ax0.set(xlim=[-1,nneurons],xlabel='MLP expansion neurons',ylabel='T-value',
              title=f'{himNeurons.sum()} "him" neurons and {herNeurons.sum()} "her" neurons')


# and the pie chart
# TODO(exercise): slice sizes are the "him" count, the "her" count, and the
# remainder (presumably nneurons minus both), matching the three labels
ax1.pie([.sum(),.sum(),-.sum()-.sum()],autopct='%1.1f%%',
           labels=['"him" significant','"her" significant','Non-significant'],
           colors=[[.7,.9,.7],[.9,.7,.7,.7],[1,0,0,.7]],wedgeprops={'edgecolor':'k'})

plt.tight_layout()
plt.savefig('ch7_proj46_part2.png')
plt.show()

# **Part 3: T-tests in all layers**

In [None]:
# number of transformer blocks in the BERT encoder
encoder_blocks = model.bert.encoder.layer
nlayers = len(encoder_blocks)

In [None]:
# dictionary to store the mlp t-test results, keyed like 'L4_him' / 'L4_her'
mlpTs = {}

# hook function that runs a t-test on MLP activations
def implant_hook(layer_number):
  """Build a forward hook that t-tests "him" vs. "her" activations in one layer.

  The returned closure captures `layer_number`, so every layer writes to its own
  keys in mlpTs even though all layers share the same hook code.
  """
  def hook(module,input,output):
    """Run per-neuron t-tests on this layer's output and save the results.

    Reads the global `sentences` (and presumably `target_indices`) and writes
    mlpTs['L<layer>_him'] and mlpTs['L<layer>_her']. Later cells call `.sum()`
    on these entries and use them to index neurons, so presumably they are boolean
    masks over neurons. There is no return statement, so the module output is
    left unchanged.
    """

    # detach activations
    # TODO(exercise): `.shape[2]` is read below, so this must be 3-D
    # (presumably sentence x token x neuron).
    mlpVals =

    # matrix of target activation values
    acts = np.zeros((len(sentences),mlpVals.shape[2]))
    for senti in range(len(sentences)):
      acts[senti,:] = mlpVals[,,:]

    # t-test and find significance
    # (same recipe as Part 2: "him" half first, so t > 0 means him > her)
    tres = stats.ttest_ind(
    issig =

    # store the results
    mlpTs[f'L{layer_number}_him'] =
    mlpTs[f'L{layer_number}_her'] =

    # ### for Part 7 (leave commented before Part 7!)
    # Part 7 variant: instead of masks, store the indices of the neurons with the
    # smallest |t| (least tuned), as many as were significant in each direction.
    # numsig_pos = (issig & (tres.statistic>0)).sum()
    # numsig_neg = (issig & (tres.statistic<0)).sum()
    # mlpTs[f'L{layer_number}_him'] = np.argsort(abs(tres.statistic))[:numsig_pos]
    # mlpTs[f'L{layer_number}_her'] = np.argsort(abs(tres.statistic))[:numsig_neg]

  return hook


# implant into all layers
# TODO(exercise): register implant_hook(layeri) on each layer and keep every handle
# so the hooks can be removed after the forward pass.
# NOTE(review): this hooks `.intermediate.dense` (the Linear sub-module), whose
# output is presumably taken before the MLP nonlinearity; confirm this is the
# intended hook point and that it matches the module hooked in Part 1.
handles = []
for layeri in range(nlayers):
  h = model.bert.encoder.layer[layeri].intermediate.dense.register_forward_hook(
  handles.

In [None]:
# forward pass: every implanted hook fires once and fills mlpTs
with torch.no_grad(): model(**tokens)

# remove handles
# TODO(exercise): call .remove() on every handle in `handles` so the t-test hooks
# do not fire on later forward passes.


In [None]:
# sanity check: list the stored keys and the shape of one layer's "her" result
mlpTs.keys(), mlpTs['L4_her'].shape

In [None]:
plt.figure(figsize=(10,3))

# draw the percentage of significant t-tests per layer
# TODO(exercise): circles ('ko') are "her > him" and squares ('ks') are "him > her",
# matching the legend entries below; divide by the neuron count as those lines do.
for i in range(nlayers):
  plt.plot(i,100*mlpTs[].sum() / ,'ko',markerfacecolor=[.9,.9,.7,.7],markersize=12)
  plt.plot(i,100*mlpTs[].sum() / ,'ks',markerfacecolor=[.7,.7,.9,.7],markersize=12)


# hacky solution to get legend
# One point per marker style is drawn at x=100, outside the xlim set below, so it
# is invisible but still creates a labeled legend entry. `i` is the last loop index.
plt.plot(100,100*mlpTs[f'L{i}_her'].sum() / nneurons,'ko',markerfacecolor=[.9,.9,.7,.7],markersize=12,label='her > him')
plt.plot(100,100*mlpTs[f'L{i}_him'].sum() / nneurons,'ks',markerfacecolor=[.7,.7,.9,.7],markersize=12,label='him > her')
plt.legend()

plt.gca().set(xlim=[-.5,nlayers-.5],xlabel='Layers',ylabel='% sig. t-values')

plt.tight_layout()
plt.savefig('ch7_proj46_part3.png')
plt.show()

# **Part 4: Gender word predictions in an independent dataset**

In [None]:
# one clean sentence, then two copies with a single gendered pronoun replaced by [MASK]:
#   texts[1] masks the "her" in "her project"; texts[2] masks the "him" in "thanked him"
texts = [
  'Robert helped Lucy with her project, and she thanked him for his hard work.',
  'Robert helped Lucy with [MASK] project, and she thanked him for his hard work.',
  'Robert helped Lucy with her project, and she thanked [MASK] for his hard work.',
]

# tokenize all three into one batch of PyTorch tensors; no padding is requested,
# so this presumably relies on the three sentences tokenizing to equal length
testtokens = tokenizer(texts, return_tensors='pt')
testtokens

In [None]:
# find indices of [MASK]
# TODO(exercise): token position of [MASK] in texts[1] (masked "her") and in
# texts[2] (masked "him"); the ablation hooks later use these as token indices.
mask_idx_her = torch.where(
mask_idx_him = torch.where(

print(f'Masks are at indices {mask_idx_her} and {mask_idx_him}')

In [None]:
# clean (un-ablated) forward pass on the three test sentences
with torch.no_grad():
  out = model(**testtokens)

# TODO(exercise): take the output logits from `out`; later cells index them with
# three subscripts (presumably sentence, token position, vocabulary item)
logits = out.
logits.shape

In [None]:
# target logits from the clean run, shaped [sentence, position, word]; presumably
#   position 0 = the "her" slot, 1 = the "him" slot;
#   word     0 = "her", 1 = "him"   (to match the bar labels in the next cell)
target_logits_clean = np.zeros((3,2,2))

for senti in range(3):
  # TODO(exercise): at the her/him position, read the logits of the her/him tokens
  target_logits_clean[senti,0,:] = logits[,,[,]]
  target_logits_clean[senti,1,:] = logits[,,[,]]

target_logits_clean

In [None]:
plt.figure(figsize=(10,4))

# Sentence i gets four bars centered on x = i*1.5 + .25 (the xticks below):
# the pair at i*1.5 +/- .1 is the "her" position, the pair shifted by +.5 is the
# "him" position; within each pair the bars are the "her" and "him" logits.
# TODO(exercise): select the matching slices of target_logits_clean for each pair.
for i in range(3):
  plt.bar(np.array([-.1,.1])+i*1.5,target_logits_clean[,,],width=.2,facecolor=[[.9,.3,.9],[.3,.9,.9]],edgecolor='k')
  plt.bar(np.array([-.1,.1])+i*1.5+.5,target_logits_clean[,,],width=.2,facecolor=[[.9,.7,.9],[.7,.9,.9]],edgecolor='k')

# create the bar labels (xticklabels is reused by the Part 5 figure)
basetxt = 'her || him     her || him\n-------------------+-------------------\nher position   ||   him position'
xticklabels = [ basetxt + '\n\n|_______$\\bf{Clean\\; sentence}$______|',
                basetxt + '\n\n|____$\\bf{HER\\; mask\\; sentence}$____|',
                basetxt + '\n\n|____$\\bf{HIM\\; mask\\; sentence}$____|' ]

plt.gca().set(xticks=np.arange(.25,3.5,1.5),xticklabels=xticklabels,ylabel='Logits')

plt.tight_layout()
plt.savefig('ch7_proj46_part4.png')
plt.show()

# **Part 5: Ablate "him" and "her" neurons in MLP**

In [None]:
# run on the final transformer layer ("-1" b/c of 0-indexing)
whichlayer = nlayers - 1

def ablation_hook(module,input,output):
  """Forward hook that edits the MLP output in place and returns it.

  In sentence 1 (HER mask) it targets this layer's "her" neurons at the [MASK]
  position; TODO(exercise): do the same for sentence 2 (HIM mask) with the "him"
  neurons at mask_idx_him. Reads the global `whichlayer` when it fires.
  """
  output[1,mask_idx_her,mlpTs[f'L{whichlayer}_her']] =
  output[2, =
  return output

# implant
# TODO(exercise): register ablation_hook. NOTE(review): the neuron masks come from
# hooks on `.intermediate.dense`, so presumably the same sub-module should be used here.
handle = model.bert.encoder.layer[whichlayer].

In [None]:
# forward pass with the ablation hook active, then remove the hook
with torch.no_grad():
  out = model(**testtokens)

handle.remove()
# TODO(exercise): as in the clean run, take the logits from `out`
logitsZero = out.
logitsZero.shape

In [None]:
# target logits from the ablated run, presumably in the same
# [sentence, position, word] layout as target_logits_clean
target_logitsZ = np.zeros((3,2,2)) # Z = zeroed

for senti in range(3):
  # TODO(exercise): same indexing as target_logits_clean, but from logitsZero
  target_logitsZ[senti,0,:] = logitsZero[
  target_logitsZ[senti,1,:] =

In [None]:
# bar plot of logit changes on the left two-thirds, clean-vs-ablated scatter on the right
fig = plt.figure(figsize=(13,4.5))
gs = GridSpec(1,3,figure=fig)
ax0 = fig.add_subplot(gs[:2])
ax1 = fig.add_subplot(gs[2])

# calculate the difference in logits
# TODO(exercise): clean minus ablated, matching panel A's title below
deltaLogits =

# show the bar plots
# TODO(exercise): same x positions and slicing as the Part 4 bar plot, using deltaLogits
for i in range(3):
  ax0.bar(,,width=.2,facecolor=[[.9,.3,.9],[.3,.9,.9]],edgecolor='k')
  ax0.bar(,,width=.2,facecolor=[[.9,.7,.9],[.7,.9,.9]],edgecolor='k')

ax0.axhline(0,color='k',linewidth=.2)
ax0.set(xticks=np.arange(.25,3.5,1.5),xticklabels=xticklabels,ylabel='$\\mathbf{\\Delta}$ logits',
        title=f'A) Clean - ablated logits (ablation in layer {whichlayer})')

# scatter plot showing modulation impact
# TODO(exercise): x = clean target logits, y = ablated target logits (per the axis labels)
ax1.plot(,,'ko',markerfacecolor=[.7,.9,.9,.6],markersize=10)
ax1.set(xlabel='Clean logits',ylabel='Ablation logits',title='B) Clean vs. ablated model logits')
ax1.grid(linewidth=.4)

plt.tight_layout()
plt.savefig('ch7_proj46_part5.png')
plt.show()

# **Part 6: Laminar nullification of MLP neurons**

In [None]:
# results are (1) magnitude of modulation, (2) "her" impact, (3) "him" impact
# i.e. one row per layer; columns 0, 1, 2 in that order
results = np.zeros((nlayers,3))


# loop over layers
for layeri in range(nlayers):

  # patch this layer
  # The hook is redefined each iteration and reads the loop variable `layeri` when it
  # fires; that happens during this iteration's forward pass, so it sees the current layer.
  def mlp_ablate_hook(module, input, output):
    # zero-out the "her neurons" on the HER token, and the "him neurons" on the HIM token
    output[1,mask_idx_her,mlpTs[f'L{layeri}_her'
    output[2,
    return output
  handle = model.bert.encoder.layer[layeri].intermediate.dense.register_forward_hook(mlp_ablate_hook)

  # forward pass to get output logits, and remove hook
  # TODO(exercise): as written, the bare `handle` line is a no-op expression; the
  # hook must actually be removed, or it stays active for the next layer's pass.
  with torch.no_grad(): out=model(**testtokens)
  logitsZero =
  handle

  # get the logits for the target tokens
  # (presumably the same [sentence, position, word] layout as target_logits_clean)
  target_logitsZ = np.zeros((,,)) # ("Z" for zeroed out)

  for senti in range(3):
    target_logitsZ[senti,0,:] =
    target_logitsZ[senti,1,:] =
  deltaLogits =  # difference between clean and ablation logits

  # measure the total magnitude change
  # (presumably a mean of absolute changes; the plot's y-axis starts at 0)
  results[layeri,0] = np.mean(

  # specific modulations
  # column 1 = "her" impact, column 2 = "him" impact (see the header comment)
  results[layeri,1] = deltaLogits[
  results[layeri,2] = deltaLogits[


In [None]:
# A) overall magnitude of the ablation effect per layer; B) signed her/him effects
fig,axs = plt.subplots(1,2,figsize=(12,4))

axs[0].plot(results[:,0],'kh',markerfacecolor=[.9,.7,.7],markersize=12)
axs[0].set(xlabel='Transformer block',ylabel='Magnitude change in logits',
           ylim=[0,None],title='A) Overall impact of manipulation on target logits')

# TODO(exercise): plot results column 1 ("her" impact) and column 2 ("him" impact)
axs[1].plot(,label='HER manipulation')
axs[1].plot(,label='HIM manipulation')
axs[1].axhline(0,color='k',zorder=-5,linewidth=.5)
axs[1].set(xlabel='Transformer block',ylabel='Signed change in logits',title='B) Clean - ablated target $\\mathbf{\\Delta}$ logits')
axs[1].legend()

plt.tight_layout()
plt.savefig('ch7_proj46_part6.png')
plt.show()

# **Part 7: Ablate the least-tuned neurons**