|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Interfering with attention</h1>|
|<h2>Lecture:</h2>|<h1><b>Impact of head-silencing on cosine similarity</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import scipy.stats as stats

import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# load the model and its tokenizer from one checkpoint name so the two cannot drift apart
model_name = 'gpt2-medium'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# switch to evaluation mode; as the last expression of the cell this also displays the architecture
model.eval()

In [None]:
# architecture constants read from the model configuration
nlayers  = model.config.n_layer   # number of transformer blocks
nheads   = model.config.n_head    # attention heads per block
n_emb    = model.config.n_embd    # embedding dimensionality
head_dim = n_emb // nheads        # embedding dimensions handled by each head

In [None]:
def implant_hook(layer_number):
  """Create a forward pre-hook that can zero out one attention head.

  The hook is intended for a layer's attention output projection: its input is
  the concatenated head outputs, with the last dimension laid out as
  `nheads` consecutive chunks of `head_dim` values.

  Which layer and head get silenced is read from the notebook-level variables
  `layer2silence` and `head2silence` each time the hook fires, so the target
  can be changed between forward passes without re-registering anything.

  Args:
    layer_number: index of the layer this hook instance is attached to.

  Returns:
    A function `hook4attn(module, input)` that returns the (possibly modified)
    input tuple. Silencing is done in-place on the incoming tensor.
  """
  def hook4attn(module,input):

    # modify the activation only for this layer
    if layer_number==layer2silence:

      # take the leading (batch, token, ...) dimensions from the tensor itself,
      # so the hook does not depend on globals defined in a later cell
      acts = input[0]
      lead_shape = acts.shape[:-1]

      # reshape so we can index heads (a view: shares memory with `acts`)
      head_tensor = acts.view(*lead_shape,nheads,head_dim)

      # ablate
      head_tensor[...,head2silence,:] = 0

      # reshape back to the original layout
      head_tensor = head_tensor.view(acts.shape)

      # return a tuple matching the original
      input = (head_tensor,*input[1:])

    return input
  return hook4attn


# attach one pre-hook to the attention output projection of every transformer block,
# collecting the returned handles in a list
handles = []
for layeri in range(nlayers):
  attn_proj = model.transformer.h[layeri].attn.c_proj
  handles.append( attn_proj.register_forward_pre_hook(implant_hook(layeri)) )

# Tokens (my apologies to the Italians :P )

In [None]:
# tokenize the prompt; the result is a (batch, tokens) tensor
prompt = 'Peanut butter and pineapple taste great on pizza'
tokens = tokenizer.encode(prompt,return_tensors='pt')
nbatches,ntokens = tokens.shape

# list every position with its token index and decoded text
for i in range(ntokens):
  tok = tokens[0,i]
  print(f'Token position {i} is index {tok:6} and is "{tokenizer.decode(tok)}"')

# Cosine similarities in an unmodulated model

In [None]:
# sentinel targets: no hook has layer number 1000, so every hook returns its input unchanged
layer2silence = head2silence = 1000

# forward pass through the unmodified model
with torch.no_grad():
  out_clean = model(tokens,output_hidden_states=True)

# keep the hidden states and check the size of one of them
hs_clean = out_clean.hidden_states
hs_clean[4].shape

In [None]:
# get cossim from all layers

# how many unique token pairs are in the CS matrix (the first token is excluded)
uniqueCScount = (ntokens-1)*(ntokens-2)//2

# positions of the strict upper triangle (each pair once, no diagonal), in row-major order
triu_r,triu_c = torch.triu_indices(ntokens-1,ntokens-1,offset=1)

cs_clean = np.zeros((nlayers,uniqueCScount))

for layeri in range(nlayers):

  # cosine similarities among all tokens except the first
  H = out_clean.hidden_states[layeri+1][0,1:,:]
  H = H / torch.linalg.vector_norm(H,dim=1,keepdim=True)
  csMat = H @ H.T

  # extract the unique elements by position rather than by value, so that a
  # similarity of exactly zero cannot be dropped and break the fixed-size assignment
  cs = csMat[triu_r,triu_c]
  cs_clean[layeri,:] = cs.numpy()


# Disrupt one head and measure impact

In [None]:
# target for the hooks: head 1 in layer 5
layer2silence,head2silence = 5,1

# forward pass; the hook attached to the matching layer now zeroes that head
with torch.no_grad():
  out_silence = model(tokens,output_hidden_states=True)

In [None]:
# cosine similarities among all tokens except the first, taken from the
# hidden state at index layer2silence+1
H = out_silence.hidden_states[layer2silence+1][0,1:,:]
H = H / torch.linalg.vector_norm(H,dim=1,keepdim=True)
csMat = H @ H.T

# and extract the unique elements (strict upper triangle) by position rather
# than by value, so that a similarity of exactly zero cannot be dropped
triu_r,triu_c = torch.triu_indices(csMat.shape[0],csMat.shape[0],offset=1)
cs_silence = csMat[triu_r,triu_c]
cs_silence

# Manipulate and measure from the same layer

In [None]:
cs_manip = np.zeros((nlayers,uniqueCScount))

