|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[43] Characterizing the MLP progression</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from sklearn.metrics.pairwise import cosine_similarity

# to import text of a wiki page
!pip install wikipedia
import wikipedia

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib display/style configuration
# (the commented-out settings at the bottom switch figures to dark mode)

# render inline figures as SVG (vector output, sharper than the PNG default)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# hide the right and top axis spines
plt.rc('axes.spines', right=False, top=False)

# bold axis titles and axis labels
plt.rc('axes', titleweight='bold', labelweight='bold')

# resolution of files written by plt.savefig
plt.rc('savefig', dpi=300)

# --- dark mode (uncomment to enable) ---
# plt.rc('figure', facecolor='#282a2c', edgecolor='#282a2c')
# plt.rc('axes', facecolor='#282a2c', edgecolor='#DDE2F4', labelcolor='#DDE2F4')
# plt.rc('xtick', color='#DDE2F4')
# plt.rc('ytick', color='#DDE2F4')
# plt.rc('text', color='#DDE2F4')

# **Part 1: Tokens**

In [None]:
# pretrained GPT-2 checkpoint and its matching tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').eval()  # eval() returns the model itself

# number of transformer blocks; each block's MLP gets a hook further down
n_layers = model.config.n_layer

In [None]:
# Fetch the full text of one Wikipedia article and tokenize it.
# auto_suggest=False requests the exact title instead of letting the package's
# search-suggestion step rewrite it (which can silently return a different page).
page_title = 'Coffee_production_in_Vietnam'
text = wikipedia.page(page_title, auto_suggest=False).content
tokens = tokenizer.encode(text)

# show only the start of the article; the full page floods the notebook output
print(f'{len(text):,} characters -> {len(tokens):,} tokens')
text[:1000]

In [None]:
# The analyses need a target word that the tokenizer maps to exactly ONE token.
target_word = ' coffee'

coffee_token = tokenizer.encode(target_word)

print(f'"{target_word}" comprises token(s): {coffee_token}')
assert len(coffee_token) == 1, \
  f'"{target_word}" is split into {len(coffee_token)} tokens; choose a single-token word'

# just the number, not a list
coffee_token = coffee_token[0]

# positions (ascending) in the article where the target token occurs
coffee_idx = np.where(np.array(tokens)==coffee_token)[0]
ntargets = len(coffee_idx)
assert ntargets > 0, f'"{target_word}" does not occur in the text'

print(f'There are {ntargets} "{target_word}" tokens, and the first is in position {coffee_idx[0]}')

In [None]:
# Build the two datasets compared throughout the project:
#  - oneLongSeq: the first ntargets+1 tokens of the article as ONE sequence.
#    The analyses drop position 0, leaving ntargets token positions.
#  - batch: ntargets sequences of nseq tokens, each one ending on a target token.
oneLongSeq = torch.tensor(tokens[:ntargets+1]).unsqueeze(0)

nseq = 10

# Every window needs nseq-1 tokens of preceding context. An earlier target would
# give a negative slice start and a wrong-length window.
assert coffee_idx[0] >= nseq-1, \
  f'first target is at position {coffee_idx[0]}, which has fewer than {nseq-1} preceding tokens'

batch = torch.zeros((ntargets,nseq),dtype=torch.long)

for bi in range(ntargets):
  batch[bi,:] = torch.tensor(tokens[coffee_idx[bi]-nseq+1:coffee_idx[bi]+1])

batch

In [None]:
# display both shapes: (1, ntargets+1) for the single sequence, (ntargets, nseq) for the batch
oneLongSeq.shape, batch.shape

# **Part 2: Multistage MLP activations**

In [None]:
# The forward hooks defined below capture four processing stages of every MLP.
# They write into this global dict, using keys of the form 'L{layer}_{stage}' (stage 1..4).
mlp_values = dict()

