|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[45] Minkowski distance, mutual information, and token positions</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

!pip install wikipedia
import wikipedia

# stats library for kendall correlation (when one variable is ordinal [sorted categorical])
import scipy.stats as stats

from sklearn.feature_selection import mutual_info_regression
from scipy.spatial.distance import minkowski

import torch
from transformers import AutoModelForCausalLM,GPT2Tokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (uncomment to enable)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# hide the top/right axis spines
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

# bold titles and axis labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution of figures written by plt.savefig
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Minkowski distances**

In [None]:
# generate some data (two independent vectors of 100 normally distributed numbers)
x = np.random.randn(100)
y = np.random.randn(100)

# implement the distance measure (p=1 is the L1/Manhattan distance, p=2 is the L2/Euclidean distance)
for p in range(1,3):

  # manual calculation: p-th root of the sum of the absolute differences raised to the power p
  mink_man = np.sum( np.abs(x-y)**p )**(1/p)

  # via scipy
  mink_sp = minkowski(x,y,p)

  # show their equivalence
  print(f'p = {p}:\n manual: {mink_man:.3f}\n  scipy: {mink_sp:.3f}\n')

# **Part 2: Mutual information: manual and scikit-learn**

In [None]:
# the data
# NOTE(review): exercise template — the arguments of the next two assignments are intentionally blank.
#   x draws from a uniform distribution; y is built from some term plus normally distributed noise.
#   N is presumably the sample size to pass as `size=` — TODO confirm against the book.
N = 347
x = np.random.uniform(low=,high=,size=)
y =  + np.random.normal


# 2D histogram
# (np.histogram2d returns the bin counts and the bin edges along each variable;
#  the figure title below shows the intended call: np.histogram2d(x,y,bins=8))
Z,xx,yy = np.histogram2d()

_,axs = plt.subplots(1,2,figsize=(12,4))
# left panel: the raw (x,y) samples
axs[0].plot(x,y,'ko',markersize=8,markerfacecolor=[.7,.7,.9,.3])
axs[0].set(xlabel='x',ylabel='y',title='A) Full resolution data')

# right panel: image of the binned data with the raw samples drawn on top
# NOTE(review): the imshow arguments are blank — TODO fill in (image data and display options)
h = axs[1].imshow()
axs[1].set(xlabel='x',ylabel='y',title='B) Discretized (binned) data')
axs[1].plot(x,y,'ko',markerfacecolor=[.7,.7,.9,.5],markersize=8)
plt.colorbar(h,ax=axs[1],pad=.01,label='Count')
plt.suptitle('Z,xx,yy = np.histogram2d(x,y,bins=8)', fontfamily='monospace')

plt.tight_layout()
plt.savefig('ch7_proj45_part2a.png')
plt.show()

In [None]:
# tiny constant added inside the logarithms so that empty bins (p=0) give 0*log2(eps) = 0 instead of nan
eps = 1e-13

# joint entropy from proportion (counts are sum-scaled so that they add up to 1)
p_Z = Z / np.sum(Z)
entropy_Z = -np.sum( p_Z * np.log2(p_Z+eps) )

# single-variable entropies
# (np.histogram2d bins x along the first dimension of Z and y along the second,
#  so summing over axis 1 leaves the distribution of x, and summing over axis 0 leaves that of y)
p_x = np.sum(p_Z, axis=1)
entropy_x = -np.sum( p_x * np.log2(p_x+eps) )
p_y = np.sum(p_Z, axis=0)
entropy_y = -np.sum( p_y * np.log2(p_y+eps) )

# mutual information via direct translation of the formula: MI(x,y) = H(x) + H(y) - H(x,y)
miMan = entropy_x + entropy_y - entropy_Z

print(f'Mutual information (manual) : {miMan:.2f}')

