|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Editing hidden states</h1>|
|<h2>Lecture:</h2>|<h1><b>Downstream impact of early layer scaling</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Import a model, implant hooks

In [None]:
# Load the pretrained 'gpt2' tokenizer and language model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# evaluation mode (disables dropout) so repeated forward passes are comparable
model.eval()

In [None]:
# Global scaling factor, read by the hook at call time. Later cells change it:
# 1 makes the hook a no-op; any other value scales the hooked block's output.
scaling_factor = 1

# Index of the transformer block whose output is scaled (was hard-coded to 2)
layer_to_scale = 2

def hook(module, input, output):
  """Forward hook: return the block's output with its hidden states multiplied by `scaling_factor`.

  The returned value replaces the module's output, so every downstream block
  (and `output_hidden_states`) sees the scaled activations. The original
  tensor is not modified in place.
  """
  # Depending on the transformers version, a block may return a bare tensor
  # instead of a tuple; unpacking a tensor would silently split the batch dim.
  if isinstance(output, torch.Tensor):
    return output * scaling_factor
  hidden, *rest = output # unpack the tuple
  return (hidden * scaling_factor, *rest) # pack the output back into a tuple

# Remove any hook left over from a previous run of this cell; otherwise hooks
# stack up and the scaling is applied several times per forward pass.
if 'hook_handle' in globals():
  hook_handle.remove()
hook_handle = model.transformer.h[layer_to_scale].register_forward_hook(hook)

# Create tokens and get activations

In [None]:
# Token IDs for the example sentence, returned as a PyTorch tensor (batch of one).
# NOTE(review): the sentence is missing "be"; left verbatim so the token sequence is unchanged.
tokens = tokenizer.encode(
  'Duct tape will very useful after the apocalypse.',
  return_tensors='pt',
)

In [None]:
# Baseline ("pure") forward pass: a scaling factor of 1 makes the hook a no-op
scaling_factor = 1
with torch.no_grad():
  out = model(tokens, output_hidden_states=True)

# tuple of hidden states: the embedding output followed by one entry per block
hs_pure = out.hidden_states

n_hidden = len(hs_pure)
print(f'There are {n_hidden} hidden_states.')
print(f'Each hidden state is of size {list(hs_pure[3].shape)}')

# Scale a layer and visualize the impact

In [None]:
# Intervention pass: the hook now multiplies the hooked block's output by 0.5
scaling_factor = .5
with torch.no_grad():
  hs_scale = model(tokens, output_hidden_states=True).hidden_states

In [None]:
# For every hidden-state index, the Frobenius norm of (unscaled - scaled)
# activations. The first token position is excluded from the comparison.
n_states = len(hs_scale)
layer_norm = np.zeros(n_states)

for state_idx in range(n_states):
  unscaled = hs_pure[state_idx][0,1:,:].detach().numpy()
  scaled = hs_scale[state_idx][0,1:,:].detach().numpy()
  layer_norm[state_idx] = np.linalg.norm(unscaled - scaled)

In [None]:
# Tick labels: 'emb' for the embedding output, then T0..T{n-1} for each transformer block
xlabels = ['emb'] + [f'T{i}' for i in range(model.config.n_layer)]

fig, ax = plt.subplots(figsize=(8,3))
ax.plot(layer_norm, 'ks-', markerfacecolor=[.7,.9,.7], markersize=10, linewidth=.5)
ax.set(
  xlabel='Hidden state',
  xticks=range(len(hs_pure)),
  xticklabels=xlabels,
  ylabel='Matrix difference norm',
  title='Difference between scaled and unscaled hidden states',
)
plt.show()

# Explore a range of scaling factors

In [None]:
# Sweep the scaling factor applied by the hook and, for each value, measure how
# far every hidden state moves away from the unscaled ("pure") run.
scale_factors = np.linspace(.5,1.5,21)

# results matrix: rows = hidden-state index (embeddings + each block), columns = scaling factor
# (sized from hs_pure rather than from whatever hs_scale a previous cell left behind)
layer_norms = np.zeros((len(hs_pure),len(scale_factors)))

try:
  # the loop variable is the global `scaling_factor` that the hook reads
  for scalei, scaling_factor in enumerate(scale_factors):

    # run the model and get hidden states
    with torch.no_grad():
      out = model(tokens,output_hidden_states=True)
    hs_scale = out.hidden_states

    # Frobenius norm of the difference between unscaled and scaled activations
    # at each hidden-state index (first token position excluded)
    for layeri in range(len(hs_scale)):
      pure = hs_pure[layeri][0,1:,:].detach().numpy()
      scal = hs_scale[layeri][0,1:,:].detach().numpy()
      layer_norms[layeri,scalei] = np.linalg.norm(pure-scal)

finally:
  # Reset so the still-registered hook is a no-op for any later cell,
  # instead of silently staying at the last swept value (1.5).
  scaling_factor = 1

In [None]:
fig,axs = plt.subplots(1,2,figsize=(10,3))
n_states = len(hs_pure)

# One shared scaling-factor -> color mapping, used by both the lines and the colorbar.
# (Previously the lines used scalei/len(...), which does not match the colorbar's range.)
norm = mpl.colors.Normalize(vmin=scale_factors[0],vmax=scale_factors[-1])
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.plasma,norm=norm)

# plot one line per scaling factor (in a loop for color specification)
for scalei, sf in enumerate(scale_factors):
  axs[0].plot(layer_norms[:,scalei],'.-',color=sm.to_rgba(sf))

# colorbar for the lines
cbar = fig.colorbar(sm,ax=axs[0],pad=.01)
cbar.set_label('Scaling factor')

# axis aesthetics
axs[0].set(xlabel='Hidden state',xticks=range(0,n_states,2),xticklabels=xlabels[::2],xlim=[-.5,n_states-.5],
           ylabel='Matrix difference norm',title='Norm of differences in lines')


# Image of the same data. Pad the extent by half a cell on each side so pixel
# centers sit exactly on the hidden-state indices (x) and scaling factors (y),
# which is where the ticks are placed.
dy = (scale_factors[-1]-scale_factors[0]) / (len(scale_factors)-1)
extent = [-.5, n_states-.5, scale_factors[0]-dy/2, scale_factors[-1]+dy/2]
h = axs[1].imshow(layer_norms.T,aspect='auto',cmap=mpl.cm.hot,vmin=0,vmax=layer_norms.max()*.7,
                  extent=extent,origin='lower')
axs[1].set(xlabel='Hidden state',xticks=range(1,n_states,2),xticklabels=xlabels[1::2],
           ylabel='Scaling factor',title='Norm of differences in an image')
# raw string: '\D' in a normal string is an invalid escape sequence
fig.colorbar(h,ax=axs[1],pad=.01,label=r'Norm of $\Delta$')

plt.tight_layout()
plt.show()