def outerHook(whichlayer,store=None):
  """Build a forward hook that records the four MLP processing stages of one layer.

  The returned hook recomputes the MLP's intermediate steps from the module input
  and stores NumPy copies under the keys 'L{whichlayer}_1' ... 'L{whichlayer}_4':
    1: the input to the MLP
    2: the expansion (module.c_fc) before the nonlinearity
    3: the expansion after the nonlinearity (module.act)
    4: the module's output

  Parameters
  ----------
  whichlayer : int
      Layer index used in the dictionary keys.
  store : dict or None
      Dictionary to write into. If None (default), the global `mlp_values` is
      looked up each time the hook fires. Re-binding `mlp_values = {}` between
      forward passes therefore redirects later writes.

  Returns
  -------
  callable
      A hook with signature (module, inputs, output) for `register_forward_hook`.
  """
  def hookmlp(module,inputs,output):

    # Recompute the MLP steps without building an autograd graph. Otherwise
    # .numpy() fails on tensors that require grad when the model is called
    # outside torch.no_grad().
    with torch.no_grad():
      X1 = inputs[0].detach() # input to MLP
      X2 = module.c_fc(X1)    # expansion pre-gelu
      X3 = module.act(X2)     # expansion post-gelu
      # X4 = module.c_proj(X3) # projection from MLP; the hook's 'output' is stored instead

    target = mlp_values if store is None else store

    # .cpu() so this also works when the model lives on a GPU
    target[f'L{whichlayer}_1'] = X1.cpu().numpy()
    target[f'L{whichlayer}_2'] = X2.cpu().numpy()
    target[f'L{whichlayer}_3'] = X3.cpu().numpy()
    target[f'L{whichlayer}_4'] = output.detach().cpu().numpy()
  return hookmlp

# Implant one forward hook on each layer's MLP.
# First remove hooks left over from a previous run of this cell. Otherwise each
# re-run stacks another duplicate hook on every MLP.
for old_handle in globals().get('handles', []):
  old_handle.remove()

handles = []
for layi in range(n_layers):
  h = model.transformer.h[layi].mlp.register_forward_hook(outerHook(layi))
  handles.append(h)

In [None]:
# four-letter labels for the MLP stages, in hook-key order (stage 1..4)
stagelabels = 'Inpu Expa GELU Proj'.split()

In [None]:
# Run each dataset through the model; the MLP hooks fill `mlp_values` as a side effect.
# Every pass writes the same keys, so take a snapshot of the dict after each run.
with torch.no_grad():
  model(oneLongSeq)
mlp_oneseq = dict(mlp_values)

# start the next pass from an empty dict (optional: all keys would be overwritten anyway)
mlp_values = {}

with torch.no_grad():
  model(batch)
mlp_manyseq = dict(mlp_values)

In [None]:
# one line per stored activation: key and array shape
for name, arr in mlp_oneseq.items():
  print('"{}" is shape {}'.format(name, list(arr.shape)))

In [None]:
# one line per stored activation: key and array shape
for name, arr in mlp_manyseq.items():
  print('"{}" is shape {}'.format(name, list(arr.shape)))

In [None]:
# some visualizations
# Distributions of activation values at each of the four MLP stages in one layer
# (layer n_layers//2), for the single long sequence (A) and the batch (B).
# Later cells also use `whichlayer`.
whichlayer = n_layers//2

_,axs = plt.subplots(1,2,figsize=(12,4))

for i in range(4):

  # name of key to plot
  keyname = f'L{whichlayer}_{i+1}'

  # one sequence: every token position except 0, pooled over positions and features
  # ('fd' = Freedman-Diaconis bin-width rule)
  y,x = np.histogram(mlp_oneseq[keyname][0,1:,:].flatten(),bins='fd',density=True)
  # NOTE(review): x[:-1] are left bin edges, so each curve is drawn half a bin to the left
  axs[0].plot(x[:-1],y,linewidth=2,label=f'L{whichlayer} {stagelabels[i]}')

  # many sequences: only the final (target) token of each sequence
  y,x = np.histogram(mlp_manyseq[keyname][:,-1,:].flatten(),bins='fd',density=True)
  axs[1].plot(x[:-1],y,linewidth=2,label=f'L{whichlayer} {stagelabels[i]}')


# fixed axis limits so both panels are directly comparable
axs[0].set(xlabel='Activation value',xlim=[-5,5],ylabel='Density',ylim=[0,4],title='A) One-sequence batch')
axs[0].legend()

axs[1].set(xlabel='Activation value',xlim=[-5,5],ylabel='Density',ylim=[0,4],title='B) Many-sequence batch')
axs[1].legend()

