|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[38] Laminar profile of RSA and category selectivity</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from scipy.stats import ttest_ind

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render the inline figures as SVG (higher-res than the default)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (uncomment to activate)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# no box around the axes (keep only the left and bottom spines)
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top']   = False

# bold titles and axis labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution of the saved figures
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Hook all layers and get activations**

In [None]:
# load the pretrained GPT2-XL tokenizer and model (same checkpoint for both)
model_name = 'gpt2-xl'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# number of transformer blocks (used in the layer-wise analyses)
n_layers = model.config.n_layer

# switch to evaluation mode (as the final expression, this also prints the model)
model.eval()

In [None]:
# define a hook function to store QKV vectors
# (the dictionary is filled during the forward pass, one set of entries per layer)
activations = {}

def implant_hook(layer_number):
  """Create a forward hook that stores the Q, K, and V activations of one layer.

  The hook is meant for the fused QKV projection (`attn.c_attn`), whose output
  holds Q, K, and V concatenated in equal thirds along the last dimension.

  Args:
    layer_number: index of the transformer block; used only to build the keys.

  Returns:
    A function with the forward-hook signature (module, input, output) that
    writes three numpy arrays into `activations` under the keys
    'attn_{layer_number}_q', 'attn_{layer_number}_k', and 'attn_{layer_number}_v'.
  """
  def hook(module,input,output):

    # get and separate the QKV (three equal parts of the last dimension)
    q,k,v = output.detach().chunk(3,dim=-1)

    # numpyify and store
    # .cpu() makes the conversion work when the tensor is not on the CPU;
    # .copy() decouples the arrays from the tensor's memory (a plain .numpy()
    # shares memory, so in-place edits of the arrays would alter the tensor)
    activations[f'attn_{layer_number}_q'] = q.cpu().numpy().copy()
    activations[f'attn_{layer_number}_k'] = k.cpu().numpy().copy()
    activations[f'attn_{layer_number}_v'] = v.cpu().numpy().copy()
  return hook


# surgery ;)

# remove the hooks from a previous run of this cell
# (otherwise, each re-run would stack one more hook onto every layer)
for handle in globals().get('hook_handles',[]):
  handle.remove()

# implant one hook per transformer block, on the fused QKV projection;
# the handles are kept so that the hooks can be removed (handle.remove())
hook_handles = []
for layeri in range(model.config.n_layer):
  handle = model.transformer.h[layeri].attn.c_attn.register_forward_hook(implant_hook(layeri))
  hook_handles.append(handle)

In [None]:
# words grouped by category (the keys are the numerical category labels)
category_words = {
    1: ['galaxy', 'asteroid', 'comet', 'cosmos', 'space',
        'sun', 'planet', 'moon', 'star', 'orbit'],
    2: ['ceiling', 'sofa', 'couch', 'carpet', 'door', 'window', 'lamp',
        'chair', 'table', 'rug', 'bed', 'floor', 'wall'],
    3: ['pear', 'grape', 'banana', 'cherry', 'peach', 'apple',
        'seed', 'jelly', 'orange', 'lime', 'fruit'],
}

# flatten into a list of [category, word] pairs, ordered by category
words = [ [category,word] for category,members in category_words.items() for word in members ]


# column vector of word categories
labels = np.array([ pair[0] for pair in words ]).reshape(-1,1)

# pre-mask matrices:
#   same_cat is True where the row word and the column word share a category
#   upper_tr is True above the diagonal (each word pair once, no self-pairs)
same_cat = np.equal(labels,labels.T)
upper_tr = np.triu(np.ones(same_cat.shape,dtype=bool),k=1)

# create the two category masks
within_mask = np.logical_and(upper_tr,same_cat)
across_mask = np.logical_and(upper_tr,~same_cat)

# and visualize (one panel per mask)
_,axs = plt.subplots(2,2,figsize=(10,6))

panels = [
    (same_cat,    'A) Same-category mask'),
    (upper_tr,    'B) Upper-triangular mask'),
    (within_mask, 'C) Within-category mask'),
    (across_mask, 'D) Across-category mask'),
]

# the flattened axes are in row-major order, which matches the panel letters
for ax,(mask,title) in zip(axs.flatten(),panels):
  ax.imshow(mask)
  ax.set(xlabel='Words',ylabel='Words',title=title)

