|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: SVD projections of related embeddings</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# BERT tokenizer and pretrained model (same 'bert-base-uncased' checkpoint for both)
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model     = BertModel.from_pretrained('bert-base-uncased')

# word-embeddings weights, detached from the graph and converted to a numpy array
# (rows are indexed by token id in the cells below)
wordEmbedWeights = model.embeddings.word_embeddings.weight
embeddings = wordEmbedWeights.detach().numpy()

# Exercise 1: Tokenize and create embeddings submatrices


In [None]:
# token index for each of the digits 0-9
digitTokens = np.zeros(10,dtype=int)

for digit in range(10):

  # tokenize the digit's string form, without special tokens
  toks = tokenizer.encode(str(digit), add_special_tokens=False)
  print(f'{len(toks)} token for "{digit}"')

  # confirm it's a single-token word: only the first token is stored below,
  # so a multi-token split would silently pick up the wrong embedding
  assert len(toks)==1, f'"{digit}" was split into {len(toks)} tokens; expected exactly 1'

  digitTokens[digit] = toks[0]

In [None]:
# list of 10 EU countries
EUstates = ['estonia','france','germany','italy','latvia','lithuania','netherlands','poland','romania','slovenia' ]


# find the token index for each country
EUtokens = np.zeros(len(EUstates),dtype=int)
for i,country in enumerate(EUstates):

  # tokenize the country name, without special tokens
  toks = tokenizer.encode(country, add_special_tokens=False)
  print(f'{len(toks)} token for "{country}"')

  # confirm it's a single-token word: only the first token is stored below,
  # so a multi-token split would silently pick up the wrong embedding
  assert len(toks)==1, f'"{country}" was split into {len(toks)} tokens; expected exactly 1'

  EUtokens[i] = toks[0]

In [None]:
# get the category-mean embeddings vector for later centering
# (keepdims=True keeps each mean as a row vector for later broadcasting)

# for countries
EU_embedMean = embeddings[EUtokens,:].mean(axis=0,keepdims=True)

# and for digits
digs_embedMean = embeddings[digitTokens,:].mean(axis=0,keepdims=True)


# any obvious relationship?
# Flatten to 1D for plotting and correlation: when given 2D inputs, plt.plot treats
# every column as its own data series, creating one line object per embedding dimension.
EU_meanVec   = EU_embedMean.ravel()
digs_meanVec = digs_embedMean.ravel()
meansCorr    = np.corrcoef(EU_meanVec,digs_meanVec)[0,1]

plt.plot(EU_meanVec,digs_meanVec,'ko',markerfacecolor=[.9,.9,.7,.6])
plt.gca().set(xlabel='EU average',ylabel='Digits average',
              title=f'Correlation = {meansCorr:.3f}')
plt.show()

In [None]:
# mean-centered embeddings submatrices (one row per seed token)

# countries: select the EU rows, then subtract the EU category mean
subembEU = embeddings[EUtokens] - EU_embedMean

# numbers: select the digit rows, then subtract the digits category mean
subembDigs = embeddings[digitTokens] - digs_embedMean

# Exercise 2: Extract and visualize the singular value spectrum

In [None]:
# SVD of each centered submatrix.
# Note: numpy returns the third output already transposed (Vt, technically Vh),
# so the singular vectors are the ROWS of the V_* variables.

# EU countries
U_EU,s_EU,V_EU = np.linalg.svd(subembEU,full_matrices=True)

# digits
U_dig,s_dig,V_dig = np.linalg.svd(subembDigs,full_matrices=True)

# report the matrix sizes
print(f'Embeddings is size {subembDigs.shape}')
print(f'U  is size {U_EU.shape}')
print(f's  is size {s_EU.shape}')
print(f'Vh is size {V_EU.shape}')

In [None]:
# two-panel figure: singular value spectra (left), top singular vectors (right)
_,axs = plt.subplots(1,2,figsize=(12,5))
specAx,vecAx = axs

# left panel: one spectrum per category
spectra = [ (s_dig,'ks-',[.7,.7,.9],'Digits'),
            (s_EU, 'ko-',[.9,.7,.7],'EU') ]
for svals,fmt,faceColor,lab in spectra:
  specAx.plot(svals,fmt,markerfacecolor=faceColor,label=lab,markersize=10)
specAx.set(xlabel='Component (sorted index)',ylabel='Singular value',title='Spectrum of embeddings submatrices',xlim=[-.5,9.5])
specAx.legend()

# right panel: first row of each Vh (the top singular vector) against each other
topEU,topDig = V_EU[0],V_dig[0]
topVecCorr = np.corrcoef(topEU,topDig)[0,1]
vecAx.plot(topEU,topDig,'ko',markerfacecolor=[.7,.9,.7,.7])
vecAx.set(xlabel='EU basis vector',ylabel='Digits basis vector',
          title=f'Singular vectors (r = {topVecCorr:.2f})')