plt.tight_layout()
plt.savefig('ch7_proj43_part2.png')
plt.show()

# **Part 3: Cosine similarities in one layer**

In [None]:
# Token-by-token cosine-similarity matrices at each MLP stage of layer `whichlayer`:
# top row = single long sequence, bottom row = batch (final token of each sequence).
fig,axs = plt.subplots(2,4,figsize=(12,6))

for i in range(4):

  # name of key to plot
  keyname = f'L{whichlayer}_{i+1}'

  # one sequence (token position 0 excluded)
  cMat = cosine_similarity(mlp_oneseq[keyname][0,1:,:].squeeze())
  h = axs[0,i].imshow(cMat,cmap='magma')
  # colorbar placed above each panel; the '\n\n' in the title leaves room for it
  cbh = fig.colorbar(h,ax=axs[0,i],pad=.01,fraction=.047,location='top')
  cbh.ax.tick_params(labelsize=8)
  axs[0,i].set(xticks=[],yticks=[],xlabel='Tokens',ylabel='Tokens',title=f'One seq: L{whichlayer} {stagelabels[i]}\n\n')

  # many sequences (final token of each sequence)
  cMat = cosine_similarity(mlp_manyseq[keyname][:,-1,:].squeeze())
  h = axs[1,i].imshow(cMat,cmap='magma')
  cbh = fig.colorbar(h,ax=axs[1,i],pad=.01,fraction=.047,location='top')
  cbh.ax.tick_params(labelsize=8)
  axs[1,i].set(xticks=[],yticks=[],xlabel='Tokens',ylabel='Tokens',title=f'Many seqs: L{whichlayer} {stagelabels[i]}\n\n')


plt.tight_layout()
plt.savefig('ch7_proj43_part3a.png')
plt.show()

In [None]:
# Distributions of the unique pairwise token similarities (upper triangle,
# diagonal excluded) at each MLP stage of layer `whichlayer`.
fig,axs = plt.subplots(1,2,figsize=(12,3.5))

for i in range(4):

  # name of key to plot
  keyname = f'L{whichlayer}_{i+1}'

  # one sequence
  cMat = cosine_similarity(mlp_oneseq[keyname][0,1:,:].squeeze())
  # NOTE(review): triangle size comes from the global ntargets; it must equal cMat's size
  y,x = np.histogram(cMat[np.triu_indices(ntargets,1)],bins='fd',density=True)
  axs[0].plot(x[:-1],y,linewidth=2,label=f'L{whichlayer} {stagelabels[i]}')

  # many sequences
  cMat = cosine_similarity(mlp_manyseq[keyname][:,-1,:].squeeze())
  y,x = np.histogram(cMat[np.triu_indices(ntargets,1)],bins='fd',density=True)
  axs[1].plot(x[:-1],y,linewidth=2,label=f'L{whichlayer} {stagelabels[i]}')

axs[0].legend()
axs[1].legend()
axs[0].set(xlabel='Cosine similarity',ylabel='Density',ylim=[0,None],title='A) One sequence')
axs[1].set(xlabel='Cosine similarity',ylabel='Density',ylim=[0,None],title='B) Many sequences')

plt.tight_layout()
plt.savefig('ch7_proj43_part3b.png')
plt.show()

In [None]:
# Cross-stage cosine similarities (one-sequence data only).
# cosine_similarity(A,B)[i,j] compares row i of A with row j of B. imshow draws
# rows on the y-axis, so A's tokens run vertically and B's tokens horizontally.
fig,axs = plt.subplots(1,3,figsize=(10,3.5))

# cosine similarities between input (rows) and projection (columns)
cMat_s1s4 = cosine_similarity(mlp_oneseq[f'L{whichlayer}_1'][0,1:,:].squeeze(),mlp_oneseq[f'L{whichlayer}_4'][0,1:,:].squeeze())
h = axs[0].imshow(cMat_s1s4,cmap='plasma',vmin=-.3,vmax=.3)
fig.colorbar(h,ax=axs[0],location='top',pad=.01,fraction=.047)
axs[0].set(xticks=[],yticks=[],xlabel='Tokens (MLP projection)',ylabel='Tokens (MLP input)',title='A) Input and projection\n\n')

