|<h2>Course:</h2>|<h1><b><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></b></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Identifying circuits and components</h1>|
|<h2>Lecture:</h2>|<h1><b>SAE in GPT2 learns about Hungarian Palinka</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, GPT2Tokenizer

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# run on the first CUDA device when one is present, otherwise on the CPU
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# pretrained GPT-2 weights and the matching tokenizer
model     = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# inference only: switch to eval mode, then move the weights to the device
# (left as the last expression so the notebook echoes the model)
model.eval()
model.to(device)

In [None]:
# capture MLP activations with a forward hook;
# the hook writes into this dictionary (the key encodes the hooked layer)
activations = {}

# hook a middle layer
layer2hook = model.config.n_layer//2

def mlp_hook(module, input, output):
  # keep the hooked module's output, detached from the autograd graph
  key = f'mlp_{layer2hook}'
  activations[key] = output.detach()

# attach to the `c_fc` submodule of that block's MLP
hooked_module = model.transformer.h[layer2hook].mlp.c_fc
handle = hooked_module.register_forward_hook(mlp_hook)

# note: the handle returned by the registration is kept so the hook can be removed later

# Tokens and activations

In [None]:
# the text to analyze: a single string (the opening of the Wikipedia article on palinka,
# written here without accents)
# https://en.wikipedia.org/wiki/P%C3%A1linka
texts =  'Palinka is a traditional fruit spirit (or fruit brandy) with origins in the medieval Hungary, known under several names. Protected as a geographical indication of the European Union, only fruit spirits mashed, distilled, matured and bottled in Hungary and similar apricot spirits from four provinces of Austria can be called "palinka". A similar product exists in the Czech Republic and Slovakia where it is known as palenka, and in Romania (Transylvania), Italy, and Greece under the name palinca. While palinka is traditionally made from a mash of ripe fruit, the law does not control the addition of non-concentrated fruit juice, and explicitly allows the use of fruit pulp. Dried fruits are excluded from the mash only, and may be used in the aging process. While palinka has PDO on its own, some regions of Hungary are especially suitable for the production of certain fruits, and palinka of excellent quality has been produced in those regions for centuries. These local variations are protected as separate geographical indications and have their own well-detailed regulations.'

# tokenize the text into a PyTorch tensor (first axis indexes the single text),
# then put it on the same device as the model
tokens = tokenizer.encode(texts, return_tensors='pt')
tokens = tokens.to(device)

# number of tokens in the text (echoed as the cell output)
numtoks = tokens.shape[1]
numtoks

In [None]:
# run the model to get the activations
# (the forward pass triggers the hook; the model's own output is not needed)
try:
  with torch.no_grad():
    model(tokens)
finally:
  # remove the hook even if the forward pass fails, so that re-running the
  # hook cell cannot leave two hooks attached to the same module
  handle.remove()

In [None]:
# grab the activations and drop only the leading (single-text) dimension
# (a bare .squeeze() would also drop the token axis for a one-token text)
X = activations[f'mlp_{layer2hook}'].squeeze(0)
print(f'Activations size: {X.shape}')

# make sure the data are on the device used for training the autoencoder
# (a no-op when the tensor is already there)
X = X.to(device)

# Create and train the autoencoder