In [None]:
# via scikit-learn's MI function optimized for continuous variables
# - the first input must be a 2D (samples x features) matrix, hence the reshape
# - the output has one value per feature, so [0] extracts the scalar needed for the formatted print
# - scikit-learn reports the estimate in nats (natural log), whereas the manual calculation uses log2 (bits)
miSk = mutual_info_regression(x.reshape(-1,1),y)[0]
print(f'Mutual information (scikit-learn): {miSk:.2f}')

In [None]:
# impact of discretization: recompute the manual mutual information for a range of histogram resolutions

bincounts = np.arange(4,25)
mi_by_bincount = np.zeros(len(bincounts))

for i,nbins in enumerate(bincounts):

  # joint counts at this resolution (x along the first dimension of Z, y along the second)
  Z,xx,yy = np.histogram2d(x,y,bins=nbins)

  # proportion via sum-scaling
  p_Z = Z / np.sum(Z)
  p_x = np.sum(p_Z,axis=1)
  p_y = np.sum(p_Z,axis=0)

  # calculate entropy (eps keeps log2 finite for empty bins)
  entropy_x = -np.sum( p_x * np.log2(p_x+eps) )
  entropy_y = -np.sum( p_y * np.log2(p_y+eps) )

  # MI as difference of entropies
  entropy_Z = -np.sum( p_Z * np.log2(p_Z+eps) )
  mi_by_bincount[i] = entropy_x + entropy_y - entropy_Z


# manual MI per bin count, with the scikit-learn estimate as a horizontal reference line
plt.figure(figsize=(8,4))
plt.axhline(miSk,color='r',linestyle='--',linewidth=2,label='scikit-learn')
plt.plot(bincounts,mi_by_bincount,'kh',markersize=12,markerfacecolor=[.7,.7,.9],label='Manual')
plt.gca().set(xticks=bincounts[::2],xlabel='Number of bins',ylabel='Mutual information')
plt.legend()

plt.tight_layout()
plt.savefig('ch7_proj45_part2b.png')
plt.show()

In [None]:
# correlation coefficient
# (np.corrcoef returns the full correlation matrix; for 1D x and y that is a 2x2 matrix
#  whose off-diagonal elements are the Pearson correlation between x and y)
np.corrcoef(x,y)

# **Part 3: Model, hooks, tokens, and activations**

In [None]:
# load pretrained GPT-2 model and tokenizer
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2-large')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

# variable for the number of transformer layers
# (gpt2.transformer.h is the list of transformer blocks, the same list that is indexed when implanting hooks)
nLayers = len(gpt2.transformer.h)

# switch to evaluation mode (no dropout)
gpt2.eval()

In [None]:
# hook function to store attention and MLP projections (see also part5_neuron_hookVsHiddenStates.ipynb)
# dictionary that the hooks overwrite on every forward pass; keys are 'att_<layer>' and 'mlp_<layer>'
activations = {}

def implant_hook_attn(layer_number):
  """Return a forward hook that stores the module's output in activations['att_<layer_number>'].

  The output is detached from the graph and converted to a numpy array.
  """
  def hook(module, input, output):
    # some modules return a tuple whose first element is the output tensor
    out = output[0] if isinstance(output,tuple) else output
    activations[f'att_{layer_number}'] = out.detach().cpu().numpy()
  return hook

# and mlp layers
def implant_hook_mlp(layer_number):
  """Return a forward hook that stores the module's output in activations['mlp_<layer_number>'].

  The output is detached from the graph and converted to a numpy array.
  """
  def hook(module, input, output):
    # some modules return a tuple whose first element is the output tensor
    out = output[0] if isinstance(output,tuple) else output
    activations[f'mlp_{layer_number}'] = out.detach().cpu().numpy()
  return hook

