|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>Singular value spectrum of embedding submatrices</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# numerical arrays and linear algebra
import numpy as np
# plotting; GridSpec is used for the multi-panel figure layout
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# ask the inline backend to render figures as SVG
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# bring in the pretrained (uncased) BERT model and its matching tokenizer
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# the token-embedding weights, detached from autograd and exposed as a numpy array
# (one row per vocabulary token)
word_embedding_layer = model.embeddings.word_embeddings
embeddings = word_embedding_layer.weight.detach().numpy()

In [None]:
# the figure in the beginning of the lecture:
# embedding vector of the token "1" against that of the token "2", each with its own mean removed
e1,e2 = (embeddings[tokenizer.encode(digit,add_special_tokens=False)[0]] for digit in '12')

plt.figure(figsize=(6,5))
plt.plot(e1-e1.mean(), e2-e2.mean(), 'ko', markerfacecolor=[.9,.7,.7,.6])
plt.xlabel('Embedding 1')
plt.ylabel('Embedding 2')
plt.show()

# Extract submatrices and mean-center

In [None]:
# get the tokens for lower-case letters
# (each letter is encoded on its own; the first returned token id is kept)
import string
lettertoks = [tokenizer.encode(i,add_special_tokens=False)[0] for i in string.ascii_lowercase]

# and the same number of random tokens.
# - a seeded generator makes the selection (and every figure that depends on it)
#   reproducible after restarting the kernel; change the seed for a different draw
# - replace=False guarantees that no token is picked twice, because duplicated rows
#   would artificially reduce the rank of the random submatrix
randtoks = np.random.default_rng(0).choice(tokenizer.vocab_size,size=len(lettertoks),replace=False)

In [None]:
# pull out the rows of the embeddings matrix for the letters and for the random tokens
# (indexing with a list/array of row indices gives a copy, so `embeddings` itself is untouched)
subembLetters = embeddings[lettertoks]
subembRandom = embeddings[randtoks]

# per-dimension mean across the selected tokens, kept as a 1-row matrix
lettersMeanVect = subembLetters.mean(axis=0,keepdims=True)
randomMeanVect  = subembRandom.mean(axis=0,keepdims=True)

# remove the mean from every row
subembLetters = subembLetters - lettersMeanVect
subembRandom  = subembRandom - randomMeanVect

# SVD

In [None]:
# full SVD of each mean-centered submatrix.
# U isn't needed for the analysis, so the second call simply replaces the first one's U
# (the U that gets printed below is therefore the one from the letters submatrix)
U,sRandom,VtRandom = np.linalg.svd(subembRandom,full_matrices=True)
U,sLetters,VtLetters = np.linalg.svd(subembLetters,full_matrices=True)

# report the matrix sizes
print(f'Embeddings is size {subembRandom.shape}',
      f'U  is size {U.shape}',
      f's  is size {sLetters.shape}',
      f'Vh is size {VtLetters.shape}',
      sep='\n')

# Visualize the spectra

In [None]:
# layout: two stacked panels on the left (top singular vectors),
# one tall panel on the right (singular value spectra)
fig = plt.figure(figsize=(12,5))
gs = GridSpec(2,2,figure=fig)

ax0 = fig.add_subplot(gs[0,0])
ax1 = fig.add_subplot(gs[1,0])
ax2 = fig.add_subplot(gs[:,1])

# first right singular vector of the letters submatrix (one value per embedding dimension)
ax0.plot(VtLetters[0],'k')
ax0.set(xlim=[0,model.config.hidden_size],xticklabels=[],ylabel='Score',
        title='Top letters singular vector')

# first right singular vector of the random-tokens submatrix
ax1.plot(VtRandom[0],'b')
ax1.set(xlim=[0,model.config.hidden_size],xticklabels=[],xlabel='Embedding dimension',ylabel='Score',
        title='Top random-token singular vector')


# plot their spectra
# (these are the raw singular values -- nothing is normalized here, so the
#  y-axis label must not claim max-normalization)
ax2.plot(sLetters,'ks-',markerfacecolor=[.7,.7,.9],label='Letters',markersize=8)
ax2.plot(sRandom,'bo-',markerfacecolor=[.9,.7,.7],label='Random',markersize=8)
ax2.set(xlabel='Component (sorted index)',ylabel='Singular value',title='Spectrum of embeddings submatrix')
ax2.legend()

plt.tight_layout()
plt.show()

# Project all embeddings onto the top singular vector

In [None]:
# subtract the letters' mean vector from every token embedding
# (the same centering that was applied to the letters submatrix before its SVD)
embeddingsCentered = embeddings - lettersMeanVect

# one scalar per token: its projection onto the first right singular vector of the letters
projections = embeddingsCentered @ VtLetters[0]

# show the projection of every token in the vocabulary
plt.figure(figsize=(12,4))
plt.plot(projections,'ko',markerfacecolor=[.7,.7,.7,.6])
plt.xlabel('Token')
plt.ylabel('Projection')
plt.show()

In [None]:
# rank all tokens by the magnitude of their projection (largest magnitude first)
sortidx = np.argsort(abs(projections))[::-1]

# the single ASCII letters, as a set so that membership is an exact match.
# (testing `token in string.ascii_letters` would be a *substring* test, which also
#  hides multi-character tokens that happen to be alphabet runs, e.g. "hi", "no", "de")
singleLetters = set(string.ascii_letters)

# print the signed projections of the 20 largest-magnitude tokens, skipping single letters
for i in range(20):
  token = tokenizer.decode(sortidx[i])
  if token not in singleLetters:
    print(f'{projections[sortidx[i]]:6.3f} for "{token}"')

In [None]:
# absolute pairwise correlations between the (mean-centered) letter embeddings;
# each row of the submatrix is one letter, so the result is letters-by-letters
plt.imshow(np.abs(np.corrcoef(subembLetters)),vmin=0,vmax=1)

plt.xlabel('Tokens')
plt.ylabel('Tokens')
plt.title('|R| for letter embeddings')
plt.colorbar()
plt.show()