|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[47] Supervised probing with XGBoost</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import xgboost as xgb
from sklearn.model_selection import train_test_split

from datasets import load_dataset

import torch
from transformers import AutoModelForCausalLM,GPT2Tokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# draw inline figures as svg (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors (uncomment to activate)
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# no top/right axis spines
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top']   = False

# bold titles and axis labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution of saved figures
plt.rcParams['savefig.dpi'] = 300

# **Part 1: XGBoost in a toy example**

In [None]:
# size of the toy dataset: observations per class, and variables per observation
samplesize = 500
n_features = 10

# class 0: every feature is drawn from a standard normal distribution (mean 0)
X0 = np.random.normal(
    loc   = 0,
    scale = 1,
    size  = (samplesize,n_features) )

# class 1: unit variance, with feature means that decrease linearly from +.7 to -.7
X1 = np.random.normal(
    loc   = np.linspace(.7,-.7,n_features),
    scale = 1,
    size  = (samplesize,n_features) )

# stack the classes into one data matrix (class 0 on top) with matching labels
X = np.concatenate([X0,X1],axis=0)
y = np.concatenate([np.zeros(samplesize),np.ones(samplesize)])

X.shape

In [None]:
# two panels: the data matrix as an image, and the same data as per-feature scatter
fig,axs = plt.subplots(1,2,figsize=(10,3))

# panel A: all observations-by-features; the dashed line marks the class boundary
h = axs[0].imshow(X,aspect='auto',vmin=-1,vmax=1,cmap='plasma')
fig.colorbar(h,ax=axs[0],pad=.01)
axs[0].axhline(samplesize,linestyle='--',color='k')
axs[0].set(
    xlabel = 'Feature (variable)',
    ylabel = 'Observation',
    title  = 'A) Full data matrix' )

# panel B: one column of markers per feature, with class 0 around x=0 and class 1
# around x=1 (features are shifted horizontally and color-coded)
for i in range(n_features):
  facecolor = plt.cm.rainbow(i/n_features)
  axs[1].plot(np.zeros(samplesize)+i/n_features/2-.25,X0[:,i],'ko',markerfacecolor=facecolor,alpha=.3)
  axs[1].plot(np.ones(samplesize)+i/n_features/2-.25,X1[:,i],'ks',markerfacecolor=facecolor,alpha=.3)

axs[1].set(
    xlim        = [-1,2],
    xticks      = [0,1],
    xticklabels = ['Data 0','Data 1'],
    ylabel      = 'Data value',
    title       = 'B) Scatter plot of all data' )

plt.tight_layout()
plt.savefig('ch7_proj47_part1a.png')
plt.show()

In [None]:
# hold out 25% of the observations for testing (class proportions are preserved)
X_train,X_test, y_train,y_test = train_test_split(X,y,test_size=.25,stratify=y)

# settings of the XGBoost classifier
model = xgb.XGBClassifier(**{
    'n_estimators':     100,  # number of sequential trees
    'max_depth':        2,    # model complexity (depth of tree)
    'learning_rate':    .05,  # learning rate
    'subsample':        .8,   # percent data to sample at each iteration
    'colsample_bytree': .8,   # percent features to sample
    'eval_metric':      'logloss',
})

# train on the train set only
model.fit(X_train,y_train)

# predicted labels for the data the model has seen (train) and has not seen (test)
yHat_train = model.predict(X_train)
yHat_test = model.predict(X_test)

# accuracy is the proportion of predictions that match the true labels
acc_train = (yHat_train==y_train).mean()
acc_test  = (yHat_test==y_test).mean()
print(f'Train accuracy: {acc_train:.2%}')
print(f'Test accuracy: {acc_test:.2%}')

In [None]:
# feature importance ("gain" of the splits in which each feature is used)
booster = model.get_booster()
importances = booster.get_score(importance_type='gain')

