|<h2>Book:</h2>|<h1>50 ML projects to understand LLMs</h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[24] Cosine similarities within and across layers</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import requests

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments

# render inline figures as SVG (vector output stays sharp when zoomed)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# light-mode styling: hide the top/right spines, bold titles and axis labels,
# and write 300-dpi images from savefig
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'
plt.rcParams['savefig.dpi'] = 300

# dark mode: uncomment the lines below
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor'] = '#282a2c'
# plt.rcParams['axes.edgecolor'] = '#DDE2F4'
# plt.rcParams['axes.labelcolor'] = '#DDE2F4'
# plt.rcParams['xtick.color'] = '#DDE2F4'
# plt.rcParams['ytick.color'] = '#DDE2F4'
# plt.rcParams['text.color'] = '#DDE2F4'

# **Part 1: Extract hidden states**

In [None]:
# GPT-2 medium, switched to inference mode right away
# (nn.Module.eval() returns the module itself, so the call can be chained)
model = AutoModelForCausalLM.from_pretrained('gpt2-medium').eval()

# tokenizer from the base 'gpt2' checkpoint -- presumably interchangeable with
# gpt2-medium's, since the GPT-2 sizes share one BPE vocabulary (TODO confirm if
# swapping in a different model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# echo the model as the cell's final expression so its module tree is displayed
model

In [None]:
# Display the model configuration; later cells read fields from it
# (e.g., model.config.n_ctx, used below as the width of the token batch).
model.config

In [None]:
# Tokenize a short example sentence.
# TODO(exercise): fill in the blanks. Later cells index tokens['input_ids'][0,i],
# so `tokens` is presumably the tokenizer output as PyTorch tensors (a dict-like
# object whose 'input_ids' has shape (1, seq_len)) -- confirm against the book.
txt = 'A wise man once said: Penguins are cute.'
tokens =
# number of tokens in txt (used below for the token axis, e.g. range(num_tokens))
num_tokens =

# print every entry of the tokenizer output
for key,item in tokens.items():
  # TODO(exercise): fill the two {} placeholders (presumably key and item)
  print(f'"{}" contains:\n  {}\n')

In [None]:
# Forward pass without gradient tracking, then inspect the output sizes.
with torch.no_grad():
  # TODO(exercise): run the model on the tokens. The prints below read
  # outputs.hidden_states, so hidden states must be requested from the model
  # (presumably via output_hidden_states=True -- confirm against the book).
  outputs =

print('Keys in "outputs":\n  ',outputs.keys())
print('\nSize of outputs.logits:\n  ',outputs.logits.shape)
print('\nNumber of hidden states:\n  ',len(outputs.hidden_states))
print('\nSize of each hidden state:\n  ',outputs.hidden_states[0].shape)

# **Part 2: Inspect hidden states**

In [None]:
# Convenience variables reused throughout the rest of the notebook.
# hs: the sequence of hidden-state tensors; the first entry is labeled 'Emb'
# (embeddings) in layer_labels further down, the rest as transformer layers
hs = outputs.hidden_states
# TODO(exercise): the number of hidden-state tensors in hs
num_hidden =
# TODO(exercise): the hidden (embedding) size, read from the model configuration
hidden_dim = model.config.

In [None]:
# Inspect every hidden state for a single token across model depth:
# raw activations (panel A), vector means (panel B), vector variances (panel C).
# whichToken: position of the token to inspect within tokens['input_ids'][0]
whichToken = 8

# initialize stdev matrix
# TODO(exercise): the loop fills token_stds[layeri,:] and the next cell images it
# with layers on the y-axis and tokens on the x-axis, so its shape is presumably
# (num_hidden, num_tokens) -- confirm against the book.
token_stds = torch.zeros(())

# setup the figure
_,axs = plt.subplots(1,3,figsize=(12,3.5))

