|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Modifying MLP</h1>|
|<h2>Lecture:</h2>|<h1><b>Explorations in subspace removal</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# --- numerics and plotting ---
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec  # NOTE(review): not used in the cells shown here

# --- statistics (one-sample t-test in the all-layers analysis) ---
import scipy.stats as stats

# --- model, tokenizer, and tensor ops ---
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

# render inline matplotlib figures as SVG (vector graphics)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# run on the GPU when one is present, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# the 'gpt2' tokenizer is paired with the 'gpt2-xl' weights
# (the GPT-2 family shares one BPE vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
model = model.to(device)

# number of transformer blocks
nlayers = model.config.n_layer

# inference mode (disables dropout); the module repr is the cell's displayed output
model.eval()

# Forward pass and get MLP activations

In [None]:
# source text: https://en.wikipedia.org/wiki/Maurice_Ravel
text = "Joseph Maurice Ravel (7 March 1875 – 28 December 1937) was a French composer, pianist and conductor. He is often associated with Impressionism along with his elder contemporary Claude Debussy, although both composers rejected the term. In the 1920s and 1930s Ravel was internationally regarded as France's greatest living composer."

# token ids as a (1, n_tokens) tensor, moved to the model's device
tokens = tokenizer.encode(text,return_tensors='pt')
tokens = tokens.to(device)
(tokens.shape, tokens)

In [None]:
# transformer block whose MLP expansion layer (c_fc) will be recorded
whichlayer = 9

# captured MLP activations, keyed as 'L<layer>'
mlp_values = {}

def hook(module, input, output):
  # The layer index is read from the global `whichlayer` when the hook fires,
  # so it must still name the hooked layer at forward-pass time.
  # Only the first batch element is kept, detached from the graph and moved to CPU.
  activations = output[0].detach().cpu()
  mlp_values[f'L{whichlayer}'] = activations

# surgery ;) attach the recorder to the MLP's first linear layer
handle = model.transformer.h[whichlayer].mlp.c_fc.register_forward_hook(hook)

In [None]:
# Clean forward pass: the recording hook fires inside block `whichlayer`'s c_fc.
# try/finally detaches the hook even if the pass raises, so a failed run cannot
# leave it attached to later forward passes.
try:
  with torch.no_grad():
    out_clean = model(tokens,output_hidden_states=True)
finally:
  handle.remove()

# recorded activations of the first sequence (assumed tokens x MLP units — confirm)
mlp = mlp_values[f'L{whichlayer}']
mlp.shape

In [None]:
# distribution of the recorded MLP activation values
fig,ax = plt.subplots(figsize=(8,4))
ax.hist(mlp.flatten(),80,edgecolor='k',facecolor='lightgray')
ax.set(xlabel='Activation value',ylabel='Count (a.u.)',yticks=[])
plt.show()

## SVD and reconstruct data without top PC

In [None]:
# thin SVD of the recorded activation matrix
U,s,Vt = torch.linalg.svd(mlp,full_matrices=False)

# keep every singular value except the largest one (s itself is left untouched)
s_kept = s.clone()
s_kept[0] = 0
S = torch.diag(s_kept)

# reconstruction with the top component removed
proj = (U @ S) @ Vt
proj.shape

In [None]:
# Reconstruction (top component removed) plus the per-token mean of the
# original activations. Rebuilt from U, S, Vt rather than `proj = proj + ...`
# so that re-running this cell does not add the mean a second time.
proj = U @ S @ Vt + mlp.mean(dim=-1,keepdim=True)

# sanity check: mean of token 4, reconstructed vs. original
proj[4,:].mean(), mlp[4,:].mean()

In [None]:
## Visualize the impact
# token-by-token covariance matrices (torch.cov treats each row as a variable)
cov_mlp = torch.cov(mlp)
cov_prj = torch.cov(proj)


fig,axs = plt.subplots(2,3,figsize=(12,7))

# show the data (every 14th value, to keep the scatter plot light)
axs[0,0].plot(mlp.flatten()[::14],proj.flatten()[::14],'ko',markersize=3,markerfacecolor=[.7,.9,.7,.3])
axs[0,0].set(xlabel='Original data',ylabel='Reconstructed')

# and their distributions
binedges = np.linspace(-5,4,121)
y,_ = np.histogram(mlp.flatten(),binedges)
axs[0,1].plot(binedges[:-1],y,label='Original')
y,_ = np.histogram(proj.flatten(),binedges)
axs[0,1].plot(binedges[:-1],y,label='Recon')
axs[0,1].legend()
axs[0,1].set(xlim=binedges[[0,-1]],xlabel='Data value',ylabel='Count',yticks=[],title='Activation value distributions')

