|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[44] Grammar tuning in MLP neurons</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import scipy.stats as stats
!pip install pingouin
import pingouin as pg # for effect size calculations

import requests

import torch
from transformers import AutoModelForCausalLM,AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render inline figures as svg (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (disabled)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# no top/right spines, bold titles and axis labels, 300-dpi saved figures
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top']   = False
plt.rcParams['axes.titleweight']  = 'bold'
plt.rcParams['axes.labelweight']  = 'bold'
plt.rcParams['savefig.dpi']       = 300

# **Part 1: Get nouns and verbs**

In [None]:
# the tokenizer and the model are loaded from the same checkpoint
checkpoint = 'EleutherAI/gpt-neo-125m'

# Eleuther's tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# load in GPTneo and switch it to evaluation mode
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.eval()

In [None]:
# main repo: https://github.com/david47k/top-english-wordlists/

# download the list of verbs (one word per line)
url = 'https://raw.githubusercontent.com/david47k/top-english-wordlists/refs/heads/master/top_english_verbs_lower_10000.txt'
response = requests.get(url)
response.raise_for_status() # stop here on a failed download instead of tokenizing an error page
all_verbs = response.text.split('\n')

# token IDs and character counts of the single-token verbs
verbs = []
len_verbs = []

# loop over all the verbs
for word in all_verbs:

  # skip blank entries (.split('\n') leaves an empty final element); otherwise
  # f' {word}' is a bare space, which can pass the single-token test below
  # and be counted as a verb
  if not word.strip():
    continue

  # tokenize with preceding space
  tok = tokenizer.encode(f' {word}')

  # add to the list if it's single-token
  if len(tok)==1:
    verbs.append(tok[0])
    len_verbs.append(len(word))


# split the first 2000 single-token verbs by odd/even position
verbs_split1 = verbs[:2000:2]
verbs_split2 = verbs[1:2000:2]

# and print
print(f'{len(verbs)} out of {len(all_verbs)} verbs are single-token.')
print(f'There are {len(verbs_split1)} split-1 and {len(verbs_split2)} split-2 samples.')

In [None]:
# Note: the .split('\n') method adds an extra element at the end, which is why there seems to be 10,001 verbs:
# (this cell displays that final element of the list)
all_verbs[-1]

In [None]:
# repeat for nouns: download the list (one word per line)
url = 'https://raw.githubusercontent.com/david47k/top-english-wordlists/refs/heads/master/top_english_nouns_lower_10000.txt'
response = requests.get(url)
response.raise_for_status() # stop here on a failed download instead of tokenizing an error page
all_nouns = response.text.split('\n')

# token IDs and character counts of the single-token nouns
nouns = []
len_nouns = []

# loop over all the nouns
for word in all_nouns:

  # skip blank entries (.split('\n') leaves an empty final element); otherwise
  # f' {word}' is a bare space, which can pass the single-token test below
  # and be counted as a noun
  if not word.strip():
    continue

  # tokenize with preceding space
  tok = tokenizer.encode(f' {word}')

  # add to the list if it's single-token
  if len(tok)==1:
    nouns.append(tok[0])
    len_nouns.append(len(word))

# split the first 2000 single-token nouns by odd/even position
nouns_split1 = nouns[:2000:2]
nouns_split2 = nouns[1:2000:2]

# and print
print(f'{len(nouns)} out of {len(all_nouns)} nouns are single-token.')
print(f'There are {len(nouns_split1)} split-1 and {len(nouns_split2)} split-2 samples.')

In [None]:
# header line and token list for each of the four word sets
samples = [
  ('First 5 split-1 verbs:'    , verbs_split1),
  ('\nFirst 5 split-2 verbs:'  , verbs_split2),
  ('\n\nFirst 5 split-1 nouns:', nouns_split1),
  ('\nFirst 5 split-2 nouns:'  , nouns_split2),
]

# decode the first five token IDs of each set back into text
for header,token_ids in samples:
  print(header)
  print([tokenizer.decode(t) for t in token_ids[:5]])

In [None]:
# check lengths: compare the number of characters in the nouns vs. the verbs

# convenient to have in numpy (only the first 2000 words, i.e., those used in the splits)
len_nouns = np.array(len_nouns[:2000])
len_verbs = np.array(len_verbs[:2000])

# number of words at each length (index = number of characters)
yN = np.bincount(len_nouns)
yV = np.bincount(len_verbs)

plt.figure(figsize=(10,3))
plt.bar(np.arange(len_nouns.max()+1)-.15,yN,width=.6,label='Nouns',alpha=.9,edgecolor='b')
plt.bar(np.arange(len_verbs.max()+1)+.15,yV,width=.6,label='Verbs',alpha=.9,edgecolor='r')

# independent-samples t-test and effect size on the lengths
tres = stats.ttest_ind(len_nouns,len_verbs)
cohensd = pg.compute_effsize(len_nouns,len_verbs,paired=False,eftype='cohen')

# one tick per bar position, up to and including the longest word in either category
max_len = max(len_nouns.max(),len_verbs.max())
plt.gca().set(xticks=range(max_len+1),xlabel='Number of characters',ylabel='Count',
              title=f"t({tres.df:g}) = {tres.statistic:.2f}, p = {tres.pvalue:.3f}\nCohen's d = {abs(cohensd):.3f}")

plt.tight_layout()
plt.savefig('ch7_proj44_part1.png')
plt.show()

# **Part 2: Implant a hook and get activations**

In [None]:
# display the model's printed representation to inspect its modules
model

In [None]:
# activations captured by the hook, keyed by the current dataset label
mlp_acts = {}

def hook(module,input,output):
  """Forward hook that saves a module's output into `mlp_acts`.

  The output is detached, converted to numpy, and stripped of singleton
  dimensions. It is stored under the value that the global `whichdata`
  has at the moment the hook fires.
  """
  captured = output.detach().numpy()
  mlp_acts[f'{whichdata}'] = np.squeeze(captured)

# attach the hook to the c_fc module of the MLP in transformer block index 8;
# the returned handle is kept so the hook can be removed later
handle = model.transformer.h[8].mlp.c_fc.register_forward_hook(hook)

In [None]:
# make a batch
# (unsqueeze(1) turns the list of token IDs into an N-by-1 tensor; only its shape is displayed here)
torch.tensor(nouns_split1).unsqueeze(1).shape

In [None]:
# this cell takes around 20 s

# reinitialize data-dictionary
mlp_acts = {}

# token lists to push through the model; each key is the label the hook stores the result under
datasets = {
  'nouns_split1': nouns_split1,
  'nouns_split2': nouns_split2,
  'verbs_split1': verbs_split1,
  'verbs_split2': verbs_split2,
}

with torch.no_grad():

  # `whichdata` is the global that the hook reads when it fires
  for whichdata,token_ids in datasets.items():
    model(torch.tensor(token_ids).unsqueeze(1))


In [None]:
# report the shape of every stored activation array
for label,acts in mlp_acts.items():
  print(f"mlp_acts['{label}'] has shape {list(acts.shape)}")

In [None]:
# split-1 activations and their per-neuron averages (mean over axis 0)
nounActs  = mlp_acts['nouns_split1']
nounMeans = nounActs.mean(axis=0)
verbMeans = mlp_acts['verbs_split1'].mean(axis=0)

# marker style shared by panels B and C
dots = dict(markersize=5,markerfacecolor=[.9,.7,.9,.5])

_,axs = plt.subplots(1,3,figsize=(12,3.5))

# A) the full activation matrix for the nouns
axs[0].imshow(nounActs,aspect='auto',vmin=-2,vmax=2,cmap='plasma')
axs[0].set(xlabel='Neurons',ylabel='Nouns (index)',title='A) Nouns activations')

# B) average noun activation of each neuron
axs[1].plot(nounMeans,'ko',**dots)
axs[1].set(xlabel='Neurons',ylabel='Activation',title='B) Mean activations over all nouns')

# C) average noun vs. average verb activation
axs[2].plot(nounMeans,verbMeans,'ko',**dots)
axs[2].set(xlabel='Nouns',ylabel='Verbs',title='C) Activations to all words')

plt.tight_layout()
plt.savefig('ch7_proj44_part2.png')
plt.show()

# **Part 3: T-tests on split-1 data**

In [None]:
# number of neurons = size of the last dimension of the stored activations
nneurons = mlp_acts['nouns_split1'].shape[-1]
nneurons

In [None]:
# split-1 activations of the two categories
nounActs = mlp_acts['nouns_split1']
verbActs = mlp_acts['verbs_split1']

# independent-samples t-test along axis 0 (one test per neuron)
T_split1 = stats.ttest_ind(nounActs,verbActs,axis=0)

# Cohen's d, one neuron (column) at a time
cohensd = np.array([ pg.compute_effsize(nounActs[:,i],verbActs[:,i],paired=False,eftype='cohen')
                     for i in range(nneurons) ])

# effect size against t-value
plt.plot(cohensd,T_split1.statistic,'ko',markerfacecolor='w')
plt.gca().set(xlabel="Cohen's d",ylabel='T-value')

plt.tight_layout()
plt.savefig('ch7_proj44_part3a.png')
plt.show()

In [None]:
plt.figure(figsize=(10,4))

# boolean masks over neurons
sigPvals1 = stats.false_discovery_control(T_split1.pvalue)<.05 # FDR-corrected p<.05
notSig    = ~sigPvals1
bigEffect = abs(cohensd)>.8                                    # |Cohen's d|>.8

tvals = T_split1.statistic

# significant neurons (open green circles)
plt.plot(np.flatnonzero(sigPvals1),tvals[sigPvals1],'go',markerfacecolor='w')

# large effect sizes (filled green circles, drawn on top)
plt.plot(np.flatnonzero(bigEffect),tvals[bigEffect],'go')

# non-significant neurons (red crosses)
plt.plot(np.flatnonzero(notSig),tvals[notSig],'rx')

# adjustments
plt.gca().set(xlabel='Neuron index',ylabel='T-value',xlim=[-10,nneurons+10],
              title=f'{np.sum(sigPvals1)}/{len(sigPvals1)} were significant, {np.sum(abs(cohensd)>.8)}/{len(cohensd)} were large effects.')

plt.tight_layout()
plt.savefig('ch7_proj44_part3b.png')
plt.show()

# **Part 4: T-tests on split-2 data**

In [None]:
# noun-vs-verb t-tests that involve split 2 (one test per neuron):
#   T_split2 : split-2 nouns vs. split-2 verbs
#   T_split12: split-1 nouns vs. split-2 verbs (across the two splits)
comparisons = [ ('nouns_split2','verbs_split2'),
                ('nouns_split1','verbs_split2') ]

T_split2,T_split12 = [ stats.ttest_ind(mlp_acts[nounKey],mlp_acts[verbKey],axis=0)
                       for nounKey,verbKey in comparisons ]

In [None]:
# Bonferroni correction (alpha divided by the number of tests)
sigPthresh = .05 / len(T_split2.pvalue)

# per neuron: how many of the two tests are significant (0, 1, or 2)
sig1 = (T_split1.pvalue<sigPthresh).astype(int)
bothSig_2  = sig1 + (T_split2.pvalue<sigPthresh).astype(int)
bothSig_12 = sig1 + (T_split12.pvalue<sigPthresh).astype(int)

# correlations between t-values
r_2  = np.corrcoef(T_split1.statistic,T_split2.statistic)[0,1]
r_12 = np.corrcoef(T_split1.statistic,T_split12.statistic)[0,1]


# visualizations
_,axs = plt.subplots(1,2,figsize=(8,3.5))

# (number of significant tests, marker, marker style, legend label), in drawing order
groups = [
  (2,'ks',dict(markerfacecolor=[.7,.9,.7,.5],markersize=5),'Both sig.'),
  (0,'rx',dict(markersize=3,alpha=.5),'Neither sig'),
  (1,'ko',dict(markerfacecolor=[.9,.7,.7,.5],markersize=4),'One sig.'),
]

# left panel: split-1 vs. split-2; right panel: split-1 vs. split-12
for ax,T_other,nsig in zip(axs,(T_split2,T_split12),(bothSig_2,bothSig_12)):
  for count,fmt,style,label in groups:
    pick = nsig==count
    ax.plot(T_split1.statistic[pick],T_other.statistic[pick],fmt,label=label,**style)

# axis adjustments
axs[0].set(xlabel='Split-1 t-value',ylabel='Split-2 t-value',title=f'A) T-val comparison (r = {r_2:.3f})')
axs[1].set(xlabel='Split-1 t-value',ylabel='Split-12 t-value',title=f'B) T-val comparison (r = {r_12:.3f})')

# common adjustments: zero lines and legend
for a in axs:
  for drawLine in (a.axhline,a.axvline):
    drawLine(0,color='k',linestyle='--',linewidth=.5)
  a.legend()


plt.tight_layout()
plt.savefig('ch7_proj44_part4.png')
plt.show()

# **Part 5: Investigating distributions**

In [None]:
# histograms (counts and bin edges) over all activation values in each dataset
yNouns1,xNouns1 = np.histogram(mlp_acts['nouns_split1'],bins='fd')
yNouns2,xNouns2 = np.histogram(mlp_acts['nouns_split2'],bins='fd')
yVerbs1,xVerbs1 = np.histogram(mlp_acts['verbs_split1'],bins='fd')
yVerbs2,xVerbs2 = np.histogram(mlp_acts['verbs_split2'],bins='fd')

# one line per dataset, drawn at the left edge of each bin
plt.figure(figsize=(9,3))
curves = [ (xNouns1,yNouns1,'Nouns 1'),
           (xNouns2,yNouns2,'Nouns 2'),
           (xVerbs1,yVerbs1,'Verbs 1'),
           (xVerbs2,yVerbs2,'Verbs 2') ]
for edges,counts,label in curves:
  plt.plot(edges[:-1],counts,linewidth=2,label=label)

plt.legend()
plt.gca().set(xlabel='MLP expansion activation',ylabel='Count',ylim=[0,None],
              title='Histograms of all MLP neurons')

plt.tight_layout()
plt.savefig('ch7_proj44_part5a.png')
plt.show()

In [None]:
# neurons (columns) with negative vs. positive split-1 t-values
negT = T_split1.statistic<0
posT = T_split1.statistic>0

nounActs = mlp_acts['nouns_split1']
verbActs = mlp_acts['verbs_split1']

# density histograms of the t<0 and t>0 subpopulations
yNouns1_neg,xNouns1_neg = np.histogram(nounActs[:,negT],bins='fd',density=True)
yVerbs1_neg,xVerbs1_neg = np.histogram(verbActs[:,negT],bins='fd',density=True)
yNouns1_pos,xNouns1_pos = np.histogram(nounActs[:,posT],bins='fd',density=True)
yVerbs1_pos,xVerbs1_pos = np.histogram(verbActs[:,posT],bins='fd',density=True)

# one line per subpopulation, drawn at the left edge of each bin
plt.figure(figsize=(9,3))
curves = [ (xNouns1_neg,yNouns1_neg,'Nouns t<0'),
           (xVerbs1_neg,yVerbs1_neg,'Verbs t<0'),
           (xNouns1_pos,yNouns1_pos,'Nouns t>0'),
           (xVerbs1_pos,yVerbs1_pos,'Verbs t>0') ]
for edges,density,label in curves:
  plt.plot(edges[:-1],density,linewidth=2,label=label)

plt.legend()
plt.gca().set(xlabel='MLP expansion activation',ylabel='Density',ylim=[0,None],
              title='Histograms separated by t-value sign')

plt.tight_layout()
plt.savefig('ch7_proj44_part5b.png')
plt.show()

In [None]:
# indices of the neurons with the largest positive and the largest negative t-value
max_t = np.argmax(T_split1.statistic)
min_t = np.argmin(T_split1.statistic)

nounActs = mlp_acts['nouns_split1']
verbActs = mlp_acts['verbs_split1']

# density histograms of those two neurons, separately for nouns and verbs
yNouns1_max,xNouns1_max = np.histogram(nounActs[:,max_t],bins='fd',density=True)
yVerbs1_max,xVerbs1_max = np.histogram(verbActs[:,max_t],bins='fd',density=True)
yNouns1_min,xNouns1_min = np.histogram(nounActs[:,min_t],bins='fd',density=True)
yVerbs1_min,xVerbs1_min = np.histogram(verbActs[:,min_t],bins='fd',density=True)

# solid lines: max-t neuron; dashed lines: min-t neuron; red: nouns; green: verbs
plt.figure(figsize=(9,3))
curves = [ (xNouns1_max,yNouns1_max,'r'  ,f'Nouns (t = {T_split1.statistic[max_t]:.2f})'),
           (xVerbs1_max,yVerbs1_max,'g'  ,f'Verbs (t = {T_split1.statistic[max_t]:.2f})'),
           (xNouns1_min,yNouns1_min,'r--',f'Nouns (t = {T_split1.statistic[min_t]:.2f})'),
           (xVerbs1_min,yVerbs1_min,'g--',f'Verbs (t = {T_split1.statistic[min_t]:.2f})') ]
for edges,density,fmt,label in curves:
  plt.plot(edges[:-1],density,fmt,linewidth=2,label=label)

plt.gca().set(xlabel='MLP expansion activation',ylabel='Density',ylim=[0,None],
              title='Histograms from two neurons')

plt.legend()
plt.tight_layout()
plt.savefig('ch7_proj44_part5c.png')
plt.show()

# **Part 6: Within-category tests**

In [None]:
# within-category t-tests: split 1 vs. split 2 of the same word category
T_withinNoun = stats.ttest_ind(mlp_acts['nouns_split1'],mlp_acts['nouns_split2'],axis=0)
T_withinVerb = stats.ttest_ind(mlp_acts['verbs_split1'],mlp_acts['verbs_split2'],axis=0)

# across-category t-values (x-axis) against within-category t-values (y-axis)
fig,axs = plt.subplots(1,2,figsize=(9,4))
axs[0].plot(T_split1.statistic,T_withinNoun.statistic,'ko',markerfacecolor=[.9,.7,.7,.3])
axs[1].plot(T_split2.statistic,T_withinVerb.statistic,'ks',markerfacecolor=[.7,.9,.7,.3])

# symmetric limits shared by all axes: 10% beyond the largest across-category |t|
axlim = 1.1 * max(np.abs(T_split1.statistic).max(),np.abs(T_split2.statistic).max())
lims = [-axlim,axlim]

axs[0].set(xlim=lims,ylim=lims,xlabel='t(nouns,verbs), split 1',ylabel='t(nouns-1,nouns-2)',
           title='A) Across vs. within-nouns comparison')
axs[1].set(xlim=lims,ylim=lims,xlabel='t(nouns,verbs), split 2',ylabel='t(verbs-1,verbs-2)',
           title='B) Across vs. within-verbs comparison')

plt.tight_layout()
plt.savefig('ch7_proj44_part6a.png')
plt.show()

In [None]:
# Cohen's d between the two noun splits, one neuron (column) at a time
nouns1 = mlp_acts['nouns_split1']
nouns2 = mlp_acts['nouns_split2']
cohensd_within = np.array([ pg.compute_effsize(nouns1[:,i],nouns2[:,i],paired=False,eftype='cohen')
                            for i in range(nneurons) ])

# histograms of the absolute effect sizes
yW,xW = np.histogram(np.abs(cohensd_within),bins='fd')
yA,xA = np.histogram(np.abs(cohensd),bins='fd')

# visualize (counts drawn at the left edge of each bin)
plt.figure(figsize=(9,3))
for edges,counts,fmt,label in [ (xW,yW,'o-','Within category'),
                                (xA,yA,'s-','Across category') ]:
  plt.plot(edges[:-1],counts,fmt,linewidth=2,label=label)

# reference lines at d=.2 and d=.8
plt.axvline(.2,linestyle='--',color='r',label='Small effect')
plt.axvline(.8,linestyle=':',color='m',label='Large effect')

plt.legend()
plt.gca().set(xlabel="Cohen's d",ylabel='Count')

plt.tight_layout()
plt.savefig('ch7_proj44_part6b.png')
plt.show()

# **Part 7: Laminar profile of tuning**

In [None]:
# number of entries in model.transformer.h (the transformer blocks that get hooked below)
n_layers = len(model.transformer.h)

In [None]:
# remove previous hook
# (`handle` is what register_forward_hook returned when that hook was attached)
handle.remove()

def outerHook(layeri):
  """Build a forward hook that is bound to one layer index.

  The returned hook stores the module's output (detached, converted to numpy,
  singleton dimensions removed) in `mlp_acts` under the key
  'L{layeri}_{whichdata}', where the global `whichdata` is read at the
  moment the hook fires.
  """
  def saveActs(module,input,output):
    key = f'L{layeri}_{whichdata}'
    mlp_acts[key] = np.squeeze(output.detach().numpy())
  return saveActs


# surgery ;)
# one hook per transformer block, on that block's MLP c_fc module; the handles are kept in a list
handles = [ model.transformer.h[layeri].mlp.c_fc.register_forward_hook(outerHook(layeri))
            for layeri in range(n_layers) ]

In [None]:
# this cell takes around 20 s

# reinitialize data-dictionary
mlp_acts = {}

# token lists to push through the model; each key becomes part of the label the hooks store under
datasets = {
  'nouns_split1': nouns_split1,
  'nouns_split2': nouns_split2,
  'verbs_split1': verbs_split1,
  'verbs_split2': verbs_split2,
}

with torch.no_grad():

  # `whichdata` is the global that the hooks read when they fire
  for whichdata,token_ids in datasets.items():
    model(torch.tensor(token_ids).unsqueeze(1))


In [None]:
# report the shape of every stored activation array
for label,acts in mlp_acts.items():
  print(f"mlp_acts['{label}'] has shape {list(acts.shape)}")

In [None]:
# per-layer summary of the noun-vs-verb tests; the columns are
#   0,1: proportion of FDR-significant neurons (split 1, split 2)
#   2,3: mean |t| over the significant neurons (split 1, split 2)
#   4  : correlation between the split-1 and split-2 t-values
sig_neurons = np.zeros((n_layers,5))

for layeri in range(n_layers):

  # run the t-tests; loop-local names are used so that the T_split1/T_split2
  # results computed in the earlier parts are not overwritten by the last layer
  T_layer1 = stats.ttest_ind(mlp_acts[f'L{layeri}_nouns_split1'],
                             mlp_acts[f'L{layeri}_verbs_split1'],axis=0)
  T_layer2 = stats.ttest_ind(mlp_acts[f'L{layeri}_nouns_split2'],
                             mlp_acts[f'L{layeri}_verbs_split2'],axis=0)

  # boolean of significant tests (FDR-corrected p<.05)
  issig1 = stats.false_discovery_control(T_layer1.pvalue)<.05
  issig2 = stats.false_discovery_control(T_layer2.pvalue)<.05

  # proportion of significant neurons
  sig_neurons[layeri,0] = np.mean(issig1)
  sig_neurons[layeri,1] = np.mean(issig2)

  # average |t| of the significant neurons
  sig_neurons[layeri,2] = np.mean(np.abs(T_layer1.statistic[issig1]))
  sig_neurons[layeri,3] = np.mean(np.abs(T_layer2.statistic[issig2]))

  # correlation between the t-values of the two splits
  sig_neurons[layeri,4] = np.corrcoef(T_layer1.statistic,T_layer2.statistic)[0,1]


In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3))

