|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[40] Downstream impacts of head silencing</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
import torch.nn.functional as F

In [None]:
### Plot settings used by every figure in this notebook.
# The svg output and the axis style below apply to all plots. The color
# settings are only for "dark mode": uncomment them if you use a dark theme.

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

plt.rcParams.update({
    # -- dark-mode colors (uncomment for a dark notebook theme)
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
    # -- axis style for all figures: no top/right spines, bold titles and labels
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    # resolution of the files written by plt.savefig
    'savefig.dpi':300
})

# **Part 1: Hook all layers and get activations**

In [None]:
# run on the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

In [None]:
# pretrained GPT-2 medium and its matching tokenizer;
# the model is put in evaluation mode and moved to the selected device
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = AutoModelForCausalLM.from_pretrained('gpt2-medium').to(device)
model.eval()

In [None]:
# model dimensions used throughout the notebook
n_layers = model.config.n_layer   # number of transformer blocks
n_heads  = model.config.n_head    # attention heads per block
n_embd   = model.config.n_embd    # embedding width
head_dim = n_embd // n_heads      # width of each head's slice of the embedding

In [None]:
## hook to silence attention heads

# Which layer/head to silence. 1000 is out of range for both, so by default
# no ablation hook matches its layer and the model runs unmodified.
layer2silence, head2silence = 1000, 1000

def implant_hook(layer_number):
  """Build a forward pre-hook that silences one attention head in one layer.

  The hook's first input tensor is assumed to hold all heads' outputs side by
  side along its last dimension (n_heads blocks of head_dim values). When
  `layer_number` equals the global `layer2silence`, the block belonging to head
  `head2silence` is set to zero; otherwise the input is returned unchanged.

  The globals layer2silence, head2silence, n_heads and head_dim are read each
  time the hook runs, so changing them re-targets the ablation without
  re-registering any hooks.

  Args:
    layer_number: index of the transformer block this hook is attached to.

  Returns:
    hook4attn(module, input), for use with register_forward_pre_hook.
  """
  def hook4attn(module,input):

    # leave every layer except the targeted one untouched
    if layer_number!=layer2silence:
      return input

    # split the last dimension into (heads, head_dim) so a single head can be indexed.
    # clone first so the tensor handed to the module is not modified in place, and
    # use reshape (not view) so non-contiguous inputs also work
    x = input[0]
    head_tensor = x.clone().reshape(*x.shape[:-1],n_heads,head_dim)

    # ablate
    head_tensor[...,head2silence,:] = 0

    # merge the heads back and return a tuple matching the original input
    return (head_tensor.reshape(x.shape),*input[1:])

  return hook4attn


# attach an ablation pre-hook to the attention c_proj module of every block;
# the handles are kept so the hooks can be removed later
head_handles = [
  model.transformer.h[layeri].attn.c_proj.register_forward_pre_hook(implant_hook(layeri))
  for layeri in range(n_layers)
]

In [None]:
## hook to grab the hidden states

# 'layer_<i>' -> output hidden states of block i from the most recent forward pass
hook_hs = {}

def outerHook(layeri):
  """Build a forward hook that stores a transformer block's output hidden states.

  Args:
    layeri: index of the block the hook is attached to; used in the key 'layer_{layeri}'.

  Returns:
    hook(module, input, output), for use with register_forward_hook. Each call
    overwrites hook_hs['layer_{layeri}'] with a detached numpy array on the CPU.
  """
  def hook(module, input, output):
    # a block's output may be a tuple with the hidden states first or, in some
    # library versions, the hidden-states tensor itself; handle both, so the
    # batch dimension is never indexed away by mistake
    hs = output[0] if isinstance(output, tuple) else output
    hook_hs[f'layer_{layeri}'] = hs.detach().cpu().numpy()
  return hook

# attach the hidden-state hook to the output of every transformer block
hs_handles = [model.transformer.h[layeri].register_forward_hook(outerHook(layeri))
              for layeri in range(model.config.n_layer)]

In [None]:
# text and tokens (quote from Igor Stravinsky)
# The original quote ends with " animals."; the text stops just before that word.
#   https://www.azquotes.com/quote/1459556
#   https://en.wikiquote.org/wiki/Igor_Stravinsky
text = 'Look for the music on all things, and life will be a symphony of joy. My music is best understood by children and'