plt.tight_layout()
plt.savefig('ch6_proj38_part1.png')
plt.show()

In [None]:
# tokenize one prompt per word (the category word is at the end of the prompt)
token_ids = [ tokenizer.encode(f'The next word is {w[1]}') for w in words ]

# all prompts must have the same number of tokens to form a rectangular batch
# without padding; a different count means that a word was split into a
# different number of tokens than the others
n_tokens = [ len(ids) for ids in token_ids ]
if len(set(n_tokens)) != 1:
  counts = ', '.join( f'{w[1]}={n}' for w,n in zip(words,n_tokens) )
  raise ValueError(f'Prompts have different token counts: {counts}')

# create a batch of tokens (words x tokens)
batch = torch.tensor(token_ids,dtype=torch.long)

# push through the model (the hooks store the activations during this pass)
with torch.no_grad():
  outs = model(batch)

In [None]:
# inspect the stored activations: all dictionary keys, and the shape of one example array
activations.keys(), activations['attn_5_q'].shape

# **Part 2: Correlating Q, K, and V activations**

In [None]:
layeri = 6

# token position to analyze: the final token, which is the category word.
# The earlier positions belong to the prefix "The next word is", which is the
# same in every prompt; with causal attention, their activations cannot depend
# on the category word.
token_idx = -1

# get activations (words x embedding dimensions)
actsQ = activations[f'attn_{layeri}_q'][:,token_idx,:]
actsK = activations[f'attn_{layeri}_k'][:,token_idx,:]
actsV = activations[f'attn_{layeri}_v'][:,token_idx,:]

# scatter plots and correlations
_,axs = plt.subplots(1,3,figsize=(12,3.5))

# one panel per pair: (x name, x data, y name, y data, marker face color)
pairs = [
    ('Q',actsQ, 'K',actsK, [.9,.7,.7,.3]),
    ('Q',actsQ, 'V',actsV, [.7,.9,.7,.3]),
    ('K',actsK, 'V',actsV, [.7,.7,.9,.3]),
]

for ax,(xname,xacts,yname,yacts,facecolor) in zip(axs,pairs):

  # pool over all words and dimensions
  x = xacts.flatten()
  y = yacts.flatten()

  # correlation matrix; the off-diagonal element is the coefficient
  r = np.corrcoef(x,y)

  ax.plot(x,y,'k.',markeredgewidth=.4,markerfacecolor=facecolor)
  ax.set(xlabel=f'{xname} activations',ylabel=f'{yname} activations',title=f'{xname} by {yname} (r = {r[0,1]:.3f})')

plt.tight_layout()
plt.savefig('ch6_proj38_part2a.png')
plt.show()

In [None]:
# token position to analyze: the final token, which is the category word.
# The earlier positions belong to the prefix "The next word is", which is the
# same in every prompt; with causal attention, their activations cannot depend
# on the category word.
token_idx = -1

# correlation coefficients (layers x pairs; the pairs are Q-K, Q-V, K-V)
corrs = np.zeros((n_layers,3))

for layeri in range(n_layers):

  # get activations (pooled over all words and embedding dimensions)
  actsQ = activations[f'attn_{layeri}_q'][:,token_idx,:].flatten()
  actsK = activations[f'attn_{layeri}_k'][:,token_idx,:].flatten()
  actsV = activations[f'attn_{layeri}_v'][:,token_idx,:].flatten()

  # correlations
  corrs[layeri,0] = np.corrcoef(actsQ,actsK)[0,1]
  corrs[layeri,1] = np.corrcoef(actsQ,actsV)[0,1]
  corrs[layeri,2] = np.corrcoef(actsK,actsV)[0,1]


plt.figure(figsize=(10,3))
plt.plot(corrs[:,0],'rs-',markerfacecolor=[.9,.7,.7,.7],markersize=8,label='Q-K')
plt.plot(corrs[:,1],'go-',markerfacecolor=[.7,.9,.7,.7],markersize=8,label='Q-V')
plt.plot(corrs[:,2],'b^-',markerfacecolor=[.7,.7,.9,.7],markersize=8,label='K-V')
plt.gca().set(xlabel='Transformer layer',ylabel='Correlation coefficient')

plt.legend()

plt.tight_layout()
plt.savefig('ch6_proj38_part2b.png')
plt.show()

