|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>Pairwise mutual information through the LLM</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# MI
from sklearn.feature_selection import mutual_info_regression

import torch

from transformers import AutoModelForCausalLM,GPT2Tokenizer

In [None]:
# GPT-2 tokenizer; it has no dedicated padding token, so reuse end-of-sequence
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# pretrained GPT-2 with its language-model head, switched to evaluation mode
# (eval() is the cell's last expression, so the model summary is displayed)
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')
gpt2.eval()

In [None]:
# from https://en.wikipedia.org/wiki/Turkish_coffee
text = 'Turkish coffee is very finely ground coffee brewed by boiling. Any coffee bean may be used; arabica varieties are considered best, but robusta or a blend is also used.[1] The coffee grounds are left in the coffee when served.[2][3] The coffee may be ground at home in a manual grinder made for the very fine grind, ground to order by coffee merchants in most parts of the world, or bought ready-ground from many shops.'
tokens = tokenizer.encode(text,return_tensors='pt')

# tokens has a leading batch dimension of size 1; count all ids and distinct ids in the sequence
print(f'There are {tokens.shape[1]} tokens, {tokens[0].unique().numel()} of which are unique.')

In [None]:
# one forward pass with gradient tracking disabled, keeping every layer's hidden states
with torch.no_grad():
  output = gpt2(input_ids=tokens,output_hidden_states=True)

In [None]:
# number of hidden-state tensors returned by the forward pass
nLayers = len(output.hidden_states)

# shape of one of them (displayed as the cell output)
output.hidden_states[3].size()

# A function to calculate mutual information based on entropy

In [None]:
# a function for mutual information
def mutInfo_manual(x,y,bins=15):
  """Estimate mutual information between two variables from a 2D histogram.

  Uses I(X;Y) = H(X) + H(Y) - H(X,Y), where all entropies are computed
  from histogram-based probability estimates, in bits (log base 2).

  Parameters
  ----------
  x, y : array-like
      Paired 1D samples of equal length (e.g., two embedding dimensions
      across tokens). Anything np.histogram2d accepts works.
  bins : int or sequence, optional
      Binning passed straight to np.histogram2d (default 15).

  Returns
  -------
  float
      Histogram-based mutual information estimate, in bits.
  """

  def entropy_bits(p):
    # empty bins contribute 0 (0*log0 := 0), so drop them rather than adding an eps
    p = p[p>0]
    return -np.sum( p * np.log2(p) )

  # joint histogram, converted to proportions (estimate of the joint probability)
  Z  = np.histogram2d(x,y,bins=bins)[0]
  pZ = Z / Z.sum()

  # marginals are the row/column sums of the joint
  px = pZ.sum(axis=1)
  py = pZ.sum(axis=0)

  return entropy_bits(px) + entropy_bits(py) - entropy_bits(pZ)

# Example MI between one pair of dimensions across all tokens

In [None]:
# example MI between two embedding dimensions, across all tokens, in one layer

dim1 = 4
dim2 = 5
layer2use = 3

# extract some data (one value per token)
x = output.hidden_states[layer2use][0,:,dim1].detach()
y = output.hidden_states[layer2use][0,:,dim2].detach()

# 2D histogram
Z,xx,yy = np.histogram2d(x,y,bins=15)

# mutual information (kNN estimator, reported in nats). mutual_info_regression adds
# small random noise to the data, so fix random_state for a reproducible value.
# Named mi_example so it does not clobber the layers x dims x dims `mi` array used later.
mi_example = mutual_info_regression(x.reshape(-1,1),y,random_state=0)[0]


# plotting
_,axs = plt.subplots(1,2,figsize=(11,5))

minmax = 1.1 * torch.cat((x,y),dim=0).abs().max().item() # extreme value for axis
axs[0].plot(x,y,'bo',markerfacecolor=[.4,.4,.9])
axs[0].set(xlabel=f'Dimension {dim1}',ylabel=f'Dimension {dim2}',xlim=[-minmax,minmax],ylim=[-minmax,minmax],
           title='Full resolution data\nEach dot is a token')

# transpose: histogram2d puts x on the first axis, whereas imshow draws rows along the y-axis
h = axs[1].imshow(Z.T,extent=[xx[0],xx[-1],yy[0],yy[-1]],vmin=0,vmax=Z.max()*.5,origin='lower',aspect='auto')
axs[1].set(xlabel=f'Dimension {dim1}',ylabel=f'Dimension {dim2}',title=f'Discretized data (MI = {mi_example:.3f} nats)')
axs[1].plot(x,y,'wo',markerfacecolor=[.4,.4,.4],alpha=.7)
plt.colorbar(h,ax=axs[1])