# get_score() only contains the features that were used in at least one split, so
# its length can be smaller than n_features. Expand it into a dense list (gain=0
# for unused features) so that the markers and the tick labels always line up.
feature_names = [f'f{i}' for i in range(n_features)]
gains = [importances.get(name,0) for name in feature_names]

# and visualize
plt.figure(figsize=(8,3))
plt.plot(range(n_features),gains,'kh',markerfacecolor=[.9,.7,.7],markersize=15)
plt.gca().set(xlabel='Features',xticks=range(n_features),xticklabels=feature_names,ylabel='Gain')

plt.tight_layout()
plt.savefig('ch7_proj47_part1b.png')
plt.show()

# **Part 2: Get "that" categories**

In [None]:
# pretrained GPT-2 tokenizer (used below to convert text into token ids)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [None]:
# import the wikitext training split
wikitxt = load_dataset('wikitext','wikitext-2-raw-v1',split='train')

# merge all text entries into one string (separated by blank lines) and tokenize it
corpus = '\n\n'.join(wikitxt['text'])
tokens = tokenizer.encode(corpus)

In [None]:
# size of the tokenized dataset
n_tokens = len(tokens)
print(f'There are {n_tokens:,} tokens in the wikitext dataset')

# token id for " that" (first token of the encoded string; note the leading space)
that_ids = tokenizer.encode(' that')
that_token = that_ids[0]

In [None]:
# Heuristic labeling of each " that" token according to the token that follows it.

# followers that indicate a complementizer "that" (pronouns and determiners)
post_comp = [' he', ' she', ' it', ' they', ' we', ' I', ' you',
             ' the', ' a', ' an', ' this', ' that', ' these', ' those']
post_comp_toks = [tokenizer.encode(word)[0] for word in post_comp]

# followers that indicate a demonstrative-pronoun "that" (finite verbs)
post_dp = [' is', ' was', ' are', ' were', ' has', ' have', ' had', ' do', ' did', ' does',
           ' works', ' happened', ' sucks', ' matters', ' helps', ' fails', ' changed']
post_dp_toks = [tokenizer.encode(word)[0] for word in post_dp]

# positions (indices into 'tokens') of each kind of "that"
comp_idx = []   # complementizer "that"
detr_idx = []   # demonstrative determiner "that"
pron_idx = []   # demonstrative pronoun "that" (used here only as a rejective filter)


# walk through all (token, next token) pairs; the final token has no follower
for i,(tok,next_tok) in enumerate(zip(tokens,tokens[1:])):

  # only " that" tokens are of interest
  if tok != that_token:
    continue

  # pick the list according to the follower:
  #   pronoun/determiner -> complementizer
  #   finite verb        -> demonstrative pronoun
  #   anything else      -> demonstrative determiner
  if next_tok in post_comp_toks:
    target_list = comp_idx
  elif next_tok in post_dp_toks:
    target_list = pron_idx
  else:
    target_list = detr_idx

  target_list.append(i)

print('There are')
print(f'  {len(comp_idx)} complementizer "that"s')
print(f'  {len(detr_idx)} demonstrative determiner "that"s')
print(f'  {len(pron_idx)} demonstrative pronoun "that"s')

In [None]:
# some data parameters
batchsize    = 2500 # sample size (number of target tokens drawn per category, i.e., rows per batch)
context_pre  = 5    # tokens before each target (each sequence is context_pre+1 tokens long)

In [None]:
# Build one batch of token sequences per category. Each row contains the
# 'context_pre' tokens that precede a target " that", followed by the target
# itself (so the target is always the final token in the sequence).

# A target that is fewer than 'context_pre' tokens from the start of the text has
# no complete context window (the slice start would be negative, giving a window
# of the wrong length), so those positions are excluded before sampling.
detr_valid = [idx for idx in detr_idx if idx >= context_pre]
comp_valid = [idx for idx in comp_idx if idx >= context_pre]