# tokenize (shape is batches x tokens), then move the tokens to the model's device
tokens = tokenizer.encode(text,return_tensors='pt')
nbatches,ntokens = tokens.shape
tokens = tokens.to(device)

# list every token id next to its decoded text
print(f'There are {ntokens} tokens:')
for t in tokens[0]:
  print(f'{t:>5} is "{tokenizer.decode(t)}"')

In [None]:
# clean run: the out-of-range layer index means no head is silenced
layer2silence = 1000

with torch.no_grad():
  out_clean = model(tokens)

# The hook rebinds the entries of hook_hs on every forward pass, so snapshot the
# dict now; a shallow copy is enough because the stored arrays are replaced, not edited.
hs_clean = dict(hook_hs)
logits_clean = out_clean.logits.cpu()

In [None]:
# sanity check on what the hooks and the model returned
for label,value in [('Keys in hs_clean:\n  ',hs_clean.keys()),
                    ('\nShape of hs vectors:\n  ',hs_clean['layer_2'].shape),
                    ('\nShape of output logits:\n  ',logits_clean.shape)]:
  print(label,value)

# **Part 2: Cosine similarity in the "clean" model**

In [None]:
# number of unique token pairs in the cosine-similarity matrix: token 0 is dropped,
# leaving m = ntokens-1 tokens and m*(m-1)/2 pairs above the diagonal
uniqueCScount = (ntokens-1)*(ntokens-2)//2

cs_clean = np.zeros((n_layers,uniqueCScount))
upper_idx = np.triu_indices(ntokens-1,1)  # strictly upper triangle (diagonal excluded)

for layeri in range(n_layers):
  # unit-normalize each token's hidden-state vector (skipping token 0),
  # so the Gram matrix holds cosine similarities
  H = hs_clean[f'layer_{layeri}'][0,1:,:]
  H = H / np.linalg.norm(H,axis=1,keepdims=True)
  cs_clean[layeri,:] = (H @ H.T)[upper_idx]

cs_clean.shape

In [None]:
# histogram bins for the similarities (values below .6 fall outside the bins and are not counted)
Nbins = 31
binbounds = np.linspace(.6,1,Nbins+1)
bincenters = (binbounds[:-1]+binbounds[1:])/2

# histogram counts, one column per layer
Y = np.zeros((Nbins,n_layers))
for layeri,cs_layer in enumerate(cs_clean):
  Y[:,layeri] = np.histogram(cs_layer,bins=binbounds)[0]

fig,ax = plt.subplots(1,figsize=(10,4))

# hexagon marker at each layer's average similarity
for layeri,cs_layer in enumerate(cs_clean):
  ax.plot(layeri,cs_layer.mean(),'kh',markerfacecolor=[.7,.9,.9],markeredgewidth=1.5,markersize=12)

# the histograms drawn as an image underneath the markers
h = ax.pcolor(range(n_layers),bincenters,Y,vmin=0,vmax=Y.max()*.8,cmap='afmhot_r')
fig.colorbar(h,ax=ax,pad=.01,label='Count')

ax.set(xlabel='Layer',xticks=range(0,n_layers),ylabel='Cosine similarity',
       title='Cosine similarity averages and distributions')

plt.tight_layout()
plt.savefig('ch6_proj40_part2.png')
plt.show()

# **Part 3: Downstream impact of silencing one head**

In [None]:
# silence head index 1 in layer index 4 and rerun the same tokens
layer2silence,head2silence = 4,1

with torch.no_grad():
  out_silence = model(tokens)

# snapshot the hidden states and the logits (the logits are used in part 4)
hs_silence = dict(hook_hs)
logits_silence = out_silence.logits.cpu()

In [None]:
# create a figure: difference distributions (left 2/3) and per-layer averages (right 1/3)
fig = plt.figure(figsize=(11,3))
gs = GridSpec(1,3,figure=fig)
ax0 = fig.add_subplot(gs[:-1])
ax1 = fig.add_subplot(gs[-1])

# for the colormap of the lines (one color per layer)
norm = mpl.colors.Normalize(vmin=0,vmax=n_layers)
cmap = mpl.colormaps['plasma'].resampled(n_layers+1)

# unique token pairs (strict upper triangle), with token 0 dropped as in part 2
triu_idx = np.triu_indices(ntokens-1,1)