In [None]:
class SparseAE(nn.Module):
  """k-sparse autoencoder with tied encoder/decoder weights.

  The encoder is a bias-free linear layer followed by a ReLU. For every
  sample, only latent units at or above the k-th largest activation are kept;
  the rest are zeroed. The decoder reuses the encoder weight matrix.

  Args:
    input_dim: size of the last axis of the data.
    latent_dim: number of latent components.
    k: number of latent units to keep per sample. Defaults to half of
      `input_dim`. Values larger than `latent_dim` are clamped to `latent_dim`.
    sparsity_weight: multiplier for `sparsity_loss`.
    decor_weight: multiplier for `decorrelation_loss`.

  Raises:
    ValueError: if an explicit `k` is smaller than 1.
  """

  def __init__(self, input_dim, latent_dim, k=None, sparsity_weight=1, decor_weight=.0005):
    super().__init__()
    self.encoder = nn.Linear(input_dim, latent_dim, bias=False)
    # note: there is no separate decoder layer; the weights are tied by
    # reusing the encoder weights in forward()

    self.sparsity_weight = sparsity_weight
    self.decor_weight = decor_weight

    # k-sparse parameter defaults to 50% of input (but at least 1)
    if k is None:
      k = max(1, input_dim//2)
    if k < 1:
      raise ValueError(f'k must be at least 1, got {k}')

    # top-k cannot select more units than the latent layer has
    self.k = min(k, latent_dim)

  def forward(self, x):
    """Encode, k-sparsify, and decode.

    Args:
      x: tensor whose last axis has size `input_dim`; any leading axes are
        treated as independent samples.

    Returns:
      (y, latent_sparse): the reconstruction (same shape as `x`) and the
      sparsified latent activations (last axis of size `latent_dim`).
    """

    # forward pass to the latent layer
    latent = F.relu(self.encoder(x))

    # "k-sparsify": force sparsity by zeroing out small activations.
    # topk returns sorted values, so the last one is the kth-largest;
    # working on the last axis makes this per-sample for any number of leading axes
    topk_vals = torch.topk(latent, self.k, dim=-1)[0]
    thresh = topk_vals[..., -1:]

    # mask is 0's and 1's in the latent's own dtype (ties at the threshold are all kept)
    mask = (latent >= thresh).to(latent.dtype)
    latent_sparse = latent * mask

    # finally, decode via tied weights (equivalent to latent_sparse @ encoder.weight)
    y = F.linear(latent_sparse, self.encoder.weight.t())

    return y,latent_sparse

  def sparsity_loss(self, z):
    """Return the weighted mean absolute value of the latent activations `z`."""
    return self.sparsity_weight * torch.mean(torch.abs(z))

  # penalty on inter-latent covariance
  def decorrelation_loss(self, estLatent):
    """Return the weighted sum of squared off-diagonal latent covariances.

    Args:
      estLatent: latent activations with components on the last axis; all
        leading axes are pooled as observations.
    """
    # rows = observations, columns = latent components
    flat = estLatent.reshape(-1, estLatent.shape[-1])

    # torch.cov wants variables in rows; atleast_2d covers a single component
    cov = torch.atleast_2d(torch.cov(flat.T))
    off_diag = cov - torch.diag(torch.diag(cov))
    return self.decor_weight * torch.sum(off_diag**2)


In [None]:
# build the autoencoder on the chosen device:
# the latent layer has twice as many units as the data have features,
# and k is set to a third of the number of features
num_latent = 2*X.shape[1]
ae = SparseAE(input_dim=X.shape[1], latent_dim=num_latent, k=X.shape[1]//3).to(device)

In [None]:
# reconstruction criterion and optimizer
mse_loss  = nn.MSELoss()
optimizer = optim.Adam(ae.parameters(), lr=.0001)

# per-epoch record; columns are [total, sparsity, decorrelation]
n_epochs  = 100
losses = np.zeros((n_epochs,3))


# train on all token activations at once (one batch per epoch)
for epoch in range(n_epochs):

  # reconstruction and sparse latent code
  x_pred,latent = ae(X)

  # the two penalty terms, then the total objective (MSE + penalties)
  L1 = ae.sparsity_loss(latent)
  deCor = ae.decorrelation_loss(latent)
  loss = mse_loss(x_pred,X) + L1 + deCor

  # log this epoch's values for later inspection
  losses[epoch,:] = loss.item(),L1.item(),deCor.item()

  # clear old gradients, backprop, and update the weights
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # progress report every 11th epoch (with 100 epochs this includes the last one)
  if not epoch%11:
    print(f'Epoch {epoch+1:3d}, loss = {losses[epoch,0]:.4f}')



## one more pass with the trained weights, without gradient tracking
with torch.no_grad():
  aeout,latent = ae(X)

# latent activations as a numpy array (tokens x components) for the analyses below
latent = latent.cpu().numpy()
latent.shape

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3))

# one panel per column of `losses`: title, and (line/marker spec, marker face color)
losslabels = [ 'Total loss','Sparsity loss','Decorrelation loss']
panelspecs = [ ('ks-',[.9,.7,.7,.5]), ('ko-',[.7,.9,.7,.5]), ('k^-',[.7,.7,.9,.5]) ]