# cosine similarities between pre-GELU (rows) and post-GELU (columns); color limits autoscaled
cMat_s2s3 = cosine_similarity(mlp_oneseq[f'L{whichlayer}_2'][0,1:,:].squeeze(),mlp_oneseq[f'L{whichlayer}_3'][0,1:,:].squeeze())
h = axs[1].imshow(cMat_s2s3,cmap='plasma')
fig.colorbar(h,ax=axs[1],location='top',pad=.01,fraction=.047)
axs[1].set(xticks=[],yticks=[],xlabel='Tokens (post-GELU)',ylabel='Tokens (expansion)',title='B) Expansion and GELU\n\n')

# histograms (raw counts) of every entry of each cross-stage matrix
y14,x14 = np.histogram(cMat_s1s4.flatten(),bins='fd')
y23,x23 = np.histogram(cMat_s2s3.flatten(),bins='fd')
axs[2].plot(x14[:-1],y14,linewidth=2,label='Stages 1,4')
axs[2].plot(x23[:-1],y23,linewidth=2,label='Stages 2,3')
axs[2].legend(fontsize=8)
axs[2].set(xlabel='Cosine similarity',ylabel='Count',ylim=[0,None],title='C) Distributions\n')

plt.tight_layout()
plt.savefig('ch7_proj43_part3c.png')
plt.show()

# **Part 4: RSA in one layer**

In [None]:
def _upper_cossim(X):
  """Unique pairwise cosine similarities between the rows of X (upper triangle, diagonal excluded)."""
  cMat = cosine_similarity(X)
  # Size the triangle from the matrix itself rather than from the global `ntargets`,
  # so a size mismatch cannot silently select the wrong subset of pairs.
  return cMat[np.triu_indices(cMat.shape[0],1)]


def _rsa_corr(A,B):
  """Pearson correlation between the row-wise cosine-similarity structures of A and B."""
  return np.corrcoef(_upper_cossim(A),_upper_cossim(B))[0,1]


def calc_RSA(keyname1,keyname2,oneseq=None,manyseq=None):
  """Representational similarity analysis (RSA) between two sets of MLP activations.

  For each dataset, a token-by-token cosine-similarity matrix is computed for each
  activation set. The unique off-diagonal entries (upper triangle) of the two
  matrices are then correlated (Pearson r).

  Parameters
  ----------
  keyname1, keyname2 : str
      Keys into the activation dictionaries, e.g. 'L6_1' and 'L6_3'.
  oneseq : dict or None
      One-sequence activations, indexed as [0, position, feature]; position 0 is
      excluded. The default None uses the global `mlp_oneseq`.
  manyseq : dict or None
      Many-sequence activations, indexed as [sequence, position, feature]; only the
      final position of each sequence is used. The default None uses the global
      `mlp_manyseq`.

  Returns
  -------
  tuple of float
      (oneSeq_r, manySeq_r): the RSA correlation for each dataset.
  """
  if oneseq is None:
    oneseq = mlp_oneseq
  if manyseq is None:
    manyseq = mlp_manyseq

  ### --- for one-seq (skip token position 0)
  oneSeq_r = _rsa_corr(oneseq[keyname1][0,1:,:],oneseq[keyname2][0,1:,:])

  ### --- for many-seq (final token of each sequence)
  manySeq_r = _rsa_corr(manyseq[keyname1][:,-1,:],manyseq[keyname2][:,-1,:])

  return oneSeq_r,manySeq_r

In [None]:
# RSA between every pair of the four MLP stages (6 pairs) in layer `whichlayer`,
# for the one-sequence (A) and many-sequence (B) data.
_,axs = plt.subplots(1,2,figsize=(10,3))

# tick labels such as 'I.v.E' (first letters of the two stage labels)
pairnames = []
row = 0

for i in range(4):
  for j in range(i+1,4):

    # define keynames for MLP dictionary
    keyname1 = f'L{whichlayer}_{i+1}'
    keyname2 = f'L{whichlayer}_{j+1}'
    pairnames.append( f'{stagelabels[i][0]}.v.{stagelabels[j][0]}' )

    # get the RSA correlation scores
    oneSeq_r,manySeq_r = calc_RSA(keyname1,keyname2)

    # and visualize (one bar per stage pair)
    axs[0].bar(row,oneSeq_r)
    axs[1].bar(row,manySeq_r)
    row += 1