# select random tokens (without replacement, so no target is used twice)
detr_tokens = np.random.choice(detr_valid,batchsize,replace=False)
comp_tokens = np.random.choice(comp_valid,batchsize,replace=False)

### create batches
# offsets of the window relative to each target: -context_pre, ..., -1, 0
tokens_t = torch.tensor(tokens,dtype=torch.long)
window = torch.arange(-context_pre,1)

# broadcasting (batchsize,1) + (context_pre+1,) gives a matrix of token positions,
# which is used to pull out all windows at once
batch_detr = tokens_t[torch.as_tensor(detr_tokens,dtype=torch.long)[:,None] + window]
batch_comp = tokens_t[torch.as_tensor(comp_tokens,dtype=torch.long)[:,None] + window]


print('batch_detr has shape:',list(batch_detr.shape))
print('And it looks like this:\n',batch_detr)
print('\n\n')

print('batch_comp has shape:',list(batch_comp.shape))
print('And it looks like this:\n',batch_comp)

# **Part 3: Get MLP activations**

In [None]:
# use the first GPU if there is one, otherwise the CPU
use_gpu = torch.cuda.is_available()
device = torch.device('cuda:0' if use_gpu else 'cpu')

# load the pretrained model, move it to the device, and switch to eval mode
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2-medium').to(device)
gpt2.eval()

In [None]:
# number of transformer blocks in the model (used below to loop over layers)
n_layers = len(gpt2.transformer.h)
print(f'There are {n_layers} transformer layers')

In [None]:
# storage for the mlp activations; keys are layer labels ('L0', 'L1', ...)
mlp_acts = {}

def implant_hook(layer_number):
  """Create a forward hook that saves a module's output into 'mlp_acts'.

  The returned hook stores the output (detached, moved to the CPU, and converted
  to a numpy array) under the key f'L{layer_number}'. It returns nothing, so the
  output of the hooked module is left unmodified.
  """
  key = f'L{layer_number}'

  def hook(module, inputs, output):
    activations = output.detach().cpu()
    mlp_acts[key] = activations.numpy()

  return hook

# implant one hook into the 'act' (gelu) layer of each transformer block, and keep
# the returned handles so that the hooks can be removed later
handles = [
    gpt2.transformer.h[layi].mlp.act.register_forward_hook(implant_hook(layi))
    for layi in range(n_layers)
]

In [None]:
# Push both batches through the model. The hooks fill 'mlp_acts' with the
# activations of every layer during each forward pass.
# The try/finally guarantees that the hooks are removed even if a forward pass
# fails (e.g., out of memory); otherwise they would stay attached to the model.
try:

  # process the determiner tokens
  with torch.no_grad():
    outs = gpt2(batch_detr.to(device))

  # note: 'mlp_acts' gets overwritten! make sure to copy beforehand :)
  detr_mlp = mlp_acts.copy()
  mlp_acts = {}
  logits_clean_detr = outs.logits[:,-1,:].cpu() # logits at the final token (the target "that")


  # repeat for complementizer tokens
  with torch.no_grad():
    outs = gpt2(batch_comp.to(device)) # variable 'outs' overwrites, but that saves on GPU RAM
  comp_mlp = mlp_acts.copy()
  logits_clean_comp = outs.logits[:,-1,:].cpu()

finally:
  # remove the handles
  for h in handles:
    h.remove()


In [None]:
# check sizes of activations matrices (one entry per hooked layer)
for k,v in detr_mlp.items():
  print(f'"{k}" has shape {list(v.shape)}')

In [None]:
# visualize the activations of one layer at the final (target) token
_,axs = plt.subplots(1,3,figsize=(12,4))

whichlayer = 'L10'

# sequences-by-neurons activation matrices for the two categories
detr_last = detr_mlp[whichlayer][:,-1,:]
comp_last = comp_mlp[whichlayer][:,-1,:]