plt.tight_layout()
plt.show()

# All-to-all MI in each layer

In [None]:
# NOTE: this cell takes waaayyyy too long -- mutual_info_regression is a kNN estimator
# and each call below runs it against every embedding dimension. A faster,
# histogram-based version of the all-to-all MI follows further below.

# initialize (layers x dims x dims)
mi = np.zeros((nLayers,gpt2.config.n_embd,gpt2.config.n_embd))

# loop over layers
for layeri in range(nLayers):

  # extract this layer's (tokens x dims) matrix, detaching only once per layer
  hidden_states = output.hidden_states[layeri].detach().squeeze()

  # single loop over every second dimension: each call estimates the MI between
  # dimension dimi and ALL dimensions (itself included) at once, filling one row.
  # random_state fixes the small noise sklearn adds, for reproducible values.
  for dimi in range(0,gpt2.config.n_embd,2):

    mi[layeri,dimi,:] = mutual_info_regression(hidden_states,hidden_states[:,dimi],random_state=0)

  print(f'Finished {layeri+1:2}/{nLayers} layers')

In [None]:
# inspect where the (interrupted) loop above stopped
(layeri, dimi)

In [None]:
# histogram-based pairwise MI for every layer (this cell takes ~6 mins)
# initialize (layers x dims x dims); entries not computed below stay 0
mi = np.zeros((nLayers,gpt2.config.n_embd,gpt2.config.n_embd))

# loop over layers
for layeri in range(nLayers):

  # this layer's (tokens x dims) matrix, converted to numpy ONCE per layer so the
  # inner loop does not hand torch tensors to np.histogram2d on every call
  hidden_states = output.hidden_states[layeri].detach().squeeze().cpu().numpy()

  # double-loop over dimension pairs within this layer: dimi visits the even
  # dimensions and dimj the odd dimensions above dimi, so only the even(row) x
  # odd(column) upper triangle of mi[layeri] is filled
  for dimi in range(0,gpt2.config.n_embd,2):
    for dimj in range(dimi+1,gpt2.config.n_embd,2):

      # pairwise mutual information (bits)
      mi[layeri,dimi,dimj] = mutInfo_manual( hidden_states[:,dimi],hidden_states[:,dimj] )

  print(f'Finished {layeri+1:2}/{nLayers} layers')

In [None]:
# visualize all pairwise MI from one layer

layer2show = 2

_,axs = plt.subplots(1,2,figsize=(10,4))

# submatrix for this layer
mitmp = mi[layer2show,:,:]

# distribution of the computed (nonzero) MI values; zeros are pairs that were never computed
axs[0].hist(mitmp[np.nonzero(mitmp)],bins=60,color='gray',edgecolor='k')
axs[0].set(xlabel='Mutual information',ylabel='Count',title='Distribution of MI values')

# rows are the even dimensions (0,2,4,...) and columns the odd ones (1,3,5,...).
# Each pixel therefore spans 2 dimensions; extent places pixel centers on the
# true dimension indices instead of the halved array indices.
mi_evenodd = mitmp[::2,1::2]
nrows,ncols = mi_evenodd.shape
h = axs[1].imshow(mi_evenodd,vmin=.5,vmax=1.2,origin='lower',
                  extent=[0,2*ncols,-1,2*nrows-1])
plt.colorbar(h,ax=axs[1])
axs[1].set(xlabel='Embedding dimension (odd)',ylabel='Embedding dimension (even)',title='Pairwise mutual information')

plt.tight_layout()
plt.show()

In [None]:
# histograms from all layers
fig,ax = plt.subplots(figsize=(10,4))

# bin boundaries spanning every computed (nonzero) MI value, so no values are
# silently dropped from the density estimates
allmi = mi[np.nonzero(mi)]
edges = np.linspace(allmi.min(),allmi.max(),64)
centers = (edges[:-1]+edges[1:]) / 2  # plot densities at bin centers, not left edges


for layeri in range(nLayers):

  # get distribution of the computed (nonzero) MI values in this layer
  mitmp = mi[layeri,:,:]
  density,_ = np.histogram(mitmp[np.nonzero(mitmp)],bins=edges,density=True)

  # plot; colors run through the colormap from the first to the last layer
  ax.plot(centers,density,color=mpl.cm.plasma(layeri/max(nLayers-1,1)),label=f'L{layeri}')


ax.legend()
ax.set(xlim=(edges[0],edges[-1]),xlabel='Mutual information value',ylabel='Density (pdf estimate)',title='Distribution of MI over layers')
plt.show()