# **Part 3: Cosine similarities and RSA (one layer)**

In [None]:
layeri = 6


### cosine similarity matrices
# Each matrix is words x words: the rows of the activations matrix (final token
# of each prompt) are scaled to unit norm, so their dot products are cosine similarities.
# Note: the normalization creates a new array (x = x/norm, not x /= norm). The
# slices are views into the stored activations, so an in-place division would
# overwrite the data in the `activations` dictionary.

# for Q
actsQ = activations[f'attn_{layeri}_q'][:,-1,:]
actsQ = actsQ / np.linalg.norm(actsQ,axis=1,keepdims=True)
cs_actsQ = actsQ @ actsQ.T

# for K
actsK = activations[f'attn_{layeri}_k'][:,-1,:]
actsK = actsK / np.linalg.norm(actsK,axis=1,keepdims=True)
cs_actsK = actsK @ actsK.T

# for V
actsV = activations[f'attn_{layeri}_v'][:,-1,:]
actsV = actsV / np.linalg.norm(actsV,axis=1,keepdims=True)
cs_actsV = actsV @ actsV.T


fig,axs = plt.subplots(1,3,figsize=(10,3))

# positions of the category lines (halfway between the last word of one
# category and the first word of the next)
boundaries = [ (labels==1).sum()-.5, (labels<3).sum()-.5 ]

panels = [
    (cs_actsQ,'A) Q cosine similarities'),
    (cs_actsK,'B) K cosine similarities'),
    (cs_actsV,'C) V cosine similarities'),
]

for ax,(cs,title) in zip(axs,panels):

  # the matrix and its colorbar
  h = ax.imshow(cs,vmin=.3,vmax=1,cmap='magma')
  fig.colorbar(h,ax=ax,orientation='horizontal',pad=.02,fraction=.047)

  # add the category lines
  for b in boundaries:
    ax.axhline(b,color='w',linestyle='--',linewidth=.5)
    ax.axvline(b,color='w',linestyle='--',linewidth=.5)

  # adjustments
  ax.set(xticklabels=[],ylabel='Words',title=title)

plt.tight_layout()
plt.savefig('ch6_proj38_part3a.png')
plt.show()

In [None]:
# extract the non-redundant similarity values (each word pair once, no diagonal)
pair_idx = np.triu_indices(len(labels),1)
qcs = cs_actsQ[pair_idx]
kcs = cs_actsK[pair_idx]
vcs = cs_actsV[pair_idx]

# setup the figure
fig,axs = plt.subplots(1,3,figsize=(10,3))

# plot the RSA for each matrix pair: (x name, x data, y name, y data, marker face color)
comparisons = [
    ('Q',qcs, 'K',kcs, [.9,.7,.7,.7]),
    ('Q',qcs, 'V',vcs, [.7,.9,.7,.7]),
    ('K',kcs, 'V',vcs, [.7,.7,.9,.7]),
]

for ax,(xname,x,yname,y,facecolor) in zip(axs,comparisons):

  # the RSA score is the correlation between the two sets of similarities
  score = np.corrcoef(x,y)[0,1]

  ax.plot(x,y,'k.',markeredgewidth=.4,markersize=8,markerfacecolor=facecolor)
  ax.set(xlabel=f'{xname} cossim',ylabel=f'{yname} cossim',title=f'{xname} and {yname} (r = {score:.3f})')

plt.tight_layout()
plt.savefig('ch6_proj38_part3b.png')
plt.show()

# **Part 4: Laminar profile of RSA scores**

In [None]:
# RSA scores (layers x pairs; the pairs are Q-K, Q-V, K-V)
rsa = np.zeros((n_layers,3))

# indices of the non-redundant word pairs (the same in all layers)
pair_idx = np.triu_indices(len(labels),1)

for layeri in range(n_layers):


  ### cosine similarities of all word pairs, separately for Q, K, and V
  pairsims = {}
  for qkv in 'qkv':

    # activations of the final token (words x embedding dimensions)
    acts = activations[f'attn_{layeri}_{qkv}'][:,-1,:]

    # unit-norm rows; this creates a new array, because the slice is a view and
    # an in-place division would overwrite the stored activations
    acts = acts / np.linalg.norm(acts,axis=1,keepdims=True)

    # cosine similarity matrix, reduced to its non-redundant values
    pairsims[qkv] = (acts @ acts.T)[pair_idx]


  # calculate RSA scores
  rsa[layeri,0] = np.corrcoef(pairsims['q'],pairsims['k'])[0,1]
  rsa[layeri,1] = np.corrcoef(pairsims['q'],pairsims['v'])[0,1]
  rsa[layeri,2] = np.corrcoef(pairsims['k'],pairsims['v'])[0,1]