layers = np.arange(n_layers)

# panels A and B share a layout: split-1 and split-2 markers, slightly offset along the x-axis
# (axis, first column of sig_neurons, y-axis label, title)
panels = [ (axs[0],0,'Proportion significant neurons','A) Proportion significant neurons'),
           (axs[1],2,'Average t-values','B) |T| of significant neurons') ]

for ax,col,ylab,title in panels:
  ax.plot(layers-.15,sig_neurons[:,col],'ko',markerfacecolor=[.7,.9,.7],markersize=10,label='Split-1')
  ax.plot(layers+.15,sig_neurons[:,col+1],'ks',markerfacecolor=[.9,.7,.7],markersize=10,label='Split-2')
  ax.set(xlabel='Transformer layer',ylabel=ylab,title=title)
  ax.legend()

# panel C: correlation between the split-1 and split-2 t-values
axs[2].plot(sig_neurons[:,4],'kh',markersize=12,markerfacecolor=[.7,.7,.9])
axs[2].set(xlabel='Transformer layer',ylabel='Correlation coefficient',title='C) T-value split correlations')


plt.tight_layout()
plt.savefig('ch7_proj44_part7.png')
plt.show()

# **Part 8: Tokens in vs. out of order**

In [None]:
# source: https://en.wikipedia.org/wiki/Coconut
text = 'The coconut (Cocos nucifera) is a member of the palm family (Arecaceae) and the only living species of the genus Cocos.'

