|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[39] Token frequencies, attention adjustments, and QK^T</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
# note: check RAM via system resources; I used 8GB

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from datasets import load_dataset

from tqdm import tqdm

from scipy.stats import binned_statistic

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
### Global plot styling.
### The dark-mode colors below are commented out; uncomment them only if you're using "dark mode".

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# defaults applied to every figure created after this cell runs
plt.rcParams.update({
    # dark-mode colors (disabled)
    # 'figure.facecolor': '#282a2c',
    # 'figure.edgecolor': '#282a2c',
    # 'axes.facecolor':   '#282a2c',
    # 'axes.edgecolor':   '#DDE2F4',
    # 'axes.labelcolor':  '#DDE2F4',
    # 'xtick.color':      '#DDE2F4',
    # 'ytick.color':      '#DDE2F4',
    # 'text.color':       '#DDE2F4',
    # hide the top/right axis spines
    'axes.spines.right': False,
    'axes.spines.top':   False,
    # bold titles and axis labels
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    # resolution of files written by savefig
    'savefig.dpi':300
})

# **Part 1: Tokenize Wikipedia data and count token frequencies**

In [None]:
# classes for the BERT tokenizer and the BERT encoder
from transformers import BertTokenizer, BertModel

# pretrained tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

# pretrained model, switched to evaluation mode (eval() returns the model itself)
model = BertModel.from_pretrained('bert-large-uncased').eval()

# show the architecture
model

In [None]:
# ~7 mins on Colab...
# load the first 1% of the 'train' split of the '20231101.en' configuration
wiki = load_dataset('wikimedia/wikipedia','20231101.en',split='train[:1%]')

# display a summary of the loaded dataset
wiki

In [None]:
# inspect one arbitrary sample (row 1234) of the dataset
wiki[1234]

In [None]:
# number of tokens in one text sample (the 'text' field of row 0),
# encoded without the tokenizer's special tokens
len(tokenizer.encode(wiki[0]['text'],add_special_tokens=False))

In [None]:
# number of tokens to count (the Part-1 figure title refers to 1M tokens)
targetlength = 1_000_000

# one counter per vocabulary entry (counts for now; converted to proportions later)
token_props = torch.zeros(tokenizer.vocab_size)

nsofar = 0
row = 0

# stop at the target length, or when the dataset runs out of rows
while nsofar<targetlength and row<len(wiki):

  # tokenize new text (special tokens are not part of the text, so exclude them)
  toks = tokenizer.encode(wiki[row]['text'],add_special_tokens=False)

  # increase counter for these tokens
  for i in toks:
    token_props[i] += 1

  # update loop counters
  nsofar += len(toks)
  row += 1

  if row%10==0:
    print(f'{nsofar:>7,} tokens after {row} rows.')

print(f'\nFinal count: {nsofar:>7,} tokens after {row} rows.')

In [None]:
# token indices sorted from most to least frequent
# (run this cell while token_props still holds counts, i.e., before the conversion to proportions)
sidx = torch.argsort(token_props,descending=True)

print('Most frequent tokens:')
for i in sidx[:20]:
  print(f'  {int(token_props[i]):,} occurrences of "{tokenizer.decode([i.item()])}"')

In [None]:
# convert from counts to proportion
# (safe to re-run: proportions already sum to 1, so dividing again changes nothing)
token_props = token_props / token_props.sum()
token_props

In [None]:
plt.figure(figsize=(10,3))

# log-proportion of each vocabulary entry; tokens that never occurred give -inf and are not drawn
plt.plot(torch.log(token_props),'k.',alpha=.3,markersize=3)
plt.gca().set(xlabel='Token index',ylabel='Token log-frequency',ylim=[-14,None],
              xlim=[0,tokenizer.vocab_size],title='Empirical proportions in 1M tokens')

plt.tight_layout()
plt.savefig('ch6_proj39_part1.png')
plt.show()

# **Part 2: Create a batch of tokens**

In [None]:
# create a batch
batchsize = 64
seqlen = 256 # tokens per sequence; must not exceed the model's maximum input length