# eigenspectra
# Covariance matrices are symmetric, so eigvalsh applies: it returns real
# eigenvalues in ascending order, flipped here to largest-first.
# (np.linalg.eig guarantees neither an ordering nor purely real output.)
mlp_evals = np.linalg.eigvalsh(np.asarray(cov_mlp))[::-1]
rcn_evals = np.linalg.eigvalsh(np.asarray(cov_prj))[::-1]
axs[0,2].plot(100*mlp_evals/mlp_evals.sum(),'ks-',markersize=8,linewidth=.4,markerfacecolor=[.9,.7,.7,.6],label='Orig')
axs[0,2].plot(100*rcn_evals/rcn_evals.sum(),'ko-',markersize=8,linewidth=.4,markerfacecolor=[.7,.7,.9,.6],label='Recon')
axs[0,2].legend()
axs[0,2].set(ylabel='% variance explained',xticks=range(0,len(mlp_evals),2),xlim=[-1,15.5],title='Eigenspectra',xlabel='Component')


# show the covariance matrices
h = axs[1,0].imshow(cov_mlp,vmin=-.2,vmax=.6,cmap='plasma',aspect='auto')
fig.colorbar(h,ax=axs[1,0],pad=.01,fraction=.047)
axs[1,0].set(xlabel='Tokens',ylabel='Tokens',title='Original covariance')

h = axs[1,1].imshow(cov_prj,vmin=-.2,vmax=.6,cmap='plasma',aspect='auto')
fig.colorbar(h,ax=axs[1,1],pad=.01,fraction=.047)
axs[1,1].set(xlabel='Tokens',ylabel='Tokens',title='Reconstructed covariance')

# histogram of covariance values (each off-diagonal value appears twice
# because the full symmetric matrix is histogrammed)
binedges = np.linspace(-.3,.8,81)
y,_ = np.histogram(cov_mlp.flatten(),binedges)
axs[1,2].plot(binedges[:-1],y,label='Orig')
y,_ = np.histogram(cov_prj.flatten(),binedges)
axs[1,2].plot(binedges[:-1],y,label='Recon')
axs[1,2].legend()
axs[1,2].set(xlim=binedges[[0,-1]],xlabel='Covariance value',ylabel='Count (a.u.)',yticks=[],title='Covariance distributions')


plt.tight_layout()
plt.show()

# Explore subspace removal in one layer

In [None]:
def implant_hook(layer_number):
  """Build a forward hook that removes the top singular component of a module's output.

  For every sequence in the batch, the (tokens x features) output matrix is
  decomposed with a thin SVD. The largest singular value is set to zero and
  the matrix is reconstructed. The per-token feature mean of the original
  output is then added back.

  Parameters
  ----------
  layer_number : int
    Index of the block the hook is meant for. It is not used in the
    computation; it is kept so existing calls `implant_hook(whichlayer)` work.

  Returns
  -------
  callable
    A `register_forward_hook`-compatible `hook(module, input, output)`. Its
    return value replaces the module's output and has the same shape.
  """
  def hook(module, input, output):
    # Batched SVD over the last two dims: each batch element is processed on
    # its own and the batch dimension is preserved in the returned tensor.
    U,s,Vt = torch.linalg.svd(output,full_matrices=False)
    S = torch.diag_embed(s)
    S[...,0,0] = 0  # drop the top component of every sequence
    proj = U @ S @ Vt

    # add back the per-token mean of the original output
    proj = proj + output.mean(dim=-1,keepdim=True)

    # return the data (replaces the module's output)
    return proj

  return hook


# implant the hook halfway through the model
whichlayer = nlayers // 2
target_module = model.transformer.h[whichlayer].mlp.c_fc
handle = target_module.register_forward_hook(implant_hook(whichlayer))

In [None]:
# Forward pass with the top MLP component removed in block `whichlayer`.
# try/finally detaches the hook even if the pass raises, so it cannot leak
# into later forward passes.
try:
  with torch.no_grad():
    out_proj = model(tokens,output_hidden_states=True)
finally:
  handle.remove()

In [None]:
# Norm of (clean - modified) for each of the nlayers+1 hidden-state tensors.
# The dashed line sits just after index `whichlayer`; this assumes HF's
# convention that hidden_states[i+1] is the output of block i.
norms = np.array([
  torch.norm(out_clean.hidden_states[i]-out_proj.hidden_states[i]).item()
  for i in range(nlayers+1)
])

fig,ax = plt.subplots(figsize=(8,3))
ax.plot(norms,'ko',markerfacecolor=[.9,.7,.7])
ax.axvline(whichlayer+.5,linestyle='--',color='gray')
ax.set(xlabel='Transformer block',ylabel='HS difference norm',
       title='Impact of subspace removal on hidden states vector lengths')
plt.show()