# panels A and B: the activation matrices as images
image_panels = [
    (axs[0],detr_last,'A) Determiner words activations'),
    (axs[1],comp_last,'B) Complementizer words activations'),
]
for ax,acts,panel_title in image_panels:
  ax.imshow(acts,aspect='auto',vmin=-.1,vmax=.1,cmap='plasma')
  ax.set(xlabel='Expansion neurons',ylabel='Batch sequence',title=panel_title)

# panel C: distributions of all activation values, using the same bins for both
binbounds = np.linspace(-.2,.2,101)
yd,_ = np.histogram(detr_last.flatten(),bins=binbounds,density=True)
yc,_ = np.histogram(comp_last.flatten(),bins=binbounds,density=True)

# densities are plotted against the left edge of each bin
axs[2].plot(binbounds[:-1],yd,linewidth=2,label='Determiner')
axs[2].plot(binbounds[:-1],yc,linewidth=2,label='Complementizer')
axs[2].legend()
axs[2].set(
    xlabel = 'Activation value',
    ylabel = 'Density',
    ylim   = [0,None],
    xlim   = [binbounds[0],binbounds[-1]],
    title  = 'C) Distributions' )

plt.suptitle(f'MLP post-GELU expansion activations from layer {whichlayer}',fontsize=16,fontweight='bold')

plt.tight_layout()
plt.savefig('ch7_proj47_part3a.png')
plt.show()

In [None]:
# For each layer, compare the distributions of final-token activations between the
# two categories (same histogram bins for all layers).
histbounds = np.linspace(-.2,.5,123)

hist_diffs = np.zeros((n_layers,len(histbounds)-1))
for layi in range(n_layers):
  yi,_ = np.histogram(detr_mlp[f'L{layi}'][:,-1,:].flatten(),bins=histbounds,density=True)
  yt,_ = np.histogram(comp_mlp[f'L{layi}'][:,-1,:].flatten(),bins=histbounds,density=True)

  # complementizer minus determiner (positive = more complementizer density)
  hist_diffs[layi,:] = yt-yi


# The vertical extent runs from -.5 to n_layers-.5 so that each row of the image
# is centered on its integer layer index.
plt.figure(figsize=(10,4))
plt.imshow(hist_diffs,origin='lower',aspect='auto',vmin=-.1,vmax=.1,cmap='plasma',
           extent=[histbounds[0],histbounds[-1],-.5,n_layers-.5])
plt.colorbar(pad=.02)
plt.gca().set(xlabel='Activation values',ylabel='Layer index',
              title='Density difference for complementizer minus determiner')

plt.tight_layout()
plt.savefig('ch7_proj47_part3b.png')
plt.show()

# **Part 4: XGBoost in one layer**

In [None]:
# final-token activations from layer 4: determiner sequences on top, complementizer below
data = np.concatenate([
    detr_mlp['L4'][:,-1,:],
    comp_mlp['L4'][:,-1,:],
],axis=0)

# category labels in the same order (0 = determiner, 1 = complementizer)
labels = np.concatenate([np.zeros(batchsize),np.ones(batchsize)])

print(f'Data matrix is size {data.shape} and labels vector is size {labels.shape}')

In [None]:
# hold out 25% of the sequences for testing (category proportions are preserved)
data_train,data_test, labels_train,labels_test = train_test_split(data,labels,test_size=.25,stratify=labels)

# settings of the XGBoost classifier (now with regularization)
xgb_model = xgb.XGBClassifier(**{
    'n_estimators':     100,  # number of boosting iterations
    'max_depth':        2,    # model complexity (depth of tree)
    'learning_rate':    .05,  # learning rate
    'subsample':        .8,   # percent data to sample at each iteration
    'colsample_bytree': .8,   # percent features to sample at each iteration
    'reg_alpha':        5,    # L1 sparsity
    'reg_lambda':       5,    # L2 shrinkage
    'eval_metric':      'logloss',
})


# train on the train set only
xgb_model.fit(data_train,labels_train)