for layeri in range(n_layers):

  # similarities in this layer
  H = hs_silence[f'layer_{layeri}'][0,1:,:]
  H = H / np.linalg.norm(H,axis=1,keepdims=True)
  csMat = H @ H.T
  cs_silence = csMat[triu_idx]

  # differences (clean minus silenced) and scatter plot of mean difference
  cs_diff = cs_clean[layeri] - cs_silence
  ax1.plot(layeri,cs_diff.mean(),'kh',markerfacecolor=cmap(norm(layeri)),markersize=10)

  # histograms, only for the silenced layer and the layers after it
  # (earlier blocks are not downstream of the ablated head)
  if layeri >= layer2silence:
    y,binedges = np.histogram(cs_diff,bins='fd')
    # plot counts at bin centers; the left edges (binedges[:-1]) would shift each curve by half a bin
    diffcenters = (binedges[:-1]+binedges[1:])/2
    ax0.plot(diffcenters,y,'.-',markersize=8,linewidth=2,color=cmap(norm(layeri)))

# and the color bar
sm = mpl.cm.ScalarMappable(cmap=cmap,norm=norm)
cbar = fig.colorbar(sm,ax=ax0,pad=.02)
cbar.set_label('Layer')

ax0.axvline(0,linestyle='--',color='k',linewidth=.4)
ax0.set(xlabel='$\\Delta$ Cosine similarity',ylabel='Count',ylim=[0,None],
        title='A) Difference distributions')

ax1.axvline(layer2silence,linestyle='--',color='k',zorder=-10,linewidth=.4)
ax1.axhline(0,linestyle='--',color='k',zorder=-10,linewidth=.4)
ax1.set(xlabel='Layer',xticks=range(0,n_layers,2),ylabel='Similarity differences',title='B) Average differences')

plt.tight_layout()
plt.savefig('ch6_proj40_part3.png')
plt.show()

# **Interlude: Copying data from the GPU to the CPU**

In [None]:
#-- still on the GPU
# next_tok = torch.argmax(out_clean.logits[0,-1,:])
# print(next_tok.device)

#-- explicitly move to the CPU
# next_tok = torch.argmax(out_clean.logits[0,-1,:]).cpu()
# print(next_tok.device)

#-- move to CPU with .item()
# next_tok = torch.argmax(out_clean.logits[0,-1,:]).item()
# print(next_tok.device)
# print(torch.tensor(next_tok).device)

#-- but item() is only for scalars
# lsm_logits = F.log_softmax(out_clean.logits[0,-1,:],dim=-1).item()

# **Part 4: Impact on output logits**

In [None]:
# log-softmax over the vocabulary at the final token position, for both runs
logsoftmax_clean = F.log_softmax(logits_clean[0,-1,:],dim=-1)
logsoftmax_silence = F.log_softmax(logits_silence[0,-1,:],dim=-1)

# most likely next token of each run
# NOTE: variable 'new_tok_clean' is important! That's the target token in part 5
new_tok_clean = logsoftmax_clean.argmax().item()
maxlsm_clean = logsoftmax_clean[new_tok_clean]
next_tok = logsoftmax_silence.argmax().item()

print(f'Original text is "{text}"\n')
print(f'   CLEAN: Predicted next token is "{tokenizer.decode(new_tok_clean)}" with {torch.exp(maxlsm_clean):.3%}')
print(f'SILENCED: Predicted next token is "{tokenizer.decode(next_tok)}" with {torch.exp(logsoftmax_silence[next_tok]):.3%}')

In [None]:
# decode the ten highest-scoring next tokens of each run
for label,lsm in [('Top 10 logits from CLEAN model:',logsoftmax_clean),
                  ('\n\nTop 10 logits from SILENCED model:',logsoftmax_silence)]:
  print(label)
  for t in lsm.topk(10).indices:
    print(f'"{tokenizer.decode(t)}", ',end='')

In [None]:
# still only silencing one head, now sweeping the layer it is silenced in
head2silence = 1