# tokenize into a PyTorch tensor; tokens[0] holds the token IDs of the sentence
tokens = tokenizer.encode(text,return_tensors='pt')
print(f'There are {len(text)} characters and {len(tokens[0])} tokens.')

In [None]:
# scramble and invert

# a dedicated, seeded generator gives the same permutation (and therefore the
# same figures) on every run, without touching torch's global random state
rng = torch.Generator().manual_seed(44)

# random token order, the tokens in that order, and the indices that undo the shuffle
scrambled_idx = torch.randperm(len(tokens[0]),generator=rng)
scrambled_tokens = tokens[0,scrambled_idx].unsqueeze(0)
inverse_idx = torch.argsort(scrambled_idx)

# sanity check: applying inverse_idx to the scrambled tokens recovers the original sentence
print(f'Original sentence:\n {tokenizer.decode(tokens[0,:])}\n')
print(f'Scrambled sentence:\n {tokenizer.decode(scrambled_tokens[0,:])}\n')
print(f'Inverted scrambling:\n {tokenizer.decode(scrambled_tokens[0,inverse_idx])}')


In [None]:
# start from an empty data-dictionary
mlp_acts = {}

# three presentations of the same tokens:
#   'sentence'  - original order, as one sequence
#   'words'     - transposed, so each token is its own row
#   'scrambled' - permuted order, as one sequence
inputs = { 'sentence' : tokens,
           'words'    : tokens.T,
           'scrambled': scrambled_tokens }