# predicted labels for the train set and for the held-out test set
yHat_train = xgb_model.predict(data_train)
yHat_test = xgb_model.predict(data_test)

# accuracy is the proportion of predictions that match the true labels
acc_train = (yHat_train==labels_train).mean()
acc_test  = (yHat_test==labels_test).mean()
print(f'Train accuracy: {acc_train:.2%}')
print(f'Test accuracy: {acc_test:.2%}')

In [None]:
# feature importance: dictionary of {feature name: gain} for the features in the score
importances = xgb_model.get_booster().get_score(importance_type='gain')

# used/unused indicator per neuron: start from small random values around 0 (the
# jitter spreads out the markers in the plot), then add 1 to each neuron that
# appears in the importances (the name is 'f' + index, hence the [1:])
imp_by_features = np.random.normal(0,.03,data.shape[1])
used_neurons = [int(name[1:]) for name in importances]
imp_by_features[used_neurons] += 1


# and visualize: a wide panel on the left (2/3 of the figure) and a narrow one on the right
fig = plt.figure(figsize=(11,3))
gs = GridSpec(1,3)
axs1 = fig.add_subplot(gs[0,:2])
axs2 = fig.add_subplot(gs[0,2])

n_used,n_total = len(importances),len(imp_by_features)
axs1.plot(imp_by_features,'kh',markerfacecolor=[.9,.7,.7,.7],markeredgewidth=.3,markersize=5)
axs1.set(
    xlabel      = 'Neuron index',
    ylabel      = 'Gain',
    yticks      = [0,1],
    yticklabels = ['Unused','Used'],
    ylim        = [-.3,1.3],
    title       = f'A) Scatter plot of important features ({n_used}/{n_total} neurons)' )

# distribution of the gain values of the used neurons
axs2.hist(importances.values(),bins='fd',edgecolor='k',facecolor='m')
axs2.set(xlabel='Gain value',ylabel='Count',title='B) Histogram of feature gains')

plt.tight_layout()
plt.savefig('ch7_proj47_part4.png')
plt.show()

# **Part 5: Laminar sweep of XGBoost**

In [None]:
# Repeat the XGBoost classification separately for each layer.

# accuracies per layer (column 0 = train, column 1 = test)
results_clean = np.zeros((n_layers,2))

# per layer: indices of the neurons that appear in the feature importances
important_neurons = []

# classifier settings (identical for all layers)
xgb_params = {
    'n_estimators':     100,  # number of boosting iterations
    'max_depth':        2,    # model complexity (depth of tree)
    'learning_rate':    .05,  # learning rate
    'subsample':        .8,   # percent data to sample at each iteration
    'colsample_bytree': .8,   # percent features to sample at each iteration
    'reg_alpha':        5,    # L1 sparsity
    'reg_lambda':       5,    # L2 shrinkage
    'eval_metric':      'logloss',
}

for layeri in range(n_layers):

  # final-token activations from this layer (determiner on top, complementizer below)
  layer_key = f'L{layeri}'
  data = np.vstack((detr_mlp[layer_key][:,-1,:],comp_mlp[layer_key][:,-1,:]))
  data_train,data_test, labels_train,labels_test = train_test_split(data,labels,test_size=.25,stratify=labels)

  # a new classifier object is created and fit for each layer
  xgb_model = xgb.XGBClassifier(**xgb_params)
  xgb_model.fit(data_train,labels_train)

  # predictions for the train and test sets
  yHat_train = xgb_model.predict(data_train)
  yHat_test = xgb_model.predict(data_test)

  # store the accuracies
  results_clean[layeri,:] = [ np.mean(yHat_train==labels_train), np.mean(yHat_test==labels_test) ]

  # save the diagnostic neurons (feature names are 'f' + index)
  importances = xgb_model.get_booster().get_score(importance_type='gain')
  important_neurons.append( [int(k[1:]) for k in importances] )

  print(f'Finished layer {layeri+1:2}/{n_layers} with {results_clean[layeri,1]:.1%} test accuracy.')