In [None]:
def _cosine_matrix(hidden):
  """Pairwise cosine similarities between the token vectors of one hidden state.

  Uses batch element 0 and skips the first token position. Each remaining
  token vector is scaled to unit length, so H @ H.T holds cosine similarities.
  Returns a CPU tensor of shape (n_tokens-1, n_tokens-1).
  """
  H = hidden[0,1:,:].cpu()
  H = H / torch.linalg.vector_norm(H,dim=1,keepdim=True)
  return H @ H.T


# cosine similarities of the hidden states right after block `whichlayer`
csM_clean = _cosine_matrix(out_clean.hidden_states[whichlayer+1])  # clean data
csM_proj  = _cosine_matrix(out_proj.hidden_states[whichlayer+1])   # subspace rejection


# show the matrices
fig,axs = plt.subplots(1,3,figsize=(12,3.5))
h = axs[0].imshow(csM_clean,vmin=.3,vmax=.8,cmap='plasma')
fig.colorbar(h,ax=axs[0],pad=.01,fraction=.047)
axs[0].set(xlabel='Tokens',ylabel='Tokens',xticks=[],yticks=[],
           title='Clean model: S$_C$ of token pairs')

h = axs[1].imshow(csM_proj,vmin=.3,vmax=.8,cmap='plasma')
fig.colorbar(h,ax=axs[1],pad=.01,fraction=.047)
axs[1].set(xlabel='Tokens',ylabel='Tokens',xticks=[],yticks=[],
           title='Subspace-removed: S$_C$ of token pairs')

# and their distributions (the diagonal of 1s and both symmetric halves are included)
binedges = np.linspace(0,1.1,101)
y,_ = np.histogram(csM_clean.flatten(),binedges)
axs[2].plot(binedges[:-1],y,label='Clean')
y,_ = np.histogram(csM_proj.flatten(),binedges)
axs[2].plot(binedges[:-1],y,label='Projected')
axs[2].legend()
axs[2].set(xlabel='Cosine similarity',ylabel='Count',title='S$_C$ distributions')

plt.tight_layout()
plt.show()

# Scatter plot and text heatmap of impact on logits

In [None]:
# Change in log-probability of each actual next token, clean minus subspace-removed.
# The logits at position p predict token p+1, so logitdiffs[p+1] holds
# log P_clean(token[p+1]) - log P_removed(token[p+1]). logitdiffs[0] stays 0
# as a placeholder because the first token is never predicted.
ntokens = tokens.shape[1]
targets = tokens[0,1:].unsqueeze(-1)  # (ntokens-1, 1) ids of the predicted tokens

lp_clean = F.log_softmax(out_clean.logits[0,:-1,:],dim=-1).gather(-1,targets).squeeze(-1)
lp_proj  = F.log_softmax(out_proj.logits[0,:-1,:],dim=-1).gather(-1,targets).squeeze(-1)

logitdiffs = np.zeros(ntokens)
logitdiffs[1:] = (lp_clean - lp_proj).cpu().numpy()


# Diverging normalization centred on zero, so negative values map onto blue.
# TwoSlopeNorm needs vmin < 0 < vmax strictly, so a side with no data is widened slightly.
lo, hi = logitdiffs.min(), logitdiffs.max()
norm = mpl.colors.TwoSlopeNorm(vmin=min(lo,-1e-9),vcenter=0,vmax=max(hi,1e-9))

# and visualize: one stem + marker per predicted token
fig,ax = plt.subplots(figsize=(14,4))
for toki in range(1,ntokens):
  color = mpl.cm.coolwarm(norm(logitdiffs[toki]))
  ax.plot([toki,toki],[lo,logitdiffs[toki]],':',color=color,linewidth=.7)
  ax.plot(toki,logitdiffs[toki],'kh',markersize=10,markerfacecolor=color)

ax.axhline(0,color='gray',zorder=-100)

xtickskip = 1 # skipping tokens in x-ticks
ax.set(ylabel=r'$\Delta$ logit from clean',xlim=[-1,ntokens],
       title='Change in token logit from subspace removal',
       xticks=range(0,ntokens,xtickskip),xticklabels=[tokenizer.decode(i) for i in tokens[0,range(0,ntokens,xtickskip)]])
ax.tick_params(axis='x',labelrotation=90)

plt.show()

In [None]:
# Measure the width of one monospace character in axes coordinates; the text
# heatmap uses it to lay token boxes out horizontally.
fig,ax = plt.subplots(figsize=(10,2))
probe = ax.text(0,0,'n',fontsize=12,fontfamily='monospace')
renderer = fig.canvas.get_renderer()
extent = probe.get_window_extent(renderer=renderer)
(x_left,_),(x_right,_) = ax.transAxes.inverted().transform([[extent.x0,extent.y0], [extent.x1,extent.y1]])
en_width = x_right - x_left
plt.close(fig)