# positions of the strict upper triangle (each token pair once, no diagonal), in row-major order
triu_r,triu_c = torch.triu_indices(ntokens-1,ntokens-1,offset=1)

# silence one layer at a time; the loop variable is the global the hooks read.
# NOTE: head2silence is not set here and keeps its value from the earlier cell.
for layer2silence in range(nlayers):

  with torch.no_grad():
    out_manip = model(tokens,output_hidden_states=True)

  # cosine similarities, measured on the hidden state at index layer2silence+1
  H = out_manip.hidden_states[layer2silence+1][0,1:,:]
  H = H / torch.linalg.vector_norm(H,dim=1,keepdim=True)
  csMat = H @ H.T

  # extract the unique elements by position rather than by value, so that a
  # similarity of exactly zero cannot be dropped and break the fixed-size assignment
  cs = csMat[triu_r,triu_c]
  cs_manip[layer2silence,:] = cs.numpy()


In [None]:
fig,axs = plt.subplots(1,3,figsize=(12,3))

for i in range(nlayers):

  # quantities shared by all three panels for this layer
  xvals = np.zeros(uniqueCScount)+i
  layercolor = mpl.cm.plasma(i/nlayers)
  csdiff = cs_manip[i,:]-cs_clean[i,:]

  # plot the raw cossim values
  axs[0].plot(xvals,cs_clean[i,:],'ks',markersize=4,alpha=.5,markerfacecolor=layercolor)
  axs[1].plot(xvals,cs_manip[i,:],'ko',markersize=4,alpha=.5,markerfacecolor=layercolor)

  # test for a statistical difference (paired t-test across token pairs)
  tres = stats.ttest_rel(cs_manip[i,:],cs_clean[i,:])

  # plot marker depends on significance
  if tres.pvalue<.05:
    axs[2].plot(xvals,csdiff,'k^',markersize=4,alpha=.5,markerfacecolor=layercolor)
  else:
    axs[2].plot(xvals,csdiff,'rx',markersize=4,alpha=.5)


# some embellishments etc
axs[0].set(xlabel='Transformer block',ylabel='Cosine similarity',title='Clean model')
axs[1].set(xlabel='Transformer block',ylabel='Cosine similarity',title='Layer-specific silencing')
# raw string: '\D' is not a valid escape sequence in a regular string literal
axs[2].set(xlabel='Transformer block',ylabel=r'$\Delta S_c$',title='Silenced - clean difference')
axs[2].axhline(0,color='k',zorder=-30,linewidth=.5)

plt.tight_layout()
plt.show()

# Manipulate one layer and measure downstream impact

In [None]:
# silence a single layer; head2silence is not set here and keeps its value from the earlier cell
layer2silence = 3

# forward pass
with torch.no_grad():
  out_silence = model(tokens,output_hidden_states=True)

# positions of the strict upper triangle (each token pair once, no diagonal), in row-major order
triu_r,triu_c = torch.triu_indices(ntokens-1,ntokens-1,offset=1)

# no more forward passes or silencing; just measure cs
cs_manip2 = np.zeros((nlayers,uniqueCScount))
for layeri in range(nlayers):

  # cosine similarities among all tokens except the first
  H = out_silence.hidden_states[layeri+1][0,1:,:]
  H = H / torch.linalg.vector_norm(H,dim=1,keepdim=True)
  csMat = H @ H.T

  # extract the unique elements by position rather than by value, so that a
  # similarity of exactly zero cannot be dropped and break the fixed-size assignment
  cs_silence = csMat[triu_r,triu_c]
  cs_manip2[layeri,:] = cs_silence.numpy()


In [None]:
fig,axs = plt.subplots(1,3,figsize=(12,3))

for i in range(nlayers):

  # quantities shared by all three panels for this layer
  xvals = np.zeros(uniqueCScount)+i
  layercolor = mpl.cm.plasma(i/nlayers)
  csdiff = cs_manip2[i,:]-cs_clean[i,:]

  # plot the raw cossim values
  axs[0].plot(xvals,cs_clean[i,:],'ks',markersize=4,alpha=.5,markerfacecolor=layercolor)
  axs[1].plot(xvals,cs_manip2[i,:],'ko',markersize=4,alpha=.5,markerfacecolor=layercolor)

  # test for a statistical difference (one-sample t-test of the differences against zero)
  tres = stats.ttest_1samp(csdiff,0)

  # plot marker depends on significance
  if tres.pvalue<.05:
    axs[2].plot(xvals,csdiff,'k^',markersize=4,alpha=.5,markerfacecolor=layercolor)
  else:
    axs[2].plot(xvals,csdiff,'rx',markersize=4)


# some embellishments etc
axs[0].set(xlabel='Transformer block',ylabel='Cosine similarity',title='Clean model')
axs[1].set(xlabel='Transformer block',ylabel='Cosine similarity',title=f'Silenced layer {layer2silence}')
# raw string: '\D' is not a valid escape sequence in a regular string literal
axs[2].set(xlabel='Transformer block',ylabel=r'$\Delta S_c$',title='Silenced - clean difference')
axs[2].axhline(0,color='k',zorder=-30,linewidth=.5)

plt.tight_layout()
plt.show()