In [None]:
# visualize the results of the sweep over layers
_,axs = plt.subplots(1,2,figsize=(12,3.5))

layers = np.arange(n_layers)

# panel A: accuracies (train/test markers are shifted slightly left/right to avoid overlap)
# each entry: (column in results_clean, x-shift, line format, marker face color, label)
accuracy_series = [
    (0,-.1,'ro-',[.9,.7,.7],'Train accuracy'),
    (1, .1,'gs-',[.7,.9,.7],'Test accuracy'),
]
for col,shift,fmt,facecolor,label in accuracy_series:
  axs[0].plot(layers+shift,results_clean[:,col],fmt,linewidth=.5,
              markerfacecolor=facecolor,markersize=12,label=label)
axs[0].legend()
axs[0].set(xlabel='Transformer layer',ylabel='Accuracy',title='A) Train and test accuracy by layer')

# panel B: how many neurons appear in the feature importances of each layer
n_important = [len(neurons) for neurons in important_neurons]
axs[1].plot(n_important,'kh',markerfacecolor=[.7,.7,.9],markersize=12)
axs[1].set(xlabel='Transformer layer',ylabel='Number of "important" neurons',title='B) Non-zero-gain neurons per layer')

plt.tight_layout()
plt.savefig('ch7_proj47_part5.png')
plt.show()

# **Part 6: Ablate the "important" neurons**

In [None]:
# For each layer: zero-out that layer's "important" neurons (at the final token),
# measure how much the final-token logits change relative to the clean run, and
# re-run the XGBoost classification on the ablated activations.
# Note: 'detr_mlp' and 'comp_mlp' are overwritten here with the ablated activations.

# initialize
results_ablate = np.zeros_like(results_clean)  # accuracies (column 0 = train, 1 = test)
logit_diffs_norms = np.zeros((2,n_layers))     # row 0 = determiner, row 1 = complementizer
important_neurons_ablate = []

# loop over layers
for layeri in range(n_layers):

  # the neurons to silence in this layer (fixed before the hook is created)
  neurons_to_zero = important_neurons[layeri]


  ### ------ hook to manipulate this layer
  def mlp_ablate_hook(module,input,output):

    # zero-out the important neurons in this layer (final token only)
    out = output.clone() # copy of the output
    out[:,-1,neurons_to_zero] = 0

    # and store those activations
    mlp_acts['acts'] = out.detach().cpu().numpy()

    return out # return the modified output

  handle = gpt2.transformer.h[layeri].mlp.act.register_forward_hook(mlp_ablate_hook)
  ### ------ hook to manipulate this layer


  # forward passes; the try/finally guarantees that the hook is removed even if a
  # forward pass fails, so that it cannot stay attached to the model
  try:
    mlp_acts = {} # re-initialize dictionary
    with torch.no_grad():
      outs = gpt2(batch_detr.to(device))
    detr_mlp = mlp_acts.copy()

    # compare logits matrices
    diffmat = outs.logits[:,-1,:].cpu()-logits_clean_detr
    logit_diffs_norms[0,layeri] = torch.norm( diffmat ).item()



    mlp_acts = {} # re-initialize dictionary
    with torch.no_grad():
      outs = gpt2(batch_comp.to(device))
    comp_mlp = mlp_acts.copy()

  finally:
    # remove the hook
    handle.remove()

  # compare logits matrices
  diffmat = outs.logits[:,-1,:].cpu()-logits_clean_comp
  logit_diffs_norms[1,layeri] = torch.norm( diffmat ).item()




  ### --- XGBoost analysis --- ###
  # collect new data (final-token activations after the ablation)
  data = np.vstack((detr_mlp['acts'][:,-1,:],comp_mlp['acts'][:,-1,:]))
  data_train,data_test, labels_train,labels_test = train_test_split(data,labels,test_size=.25,stratify=labels)

  # a new classifier object is created and fit for each layer
  xgb_model = xgb.XGBClassifier(
      n_estimators  = 100, # number of boosting iterations
      max_depth     = 2,   # model complexity (depth of tree)
      learning_rate = .05, # learning rate
      subsample     = .8,  # percent data to sample at each iteration
      colsample_bytree=.8, # percent features to sample at each iteration
      reg_alpha     = 5,   # L1 sparsity
      reg_lambda    = 5,   # L2 shrinkage
      eval_metric   = 'logloss'
  )

  # fit the model to the new data
  xgb_model.fit(data_train,labels_train)

  # test performances
  yHat_train = xgb_model.predict(data_train)
  yHat_test = xgb_model.predict(data_test)

  # capture the results
  results_ablate[layeri,0] = np.mean(yHat_train==labels_train)
  results_ablate[layeri,1] = np.mean(yHat_test==labels_test)

  # save the diagnostic neurons (feature names are 'f' + index)
  importances = xgb_model.get_booster().get_score(importance_type='gain')
  important_neurons_ablate.append( [int(k[1:]) for k in importances.keys()] )

  print(f'Finished layer {layeri+1:2}/{n_layers} with {results_ablate[layeri,1]:.1%} test accuracy.')