# implant hooks on the output projection of the attention and MLP sublayers in every transformer block
# NOTE(review): c_proj is assumed to be the intended module ("projections" in the hook comment) — confirm against the book
# the handles are kept so that the hooks can be removed later via handle.remove()
handles = []
for layeri in range(nLayers):
  h1 = gpt2.transformer.h[layeri].attn.c_proj.register_forward_hook(implant_hook_attn(layeri))
  h2 = gpt2.transformer.h[layeri].mlp.c_proj.register_forward_hook(implant_hook_mlp(layeri))

  handles.append(h1)
  handles.append(h2)

In [None]:
# from https://en.wikipedia.org/wiki/Turkish_coffee
text = wikipedia.page('Turkish_coffee').content

# return_tensors='pt' gives a (1 x number-of-tokens) tensor
tokens = tokenizer.encode(text,return_tensors='pt')
print(f'There are {tokens.shape[1]} tokens, {len(torch.unique(tokens))} of which are unique.')

In [None]:
# find all the "coffee" target indices
# NOTE(review): the leading space is an assumption (it selects the token for "coffee" preceded by a space,
#               i.e., in the middle of a sentence) — confirm against the book
target = tokenizer.encode(' coffee')[0]
target_idxs = torch.where(tokens[0]==target)[0]

# just the first 10
n_targets = 10
target_idxs = target_idxs[:n_targets]
target_idxs

In [None]:
# forward pass to get activations
# (the text is truncated to the model's maximum sequence length; the result itself is not needed
#  here because the implanted hooks store the activations)
# NOTE(review): the slice is an assumption — all target indices used later must fall inside it
with torch.no_grad():
  gpt2(tokens[:,:gpt2.config.n_positions])

In [None]:
# list every stored activation together with the size of its array
for layer_key in activations:
  print(f'{layer_key} has shape {activations[layer_key].shape}')

# **Part 4: Relationships between distance measures**

In [None]:
# pairwise measures between target tokens: [token i, token j, (mutual information, L1, L2)]
distances = np.zeros((n_targets,n_targets,3))
# difference in text position (in tokens) for the same pairs
tokdists = np.zeros((n_targets,n_targets))

# double-loop over the word pairs
# (only pairs with tokj > toki are computed; the rest of each matrix stays at zero)
for toki in range(n_targets):
  for tokj in range(toki+1,n_targets):

    # extract the data (MLP activation vectors from layer 3 at the two target positions)
    x = np.asarray(activations['mlp_3'][0,target_idxs[toki],:])
    y = np.asarray(activations['mlp_3'][0,target_idxs[tokj],:])

    # pairwise distance measures
    # (mutual_info_regression needs a 2D feature matrix and returns one value per feature)
    distances[toki,tokj,0] = mutual_info_regression(x.reshape(-1,1),y)[0]
    distances[toki,tokj,1] = minkowski(x,y,1)
    distances[toki,tokj,2] = minkowski(x,y,2)

    # number of tokens separating the two targets in the text
    tokdists[toki,tokj] = int(target_idxs[tokj] - target_idxs[toki])

In [None]:
# indices of nonzero values
# (row and column indices of the upper triangle; k=1 excludes the diagonal)
triu_idx = np.triu_indices(n_targets,k=1)

_,axs = plt.subplots(1,2,figsize=(10,3.5))

# without normalization
# NOTE(review): exercise template — the index into `distances` on the second line of this call,
#               and the data/format arguments of the next two plot calls, are intentionally blank
axs[0].plot(distances[triu_idx[0],triu_idx[1],0],
            distances[,,],'ro',
            markerfacecolor=[.9,.7,.7,.7],markersize=8,label='MI vs. $L_1$')
axs[0].plot(,label='MI vs. $L_2$')
axs[0].plot(,label='$L_1$ vs. $L_2$')


# with normalization
# (indexing with index arrays returns a copy, so the in-place division does not modify `distances`)
mi = distances[triu_idx[0],triu_idx[1],0]
mi /= max(mi)
# NOTE(review): L1 and L2 are blank; presumably they follow the same extract-then-divide-by-max
#               pattern as `mi` above — TODO confirm against the book
L1 =
L1
L2 =
L2