# NOTE(review): the y-axis starts at .3, so any RSA score below .3 is cut off
axs[0].set(xlabel='Pair',ylabel='RSA score',ylim=[.3,1],title='A) One-sequence',
              xticks=range(6),xticklabels=pairnames)
axs[1].set(xlabel='Pair',ylabel='RSA score',ylim=[.3,1],title='B) Many-sequence',
              xticks=range(6),xticklabels=pairnames)

plt.tight_layout()
plt.savefig('ch7_proj43_part4.png')
plt.show()

# **Part 5: Participation ratio and effective dimensionality in one layer**

In [None]:
def _pr_effdim(X,thresh):
  """Participation ratio and effective dimensionality of a (samples x features) matrix.

  X is mean-centered over rows. With s the singular values and s2 = s**2:
    participation ratio = (sum s2)**2 / sum(s2**2)
    effective dim.      = number of leading components whose cumulative variance
                          reaches `thresh` percent of the total
  """
  X = X - X.mean(axis=0,keepdims=True)       # mean-center each feature
  s2 = np.linalg.svd(X,compute_uv=False)**2  # squared singular values (descending)
  part_rat = s2.sum()**2/(s2**2).sum()

  cumPctVar = 100*np.cumsum(s2)/s2.sum()
  # searchsorted gives the first index whose cumulative % reaches thresh. Add 1 to
  # turn that index into a COUNT of components. Clip so floating-point rounding
  # just below 100% cannot run past the last component.
  effective_dim = min(int(np.searchsorted(cumPctVar,thresh))+1,len(s2))
  return part_rat,effective_dim


def calc_partrat_effdim(keyname,oneseq=None,manyseq=None,thresh=None):
  """Participation ratio and effective dimensionality for one activation key.

  Parameters
  ----------
  keyname : str
      Key into the activation dictionaries, e.g. 'L6_2'.
  oneseq : dict or None
      One-sequence activations, indexed as [0, position, feature]; position 0 is
      excluded. The default None uses the global `mlp_oneseq`.
  manyseq : dict or None
      Many-sequence activations, indexed as [sequence, position, feature]; only the
      final position is used. The default None uses the global `mlp_manyseq`.
  thresh : float or None
      Percent-variance threshold for effective dimensionality. The default None
      uses the global `effectiveDimThresh`.

  Returns
  -------
  tuple
      (part_rat_oneseq, effective_dim_oneseq, part_rat_manyseq, effective_dim_manyseq)
  """
  if oneseq is None:
    oneseq = mlp_oneseq
  if manyseq is None:
    manyseq = mlp_manyseq
  if thresh is None:
    thresh = effectiveDimThresh

  ### --- for one-sequence
  part_rat_oneseq,effective_dim_oneseq = _pr_effdim(oneseq[keyname][0,1:,:],thresh)

  ### --- for many-sequence
  part_rat_manyseq,effective_dim_manyseq = _pr_effdim(manyseq[keyname][:,-1,:],thresh)

  return part_rat_oneseq,effective_dim_oneseq,part_rat_manyseq,effective_dim_manyseq

In [None]:
# percent of total variance that defines effective dimensionality;
# calc_partrat_effdim reads this as a global
effectiveDimThresh = 95


fig,axs = plt.subplots(2,3,figsize=(12,5))

# x-tick labels, one per MLP stage
keynames = ['']*4

for i in range(4):

  # tick label for this stage (display text, not a dictionary key)
  keynames[i] = f'L{whichlayer} {stagelabels[i]}'

  part_rat_oneseq,effective_dim_oneseq,part_rat_manyseq,effective_dim_manyseq = calc_partrat_effdim(f'L{whichlayer}_{i+1}')

  # top row: one-sequence data; bottom row: many-sequence data
  axs[0,0].bar(i,part_rat_oneseq)
  axs[0,1].bar(i,effective_dim_oneseq)
  axs[0,2].plot(part_rat_oneseq,effective_dim_oneseq,'kh',markersize=12,markerfacecolor='y')

  axs[1,0].bar(i,part_rat_manyseq)
  axs[1,1].bar(i,effective_dim_manyseq)
  axs[1,2].plot(part_rat_manyseq,effective_dim_manyseq,'kh',markersize=12,markerfacecolor='y')