In [None]:
# visualize the classification results after ablation
_,axs = plt.subplots(1,3,figsize=(12,3.5))

layers = np.arange(n_layers)

# each entry: (column in the results matrices, x-shift, line format, marker face color, label)
# (train/test markers are shifted slightly left/right to avoid overlap)
accuracy_series = [
    (0,-.1,'ro-',[.9,.7,.7],'Train accuracy'),
    (1, .1,'gs-',[.7,.9,.7],'Test accuracy'),
]

# panel A: train and test performance from the ablated LLM
for col,shift,fmt,facecolor,label in accuracy_series:
  axs[0].plot(layers+shift,results_ablate[:,col],fmt,linewidth=.5,
              markerfacecolor=facecolor,markersize=12,label=label)
axs[0].legend()
axs[0].set(xlabel='Transformer layer',ylabel='Accuracy',title='A) Classification accuracies')


# panel B: differences from clean model (positive = accuracy dropped after ablation)
for col,shift,fmt,facecolor,label in accuracy_series:
  axs[1].plot(layers+shift,results_clean[:,col]-results_ablate[:,col],fmt,linewidth=.5,
              markerfacecolor=facecolor,markersize=12,label=label)
axs[1].legend()
axs[1].set(xlabel='Transformer layer',ylabel='Accuracy',ylim=[-.05,.05],title='B) Change from "clean" run')
axs[1].axhline(0,linestyle='--',color='k',linewidth=.5)


# panel C: number of "important" neurons in the classifiers fit to the ablated activations
n_important = [len(neurons) for neurons in important_neurons_ablate]
axs[2].plot(n_important,'kh',markerfacecolor=[.7,.7,.9],markersize=12)
axs[2].set(xlabel='Transformer layer',ylabel='Number of "important" neurons',title='C) Non-zero-gain neurons per layer')

plt.tight_layout()
plt.savefig('ch7_proj47_part6a.png')
plt.show()

In [None]:
# impact of each layer's ablation on the final-token logits, per category
plt.figure(figsize=(9,3))

# each entry: (row in logit_diffs_norms, line format, marker face color, label)
norm_series = [
    (0,'ro-',[.9,.7,.7],'Determiner'),
    (1,'bs-',[.7,.7,.9],'Complementizer'),
]
for row,fmt,facecolor,label in norm_series:
  plt.plot(logit_diffs_norms[row,:],fmt,linewidth=.5,
           markerfacecolor=facecolor,markersize=12,label=label)

plt.legend()
plt.gca().set(xlabel='Ablated layer',ylabel='Norm of difference matrix')

plt.tight_layout()
plt.savefig('ch7_proj47_part6b.png')
plt.show()