|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[34] QKV activation characteristics</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import requests

import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer
import torch.nn.functional as F

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# switch inline figures to SVG output (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark-mode colors: uncomment these lines to enable
# plt.rcParams['figure.facecolor'] = '#282a2c'
# plt.rcParams['figure.edgecolor'] = '#282a2c'
# plt.rcParams['axes.facecolor']   = '#282a2c'
# plt.rcParams['axes.edgecolor']   = '#DDE2F4'
# plt.rcParams['axes.labelcolor']  = '#DDE2F4'
# plt.rcParams['xtick.color']      = '#DDE2F4'
# plt.rcParams['ytick.color']      = '#DDE2F4'
# plt.rcParams['text.color']       = '#DDE2F4'

# hide the right and top axis borders
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top']   = False

# bold titles and axis labels
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.labelweight'] = 'bold'

# resolution used when figures are written with savefig
plt.rcParams['savefig.dpi'] = 300

# **Part 1: Model, tokens, QKV activations**

In [None]:
# load GPT2 model and tokenizer
# (from_pretrained takes a checkpoint name; the model uses 'gpt2-large')
model = AutoModelForCausalLM.from_pretrained('gpt2-large')
# NOTE(review): the tokenizer comes from the 'gpt2' checkpoint, not 'gpt2-large';
# presumably all GPT-2 sizes share one vocabulary -- confirm before changing either name
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# put the model in evaluation mode; the model is only run under torch.no_grad() later in this file
model.eval()

In [None]:
# variable for the number of embedding dimensions
n_emb = model.

# and the number of layers
n_layers =

In [None]:
# GreatGatsby
url = 'https://www.gutenberg.org/cache/epub/64317/pg64317.txt'
text = requests.get(url).text
tokens = np.array( tokenizer.encode(text) )

target_word =

target_idx =
print(f'{len(target_idx)} instances of "{target_word}"')

In [None]:
# size of pre-target context (tokens)
contextwin =

# remove any target indices that are <contextwin
print(f'Target count before filtering: {len(target_idx)}')
target_idx = target_idx[]

print(f'Target count after filtering: {len(target_idx)}')

In [None]:
# create a batch
batch = torch.zeros((len(target_idx),contextwin+1),dtype=torch.long)

for i in range(len(target_idx)):
  toks = tokens[]
  batch[i,:] =

batch.shape

In [None]:
# note the identical final token
batch

In [None]:
# a few examples
for i in range(5):
  print(tokenizer.decode())

In [None]:
# Define a hook function to store QKV vectors
activations = {}

# Factory for per-layer forward hooks: returns a `hook` closure that remembers
# `layer_number` and uses it to build the key it writes into `activations`.
# NOTE: this is the workbook version -- the method call on `output` is
# intentionally left blank for the reader to complete.
def implant_hook(layer_number):
  # three-argument form (module, input, output), as expected by register_forward_hook
  def hook(module,input,output):
    # save this layer's output under 'attn_<layer_number>_qkv' in the global dict
    activations[f'attn_{layer_number}_qkv'] = output.()
  return hook


# surgery ;)
handles = []
for i in range(n_layers):
  h = model...register_forward_hook(implant_hook(i))
  handles.append(h)

In [None]:
# push through the model
with torch.no_grad(): model(batch)

for k,i in activations.items():
  print()

In [None]:
# concatenated activations from one layer
layeri = 6

wide_acts = activations[]

plt.figure(figsize=(10,3))
plt.imshow(,aspect='auto',vmin=-1,vmax=1,cmap='plasma')
plt.axvline(,linestyle='--',color='w')
plt.axvline(,linestyle='--',color='w')
plt.colorbar(pad=.01)

plt.gca().set(xticks=[],ylabel='Batch sequences',
              xlabel='Queries dimensions         |           Keys dimensions             |           Values dimensions')


plt.tight_layout()
plt.savefig('ch6_proj34_part1.png')
plt.show()

# **Part 2: Activation characteristics**

In [None]:
# just one layer for now
layeri = 6

