|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[25] Category selectivity via cosine similarity</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

from scipy.stats import ttest_ind
import requests

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### Matplotlib setup: inline SVG output plus a shared figure style.
### (For a dark theme, uncomment the color entries inside the dictionary below.)

# Render inline notebook figures as SVG so they stay sharp at any zoom level.
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

plt.rcParams.update({
    # --- dark-mode palette (uncomment all of these together) ---
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',

    # drop the top/right frame lines
    'axes.spines.top':   False,
    'axes.spines.right': False,

    # bold titles and axis labels
    'axes.titleweight':  'bold',
    'axes.labelweight':  'bold',

    # resolution used by plt.savefig
    'savefig.dpi':       300,
})

# **Part 1: Model, tokens, hidden states**

In [None]:
# Load the GPT-2 tokenizer and model; model.eval() puts the network in
# evaluation (inference) mode.
# NOTE(review): the from_pretrained(...) call below is left unfinished in this
# exercise skeleton (its parenthesis is never closed). Later cells read
# outs.hidden_states from plain model(batch) calls, so hidden-state outputs
# presumably have to be enabled here when the model is loaded — confirm
# against the book.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained(
model.eval()

In [None]:
# list of words in three categories
# Each entry is [category label, word]. Judging from the words themselves:
# 1 = astronomy terms, 2 = household/furniture terms, 3 = fruit-related terms.
words = [
    [1, 'galaxy'],
    [1, 'asteroid'],
    [1, 'comet'],
    [1, 'cosmos'],
    [1, 'space'],
    [1, 'sun'],
    [1, 'planet'],
    [1, 'moon'],
    [1, 'star'],
    [1, 'orbit'],
    [2, 'ceiling'],
    [2, 'sofa'],
    [2, 'couch'],
    [2, 'carpet'],
    [2, 'door'],
    [2, 'window'],
    [2, 'lamp'],
    [2, 'chair'],
    [2, 'table'],
    [2, 'rug'],
    [2, 'bed'],
    [2, 'floor'],
    [2, 'wall'],
    [3, 'pear'],
    [3, 'grape'],
    [3, 'banana'],
    [3, 'cherry'],
    [3, 'peach'],
    [3, 'apple'],
    [3, 'seed'],
    [3, 'jelly'],
    [3, 'orange'],
    [3, 'lime'],
    [3, 'fruit'],
]

# confirm they're single-token words
# (the encode/print calls below are left unfinished in this skeleton: encode
# the word in w and print it with its token ID(s))
for w in words:
  t = tokenizer.encode()
  print

In [None]:
# create a batch of tokens
# One row per word; the shape and dtype arguments are left blank in this skeleton.
batch = torch.zeros((,),dtype=)

# write each word's token ID(s) into its own row of the batch
for i,w in enumerate(words):
  batch[i,:] = tokenizer.encode

# push through the model
# no_grad: only a forward pass is needed, so no autograd graph is built
with torch.no_grad():
  outs = model(batch)

# check the shape of one hidden-state tensor (even though hidden states are not
# toggled on in the model(batch) call above!)
outs.hidden_states[3].shape

In [None]:
# For each loop index i, plot the variance of every hidden-state dimension
# (computed over dim 0), with the points jittered horizontally around x = i.
# NOTE(review): batch.shape[1] is the number of token positions per row (1 if
# every word is a single token), so this loop may run only once. If the intent
# is one column per hidden-state layer, the range should presumably be
# len(outs.hidden_states) — verify. `hs` also still needs its layer/position
# indexing; outs.hidden_states is indexed per layer elsewhere (e.g. [3]).
plt.figure(figsize=(8,3))

for i in range(batch.shape[1]):
  hs = outs.hidden_states
  # presumably hs is (words, hidden_dim) once indexed: one jittered x per dimension,
  # with variance taken across words (dim=0)
  plt.plot(np.random.normal(i,.03,hs.shape[1]),hs.var(dim=0),'o',
           markeredgecolor='k',markersize=10,alpha=.3)