# initialize as list
batch = []

txti = -1
while len(batch)<batchsize:

  # increase the counter and get tokens
  # (no special tokens: the first token of each sequence is a text token, not [CLS])
  txti += 1
  tokens = tokenizer.encode(wiki[txti]['text'],add_special_tokens=False)

  # add a list item to the batch only if it's long enough
  if len(tokens)>=seqlen:
    batch.append(tokens[:seqlen])

# convert the list to a tensor (batchsize x seqlen)
batch = torch.tensor(batch)
batch.shape

In [None]:
# heatmap of the token indices in the batch (one row per sequence)
fig,ax = plt.subplots(figsize=(8,4))
ax.pcolor(batch,cmap='magma')
ax.set_xlabel('Token index')
ax.set_ylabel('Sequence')

# tidy up, write to disk, and display
fig.tight_layout()
fig.savefig('ch6_proj39_part2.png')
plt.show()

# **Part 3: Hook attention QK^T and adjustments**

In [None]:
# reference: display the attention submodule of one encoder layer (index 3)
# to see the names of its components
model.encoder.layer[3].attention

In [None]:
# architecture sizes, read from the model configuration
n_layers = model.config.num_hidden_layers
n_heads  = model.config.num_attention_heads
emb_dim  = model.config.hidden_size
head_dim = emb_dim // n_heads # the embedding is split evenly across the heads

print(f'{n_layers} transformer layers with dimensionality {emb_dim},')
print(f'{n_heads} attention heads with dimensionality {head_dim}.')

In [None]:
# reference: display the attention-output submodule of one encoder layer (index 4)
model.encoder.layer[4].attention.output # only want 'dense'

In [None]:
# head-averaged QK^T scores, one entry per layer (keys 'L0', 'L1', ...)
qkt_acts = {}

def implant_hook_qkt(layer_number):
  """Create a forward hook that stores the head-averaged scaled QK^T scores.

  The hook must be attached to a module that has `query` and `key` sub-layers
  and receives the token embeddings as its first positional input. It reads
  the globals `n_heads` and `head_dim`.

  Args:
    layer_number: index of the layer; the result is stored in qkt_acts[f'L{layer_number}'].

  Returns:
    The hook function, to be passed to `register_forward_hook`.
  """
  def hook_qkt(module, input, output):

    # the input to the attention module (batch x tokens x embedding)
    X = input[0].detach()
    nb,nt = X.shape[:2]

    # calculate Q and K as XW
    Q = module.query(X)
    K = module.key(X)

    # reshape to get heads dimension (batch x heads x tokens x head_dim)
    Q = Q.view(nb,nt,n_heads,head_dim).transpose(1,2)
    K = K.view(nb,nt,n_heads,head_dim).transpose(1,2)

    # QK^t raw attention scores
    qkt = Q @ K.transpose(-2,-1)
    qkt = qkt / head_dim**.5 # scaling

    # average over heads and store (batch x tokens x tokens), detached and on the CPU
    qkt_acts[f'L{layer_number}'] = qkt.mean(dim=1).detach().cpu()

  return hook_qkt


# hook the attention output adjustments
# per-token vector norms of the adjustments, one entry per layer (keys 'L0', 'L1', ...)
adj_norms = {}
def implant_hook_adj(layer_number):
  """Create a forward hook that stores the norm of each output vector.

  Args:
    layer_number: index of the layer; the result is stored in adj_norms[f'L{layer_number}'].

  Returns:
    The hook function, to be passed to `register_forward_hook`.
  """
  def hook_adj(module,input,output):
    O = output.detach()

    # L2 norm over the last (embedding) dimension -> one number per token
    adj_norms[f'L{layer_number}'] = torch.linalg.vector_norm(O,dim=-1).cpu()
  return hook_adj



# remove hooks left over from a previous run of this cell (otherwise they would be duplicated)
for h in globals().get('handles',[]):
  h.remove()