plt.figure(figsize=(10,3))
plt.plot(rsa[:,0],'rs-',markerfacecolor=[.9,.7,.7,.7],markersize=8,label='Q-K')
plt.plot(rsa[:,1],'go-',markerfacecolor=[.7,.9,.7,.7],markersize=8,label='Q-V')
plt.plot(rsa[:,2],'b^-',markerfacecolor=[.7,.7,.9,.7],markersize=8,label='K-V')
plt.gca().set(xlabel='Transformer layer',ylabel='RSA scores')

plt.legend()

plt.tight_layout()
plt.savefig('ch6_proj38_part4.png')
plt.show()

# **Part 5: Category selectivity in one layer**

In [None]:
layeri = 6

### recreate the cosine similarity matrices for this layer
# (the normalization creates new arrays; the slices are views, so an in-place
#  division would overwrite the stored activations)
actsQ = activations[f'attn_{layeri}_q'][:,-1,:]
actsQ = actsQ / np.linalg.norm(actsQ,axis=1,keepdims=True)
cs_actsQ = actsQ @ actsQ.T

actsK = activations[f'attn_{layeri}_k'][:,-1,:]
actsK = actsK / np.linalg.norm(actsK,axis=1,keepdims=True)
cs_actsK = actsK @ actsK.T

actsV = activations[f'attn_{layeri}_v'][:,-1,:]
actsV = actsV / np.linalg.norm(actsV,axis=1,keepdims=True)
cs_actsV = actsV @ actsV.T



_,axs = plt.subplots(1,3,figsize=(12,3))

# one panel per matrix: (panel name, cosine similarity matrix)
panels = [ ('A) Q',cs_actsQ), ('B) K',cs_actsK), ('C) V',cs_actsV) ]

for ax,(panel,cs) in zip(axs,panels):

  # similarities of the word pairs within vs. across categories
  w = cs[within_mask]
  a = cs[across_mask]

  # selectivity: t-test of within vs. across
  selectT = ttest_ind(w,a).statistic

  # distributions (np.histogram returns the bin edges, one more than the bins)
  yw,xw = np.histogram(w,bins='fd',density=True)
  ya,xa = np.histogram(a,bins='fd',density=True)

  # plot at the bin centers; the two distributions have different bin widths,
  # so plotting at the left edges would shift them by different amounts
  ax.plot((xw[:-1]+xw[1:])/2,yw,linewidth=2,label='Within')
  ax.plot((xa[:-1]+xa[1:])/2,ya,linewidth=2,label='Across')
  ax.legend()
  ax.set(xlabel='Cosine similarity',ylabel='Density',ylim=[0,None],title=f'{panel} (t = {selectT:.2f})')

plt.suptitle(f'Category-specific similarity distributions, layer {layeri}',fontsize=15,fontweight='bold')

plt.tight_layout()
plt.savefig('ch6_proj38_part5.png')
plt.show()

# **Part 6: Laminar category selectivity**

In [None]:
# results matrix: layers x matrices (Q, K, V) x measures
#   measure 0: difference of the means (within - across)
#   measure 1: t-value (within vs. across)
#   measure 2: average of the two means
selectTs = np.zeros((n_layers,3,3))

# Note: For Part 7b, set this to True to Fisher-z transform (np.arctanh) the
# cosine similarities before calculating the measures
use_fisher_z = False


for layeri in range(n_layers):

  for qkvi,qkv in enumerate('qkv'):

    ### recreate the cosine similarity matrix for this layer
    # (the normalization creates a new array; the slice is a view, so an
    #  in-place division would overwrite the stored activations)
    acts = activations[f'attn_{layeri}_{qkv}'][:,-1,:]
    acts = acts / np.linalg.norm(acts,axis=1,keepdims=True)
    cs_acts = acts @ acts.T

    # similarities of the word pairs within vs. across categories
    w = cs_acts[within_mask]
    a = cs_acts[across_mask]
    if use_fisher_z:
      w = np.arctanh(w)
      a = np.arctanh(a)

    # the three measures
    selectTs[layeri,qkvi,0] = w.mean()-a.mean()
    selectTs[layeri,qkvi,1] = ttest_ind(w,a).statistic
    selectTs[layeri,qkvi,2] = (w.mean()+a.mean())/2



