|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating token embeddings</h1>|
|<h2>Lecture:</h2>|<h1><b>Path length and logit token prediction</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# numerical arrays and plotting
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# PyTorch and the Hugging Face classes used to load the models and the tokenizer
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
# one tokenizer, loaded from the 'gpt2' identifier, is used to encode the text for every model below
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# draw inline notebook figures in SVG format
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# dictionary of modelname:identifier
# (identifiers are passed to from_pretrained; parameter counts are noted per line)
model_ids = {
    'small':  'gpt2',        # 124M
    'medium': 'gpt2-medium', # 355M
    'large':  'gpt2-large',  # 774M
    'xl':     'gpt2-xl',     # 1.6B
}

# run on the first GPU when CUDA is available, otherwise on the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load all models into a dictionary (keyed by the short name), move each one to
# the device and switch it to evaluation mode
# note: the loop variable is deliberately not called `id`, which would leave the
# builtin id() shadowed for the rest of the notebook
models = {}
for name,model_id in model_ids.items():
  models[name] = AutoModelForCausalLM.from_pretrained(model_id).to(device)
  models[name].eval()

# Push text through the SMALL model and get hidden_states

In [None]:
# Text source: https://en.wikipedia.org/wiki/Friedrich_Nietzsche
# Encode the text into a PyTorch tensor of token ids (with a leading batch
# dimension) and place it on the same device as the models.
tokz = tokenizer.encode(
  "Nietzsche was a widely influential German philosopher. He began his career as a classical philologist, turning to philosophy early in his academic career. In 1869, aged 24, Nietzsche became the youngest professor to hold the Chair of Classical Philology at the University of Basel. Plagued by health problems for most of his life, he resigned from the university in 1879, after which he completed much of his core writing in the following decade. Nietzsche's work spans philosophical polemics, poetry, cultural criticism, and fiction while displaying a fondness for aphorism and irony. Prominent elements of his philosophy include his radical critique of truth in favour of perspectivism; a genealogical critique of religion and Christian morality and a related theory of master–slave morality; the aesthetic affirmation of life in response to both the 'death of God' and the profound crisis of nihilism; the notion of Apollonian and Dionysian forces; and a characterisation of the human subject as the expression of competing wills, collectively understood as the will to power. He also developed influential concepts such as the Übermensch and his doctrine of eternal return. In his later work, he became increasingly preoccupied with the creative powers of the individual to overcome cultural and moral mores in pursuit of new values and aesthetic health. His body of work touched a wide range of topics, including art, philology, history, music, religion, tragedy, culture, and science, and drew inspiration from Greek tragedy as well as figures such as Zoroaster, Arthur Schopenhauer, Ralph Waldo Emerson, Richard Wagner, Fyodor Dostoevsky, and Johann Wolfgang von Goethe.",
  return_tensors='pt',
).to(device)

# Forward pass through the small model with gradient tracking disabled.
# Hidden states are requested because the next cell reads outputs.hidden_states;
# the tokens are also passed as labels.
with torch.no_grad():
  outputs = models['small'](input_ids=tokz, labels=tokz, output_hidden_states=True)

# Calculate path lengths and correlate with next-token logits

In [None]:
# Path length = Euclidean norm of the change in a token's hidden-state vector
# from one layer to the next. Here it is computed for every token and layer of
# the small model, and the final-layer path length is correlated with the logit
# the model assigned to the token that actually comes next.

n_layers = models['small'].config.n_layer
n_tokens = tokz.shape[1]

# initialize
# (row 0 of pathlen stays zero: hs[0] is the embeddings, which have no previous layer)
pathlen = np.zeros((n_layers+1,n_tokens))
nextTokenLogits = np.zeros(n_tokens)

# stack the hidden states into one tensor of size [n_layers+1, n_tokens, dim]
hs = torch.stack(outputs.hidden_states)[:,0,:,:]

# path length from previous layer: norm of the difference vectors,
# for all layers and all tokens at once
pathlen[1:,:] = torch.linalg.vector_norm( hs[1:]-hs[:-1], dim=-1 ).detach().cpu().numpy()