plt.tight_layout()
plt.savefig('ch5_proj25_part1.png')
plt.show()

# **Part 2: Cosine similarities in one layer**

In [None]:
# vector of word categories
# (presumably the first element of each [label, word] pair in `words`)
labels =
print(labels.shape)

# mask for cosine similarity matrix
# NOTE(review): np.triu keeps the main diagonal unless k=1 is passed, so the
# self-similarity entries would stay in the mask. Part 6 of this notebook reads
# 1/4/9 as same-category and 2/3/6 as cross-category pairs on a mask of this
# kind. That is consistent with the upper triangle of outer(labels, labels) for
# labels 1..3 — confirm the blank arguments.
cat_mask = np.triu( )

# create a discrete colormap
# N evenly spaced samples of 'rainbow', with N = the largest value in the mask
N = cat_mask.max()
base = plt.get_cmap('rainbow')
cmapN = base.from_list(None,base(np.linspace(0,1,N)),N)

plt.imshow(cat_mask,cmap=cmapN)
plt.gca().set(xlabel='Word indices',ylabel='Word indices')
# colorbar ticks only at the values that actually occur in the mask
plt.colorbar(ticks=np.unique(cat_mask),pad=.02)

plt.tight_layout()
plt.savefig('ch5_proj25_part2a.png')
plt.show()