plt.figure(figsize=(10,3))
for layer2silence in range(n_layers):

  # rerun the model (only the logits are used, so hidden states are not requested)
  with torch.no_grad():
    out_silence = model(tokens)

  lsm_logits = F.log_softmax(out_silence.logits[0,-1,:],dim=-1).cpu()
  next_tok = torch.argmax(lsm_logits).item()
  print(f'Next token is "{tokenizer.decode(next_tok)}" with {torch.exp(lsm_logits[next_tok]):.3%}  (silenced L{layer2silence}:H{head2silence})')

  plt.plot(layer2silence,lsm_logits[next_tok].item(),'kh',markersize=12,
           markerfacecolor=mpl.colormaps['plasma'](layer2silence/n_layers))

# the loop leaves the last layer selected; restore the out-of-range sentinel
# so any later forward pass runs on the unmodified model
layer2silence = 1000

plt.axhline(maxlsm_clean.item(),color='k',linestyle='--',linewidth=.5,zorder=-100)
plt.gca().set(xlabel='Manipulated layer',ylabel='Log softmax prob',title=f'Impact of silencing head #{head2silence} on max logit')

plt.tight_layout()
plt.savefig('ch6_proj40_part4.png')
plt.show()

# **Part 5: Silencing each head**

In [None]:
# initialize matrix of log-softmax values: (head, layer, [target token, max token])
lsm_vals = np.zeros((n_heads,n_layers,2))


# double loop over layers and heads per layer
for layer2silence in tqdm(range(n_layers)):
  for head2silence in range(n_heads):

    # rerun the model with this one head silenced (only the logits are used)
    with torch.no_grad():
      out_silence = model(tokens)

    # get the log softmax at the final token position
    lsm_logits = F.log_softmax(out_silence.logits[0,-1,:],dim=-1).cpu()

    # log-softmax of the target token (the clean model's prediction)
    lsm_vals[head2silence,layer2silence,0] = lsm_logits[new_tok_clean].item()

    # and of whichever token is now most likely
    lsm_vals[head2silence,layer2silence,1] = torch.max(lsm_logits).item()

# the loops leave the last layer/head selected; restore the out-of-range
# sentinels so any later forward pass runs on the unmodified model
layer2silence = 1000
head2silence = 1000

In [None]:
fig,axs = plt.subplots(1,2,figsize=(12,4))
panel_titles = ('A) Target log-softmax','B) Max token log-softmax')

for i,axi in enumerate(axs):

  # change relative to the clean model's max log-softmax (slice 0 = target token, 1 = max token)
  I = lsm_vals[:,:,i] - maxlsm_clean.item()

  # symmetric color limits at twice the mean absolute value (L1-mean)
  clim = 2*np.abs(I).mean()

  # head x layer image with its own colorbar
  h = axi.pcolor(range(n_layers),range(n_heads),I,vmin=-clim,vmax=clim,cmap='RdBu_r')
  axi.set(xlabel='Transformer layer',ylabel='Attention head index',title=panel_titles[i])
  fig.colorbar(h,ax=axi,pad=.02)

  # top/right spines are switched off in rcParams, but a full frame suits these images
  axi.spines.top.set_visible(True)
  axi.spines.right.set_visible(True)

plt.tight_layout()
plt.savefig('ch6_proj40_part5a.png')
plt.show()

In [None]:
plt.figure(figsize=(8,3))
layers = np.arange(n_layers)

# per layer: mean +/- std over heads of the change relative to the clean max log-softmax.
# target token = red squares shifted left; maximum token = green circles shifted right
for idx,offset,color,marker,msize,face,label in [
    (0,-.05,'r','s',9,[.9,.7,.7],'Target token'),
    (1,+.05,'g','o',8,[.7,.9,.7],'Maximum token')]:
  X = lsm_vals[:,:,idx]-maxlsm_clean.item()
  xpos = layers+offset
  plt.errorbar(xpos,X.mean(axis=0),X.std(axis=0),color=color,linestyle='none')
  plt.plot(xpos,X.mean(axis=0),color+marker,markersize=msize,markerfacecolor=face,label=label)

# zero = no change from the clean model
plt.axhline(0,color='k',linestyle='--',linewidth=.4)

# axis adjustments
plt.gca().set(xlabel='Transformer block of attention silencing',ylabel='Log-softmax difference',
        xlim=[-.5,n_layers-.5],title='Average of single-head impacts')
plt.legend(loc='lower right')

plt.tight_layout()
plt.savefig('ch6_proj40_part5b.png')
plt.show()