# loop over layers
# TODO(exercise): iterate over all num_hidden hidden states
for layeri in range():

  # extract the activations from this layer and this token
  acts = hs[

  # plot all the activations
  # (marker color moves along the plasma colormap with depth)
  axs[0].plot(,,'ko',markersize=8,
           markerfacecolor=plt.cm.plasma((layeri+1)/num_hidden),alpha=.4)

  # plot the mean of the activations
  axs[1].plot(,,'kh',markersize=12,
           markerfacecolor=plt.cm.plasma((layeri+1)/num_hidden))

  # plot the variance of the activations
  axs[2].plot(,,'ks',markersize=12,
           markerfacecolor=plt.cm.plasma((layeri+1)/num_hidden))

  # standard deviation for all token positions
  # (one value per token for this layer, i.e. one row of token_stds)
  token_stds[layeri,:] =

# names of the layers, for the x-axis tick labels
# ('Emb' for the first hidden state, then L0, L1, ... for the remaining ones)
layer_labels = ['Emb'] + [f'L{i}' for i in range(num_hidden-1)]

# adjust the axes
# NOTE(review): the titles reuse single quotes inside single-quoted f-strings
# (tokens['input_ids']), which is valid syntax only on Python 3.12+ (PEP 701);
# on older Pythons change the inner quotes to "input_ids".
axs[0].set(xticks=range(0,num_hidden,2),xticklabels=layer_labels[::2],xlabel='Hidden layer (model depth)',
           ylabel='Activation value',title=f'A) Hidden state activations for token "{tokenizer.decode(tokens['input_ids'][0,whichToken])}"')

axs[1].set(xticks=range(1,num_hidden,2),xticklabels=layer_labels[1::2],xlabel='Hidden layer (model depth)',
           ylabel='Vector mean',title=f'B) Activation means for token "{tokenizer.decode(tokens['input_ids'][0,whichToken])}"')

axs[2].set(xticks=range(1,num_hidden,2),xticklabels=layer_labels[1::2],xlabel='Hidden layer (model depth)',
           ylabel='Vector variance',title=f'C) Activation variances for token "{tokenizer.decode(tokens['input_ids'][0,whichToken])}"')


plt.tight_layout()
plt.savefig('ch5_proj24_part2a.png')
plt.show()

In [None]:
# Image a (layers x tokens) matrix of hidden-state variability: layers on the
# y-axis, tokens on the x-axis (tick labels are the decoded tokens).
# imshow is called without origin=, so matplotlib's default puts row 0 at the top.
plt.figure(figsize=(8,5))
# TODO(exercise): presumably pass token_stds from the previous cell -- confirm
plt.imshow()
plt.gca().set(ylabel='Layer',xticks=range(num_tokens),title='Hidden-states variability',
              xticklabels=[tokenizer.decode(t) for t in tokens['input_ids'][0]],
              yticks=range(num_hidden),yticklabels=layer_labels)
ch = plt.colorbar(pad=.01)

plt.tight_layout()
plt.savefig('ch5_proj24_part2b.png')
plt.show()

# **Part 3: Cosine similarities across layers**

In [None]:
# pick 4 evenly spaced tokens including the first and final
# TODO(exercise): complete the call; the values are used as token positions, so
# they presumably need to be integers in [0, num_tokens-1] (e.g. via .astype(int))
tokens2analyze = np.linspace

fig,axs = plt.subplots(1,4,layout='constrained',figsize=(12,3))

# one layer-by-layer cosine-similarity matrix per selected token; reused by the
# similarity-vs-distance plots further down (cos_sims[toki])
cos_sims = []

# loop over selected tokens
for toki in range(len(tokens2analyze)):

  # extract the hidden-state activations from this token into a matrix
  # (one row per hidden layer, one column per hidden dimension)
  all_hiddens = torch.zeros((num_hidden,hidden_dim))
  for layeri in range(num_hidden):
    all_hiddens[layeri,:] =

  # and calculate the cosine similarity matrix on all pairs of layers
  # TODO(exercise): should yield a (num_hidden, num_hidden) matrix; one option is
  # to broadcast the rows against each other (unsqueeze(1) vs unsqueeze(0), dim=-1)
  M = F.cosine_similarity()
  cos_sims.append(M)

  # show the matrix (color range fixed to [.8, 1] so the four panels are comparable)
  h = axs[toki].imshow(M,cmap='plasma',vmin=.8,vmax=1,origin='lower')

# adjustments
# NOTE(review): h is not used in this cell; a colorbar could be added from it
# the same way part 4a does (fig.colorbar(h, ...)).
axs[0].set(xlabel='Hidden layer (model depth)',ylabel='Hidden layer (model depth)')

plt.savefig('ch5_proj24_part3a.png')
plt.show()

In [None]:
# create a distance matrix
# l is a (1, num_hidden) row vector of layer indices; D[i,j] should hold the
# distance between layers i and j.
# TODO(exercise): presumably |i - j|, e.g. by broadcasting l against its transpose
l = np.arange(num_hidden).reshape(1,-1)
D =

# TODO(exercise): show D
plt.imshow()
plt.colorbar(pad=.02)
plt.gca().set(title='Inter-layer distances',ylabel='Layer index',xlabel='Layer index')

plt.tight_layout()
plt.savefig('ch5_proj24_part3b.png')
plt.show()

In [None]:
# show the relationships
# (cosine similarity between two layers vs. how far apart those layers are)
_,axs = plt.subplots(1,4,figsize=(12,3))

# non-redundant rows and columns
# (rows[k], cols[k]) index the nonzero entries strictly above the main diagonal;
# if D is |i - j|, that is every layer pair with i < j, so each pair appears once
rows,cols = np.nonzero(np.triu(D,1))

for toki in range(len(tokens2analyze)):

  # TODO(exercise): fill the index blanks below (the loop over i suggests one
  # plotted series per layer -- confirm against the book)
  for i in range(num_hidden):
    x = D[,]
    y = cos_sims[toki][,]
    axs[toki].plot(x,y

  # TODO(exercise): title, presumably the decoded text of this token
  axs[toki].set(xlabel='Inter-layer distance',ylabel='Cosine similarity',
                title=f'"{}"')

plt.tight_layout()
plt.savefig('ch5_proj24_part3c.png')
plt.show()

# **Part 4: Cosine similarities across tokens**

In [None]:
# convert tokens into a list for axis labeling
# TODO(exercise): presumably the decoded text of each token (not used within this
# cell as shown)
toks_list = []

# 4 evenly spaced layers
# TODO(exercise): complete the call (and close the parenthesis); the values are
# used as layer indices, so they presumably need to be integers
layers2analyze = np.linspace(0,num_hidden-1,4

fig,axs = plt.subplots(1,4, layout='constrained',figsize=(12,3))

# loop over layers
for layeri in range(len(layers2analyze)):

  # cosine similarity matrix over all token pairs for this layer
  # TODO(exercise): should yield a (num_tokens, num_tokens) matrix for layer
  # layers2analyze[layeri]
  cos_sim = F.cosine_similarity(,,dim=

  # show the matrix (color range fixed to [.7, 1] so the four panels are comparable)
  h = axs[layeri].imshow(cos_sim,cmap='plasma',vmin=.7,vmax=1,origin='lower')


# a single colorbar from the last panel's image; valid for all panels because
# they share vmin/vmax
fig.colorbar(h,ax=axs[-1],label='Cosine similarity',pad=.02,shrink=.96)

plt.savefig('ch5_proj24_part4a.png')
plt.show()

In [None]:
# For every layer, plot a jittered column of inter-token cosine similarities.
plt.figure(figsize=(10,4))

for layeri in range(num_hidden):

  # similarities across all tokens, excluding the first
  # TODO(exercise): fill in the hs indexing and the similarity dimension
  cos_sim = F.cosine_similarity(hs[],hs[],dim=)
  # TODO(exercise): presumably keep each token pair once (e.g. the strict upper
  # triangle of cos_sim) -- confirm against the book
  unique_sim =

  # and plot all the dots
  # (x = layer index plus small Gaussian jitter, so overlapping dots spread out)
  plt.plot(np.random.normal(layeri,.05,len(unique_sim)),,'ko',markersize=8,
           markerfacecolor=plt.cm.plasma((layeri+1)/num_hidden),alpha=.4)

# adjust the axis properties
# NOTE(review): the title f-string has no placeholders, so the f prefix is unused
plt.gca().set(xticks=range(num_hidden),xticklabels=layer_labels,
              xlabel='Hidden layer (model depth)',ylabel='Cosine similarity',
              title=f'Laminar profile of inter-token cosine similarities')

plt.tight_layout()
plt.savefig('ch5_proj24_part4b.png')
plt.show()

# **Part 5: Descriptives in a larger dataset**

In [None]:
# Download a book text from Project Gutenberg and tokenize all of it.
# NOTE(review): the request has no timeout and no status check; requests does not
# raise on HTTP errors, so a failed download would tokenize an error page.
# Consider requests.get(url, timeout=...) followed by .raise_for_status().
# NOTE(review): this rebinds `tokens` (Part 1's tokenizer output) to a flat array
# of token IDs. Earlier cells that read tokens['input_ids'] will fail if re-run
# after this one, unless the Part 1 tokenization cell is re-run first.
url = 'https://www.gutenberg.org/cache/epub/64317/pg64317.txt'
text = requests.get(url).text
tokens = np.array( tokenizer.encode(text) )
print(f'There are {len(tokens):,} tokens, {len(set(tokens))} of which are unique.')

# get some data
# a (2, model.config.n_ctx) batch of token IDs, i.e. two sequences of the model's
# context length
# TODO(exercise): fill each row with n_ctx token IDs taken from `tokens`
batch = torch.zeros((2,model.config.n_ctx),dtype=torch.long)
batch[0,:] =
batch[1,:] =

with torch.no_grad():
  # TODO(exercise): run the model on batch; outs.hidden_states is read below, so
  # hidden states must be requested
  outs = model(

# check the shape
# presumably (2, n_ctx, hidden_dim) -- confirm when run
outs.hidden_states[4].shape

In [None]:
# Per-layer descriptive statistics of the hidden-state vectors in the batch:
# A) arithmetic means, B) L1 (mean-absolute) means, C) standard deviations,
# D) mean vs. standard deviation.
_,axs = plt.subplots(2,2,figsize=(10,5))

# normalization function for mapping layer index onto color
# (norm(layeri) == layeri/num_hidden, the same color fraction used in panel A)
norm = mpl.colors.Normalize(vmin=0,vmax=num_hidden)


for layeri in range(num_hidden):

  # extract means and stdevs of hidden state vectors
  # TODO(exercise): fill the blanks. Each statistic is taken over the last
  # dimension (presumably the hidden/embedding dimension) and flattened, giving
  # one value per vector in the batch.
  mAbs = abs(
  m = .mean(dim=-1).flatten()
  s = .std(dim=-1).flatten()

  # plot the individual characteristics
  # (x = layer index plus small Gaussian jitter)
  axs[0,0].plot(np.random.normal(layeri,.05,len(m)),m,'ko',markersize=3,alpha=.3,
                markeredgewidth=.4,markerfacecolor=plt.cm.rainbow(layeri/num_hidden))
  # TODO(exercise): panel B, the L1 means (mAbs)
  axs[0,1].plot()

  # TODO(exercise): panel C, the standard deviations (s)
  axs[1,0].plot()

  # and their relationships
  # TODO(exercise): panel D, means (x) vs. standard deviations (y), per the axis labels below
  axs[1,1].plot(,color=mpl.cm.rainbow(norm(layeri)))


# final adjustments
# dashed zero reference line, drawn behind the dots (very low zorder)
axs[0,0].axhline(0,color='k',linestyle='--',linewidth=.8,zorder=-1000)
axs[0,0].set(xlabel='Hidden layer',ylabel='Arithmetic mean',title='A) Hidden-state arithmetic means')
axs[0,1].set(xlabel='Hidden layer',ylabel='L1 mean',title='B) Hidden-state L1 means')
axs[1,0].set(xlabel='Hidden layer',ylabel='Standard deviation',title='C) Vector stdevs')
axs[1,1].set(xlabel='Arithmetic mean',ylabel='Standard deviation',title='D) Relationships')


# add colorbars
# (plotted lines have no mappable, so a ScalarMappable built from the same
# colormap and norm supplies the layer-to-color key)
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.rainbow,norm=norm)
cbar = plt.colorbar(sm,ax=axs[1,1])
cbar.set_label('Hidden layer')

plt.tight_layout()
plt.savefig('ch5_proj24_part5.png')
plt.show()