with torch.no_grad():

  # `whichdata` is the global that the hooks read when they fire
  for whichdata,toks in inputs.items():
    model(toks)

In [None]:
# report the shape of every stored activation array
for label,acts in mlp_acts.items():
  print(f"mlp_acts['{label}'] has shape {list(acts.shape)}")

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3.5))

skip = 10
layer = 3

# flattened activations from one layer, keeping every `skip`-th value;
# the scrambled rows are first put back into the original token order
sent = mlp_acts[f'L{layer}_sentence'].flatten()[::skip]
word = mlp_acts[f'L{layer}_words'].flatten()[::skip]
scrm = mlp_acts[f'L{layer}_scrambled'][inverse_idx,:].flatten()[::skip]

# (x data, y data, marker, face color, x label, y label, title prefix) for each panel
pairs = [ (sent,word,'ko',[.7,.7,.9,.3],'Sentence','Words','A) Sentence vs. words'),
          (sent,scrm,'ks',[.7,.9,.7,.3],'Sentence','Scrambled sentence','B) Sentence vs. scrambled'),
          (scrm,word,'k^',[.9,.7,.7,.3],'Scrambled sentence','Words','C) Scrambled vs. words') ]

for ax,(xdata,ydata,fmt,face,xlab,ylab,name) in zip(axs,pairs):
  rval = np.corrcoef(xdata,ydata)[0,1]
  ax.plot(xdata,ydata,fmt,markersize=3,markerfacecolor=face)
  ax.set(xlabel=xlab,ylabel=ylab,title=f'{name} (r = {rval:.3f})')