fig,axs = plt.subplots(1,2,figsize=(12,3))
for i in [0,1]:
  axs[i].plot(selectTs[:,0,i],'rs-',markerfacecolor=[.9,.7,.7,.7],markersize=8,label='Q')
  axs[i].plot(selectTs[:,1,i],'go-',markerfacecolor=[.7,.9,.7,.7],markersize=8,label='K')
  axs[i].plot(selectTs[:,2,i],'b^-',markerfacecolor=[.7,.7,.9,.7],markersize=8,label='V')
  axs[i].legend()

axs[0].set(xlabel='Transformer layer',ylabel='Differences',title='A) Within - Across')
axs[1].set(xlabel='Transformer layer',ylabel='Selectivity (t-stat)',title='B) T values')

plt.tight_layout()
plt.savefig('ch6_proj38_part6a.png')
plt.show()

In [None]:
_,axs = plt.subplots(1,2,figsize=(8,3))

# normalization function for mapping layer index onto color
norm = mpl.colors.Normalize(vmin=0,vmax=n_layers)

# per layer: three markers (one each for Q, K, and V) in the color of that layer
for i in range(n_layers):
  layer_color = plt.cm.rainbow(norm(i))
  averages = selectTs[i,:,2]

  axs[0].plot(selectTs[i,:,0],averages,'ko',markerfacecolor=layer_color,alpha=.7)
  axs[1].plot(selectTs[i,:,1],averages,'ks',markerfacecolor=layer_color,alpha=.7)


# add a colorbar (attached to the right panel)
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.rainbow,norm=norm)
cbar = plt.colorbar(sm,ax=axs[1],pad=.01)
cbar.set_label('Transformer layer')

axs[0].set(title='A) Differences vs. averages',
           xlabel='Cosine similarity differences',ylabel='Average cosine similarity')
axs[1].set(title='B) T-values vs. averages',
           xlabel='T-values',ylabel='Average cosine similarity')

plt.tight_layout()
plt.savefig('ch6_proj38_part6b.png') # comment out for Part 7


### --- uncomment for Part 7
# axs[0].set(xlabel='Fisher-z(cos.sim) diffs',ylabel='Fisher-z(cos.sim) ave.')
# axs[1].set(ylabel='Fisher-z(cos.sim) ave.')
# plt.savefig('ch6_proj38_part7b.png')
### ---


plt.show()

# **Part 7: Repeat with Fisher-z transform**

In [None]:
# random numbers between -1 and 1 to simulate cosine similarities
# (seeded generator, so that the figure is the same in every run)
rng = np.random.default_rng(0)
r = rng.uniform(-1,1,size=10_000)

# Fisher-z transform
fz_r = np.arctanh(r)

_,axs = plt.subplots(1,2,figsize=(10,4))

# transformed vs. original values; the dashed line is unity (no change)
axs[0].plot(r,fz_r,'ko',markerfacecolor=[.7,.9,.7,.3])
axs[0].set(xlabel='Cosine similarity',ylabel='Fisher-z transform',title='A) Impact of transformation')
axs[0].axline((-1,-1),slope=1,linestyle='--',color='k')
axs[0].grid(linestyle='--',linewidth=.3)

# distributions, plotted at the bin centers (np.histogram returns the bin
# edges; the two histograms have different bin widths, so plotting at the left
# edges would shift them by different amounts)
y,x = np.histogram(r,bins='fd')
axs[1].plot((x[:-1]+x[1:])/2,y,linewidth=2,label='Cos. similarities')
y,x = np.histogram(fz_r,bins='fd')
axs[1].plot((x[:-1]+x[1:])/2,y,linewidth=2,label='Fisher-z')
axs[1].set(xlabel='Similarity or Fisher-z value',ylabel='Count',title='B) Distributions')
axs[1].legend()

plt.tight_layout()
plt.savefig('ch6_proj38_part7a.png')
plt.show()