# panel letters: A-C on the top row, D-F on the bottom row
for i in range(2):
  axs[i,0].set(xticks=range(4),xticklabels=keynames,title=f'{"AD"[i]}) Participation ratio',ylabel='PR')
  axs[i,1].set(xticks=range(4),xticklabels=keynames,title=f'{"BE"[i]}) Effective dim.',ylabel='Eff. dim.')
  axs[i,2].set(xlabel='Participation ratio',ylabel='Effective dim.',title=f'{"CF"[i]}) Comparison')

plt.suptitle(f'Dimensionality results from layer {whichlayer}')

plt.tight_layout()
plt.savefig('ch7_proj43_part5.png')
plt.show()

# **Part 6: Laminar profiles of RSA and PR**

In [None]:
# Laminar profiles: repeat the dimensionality and RSA analyses in every layer.
dimension_results = np.zeros((n_layers,4,4))
RSA_results = np.zeros((n_layers,2,6))


# 'dimension_results' is (layer, MLP stage 1..4, metric).
# Order in 3rd dimension of 'dimension_results':
#  0: one-seq: part_rat
#  1: one-seq: effective_dim
#  2: many-seq: part_rat
#  3: many-seq: effective_dim
#
# 'RSA_results' is (layer, [one-seq, many-seq], stage pair). The 6 stage pairs
# follow the order of 'pairnames' built in the loop below.


for layeri in range(n_layers):

  row = 0
  pairnames = [] # needlessly recalculated, but that's OK

  for i in range(4):

    keyname1 = f'L{layeri}_{i+1}'

    # order: part_rat_oneseq, effective_dim_oneseq, part_rat_manyseq, effective_dim_manyseq
    dimension_results[layeri,i,:] = calc_partrat_effdim(keyname1)

    for j in range(i+1,4):

      # define keynames for MLP dictionary
      keyname2 = f'L{layeri}_{j+1}'
      pairnames.append( f'{stagelabels[i][0]}.v.{stagelabels[j][0]}' )

      # RSA correlation scores (one-seq, many-seq), computed once and stored directly
      RSA_results[layeri,:,row] = calc_RSA(keyname1,keyname2)
      row += 1

In [None]:
# Visualization 1: scatter plots of everything
# Rows 0-1: layer-by-layer ("laminar") profiles for the one-sequence (row 0) and
# many-sequence (row 1) data. Row 2: metric-vs-metric scatter plots, colored by layer.

fig,axs = plt.subplots(3,3,figsize=(12,8))

# one color/marker per line: the first 4 are used for MLP stages, all 6 for stage pairs
colors = 'rbmkyc'
shapes = 'so^P*X'

xticks = np.arange(n_layers)

### --- laminar profiles of participation ratio and effective dimensionality
for i in range(4):

  # PR one-seq (black marker edges, matching the other three dimensionality panels)
  axs[0,0].plot(xticks,dimension_results[:,i,0],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markeredgecolor='k',markerfacecolor=colors[i],alpha=.7)

  # PR many-seq
  axs[1,0].plot(xticks,dimension_results[:,i,2],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markeredgecolor='k',markerfacecolor=colors[i],alpha=.7)

  # effective dimensionality one-seq
  axs[0,1].plot(xticks,dimension_results[:,i,1],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markeredgecolor='k',markerfacecolor=colors[i],alpha=.7)

  # effective dimensionality many-seq
  axs[1,1].plot(xticks,dimension_results[:,i,3],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markeredgecolor='k',markerfacecolor=colors[i],alpha=.7)


  ### --- relationship between the two dimensionality metrics (one point per layer, colored by layer)
  axs[2,0].scatter(dimension_results[:,i,0],dimension_results[:,i,1],100,marker=shapes[i],
              edgecolor='k',c=plt.cm.plasma(xticks/n_layers),alpha=.7)
  axs[2,1].scatter(dimension_results[:,i,2],dimension_results[:,i,3],100,marker=shapes[i],
              edgecolor='k',c=plt.cm.plasma(xticks/n_layers),alpha=.7)