# NOTE(review): data/format arguments intentionally blank (same three pairings as in panel A, per the labels)
axs[1].plot(,label='MI vs. $L_1$')
axs[1].plot(,label='MI vs. $L_2$')
axs[1].plot(,label='$L_1$ vs. $L_2$')

axs[0].set(xlabel='Measure',ylabel='Measure',title='A) Non-normalized')
axs[1].set(xlabel='Measure (norm.)',ylabel='Measure (norm.)',title='B) Max-value normalized')
axs[0].legend()
axs[1].legend()

plt.tight_layout()
plt.savefig('ch7_proj45_part4.png')
plt.show()

# **Part 5: Projection distances by position distances**

In [None]:
# 2x4 grid of axes; the colorbar calls below attach to axs[0,disti] and axs[0,3], and axs[1,3] is switched off
fig,axs = plt.subplots(2,4,figsize=(11,5))

# names of the three measures stored along the last dimension of `distances`
metric_labels = [ 'Mut. Info','Mink. $L_1$','Mink. $L_2$' ]

for disti in range(3):

  # this distance measure
  # NOTE(review): exercise template — the last index is intentionally blank
  D = distances[:,:,]

  # show the distance heatmap
  # NOTE(review): the axis object and the imshow options are blank; the colorbar below targets axs[0,disti]
  h = .imshow(D,)
  .set(xlabel='Target token index',ylabel='Target token index',title=f'{"ABC"[disti]}) {metric_labels[disti]}')
  fig.colorbar(h,ax=axs[0,disti],pad=.02,fraction=.047)

  # correlate vector distance with token distance
  # NOTE(review): the function call is blank; the result must expose .statistic and .pvalue (used in the title below),
  #               presumably a Kendall correlation given the tau in the title — TODO confirm
  r = stats.

  # scatter plot
  # NOTE(review): the axis object and the x/y data are blank
  .plot(,,'ks',markersize=6,markerfacecolor=[.7,.7,.9,.7])
  .set(xlabel='Position distance',ylabel=f'{metric_labels[disti]}',
            title=f"{'EFG'[disti]}) $\\mathbf{{\\tau}}={r.statistic:.2f}\\; (p={r.pvalue:.3f})$")


# inter-token position differences
# (color range is capped at 80% of the largest position difference)
# NOTE(review): the axis object is blank on the next two lines; the colorbar below targets axs[0,3]
h = .imshow(tokdists,origin='lower',aspect='auto',vmin=0,vmax=tokdists.max()*.8,cmap='magma')
.set(xlabel='Target token index',ylabel='Target token index',title=f'D) Position dists.')
fig.colorbar(h,ax=axs[0,3],pad=.02,fraction=.047)
axs[1,3].axis('off')

plt.tight_layout()
plt.savefig('ch7_proj45_part5.png')
plt.show()

# **Part 6: Laminar sweep of correlations**

In [None]:
# summary statistics per layer:
# [attention (0) or MLP (1), layer, mean distance (0) or correlation with position (1), distance measure]
summary_results = np.zeros((2,nLayers,2,3))

# initialize temp matrices (overwritten in each layer)
dists_A = np.zeros((n_targets,n_targets,3)) # attention
dists_M = np.zeros((n_targets,n_targets,3)) # MLP

# inter-token distances (doesn't change for attn-vs-mlp, nor across layers, so it is computed once)
uDi = tokdists[triu_idx]