# logit for correct next-token prediction
# (the final token has no next token, so its entry keeps the placeholder 0)
targets = tokz[0,1:]
nextTokenLogits[:-1] = outputs.logits[0,:-1,:].gather(1,targets[:,None])[:,0].detach().cpu().numpy()


# cumulative pathlengths (the embeddings row is not a path length, so it is not plotted)
cumpathlen = np.cumsum(pathlen,axis=0)
cumpathlen[0,:] = np.nan

# tokens that enter the correlation: the first token is skipped (as in the line
# plot below), and the final token is skipped because its logit entry is only a
# placeholder and would add a spurious data point at 0
tok2use = slice(1,-1)
r_pathLogits = np.corrcoef(nextTokenLogits[tok2use],pathlen[-1,tok2use])[0,1]


# show the results!
_,axs = plt.subplots(1,2,figsize=(12,4))

# path length progressions for each token (color encodes token position)
layers = np.arange(n_layers+1)
for i in range(1,n_tokens):
  axs[0].plot(layers,cumpathlen[:,i],'s-',alpha=.2,linewidth=.5,markersize=7,
              color=mpl.cm.plasma(i/n_tokens))

axs[0].set(title='Cumulative path length (each line is a token)',xlabel='Transformer layer',ylabel='Cumulative path length')

# final-layer path length against the next-token logit
axs[1].plot(nextTokenLogits[tok2use],pathlen[-1,tok2use],'ko',markersize=10,markerfacecolor=[.9,.7,.7],alpha=.6)
axs[1].set(title=f'r = {r_pathLogits:.2}',xlabel='Next-token logits',ylabel='Path length at final layer')

plt.tight_layout()
plt.show()

# Repeat for all models

In [None]:
# Repeat the path-length vs. next-token-logit analysis for every loaded model,
# with one scatter plot per model.
_,axs = plt.subplots(2,2,figsize=(11,7))
axs = axs.flatten()

# marker face color for each model's panel
linecolors = [ [.9,.7,.7],[.7,.9,.7],[.7,.7,.9],[.7,.9,.9] ]

# tokens that enter the correlation: the first token is skipped, and the final
# token is skipped because it has no next token (its logit entry is only a
# placeholder 0 and would add a spurious data point)
tok2use = slice(1,-1)


for axi,(name,model) in enumerate(models.items()):

  ### push data through the model
  with torch.no_grad():
    outputs = model(tokz, output_hidden_states=True, labels=tokz)


  # initialize
  # (row 0 of pathlen stays zero: hs[0] is the embeddings, which have no previous layer)
  n_layers = model.config.n_layer
  n_tokens = tokz.shape[1]
  pathlen = np.zeros((n_layers+1,n_tokens))
  nextTokenLogits = np.zeros(n_tokens)


  # path length from previous layer: norm of the hidden-state difference vectors
  # for all layers and tokens at once; copied to the CPU before going into numpy
  hs = torch.stack(outputs.hidden_states)[:,0,:,:] # [n_layers+1, n_tokens, dim]
  pathlen[1:,:] = torch.linalg.vector_norm( hs[1:]-hs[:-1], dim=-1 ).detach().cpu().numpy()

  # logit for correct next-token prediction
  targets = tokz[0,1:]
  nextTokenLogits[:-1] = outputs.logits[0,:-1,:].gather(1,targets[:,None])[:,0].detach().cpu().numpy()


  # cumulative pathlengths (not plotted in this cell; kept available as in the single-model analysis)
  cumpathlen = np.cumsum(pathlen,axis=0)
  cumpathlen[0,:] = np.nan
  r_pathLogits = np.corrcoef(nextTokenLogits[tok2use],pathlen[-1,tok2use])[0,1]


  print(f'{name:>8} has pathlength-logit correlation of {r_pathLogits:.3f}')

  # the y-axis shows the path length into the final layer (pathlen[-1]), so it is labeled as such
  axs[axi].plot(nextTokenLogits[tok2use],pathlen[-1,tok2use],'ko',markersize=10,
                markerfacecolor=linecolors[axi],alpha=.6)
  axs[axi].set(title=f'{name} (r = {r_pathLogits:.2})',xlabel='Next-token logits',ylabel='Path length at final layer')


plt.tight_layout()
plt.show()