In [None]:
# Cosine similarity between every pair of words at a single hidden layer.
# NOTE(review): if the indexing below uses plain slicing, hs is a view into
# outs.hidden_states. In that case the in-place /= also normalizes the stored
# hidden states, so any later cell that expects raw values would see normalized
# ones. `hs = hs / ...` would avoid that.
layeri = 4
hs = outs.hidden_states[][]
# unit-normalize each vector so that the matrix product yields cosine similarities
hs /= torch.linalg.norm(
csMat =  @

# extract the unique values
# each value of cat_mask picks one category pairing; the mask is upper-triangular,
# so each word pair contributes once
within = torch.concatenate((csMat[cat_mask==],csMat[cat_mask==],csMat[cat_mask==]))
across = torch.concatenate((csMat[cat_mask==],csMat[cat_mask==],csMat[cat_mask==]))

# t-test
# independent-samples t statistic: within- vs across-category similarities
tval = ttest_ind(,).statistic

# histograms
# torch.histogram returns (density, bin_edges); there is one more edge than bins
yWithin,xWithin = torch.histogram(,bins=20,density=True)
yAcross,xAcross = torch.histogram(,bins=20,density=True)

_,axs = plt.subplots(1,2,figsize=(10,4))
h = axs[0].imshow(csMat,vmin=.5,vmax=.85)
plt.colorbar(h,ax=axs[0],pad=.01)
axs[0].set(xlabel='Word index',ylabel='Word index',title='A) Cosine similarity matrix')

# density curves (presumably edges[:-1] vs density, as in Part 5)
axs[1].plot(label='Within-category')
axs[1].plot(label='Across-category')
# the title reports Δ (difference of the two means) and the t statistic
axs[1].set(ylim=[0,None],xlabel='Cosine similarity',ylabel='Density',
           title=f'B) Layer {layeri} distributions ($\\Delta$ = {within.mean()-across.mean():.2f}, t = {tval:.2f})')
axs[1].legend()


plt.tight_layout()
plt.savefig('ch5_proj25_part2b.png')
plt.show()

# **Part 3: Laminar profiles**

In [None]:
# number of hidden-state tensors returned by the model (one per indexable layer)
num_hidden = len(outs.hidden_states)
# names of the layers, for the x-axis tick labels
# (one label per entry of outs.hidden_states; used as xticklabels in the Part 3 plots)
layer_labels =

In [None]:
# initializations
# one row per summary metric, one column per hidden-state layer
# NOTE(review): the plotting cell below draws rows 0 and 1 as the two points
# joined by the Δ line. Rows 2 and 3 presumably feed the Δ and t-value panels.
# Confirm the exact row assignment when filling in the blanks.
dist_metrics = np.zeros((4,num_hidden))


for layeri in range(num_hidden):

  # cosine similarity matrix
  hs = # extract
  hs /= # normalize
  csMat = # calculate

  # extract the similarity values
  # (presumably w = within-category and a = across-category values, as in Part 2)
  w =
  a =

  # calculate the metrics
  dist_metrics[0,layeri] =
  dist_metrics[1,layeri] =
  dist_metrics[2,layeri] =
  dist_metrics[3,layeri] =

In [None]:
# show the results
# 2x2 summary of the per-layer metrics stored in dist_metrics
_,axs = plt.subplots(2,2,figsize=(10,7))
# top-left: within- and across-category values per layer (data arguments left blank)
# NOTE(review): these lines set labels, but no axs[0,0].legend() call is visible
# in this cell, so the labels will not be shown.
axs[0,0].plot(,'kh',markerfacecolor=[.7,.9,.7],markersize=13,label='Within')
axs[0,0].plot(,'ko',markerfacecolor=[.9,.7,.7],markersize=12,label='Across')

# thin connecting lines representing \Delta
for i in range(1,num_hidden): # plot 0 separately for the legend
  axs[0,0].plot([i,i],dist_metrics[:2,i],'k:',linewidth=.3,zorder=-10)
axs[0,0].plot([0,0],dist_metrics[:2,0],'k:',linewidth=.3,zorder=-10,label='$\\Delta$')


# B) Δ per layer, colored along 'plasma' from the first to the last layer
axs[0,1].scatter(range(num_hidden),,s=150,edgecolor='k',
            c=plt.cm.plasma(np.linspace(0,1,num_hidden)),alpha=.7)
axs[0,1].set(xlabel='Hidden state layer',xticks=range(len(layer_labels)),xticklabels=layer_labels,
           ylabel='$\\mathbf{\\Delta}$',title='B) Similarity mean differences')

# C) t-value per layer
axs[1,0].scatter(,,s=150,edgecolor='k',
            c=plt.cm.plasma(np.linspace(0,1,num_hidden)),alpha=.7)
axs[1,0].set(xlabel='Hidden state layer',xticks=range(len(layer_labels)),xticklabels=layer_labels,
           ylabel='T-value',title='C) T-values')

# D) t-value against Δ; a t statistic scales the mean difference by its variability
axs[1,1].scatter(,,s=150,edgecolor='k',
            c=plt.cm.plasma(np.linspace(0,1,num_hidden)),alpha=.7)
axs[1,1].set(xlabel='$\\mathbf{\\Delta}$',ylabel='T-value',title='D) Impact of normalization')


plt.tight_layout()
plt.savefig('ch5_proj25_part3.png')
plt.show()

In [None]:
# For each word sequence, print the word followed by the model's top-scoring
# next token (presumably; the argmax input and f-string fields are left blank
# in this skeleton).
for seqi in range(len(words)):

  # find max token
  # (presumably the argmax over this sequence's final-position logits)
  maxtok = torch.argmax

  # build and print the sentence
  sentence = f'{}{}'
  print(sentence)

# **Part 4: Same words, different contexts**

In [None]:
# Note: Some variables here overwrite variables defined in earlier parts.

In [None]:
# The Great Gatsby (Project Gutenberg plain-text edition, ebook #64317)
url = 'https://www.gutenberg.org/cache/epub/64317/pg64317.txt'

# Fail fast instead of hanging on a stalled connection. Also raise on HTTP
# errors: otherwise an error page would be tokenized as if it were the book.
response = requests.get(url, timeout=60)
response.raise_for_status()
text = response.text

# tokenize the whole book with the GPT-2 tokenizer loaded in Part 1
tokens = np.array( tokenizer.encode(text) )
print(f'There are {len(tokens):,} tokens, {len(set(tokens))} of which are unique.')

In [None]:
# target words, written with a leading space to match their mid-text tokens
targets = [' years',' garage', ' lawn']

for i,w in enumerate(targets):

  # this token index
  t = tokenizer.

  # find all the matches
  # (positions in `tokens` equal to t, plus a category label for each — left blank)
  matches =
  labels =
  together =

  # stitch together
  # NOTE(review): later cells use `target_idxs` and per-sequence labels, so this
  # step presumably accumulates matches/labels across all targets — confirm.


  # print count and confirm single-token words
  print(f'{w:>10} is token {t} and appears {} times.')

In [None]:
# number of context tokens kept before each target occurrence
contextwin = 50

# create a batch
# One row per target occurrence, each holding contextwin+1 token IDs. These are
# presumably the contextwin preceding tokens plus the target itself, since the
# next cell prints the final tokens of each row.
batch = torch.zeros((target_idxs.shape[0],contextwin+1),dtype=torch.long)


# NOTE(review): every row must receive exactly contextwin+1 tokens. An
# occurrence fewer than contextwin tokens from the start of the text would give
# a shorter slice, and the row assignment would fail — worth guarding.
for i in range(target_idxs.shape[0]):
  toks = tokens[:]
  batch[i,:] = torch.tensor(toks)

batch.shape

In [None]:
# decode and print the final five tokens of every sequence in the batch
for seq in batch:
  tail_tokens = seq[-5:]
  print(tokenizer.decode(tail_tokens))

# **Part 5: Within- and across-word similarities**

In [None]:
# forward pass (~25 sec on CPU)
# no_grad: inference only, so no autograd graph is built.
# This overwrites the `outs` computed in Part 1.
with torch.no_grad():
  outs = model

# check the shape of one hidden-state tensor
outs.hidden_states[3].shape

In [None]:
# extract embedding vectors for the final token
# NOTE(review): `activations` is not used by the visible cells below, which
# index outs.hidden_states directly. The indexing here is left blank.
activations = outs.hidden_states

# mask for cosine similarity matrix
# Part 6 reads values 1/4/9 as same-word pairs, 2/3 as <years,garage/lawn> and
# 6 as <garage,lawn>. That is consistent with the upper triangle of the outer
# product of per-sequence labels 1..3 — confirm.
cat_mask = np.triu( )

plt.imshow(cat_mask)
plt.gca().set(xlabel='Sequence index',ylabel='Sequence index')
plt.colorbar(pad=.02)

plt.tight_layout()
plt.savefig('ch5_proj25_part5a.png')
plt.show()

In [None]:
# cosine similarity matrix for one layer
layeri = 4
# presumably the final-token vector of every sequence at this layer
hs = outs.hidden_states[][,,]
# unit-normalize each row (axis=1) so that hs @ hs.T gives cosine similarities
hs /= (hs,axis=1,keepdims=True)
csMat = hs @ hs.T

# extract the block-wise data
# within: pairs of the same target word; acros1: ' years' vs ' garage'/' lawn';
# acros2: ' garage' vs ' lawn' (matching the legend labels below)
within =
acros1 =
acros2 =

# distributions
# torch.histogram returns (density, bin_edges), with one more edge than bins
yWithin,xWithin = torch.histogram(within,bins=30,density=True)
yAcros1,xAcros1 = torch.histogram(acros1,bins=30,density=True)
yAcros2,xAcros2 = torch.histogram(acros2,bins=30,density=True)


# show the similarity matrix
# NOTE(review): rows/columns here are sequences (the mask plot above labels
# them 'Sequence index'), although these axis labels say 'Word index'.
_,axs = plt.subplots(1,2,figsize=(10,4))
h = axs[0].imshow(csMat,vmin=.5,vmax=1)
plt.colorbar(h,ax=axs[0],pad=.01)
axs[0].set(xlabel='Word index',ylabel='Word index',title='Cosine similarity matrix')

# and the distributions (left bin edges vs density)
axs[1].plot(xWithin[:-1],yWithin,linewidth=2,label='Within-category')
axs[1].plot(xAcros1[:-1],yAcros1,linewidth=2,label='<years,garage/lawn>')
axs[1].plot(xAcros2[:-1],yAcros2,linewidth=2,label='<garage,lawn>')
axs[1].set(ylim=[0,None],xlabel='Cosine similarity',ylabel='Density',
           title=f'Layer {layeri} distributions')
axs[1].legend()

plt.tight_layout()
plt.savefig('ch5_proj25_part5b.png')
plt.show()

# **Part 6: Laminar profiles of similarity distributions**

In [None]:
# common histogram bins for all distributions
# 35 evenly spaced edges -> 34 bins covering cosine similarity 0.4 to 1
binedges = torch.linspace(.4,1,35)
# bin width (identical for every bin)
dx = binedges[1]-binedges[0]

# initialize: grouping X layer X bin
# The blank last dimension must equal the number of bins, len(binedges)-1.
# Groupings presumably follow the titles used in the tornado plots:
# 0 = within, 1 = <years,garage/lawn>, 2 = <garage,lawn>.
cs_dists = torch.zeros((3,num_hidden, ))

# loop over layers
for layeri in range():

  # calculate the similarity matrix
  # [:,-1,:] selects the last sequence position of every row in the batch.
  # NOTE(review): that slice is a view, so the in-place /= also rewrites
  # outs.hidden_states. Re-normalizing unit vectors is harmless, but any cell
  # that later expects the raw hidden states would see normalized ones.
  hs = outs.hidden_states[layeri][:,-1,:]
  hs /= torch.linalg.norm(hs
  csMat =

  # extract the category elements
  # mask values: 1/4/9 = same-word pairs; 2/3 = ' years' vs the other two words;
  # 6 = ' garage' vs ' lawn'
  within = torch.concatenate((csMat[cat_mask==1],csMat[cat_mask==4],csMat[cat_mask==9]))
  acros1 = torch.concatenate((csMat[cat_mask==2],csMat[cat_mask==3]))
  acros2 = csMat[cat_mask==6]

  # calculate the histograms
  # (presumably bins=binedges so that every layer shares the same bins)
  cs_dists[0,layeri,:],_ = torch.histogram()
  cs_dists[1,layeri,:],_ = torch.histogram
  cs_dists[2,layeri,:],_ = torch.


# transform from density to probability: density = probability/dx, so
# multiplying by the bin width dx gives the probability mass in each bin


In [None]:
# the tornado plots :D
# One heatmap per grouping: rows = hidden layers (bottom to top, origin='lower'),
# columns = cosine-similarity bins; all share a fixed 0-0.25 color range.
_,axs = plt.subplots(1,3,figsize=(12,4))

# NOTE(review): the third title reads '<lawn,garage>', while the Part 5 legend
# uses '<garage,lawn>' for the same pairs — consider making them consistent.
titles = [ 'Within','<years,garage/lawn>','<lawn,garage>' ]

# generate the plots
for i in range(3):
  # extent maps the x-axis to the bin-edge range and the y-axis to layers 0..num_hidden
  h = axs[i].imshow(cs_dists[i,:,:],vmin=0,vmax=.25,aspect='auto',origin='lower',
                extent=[binedges[0],binedges[-1],0,num_hidden],cmap='magma')
  plt.colorbar(h,ax=axs[i]
  axs[i].set(xlabel='Cosine similarity',ylabel='Hidden layer')
  axs[i].set_title()

plt.tight_layout()
plt.savefig('ch5_proj25_part6.png')
plt.show()