# draw and label each panel in one pass
for i,(spec,face) in enumerate(panelspecs):
  axs[i].plot(losses[:,i],spec,markerfacecolor=face,linewidth=.5)
  axs[i].set(xlabel='Epoch',ylabel='Loss',title=losslabels[i])

plt.tight_layout()
plt.show()

In [None]:
# image of the latent activations (rows = tokens, columns = components),
# with the color range clipped to [0, .1]
plt.figure(figsize=(10,3))
plt.imshow(latent,vmin=0,vmax=.1,aspect='auto')

plt.xlabel('Latent component index')
plt.ylabel('Token index')
plt.title('Latent activations')
plt.colorbar(pad=.01)
plt.show()

# Quantitative scores for each latent component

In [None]:
# boolean map of the active (nonzero) entries of the tokens-by-components matrix
isactive = latent!=0

# create a mask that is 1 for active entries and nan for zero-valued activations
densitymask = np.full(latent.shape,np.nan)
densitymask[isactive] = 1

# number of active tokens per latent component, and the same as a
# percent of all tokens (plotted in next code cell)
numactive = isactive.sum(axis=0)
latentDensity = 100 * numactive / latent.shape[0]

# token-averaged absolute activation, excluding zeros.
# Computed as sum/count with a guarded division so that components that are
# never active get 0 directly, instead of going through np.nanmean on an
# all-nan column (which raises a RuntimeWarning and returns nan)
sumact = np.abs(latent).sum(axis=0,dtype=np.float64)
nonzeroAct = np.divide(sumact,numactive,out=np.zeros(latent.shape[1]),where=numactive>0)

# density-weighted activation: percent of active tokens times the mean nonzero activation
densityActivation = latentDensity*nonzeroAct

nonzeroAct.shape

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3.3))

# split the components into never-active ones (0% density) and the rest
whichzero = latentDensity==0
deadidx = np.where(whichzero)[0]
liveidx = np.where(~whichzero)[0]

# left panel: percent of tokens each component responds to
axs[0].plot(deadidx,latentDensity[deadidx],'rx',alpha=.1)
axs[0].plot(liveidx,latentDensity[liveidx],'ko',markerfacecolor=[.9,.7,.7],alpha=.5)
axs[0].set(xlabel='Latent component index',ylabel='% tokens with activation',
              title=f'Latent component layer is {100*(1-whichzero.mean()):.2f}% dense')

# middle panel: mean activation over the tokens each component responds to
axs[1].plot(deadidx,nonzeroAct[deadidx],'rx',alpha=.1)
axs[1].plot(liveidx,nonzeroAct[liveidx],'ko',markerfacecolor=[.7,.9,.7],alpha=.5)
axs[1].set(xlabel='Latent component index',ylabel='Nonzero activations',
              title='Latent component activations')

# right panel: the two measures against each other (active components only)
axs[2].plot(nonzeroAct[liveidx],latentDensity[liveidx],'ko',markerfacecolor=[.7,.7,.9],alpha=.5)
axs[2].set(xlabel='Nonzero activations',ylabel='% tokens with activation',
              title='Activations by density')

plt.tight_layout()
plt.show()

# Qualitative inspection (text heatmaps)

In [None]:
# measure how wide one monospace letter is, as a fraction of the axes width,
# using a throwaway figure with the same size and font as the heatmap figures
fig,ax = plt.subplots(figsize=(10,2))
temp_text = ax.text(0,0,'n',fontsize=12,fontfamily='monospace')

# bounding box of the drawn letter in display coordinates
renderer = fig.canvas.get_renderer()
bbox = temp_text.get_window_extent(renderer=renderer)

# convert the two corners from display to axes coordinates
inv = ax.transAxes.inverted()
corners = [[bbox.x0,bbox.y0], [bbox.x1,bbox.y1]]
bbox_axes = inv.transform(corners)

# width = right edge minus left edge
left,right = bbox_axes[0,0],bbox_axes[1,0]
en_width = right - left
plt.close(fig)