plt.tight_layout()
plt.show()

# Exercise 3: Project embeddings onto basis vectors

In [None]:
# project ALL token embeddings, centered on each category mean, onto that
# category's top singular vector (first row of Vt)
projEU = (embeddings-EU_embedMean) @ V_EU[0,:]
projDg = (embeddings-digs_embedMean) @ V_dig[0,:]

plt.figure(figsize=(10,4))

# histograms (density-normalized)
yD,xD = np.histogram(projDg,bins=90,density=True)
yE,xE = np.histogram(projEU,bins=90,density=True)

# np.histogram returns bin EDGES (one more than the number of bins);
# plot each density at its bin center rather than at the left edge
xD_centers = (xD[:-1]+xD[1:])/2
xE_centers = (xE[:-1]+xE[1:])/2

plt.plot(xD_centers,yD,linewidth=2,label='Digits')
plt.plot(xE_centers,yE,linewidth=2,label='EU countries')
plt.legend()

# the x-axis shows projection values (not embedding dimensions)
plt.gca().set(xlabel='Projection value',ylabel='Density',title='Distributions of projections')
plt.show()

In [None]:
# print the tokens with the most extreme projections onto the EU singular vector
nTop = 30

# argsort is ascending: most negative projections first, most positive last
sortidx = np.argsort(projEU)

# print the top positive projections
print('--- POSITIVE projections ---')

# Take the last nTop indices and reverse them so the largest projection comes first.
# (Counting from the end with a loop index that starts at 0 would hit index -0 == 0,
#  which is the most NEGATIVE projection.)
for idx in sortidx[-nTop:][::-1]:

  # get this token
  token = tokenizer.decode(idx)

  # print if it's not in the 'seed' list
  if token not in EUstates:
    print(f'  {projEU[idx]:6.3f} for "{token}"')


# repeat for top negative projections
print('\n\n\n--- NEGATIVE projections ---')
for idx in sortidx[:nTop]:

  # get this token
  token = tokenizer.decode(idx)

  # print if it's not in the 'seed' list
  if token not in EUstates:
    print(f'  {projEU[idx]:6.3f} for "{token}"')

In [None]:
# print the tokens with the most extreme projections onto the digits singular vector
nTop = 30

# the 'seed' list as separate strings; testing membership in the single string
# '0123456789' would be a substring test and would also hide tokens like "12" or "345"
seedDigits = [str(d) for d in range(10)]

# argsort is ascending: most negative projections first, most positive last
sortidx = np.argsort(projDg)

# print the top positive projections
print('--- POSITIVE projections ---')

# Take the last nTop indices and reverse them so the largest projection comes first.
# (Counting from the end with a loop index that starts at 0 would hit index -0 == 0,
#  which is the most NEGATIVE projection.)
for idx in sortidx[-nTop:][::-1]:

  # get this token
  token = tokenizer.decode(idx)

  # print if it's not in the 'seed' list
  if token not in seedDigits:
    print(f'  {projDg[idx]:6.3f} for "{token}"')


# repeat for top negative projections
print('\n\n\n--- NEGATIVE projections ---')
for idx in sortidx[:nTop]:

  # get this token
  token = tokenizer.decode(idx)

  # print if it's not in the 'seed' list
  if token not in seedDigits:
    print(f'  {projDg[idx]:6.3f} for "{token}"')

# Exercise 4: Cosine similarities of the embeddings

In [None]:
# normalize each (mean-centered) embedding vector to unit length
E_digs = subembDigs / np.linalg.norm(subembDigs,axis=1,keepdims=True)
E_EU   = subembEU   / np.linalg.norm(subembEU,  axis=1,keepdims=True)

# cosine similarity matrices: dot products between all pairs of unit-length rows
csM_EU = E_EU   @ E_EU.T
csM_dg = E_digs @ E_digs.T

fig,axs = plt.subplots(1,2,figsize=(10,4))

# EU countries (tick positions follow the length of the country list)
EUticks = range(len(EUstates))
h = axs[0].imshow(csM_EU,vmin=-.5,vmax=.5)
axs[0].set(xticks=EUticks,xticklabels=EUstates,yticks=EUticks,yticklabels=EUstates,
           title='Cosine similarity: EU countries')
axs[0].tick_params(axis='x',labelrotation=90)
fig.colorbar(h,ax=axs[0],pad=.02)

# digits: the colorbar is built from this panel's own image handle
hDg = axs[1].imshow(csM_dg,vmin=-.5,vmax=.5)
axs[1].set(xticks=range(10),yticks=range(10),title='Cosine similarity: digits')
fig.colorbar(hDg,ax=axs[1],pad=.02)

plt.tight_layout()
plt.show()