# loop over layers
for layeri in range(nLayers):

  # double-loop over the word pairs (upper triangle only; that is the part read out below)
  for toki in range(n_targets):
    for tokj in range(toki+1,n_targets):

      ### ATTENTION block
      # extract the data
      x = np.asarray(activations[f'att_{layeri}'][0,target_idxs[toki],:])
      y = np.asarray(activations[f'att_{layeri}'][0,target_idxs[tokj],:])

      # pairwise distance measures
      # (mutual_info_regression needs a 2D feature matrix and returns one value per feature)
      dists_A[toki,tokj,0] = mutual_info_regression(x.reshape(-1,1),y)[0]
      dists_A[toki,tokj,1] = minkowski(x,y,1)
      dists_A[toki,tokj,2] = minkowski(x,y,2)


      ### MLP block
      # extract the data
      x = np.asarray(activations[f'mlp_{layeri}'][0,target_idxs[toki],:])
      y = np.asarray(activations[f'mlp_{layeri}'][0,target_idxs[tokj],:])

      # pairwise distance measures
      dists_M[toki,tokj,0] = mutual_info_regression(x.reshape(-1,1),y)[0]
      dists_M[toki,tokj,1] = minkowski(x,y,1)
      dists_M[toki,tokj,2] = minkowski(x,y,2)


  # loop over distance measures
  for i in range(3):

    # ATTENTION summary statistics (average over pairs; Kendall correlation with position distance)
    vals = dists_A[triu_idx[0],triu_idx[1],i]
    summary_results[0,layeri,0,i] = np.mean(vals)
    summary_results[0,layeri,1,i] = stats.kendalltau(vals,uDi).statistic

    # MLP summary statistics
    vals = dists_M[triu_idx[0],triu_idx[1],i]
    summary_results[1,layeri,0,i] = np.mean(vals)
    summary_results[1,layeri,1,i] = stats.kendalltau(vals,uDi).statistic

In [None]:
# 2x2 grid: rows are attention (top) and MLP (bottom); columns are average distances and correlations (see titles below)
_,axs = plt.subplots(2,2,figsize=(12,6))

# one character per distance measure
# NOTE(review): presumably matplotlib color codes (c) and marker codes (s) for the plot format strings — TODO confirm
c = 'rgb'
s = 'osh'

for i in range(3):

  # Attention
  # (d is measure i across all layers, taken from index 0 of the third dimension of summary_results)
  # NOTE(review): exercise template — the data/format arguments of the plot calls are intentionally blank
  d = summary_results[0,:,0,i]
  axs[0,0].plot(,label=metric_labels[i])
  axs[0,1].plot(,label=metric_labels[i])

  # MLP
  d = summary_results[1,:,0,i]
  axs[1,0].plot(,label=metric_labels[i])
  axs[1,1].plot(,label=metric_labels[i])
  # note: flip the sign of MI by scaling summary_results by [-1,1,1][i]
  #       or just plot abs(summary_results)

# horizontal line at r=0
axs[0,1].axhline(0,color='k',linestyle='--',linewidth=.3)
axs[1,1].axhline(0,color='k',linestyle='--',linewidth=.3)


# adjustments
for a in axs.flatten(): a.legend()
axs[0,0].set(xlabel='Transformer layer',ylabel='Mean distance (max-norm)',title='A) ATTENTION: Average distances')
axs[0,1].set(xlabel='Transformer layer',ylabel='Correlation coefficient',title='B) ATTENTION: Correlations with position')
axs[1,0].set(xlabel='Transformer layer',ylabel='Mean distance (max-norm)',title='C) MLP: Average distances')
axs[1,1].set(xlabel='Transformer layer',ylabel='Correlation coefficient',title='D) MLP: Correlations with position')


plt.tight_layout()
plt.savefig('ch7_proj45_part6.png')
plt.show()

# **Part 7: A standard application of mutual information**

In [None]:
# batch size parameters
n_sequences = 100
n_context = 9 # actually this + 1 b/c context is before target

# get the target indices again (all occurrences of the target token in the text)
target_idxs = torch.where(tokens[0]==target)[0]

# remove any that are too early to have a full context
# (a target at position t needs the n_context tokens before it, so t-n_context must not be negative)
target_idxs = target_idxs[target_idxs>=n_context]

n_targets = len(target_idxs)
n_targets