In [None]:
# function to draw the heatmap, given a latent component
def drawTextHeatmap(comp):
  """Show the text token-by-token, shaded by one latent component's activation.

  Reads the notebook-level variables `latent`, `tokens`, `numtoks`,
  `tokenizer`, and `en_width`. Prints the component index and shows the figure.

  Args:
    comp: column index into `latent` of the component to visualize.
  """

  # min-max scale to [0,1]. A component with the same value for every token
  # (e.g., one that is never active) has zero range; show it as all-zero
  # instead of dividing by zero (which would give nan colors and a warning)
  compVals = latent[:,comp]
  lo,hi = compVals.min(),compVals.max()
  if hi > lo:
    compAct = (compVals - lo) / (hi - lo)
  else:
    compAct = np.zeros(len(compVals))

  tokCount = 0 # tokens drawn on the current line

  x_pos = 0  # starting x position (left edge)
  y_pos = 1  # starting y position (first line of text at the top)

  fig, ax = plt.subplots(figsize=(10,2))
  ax.axis('off')

  for toki in range(numtoks):

    # text of this token
    toktext = tokenizer.decode([tokens[0,toki]])

    # width of the token (monospace: letter width times number of characters)
    token_width = en_width*len(toktext)

    # text object with background color matching the "activation"
    # (squaring the scaled value de-emphasizes weak activations)
    ax.text(x_pos+token_width/2, y_pos, toktext, fontsize=12, ha='center', va='center',fontfamily='monospace',
            bbox = dict(boxstyle='round,pad=.3', facecolor=mpl.cm.Reds(compAct[toki]**2), edgecolor='none', alpha=.8))

    # update the token counter and x_pos
    tokCount += 1
    x_pos += token_width + .01 # plus a small gap

    # end of the line (20 tokens); reset coordinates and counter
    if tokCount>=20:
      y_pos -= .17
      x_pos = 0
      tokCount = 0

  print(f'Latent component {comp}')
  plt.show()

In [None]:
# component indices ranked from highest to lowest:
# by mean nonzero activation ...
sortByActivity = nonzeroAct.argsort()[::-1]

# ... and by percent of tokens with activation
sortByTokenDensity = latentDensity.argsort()[::-1]

latent.shape

In [None]:
# draw the text heatmap for one component picked by its rank
# (position 30 in the by-activity ranking, or position 14 in the by-density ranking)
# comment one, run the other
drawTextHeatmap(sortByActivity[30])
# drawTextHeatmap(sortByTokenDensity[14])

# Pick a component based on latent activation to a word category

In [None]:
# geography-related words to look for (expected to be single tokens)
geography_words = [ 'Hungary','geographical','European','Union',
                    'provinces','Austria','Czech','Republic',
                    'Slovakia','Romania','Italy','Greece','regions' ]

# token ids of the text as a numpy array
token_ids = tokens[0].cpu().numpy()

# boolean vector over text positions; True where a target word occurs
target_tokens = np.zeros(numtoks,dtype=bool)

for w in geography_words:
  # id of the word's first token when encoded with a leading space
  tokidx = tokenizer.encode(w,add_prefix_space=True)[0]

  # flag every position in the text that holds this token
  target_tokens |= (token_ids == tokidx)

target_tokens

In [None]:
# "geoscore": for each latent component, the mean nonzero activation on target
# tokens divided by the mean nonzero activation on non-target tokens
# (stays 0 where the ratio is not computed)
geoscore = np.zeros(latent.shape[1])

for comp in range(latent.shape[1]):

  # split this component's activations by token category
  compAct = latent[:,comp]
  targ = compAct[target_tokens]
  nontarg = compAct[~target_tokens]

  # skip components without activation in both groups
  if not (nontarg.sum()>0 and targ.sum()>0):
    continue

  # ratio of the group means, each taken over the nonzero activations only
  targMean = targ[targ!=0].mean()
  nontargMean = nontarg[nontarg!=0].mean()
  geoscore[comp] = targMean / nontargMean


# visualize the scores
plt.figure(figsize=(10,4))
plt.plot(geoscore,'ko',markerfacecolor=[.9,.7,.7,.5])
plt.xlabel('Latent components')
plt.ylabel('Score')
plt.title('Geography token activation score')
plt.show()

In [None]:
# rank the components from highest to lowest geoscore ...
sortByGeoscore = geoscore.argsort()[::-1]

# ... and draw the text heatmap for the one at position 9 of that ranking
geoComp = sortByGeoscore[9]
drawTextHeatmap(geoComp)