# split into separate matrices
Q,K,V = torch.split(
Q.shape,K.shape,V.shape

In [None]:
# target: means of means and variances
qMeans_t = torch.tensor([ torch.mean,
                          torch.mean,
                          torch.var ])

# non-target: means of means and variances
qMeans_n = torch.tensor([  ])


plt.figure(figsize=(8,3))
plt.bar(np.arange(3)-.2,qMeans_t,width=.5,edgecolor='k',label='Target')
plt.bar(np.arange(3)+.2,qMeans_n,width=.5,edgecolor='k',label='Non-target')
plt.legend()
plt.gca().set(xticks=range(3),xticklabels=['Arithmetic mean','L1 mean','Variance'],ylabel='Descriptive values')

plt.tight_layout()
plt.savefig('ch6_proj34_part2a.png')
plt.show()

In [None]:
edges = torch.linspace(-7,7,234)

ytQ,_ = torch.histogram(Q[:,-1,:].flatten(),edges,density=)
ynQ,_ = torch.histogram()
ytK,_ = torch.histogram(
ynK,_ = torch.histogram(
ytV,_ = torch.histogram(
ynV,_ = torch.histogram(

plt.figure(figsize=(10,3))
plt.plot(,label='Q target')
plt.plot(label='Q nontarget')
plt.plot(,label='K target')
plt.plot(,label='K nontarget')
plt.plot(,label='V target')
plt.plot(,label='V nontarget')

plt.legend()
plt.gca().set(xlim=edges[[0,-1]],xlabel='Activation value',ylabel='Density',
              title=f'Attention matrices distributions from layer {layeri}')

plt.tight_layout()
plt.savefig('ch6_proj34_part2b.png')
plt.show()

# **Part 3: Laminar trends in activation descriptives**

In [None]:
# helper function
# Exercise scaffold: compute three descriptives (arithmetic mean, L1 mean,
# variance) of X separately for the target and non-target activations.
# The right-hand sides and the return value are intentionally left blank.
# Callers unpack two values from the result (`T,N = meansAndVar(...)`), each
# of which is written into a length-3 slice of `descriptives`.
# assumes X is 3D with token position on the second axis -- TODO confirm
def meansAndVar(X):

  # target (X[:,-1,:] selects the final token position of every sequence)
  t1 = torch...(X[:,-1,:],      # arithmetic mean
  t2 =  # L1 mean
  t3 =        # variance

  # non-target (presumably every position except the final one -- to be filled in)
  n1 =       # arithmetic mean
  n2 =  # L1 mean
  n3 =        # variance

  return

# test
meansAndVar(Q)

In [None]:
# initialize: layers X matrix X feature X category
descriptives = torch.zeros((n_layers,3,3,2))

for layeri in range(n_layers):

  # split into separate matrices
  Q,K,V = torch.split(

  # Q: get the descriptives
  T,N = meansAndVar(Q)
  descriptives[layeri,0,:,0] =
  descriptives[layeri,0,:,1] =

  # K: get the descriptives
  T,N = meansAndVar(K)
  descriptives[
  descriptives[

  # V: get the descriptives
  T,N = meansAndVar(V)
  descriptives[
  descriptives[


descriptives.shape

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3))

feature = ['Arithmetic mean','L1 mean','Variance']

for i in range(len(axs)):
  axs[i].plot(,'gs-',markerfacecolor=[.7,.9,.7,.7],label='Q')
  axs[i].plot(,'ro-',markerfacecolor=[.9,.7,.7,.7],label='K')
  axs[i].plot(,'b^-',markerfacecolor=[.7,.7,.9,.7],label='V')
  axs[i].legend()
  axs[i].set(xlabel='Transformer layer',xticks=range(0,n_layers,4),title=f'{feature[i]}')

plt.tight_layout()
plt.savefig('ch6_proj34_part3a.png')
plt.show()

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3))

for i in range(len(axs)):
  axs[i].axhline(0,linestyle='--',linewidth=.5,color=[.7,.7,.7])
  axs[i].plot(descriptives....diff(dim=-1),'gs-',markerfacecolor=[.7,.9,.7,.7],label='Q')
  axs[i].plot(,'ro-',markerfacecolor=[.9,.7,.7,.7],label='K')
  axs[i].plot(,'b^-',markerfacecolor=[.7,.7,.9,.7],label='V')
  axs[i].legend()
  axs[i].set(xlabel='Transformer layer',xticks=range(0,n_layers,4),title=f'{feature[i]}')

plt.tight_layout()
plt.savefig('ch6_proj34_part3b.png')
plt.show()

# **Part 4: Laminar trends in activation distributions**

In [None]:
# helper function
edges = torch.linspace(-4,4,101)

# Exercise scaffold: return a pair of histograms (yt, yn) computed from X.
# The arguments to torch.histogram are intentionally left blank, so the
# function raises until they are filled in.
# Following the ytQ/ynQ naming used earlier in this file, yt is presumably the
# target histogram and yn the non-target one; callers store them in the last
# axis of `histograms` at indices 0 and 1 respectively.
def calculateHists(X):
  # torch.histogram returns (hist, bin_edges); the edges are discarded here
  yt,_ = torch.histogram()
  yn,_ = torch.histogram()
  return yt,yn

In [None]:
# initialize: layers X matrix X bins X category
histograms = torch.zeros((n_layers,3,len(edges)-1,2))

for layeri in range(n_layers):

  # split into separate matrices
  Q,K,V =

  # get the histograms
  histograms[layeri,0,:,0],histograms[layeri,0,:,1] = calculateHists(
  histograms[layeri,1,:,0],histograms[layeri,1,:,1] =
  histograms[layeri,2,:,0],histograms[layeri,2,:,1] =

histograms.shape

In [None]:
fig,axs = plt.subplots(2,3,figsize=(12,6))

for i in range(3):
  h = axs[0,i].imshow(,vmin=0,vmax=.4,aspect='auto',cmap='plasma',
                  extent=[edges[0],edges[-1],0,n_layers],origin='lower')
  axs[0,i].set(xticklabels=[],ylabel='Layer',yticks=range(0,n_layers,5),title=f"Target {'QKV'[i]}")
  fig.colorbar(h,ax=axs[0,i],pad=.03,orientation='horizontal')

  axs[1,i].imshow(,vmin=0,vmax=.4,aspect='auto',cmap='plasma',
                  extent=[edges[0],edges[-1],0,n_layers],origin='lower')
  axs[1,i].set(xlabel='Activation value',ylabel='Layer',yticks=range(0,n_layers,5),title=f"Non-target {'QKV'[i]}")


plt.tight_layout()
plt.savefig('ch6_proj34_part4.png')
plt.show()

# **Part 5: Extracting and characterizing heads**

In [None]:
n_heads = model.config.
head_dim =

In [None]:
# extract Q
Q,K,V = torch.split(activations['attn_6_qkv'],n_emb,dim=-1)

# and reshape to have a "head" dimension
Qh = Q.view

print(f'Q is size {list(Q.shape)} and Qh is size {list(Qh.shape)}')

In [None]:
# visualize
_,axs = plt.subplots(4,5,figsize=(12,6))

for i,ax in enumerate(axs.flatten()):
  ax.pcolor(,cmap='plasma',vmin=-2,vmax=2)
  ax.text(contextwin-.5,head_dim-1,f'Qh{i}',fontsize=12,fontweight='bold',color='k',ha='right',va='top')
  ax.text(contextwin,head_dim-2,f'Qh{i}',fontsize=12,fontweight='bold',color='w',ha='right',va='top')
  ax.set(xticks=[],yticks=[])

# finalize
axs[3,0].set(ylabel='Head dim',xlabel='Token position')

plt.tight_layout()
plt.savefig('ch6_proj34_part5a.png')
plt.show()

In [None]:
head_stdevs = np.zeros((n_layers,3))

for layeri in range(n_layers):

  # split into Q, K, and V
  Q,K,V =

  # reshape to have a heads dimension
  Qh = Q.view
  Kh =
  Vh =

  # then get the standard deviation along the tokens
  head_stdevs[layeri,0] = Qh.
  head_stdevs[layeri,1] =
  head_stdevs[layeri,2] =


# and visualize
plt.figure(figsize=(10,3))

plt.plot(head_stdevs[:,0],'gs-',markersize=9,markerfacecolor=[.7,.9,.7],label='Q')
plt.plot(head_stdevs[:,1],'ro-',markersize=9,markerfacecolor=[.9,.7,.7],label='K')
plt.plot(head_stdevs[:,2],'b^-',markersize=9,markerfacecolor=[.7,.7,.9],label='V')

plt.legend()
plt.gca().set(xlabel='Transformer layer',ylabel='Average standard deviation')

plt.tight_layout()
plt.savefig('ch6_proj34_part5b.png')
plt.show()