axs[0,0].set(ylabel='Participation ratio',title='A) One sequence (part.rat)')
axs[0,1].set(ylabel='Effective dim.',title='B) One sequence (effect.dim)')
axs[1,0].set(ylabel='Participation ratio',title='D) Many sequence (part.rat)')
axs[1,1].set(ylabel='Effective dim.',title='E) Many sequence (effect.dim)')

# the r values in these titles pool all layers and all four stages
axs[2,0].set(xlabel='Participation ratio',ylabel='Effective dim.',
             title=f'G) One-seq comparison (r={np.corrcoef(dimension_results[:,:,0].flatten(),dimension_results[:,:,1].flatten())[0,1]:.2f})')
axs[2,1].set(xlabel='Participation ratio',ylabel='Effective dim.',
             title=f'H) Many-seq comparison (r={np.corrcoef(dimension_results[:,:,2].flatten(),dimension_results[:,:,3].flatten())[0,1]:.2f})')

axs[0,0].legend(stagelabels,fontsize=9,ncol=2)
axs[0,1].legend(stagelabels,fontsize=9,ncol=2)


### --- RSA visualization (6 stage pairs, ordered as in 'pairnames')
for i in range(6):
  # RSA for one-seq
  axs[0,2].plot(xticks,RSA_results[:,0,i],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markerfacecolor=colors[i],alpha=.7)

  # RSA for many-seq
  axs[1,2].plot(xticks,RSA_results[:,1,i],f'{colors[i]}{shapes[i]}-',markersize=8,linewidth=.5,
                  markeredgewidth=.5,markerfacecolor=colors[i],alpha=.7)

  # one-seq vs many-seq RSA, one point per layer, colored by layer
  axs[2,2].scatter(RSA_results[:,0,i],RSA_results[:,1,i],100,marker=shapes[i],
              edgecolor='k',c=plt.cm.plasma(xticks/n_layers),alpha=.7)

axs[0,2].legend(pairnames,fontsize=9,ncol=2)
axs[1,2].legend(pairnames,fontsize=9,ncol=2)
axs[0,2].set(ylabel='RSA score',ylim=[.1,1],title='C) One sequence (RSA)')
axs[1,2].set(ylabel='RSA score',ylim=[.1,1],title='F) Many sequence (RSA)')
axs[2,2].set(xlabel='One-seq',ylabel='Many-seq',xlim=[.1,1],ylim=[.1,1],title='I) RSA comparison')



# colorbars for layer index; Normalize(0,n_layers) matches the xticks/n_layers color scaling above
cmappable = ScalarMappable(Normalize(0,n_layers),cmap=plt.cm.plasma)
for a in axs[-1,:]:
  cbh = fig.colorbar(cmappable,ax=a,pad=.01)
  cbh.ax.tick_params(labelsize=8)


plt.tight_layout()
plt.savefig('ch7_proj43_part6a.png')
plt.show()

In [None]:
# Visualization 2: violin plots
# One panel per metric (3rd axis of dimension_results). Each violin is one MLP stage,
# showing that metric's distribution across all layers.
fig,axs = plt.subplots(1,4,figsize=(12,3))

titles = [ 'A) One-seq part.rat','B) One-seq eff.dim','C) Many-seq part.rat','D) Many-seq eff.dim' ]

for i in range(4):
  # draw the plot (columns of the (n_layers, 4) slice -> 4 violins, one per stage)
  v = axs[i].violinplot(dimension_results[:,:,i])

  # adjust the axes; metrics alternate PR / eff. dim., so i%2 picks the label
  ylabel = ['Participation ratio','Effective dim.'][i%2]
  # violinplot places the violins at x = 1..4 by default
  axs[i].set(xticks=range(1,5),xticklabels=stagelabels,title=titles[i],ylabel=ylabel)

  # change the colors of the violin plots (reuses the stage colors from Visualization 1)
  for vi in range(4):
    v['bodies'][vi].set_facecolor(colors[vi])
    v['bodies'][vi].set_alpha(.7)
  v['cbars'].set_edgecolor('k')
  v['cmins'].set_edgecolor('k')
  v['cmaxes'].set_edgecolor('k')

plt.tight_layout()
plt.savefig('ch7_proj43_part6b.png')
plt.show()