plt.tight_layout()
plt.savefig('ch7_proj44_part8a.png')
plt.show()

In [None]:
# correlations per layer; columns: sentence-word, sentence-scrambled, scrambled-word
Rs = np.zeros((n_layers,3))

# loop over layers
for layeri in range(n_layers):

  # flattened activations (scrambled rows put back into the original token order)
  sent = mlp_acts[f'L{layeri}_sentence'].flatten()
  word = mlp_acts[f'L{layeri}_words'].flatten()
  scrm = mlp_acts[f'L{layeri}_scrambled'][inverse_idx,:].flatten()

  # correlation coefficient of each pairing
  for col,(a,b) in enumerate([ (sent,word),(sent,scrm),(scrm,word) ]):
    Rs[layeri,col] = np.corrcoef(a,b)[0,1]

# and the visualizations (the three lines are slightly offset along the x-axis)
plt.figure(figsize=(10,3))
layers = np.arange(n_layers)
lineSpecs = [ (layers-.1,'s-','Sentence-word'),
              (layers   ,'o-','Sentence-scrambled'),
              (layers+.1,'^-','Scrambled-word') ]
for col,(xpos,fmt,label) in enumerate(lineSpecs):
  plt.plot(xpos,Rs[:,col],fmt,linewidth=2,markersize=8,markerfacecolor='w',label=label)

plt.axhline(0,linestyle='--',color='k',linewidth=.4,zorder=-10)
plt.gca().set(xlabel='Transformer layer',ylabel='Correlation coefficient',
              title='Correlations across token organizations')
plt.legend()

plt.tight_layout()
plt.savefig('ch7_proj44_part8b.png')
plt.show()