In [None]:
# one row per sequence: n_context tokens of context followed by the target token
batch = torch.zeros((n_sequences,n_context+1),dtype=torch.long)

for i in range(n_sequences):
  start = target_idxs[i]-n_context
  stop  = target_idxs[i]+1 # +1 because the end of a slice is exclusive and the target must be included
  batch[i,:] = tokens[0,start:stop]

batch

In [None]:
# push the batch through the model with gradient tracking switched off
with torch.no_grad():
  gpt2(batch)

In [None]:
# list every stored activation together with the size of its array
for layer_key in activations:
  print(f'key "{layer_key}" has shape {activations[layer_key].shape}')

In [None]:
# results for every pair of layers: [layer i, layer j, (attention MI, MLP MI, distance between layers)]
# (only pairs with j > i are computed; the rest of the matrix stays at zero)
mi = np.zeros((nLayers,nLayers,3))

for i in range(nLayers):
  for j in range(i+1,nLayers):

    # mutual information across attention projection vector norms
    # (one norm per sequence, taken over the embedding dimension at the final position of each sequence)
    # NOTE(review): using the final (target) position is an assumption — confirm against the book
    normi = np.linalg.norm(activations[f'att_{i}'][:,-1,:],axis=1)
    normj = np.linalg.norm(activations[f'att_{j}'][:,-1,:],axis=1)
    mi[i,j,0] = mutual_info_regression(normi.reshape(-1,1),normj)[0]

    # mutual information across MLP projection vector norms
    normi = np.linalg.norm(activations[f'mlp_{i}'][:,-1,:],axis=1)
    normj = np.linalg.norm(activations[f'mlp_{j}'][:,-1,:],axis=1)
    mi[i,j,1] = mutual_info_regression(normi.reshape(-1,1),normj)[0]

    # inter-layer distances (number of layers separating the pair)
    mi[i,j,2] = j-i

In [None]:
_,axs = plt.subplots(2,3,figsize=(12,7))

# mutual information matrices (the last dimension of mi holds attention MI, MLP MI, and inter-layer distance)
axs[0,0].imshow(mi[:,:,0],aspect='auto',vmin=0,vmax=.3,cmap='magma')
axs[0,1].imshow(mi[:,:,1],aspect='auto',vmin=0,vmax=.3,cmap='magma')
axs[0,2].imshow(mi[:,:,2],aspect='auto',cmap='magma')

# scatter plots (only the layer pairs above the diagonal)
layer_pairs = np.triu_indices_from(mi[:,:,2],k=1)
axs[1,0].plot(mi[:,:,2][layer_pairs],
              mi[:,:,0][layer_pairs],'ro',markerfacecolor=[.9,.7,.7,.5])
axs[1,1].plot(mi[:,:,2][layer_pairs],
              mi[:,:,1][layer_pairs],'bs',markerfacecolor=[.7,.7,.9,.5])
axs[1,2].plot(mi[:,:,0][layer_pairs],
              mi[:,:,1][layer_pairs],'kh',markerfacecolor=[.7,.9,.7,.5])

# plot adjustments
axs[0,0].set(xlabel='Layer index',ylabel='Layer index',title='A) Attention: Pairwise mutual info')
axs[0,1].set(xlabel='Layer index',ylabel='Layer index',title='B) MLP: Pairwise mutual info')
axs[0,2].set(xlabel='Layer index',ylabel='Layer index',title='C) Inter-layer distances')

axs[1,0].set(xlabel='Inter-layer distance',ylabel='Mutual information',title='D) Attention: MI with distance')
axs[1,1].set(xlabel='Inter-layer distance',ylabel='Mutual information',title='E) MLP: MI with distance')
axs[1,2].set(xlabel='Attention mutual information',ylabel='MLP mutual information',title='F) MI: Attention vs. MLP')

plt.tight_layout()
plt.savefig('ch7_proj45_part7.png')
plt.show()