In [None]:
# Render the text with each token's box coloured by its logit change,
# using the colour normalization `norm` from the scatter plot.
tokens_per_line = 20  # wrap the text after this many tokens
x_pos, y_pos = 0, 1   # start at the left edge of the top line (axes coordinates)

fig, ax = plt.subplots(figsize=(10,2))
ax.axis('off')

for toki in range(len(tokens[0])):
  toktext = tokenizer.decode([tokens[0,toki]])
  token_width = en_width*len(toktext)  # monospace: width scales with character count
  facecolor = mpl.cm.coolwarm(norm(logitdiffs[toki]))

  ax.text(x_pos+token_width/2, y_pos, toktext, fontsize=12, ha='center', va='center',fontfamily='monospace',
          bbox=dict(boxstyle='round,pad=.3', facecolor=facecolor, edgecolor='none', alpha=.8))

  # advance past this token plus a small gap; start a new line when one is full
  x_pos += token_width + .01
  if (toki+1) % tokens_per_line == 0:
    x_pos = 0
    y_pos -= .17

# Repeat over all layers

In [None]:
# Remove the top component of c_fc's output in each block separately and record:
#   data[:,0]: t-statistic of the per-token log-prob change (clean - removed)
#   data[:,1]: % of the squared singular values carried by the top component
data = np.zeros((nlayers,2))

ntokens = tokens.shape[1]
targets = tokens[0,1:].unsqueeze(-1)  # ids of the tokens predicted at positions 0..n-2

# clean log-probabilities of the actual next tokens (identical for every layer)
lp_clean = F.log_softmax(out_clean.logits[0,:-1,:],dim=-1).gather(-1,targets).squeeze(-1)


def make_removal_hook(layer_idx):
  """Forward hook removing the top singular component of a module's output.

  Also stores that component's share (%) of the squared singular values of
  batch element 0 in data[layer_idx,1]. Each batch element is processed
  independently and the returned tensor has the module output's shape.
  """
  def hook(module, input, output):
    U,s,Vt = torch.linalg.svd(output,full_matrices=False)

    # squared singular values (this was a typo in the video; it should be squared).
    # The data are not mean-centred, so this is a share of the total sum of squares.
    top_share = 100*s[0,0]**2 / (s[0]**2).sum()
    data[layer_idx,1] = top_share.item()

    # reconstruct without the top dimension, then add back the per-token mean
    S = torch.diag_embed(s)
    S[...,0,0] = 0
    proj = U @ S @ Vt + output.mean(dim=-1,keepdim=True)

    # return the data (replaces the module's output)
    return proj
  return hook


for layeri in range(nlayers):
  # implant the hook, run the model, and always detach the hook again
  handle = model.transformer.h[layeri].mlp.c_fc.register_forward_hook(make_removal_hook(layeri))
  try:
    with torch.no_grad():
      out_proj = model(tokens)
  finally:
    handle.remove()

  # change in log-probability of each actual next token (index 0 is a placeholder)
  lp_proj = F.log_softmax(out_proj.logits[0,:-1,:],dim=-1).gather(-1,targets).squeeze(-1)
  logitdiffs = np.zeros(ntokens)
  logitdiffs[1:] = (lp_clean - lp_proj).cpu().numpy()

  # One-sample t-test against zero over the predicted tokens only. The
  # placeholder zero at index 0 is not a measurement and is excluded.
  data[layeri,0] = stats.ttest_1samp(logitdiffs[1:],0).statistic

In [None]:
fig,axs = plt.subplots(1,3,figsize=(12,3.5))

# one marker per block, coloured by depth
for i in range(nlayers):
  color = mpl.cm.plasma(i/nlayers)
  axs[0].plot(i,data[i,0],'ko',markerfacecolor=color,markersize=10,alpha=.7)
  axs[1].plot(i,data[i,1],'ks',markerfacecolor=color,markersize=10,alpha=.7)
  axs[2].plot(data[i,0],data[i,1],'kh',markerfacecolor=color,markersize=10,alpha=.7)

# glow-ups
axs[0].axhline(0,linestyle='--',color='gray',linewidth=.8,zorder=-109)
axs[0].set(xlabel='Transformer block',ylabel='T-statistic',title='T-test over tokens against zero')
axs[1].set(xlabel='Transformer block',ylabel='% variance',title='Variance explained by top component')
# the x-axis here is data[:,0], i.e. the t-statistic; r is the Pearson correlation of the two measures
axs[2].set(xlabel='T-statistic',ylabel='% variance explained',title=f'r = {np.corrcoef(data.T)[1,0]:.2f}')

plt.tight_layout()
plt.show()