# implant the hooks
handles = []
for i in range(n_layers):

  attn = model.encoder.layer[i].attention

  # QK^T hook on the self-attention module (it has the query and key sub-layers)
  h = attn.self.register_forward_hook(implant_hook_qkt(i))
  handles.append(h)

  # adjustment hook on the 'dense' layer of the attention output
  h = attn.output.dense.register_forward_hook(implant_hook_adj(i))
  handles.append(h)

In [None]:
# forward pass of the whole batch, without gradient tracking
# (original note: this is the only GPU-based improvement in the project, minutes -> seconds;
#  NOTE(review): no device transfer is visible in this cell -- confirm where the model/batch are moved)
with torch.no_grad():
  outs = model(batch)

# list which layers were captured in each dictionary
print('QK^T keys:\n',qkt_acts.keys())
print('\nAdjustment vectors keys:\n',adj_norms.keys())

In [None]:
# sanity check: sizes of the stored tensors for one layer
qkt_acts['L5'].shape, adj_norms['L5'].shape

# **Part 4: Adjustment norms by token frequencies**

In [None]:
_,axs = plt.subplots(1,3,figsize=(12,3))

# one early, one middle, and one late layer
for i,l in enumerate([0,n_layers//2,n_layers-1]):

  # extract and plot the data excluding the first token
  x = torch.log(token_props[batch[:,1:]]).flatten()
  y = adj_norms[f'L{l}'][:,1:].flatten()
  r = np.corrcoef(x.numpy(),y.numpy())[0,1]
  axs[i].plot(x,y,'k.',markersize=3,alpha=.3)
  axs[i].set(xlabel='Log token prop',ylabel='Vector norm',
            title=f'Layer {l}: (r = {r:.2f})')

  # just the first token (not [CLS]!)
  x = torch.log(token_props[batch[:,0]])
  y = adj_norms[f'L{l}'][:,0]
  axs[i].plot(x,y,'rx',markersize=5,alpha=.7)


plt.tight_layout()
plt.savefig('ch6_proj39_part4a.png')
plt.show()

In [None]:
# re-extract x
x = torch.log(token_props[batch[:,1:]]).flatten().numpy()

# loop over layers (same 'x' for all layers)
plt.figure(figsize=(8,3))
for i in range(n_layers):

  # adjustment norms in this layer, excluding the first token (to match x)
  y = adj_norms[f'L{i}'][:,1:].flatten().numpy()

  # correlation between log token proportion and adjustment norm
  r = np.corrcoef(x,y)[0,1]
  plt.plot(i,r,'ks',markersize=10,markerfacecolor=[.9,.7,.9,.7])

plt.axhline(0,linestyle='--',color='k',linewidth=.4)
plt.gca().set(xlabel='Transformer layer',ylabel='Correlation coefficient',
              title='Attention adjustment magnitudes vs. token frequency')

plt.tight_layout()
plt.savefig('ch6_proj39_part4b.png')
plt.show()

# **Part 5: QK^T by proportion product (one layer)**

In [None]:
# extract the token proportions from one sequence as a column vector
P = np.log( np.array((token_props[batch[0]],)).T )

# outer product (broadcast summing logs) to get a matrix
# (column + row broadcasts to tokens x tokens; log(a)+log(b) = log(a*b))
Pmat = P + P.T

# and visualize
plt.imshow(Pmat)
plt.gca().set(xlabel='Token index',ylabel='Token index',title='Token proportion product pairs')
plt.colorbar(pad=.01)

plt.tight_layout()
plt.savefig('ch6_proj39_part5a.png')
plt.show()

In [None]:
# confirmation of symmetry
# (a symmetric matrix minus its transpose is zero everywhere)
Pmat-Pmat.T

In [None]:
# the layer to inspect in this part
layeri = 7

# done per-sequence
for seqi in range(batchsize):

  # flattened matrix of summed log-proportions for all token pairs in this sequence
  P = np.array((token_props[batch[seqi]],)).T
  pmat = np.log(P) + np.log(P.T)
  pmat = pmat.flatten()

  # flattened QK^T scores for the same token pairs
  qk_t = qkt_acts[f'L{layeri}'][seqi,:,:].flatten().numpy()

  # plot every 500th pair to keep the figure light
  plt.plot(pmat[::500],qk_t[::500],'.',markersize=3,alpha=.3)

plt.gca().set(xlabel='Log-probability sum',ylabel='QK$^\\top$')

plt.tight_layout()
plt.savefig('ch6_proj39_part5b.png')
plt.show()

In [None]:
# FYI: two ways of defining 41 bin edges over the values in 'pmat'.
# This figure is mentioned but not shown in the book.

# percentile-based edges (equal amount of data between edges)
edges1 = np.percentile(pmat,np.linspace(0,100,41))

# linearly spaced edges (equal spacing along the axis)
edges2 = np.linspace(np.min(pmat),np.max(pmat),41)

# one set of edges against the other
plt.plot(edges1,edges2,'kh',markersize=10,markerfacecolor=[.9,.7,.9,.7])

# line of unity, drawn between the first and last edges
ends = [0,-1]
plt.plot(edges1[ends],edges2[ends],'k--',linewidth=.5)

plt.xlabel('Edges from percentile (equal data)')
plt.ylabel('Edges from linspace (equal grid spacing)')

plt.show()

In [None]:
# grid resolution and layer
N = 41 # number of bin edges -> N-1 bins
layeri = 7

# initialize correlations matrix (column 0: bin means, column 1: bin stdevs)
corrs = np.zeros((batchsize,2))

_,axs = plt.subplots(1,2,figsize=(10,3))
for seqi in range(batchsize):

  # get flattened co-probability matrix
  P = np.array((token_props[batch[seqi]],)).T
  pmat = np.log(P) + np.log(P.T)
  pmat = pmat.flatten()

  # define equal-sized bin boundaries and dataset for this sequence
  edges = np.percentile(pmat,np.linspace(0,100,N))

  qk_t = qkt_acts[f'L{layeri}'][seqi,:,:].flatten().numpy()
  meansByBin,_,_ = binned_statistic(pmat,qk_t,statistic='mean',bins=edges)
  stdesByBin,_,_ = binned_statistic(pmat,qk_t,statistic='std',bins=edges)

  # scatter plots (x-axis value is the midpoint between neighboring edges)
  bincenters = (edges[:-1]+edges[1:])/2
  axs[0].plot(bincenters,meansByBin,'h',markersize=4,alpha=.5)
  axs[1].plot(bincenters,stdesByBin,'h',markersize=4,alpha=.5)

  # correlations
  corrs[seqi,0] = np.corrcoef(bincenters,meansByBin)[0,1]
  corrs[seqi,1] = np.corrcoef(bincenters,stdesByBin)[0,1]


# plot adjustments
axs[0].set(xlabel='Token pair log-probs',ylabel='$\\mathbf{QK^\\top}$ mean',title=f'A) Average scores by token probs, L{layeri}')
axs[1].set(xlabel='Token pair log-probs',ylabel='$\\mathbf{QK^\\top}$ stdev',title=f'B) Stdev scores by token probs, L{layeri}')

plt.tight_layout()
plt.savefig('ch6_proj39_part5c.png')
plt.show()

In [None]:
# per-sequence correlations between token-pair log-probs and the binned QK^T statistics
# (corrs column 1 holds the stdev correlations, column 0 the mean correlations)
plt.figure(figsize=(8,3))
plt.plot(corrs[:,1],'s-',markersize=5,linewidth=.5,label='$\\mathbf{QK^\\top}$ stdev')
plt.plot(corrs[:,0],'o-',markersize=5,linewidth=.5,label='$\\mathbf{QK^\\top}$ mean')

# thin line for r=0 (black, so it is visible on the default light background)
plt.axhline(0,linestyle='--',linewidth=.3,color='k')
plt.legend()
plt.gca().set(xlabel='Batch sequence number',ylabel='Correlation coefficient',
              title='Relations between token frequencies and QK$^\\top$ features')

plt.tight_layout()
plt.savefig('ch6_proj39_part5d.png')
plt.show()

# **Part 6: Laminar profiles of QK^T by proportion**

In [None]:
N = 41 # number of bin edges -> N-1 bins

# layers x bins x (bin centers, bin means, bin stdevs)
descriptivesByProb = np.zeros((n_layers,N-1,3))

# loop over sequences in the batch
for seqi in tqdm(range(batchsize)):

  # get flattened co-probability matrix (same for all layers)
  P = np.array((token_props[batch[seqi]],)).T
  pmat = np.log(P) + np.log(P.T)
  pmat = pmat.flatten()

  # define equal-sized bin boundaries for this sequence
  # (they depend only on pmat, so they are computed once per sequence)
  edges = np.percentile(pmat,np.linspace(0,100,N))
  bincenters = (edges[:-1]+edges[1:])/2

  # now loop over transformer layers
  for layeri in range(n_layers):

    # get flattened QK' dot products for this layer and this sequence
    qk_t = qkt_acts[f'L{layeri}'][seqi,:,:].flatten().numpy()

    # mean and stdev of the scores within each bin
    meansByBin,_,_ = binned_statistic(pmat,qk_t,statistic='mean',bins=edges)
    stdesByBin,_,_ = binned_statistic(pmat,qk_t,statistic='std',bins=edges)

    # pool and sum
    descriptivesByProb[layeri,:,0] += bincenters
    descriptivesByProb[layeri,:,1] += meansByBin
    descriptivesByProb[layeri,:,2] += stdesByBin


# divide by batchsize for average
descriptivesByProb /= batchsize

In [None]:
_,axs = plt.subplots(2,2,figsize=(12,6))

# for mapping line color onto the colorbar
norm = mpl.colors.Normalize(vmin=0,vmax=n_layers-1)
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.plasma,norm=norm)


# loop over layers
for layeri in range(n_layers):

  # extract the descriptive stats
  x  = descriptivesByProb[layeri,:,0] # bin centers
  yM = descriptivesByProb[layeri,:,1] # bin means
  yS = descriptivesByProb[layeri,:,2] # bin stdevs

  # same color for the lines and the markers, matching the colorbar
  layercolor = plt.cm.plasma(norm(layeri))

  # plot the data
  axs[0,0].plot(x,yM,'.-',markersize=7,color=layercolor,linewidth=2)
  axs[0,1].plot(x,yS,'.-',markersize=7,color=layercolor,linewidth=2)

  # and the correlation coefficients
  axs[1,0].plot(layeri,np.corrcoef(x,yM)[0,1],'ks',markersize=11,markerfacecolor=layercolor)
  axs[1,1].plot(layeri,np.corrcoef(x,yS)[0,1],'ko',markersize=11,markerfacecolor=layercolor)


# thin lines for r=0
axs[1,0].axhline(0,color='k',linestyle='--',linewidth=.4)
axs[1,1].axhline(0,color='k',linestyle='--',linewidth=.4)

# axis labels
axs[0,0].set(xlabel='Token pair log-probs',ylabel='$\\mathbf{QK^\\top}$ mean',title='A) Average scores by token probs')
axs[0,1].set(xlabel='Token pair log-probs',ylabel='$\\mathbf{QK^\\top}$ stdev',title='B) Stdev scores by token probs')
axs[1,0].set(xlabel='Transformer layer',ylabel='Correlation coefficient',title='C) Correlations with average scores')
axs[1,1].set(xlabel='Transformer layer',ylabel='Correlation coefficient',title='D) Correlations with stdev scores')

# add colorbars
cbar = plt.colorbar(sm,ax=axs[0,0],pad=.01)
cbar.set_label('Transformer layer',fontweight='normal')
cbar = plt.colorbar(sm,ax=axs[0,1],pad=.01)
cbar.set_label('Transformer layer',fontweight='normal')

plt.tight_layout()
plt.savefig('ch6_proj39_part6.png')
plt.show()