|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge HELPER: Logit Lens in BERT</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

import torch
import torch.nn.functional as F

# render inline matplotlib figures in the notebook as SVG (vector) images
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Reminder of masked-token predictions in BERT

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

# Load pre-trained BERT model and tokenizer
# (the masked-LM model carries the `model.cls.predictions` head used later for lensing)
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForMaskedLM.from_pretrained('bert-large-uncased')

# -> GPU
# TODO(student): define `device` here and move the model onto it. Later cells call
# `.to(device)` on token tensors, so `device` must exist before they run
# (e.g. 'cuda' when available, otherwise 'cpu').

In [None]:
# getting predictions in BERT (note: target word for [MASK] is "way")
text = 'The way you do anything is the [MASK] you do everything.'
# TODO(student): token ids as a batch of one -- presumably shape (1, seq_len), given the
# tokens[0] indexing below -- placed on the same device as the model
tokens =

# show each token id next to its decoded text
for t in tokens[0]:
  print(f'Token {} is "{}"')

# index of the [MASK] token
# TODO(student): the position in tokens[0] whose id equals tokenizer.mask_token_id
mask_token_idx =
print(f'\nMask token is {tokenizer.mask_token_id} and is in index {}')

In [None]:
# forward pass (inference only, so autograd tracking is disabled)
with torch.no_grad():
  outputs = model(

# logits for [MASK] token
# TODO(student): presumably the output logits at batch 0, position mask_token_idx
# -> one logit per vocabulary entry
mask_logits =

# top predicted token (index of the largest [MASK] logit)
top_token_idx = torch.argmax(

# and plot: presumably the winning token as a large green dot ('go') on top of
# every vocabulary logit as faint black dots ('k.')
plt.figure(figsize=(10,4))

plt.plot(,,'go',markersize=9)
plt.plot(,'k.',alpha=.3)
plt.gca().set(xlim=[-5,tokenizer.vocab_size+4],xlabel='Token index',ylabel='BERT output logits',
              title=f'Predicted token is "{tokenizer.decode(top_token_idx)}"')
plt.show()

# Exercise 2: Calculate and visualize z-scores

In [None]:
# calculate z: standardize the [MASK] logits, z = (x - mean) / std
# (presumably taken across all vocabulary entries)
zLogits = (mask_logits-) / .std()

# and plot: 1x3 grid -- histogram in the first column, scatter spanning the other two
fig = plt.figure(figsize=(13,4))
gs = GridSpec(1,3,figure=fig)

ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1:])


# histogram of final output layer logits
ax0.hist(,bins=80,edgecolor='k',facecolor='gray',label='All tokens')

ax0.set(xlabel='Logit value',ylabel='Count',title='Histogram of logits')
ax0.legend(fontsize=9)

# scatter plot of logits (same layout as Exercise 1, but using the z-scored values)
ax1.plot(,,'go',markersize=9)
ax1.plot(,'k.',alpha=.3)
ax1.set(xlim=[-5,tokenizer.vocab_size+4],xlabel='Token index',ylabel='Z-scored output logits',
              title=f'Predicted token is "{tokenizer.decode(top_token_idx)}"')

plt.tight_layout()
plt.show()

# Exercise 3: Z-score final output logits for each token

In [None]:
# full sentence (no [MASK]); each token position is masked in turn below
text = 'the way you do anything is the way you do everything'
tokens = tokenizer.encode(text,return_tensors='pt')

# initialize: one z-score and one predicted string per token position
mask_zscores = np.zeros(len(tokens[0]))
predicted_text = []

print(f'[MASK] token is {tokenizer.mask_token_id}\n')

# loop over tokens, replace with [MASK], and get logits
# TODO(student): iterate over enumerate(...) of the token ids in tokens[0]
for idx,tok in enumerate:

  # make a copy and replace a token with mask
  # (copy so that `tokens` itself stays unmasked; note `tokens` was not moved to the
  # model's device above, so presumably the copy needs moving)
  masked_tokens =
  masked_tokens[] =
  # confirmation (convert to a list for readable printing):
  print(masked_tokens[0].tolist())

  # forward pass through the model
  # TODO(student): run the model on masked_tokens (inference only, as in Exercise 1)


  # get logits for the masked position
  mask_logits =

  # get the max masked prediction and its z-score
  # (z-score of the winning logit relative to all logits at this position)
  predicted_token =
  predicted_text.append( tokenizer.decode(
  mask_zscores[idx] = (mask_logits

In [None]:
# print a table of results: target token, predicted token, and z-score per position

print('    TARGET   |  PREDICTED |  Z-SCORE')
print('-'*38)
for idx in range(len(tokens[0])):
  # TODO(student): fill in the decoded target token, predicted_text[idx], and mask_zscores[idx]
  print(f'{}   |   {}   |   {}')


# Exercise 4: Proper internal lensing in BERT

In [None]:
# reminder: show the model's module tree (including the `cls.predictions` head used
# for lensing below); a bare last expression is rendered as the cell's output
model

In [None]:
# convenience variable (+1 for initial embedding layer)
# TODO(student): number of transformer blocks + 1; it should match the number of
# hidden states returned with output_hidden_states=True (presumably
# model.config.num_hidden_layers + 1 -- verify against len(outputs.hidden_states))
numlayers =

In [None]:
# getting predictions in BERT (note: target word for [MASK] is "way")
text_withMask = 'The way you do anything is the [MASK] you do everything.'
# TODO(student): encode text_withMask as an id tensor (presumably (1, seq_len), as in
# Exercise 3) and move it to `device`
masked_tokens = tokenizer).to(device)

In [None]:
# forward pass that also returns every hidden state (the embedding output plus one per
# transformer block). Inference only, so run it under no_grad -- as the Exercise 1
# forward pass does -- to avoid building an autograd graph through all of bert-large.
with torch.no_grad():
  outputs = model(masked_tokens,output_hidden_states=True)


In [None]:
### THE WRONG WAY
# Project an internal hidden state straight onto the vocabulary using only the
# decoder's weight matrix, as one might with GPT. This bypasses whatever else
# model.cls.predictions does (presumably its transform layers and decoder bias).

# unembedding matrix, transposed so that acts @ unembeddingT gives one value per
# vocabulary entry (assumes decoder.weight is (vocab, hidden) -- verify)
unembeddingT = model.cls.predictions.decoder.weight.detach().T

# position of the [MASK] token, looked up rather than hard-coded (it is 8 for the
# sentence above, right after the prepended [CLS]). .item() raises unless the
# sentence contains exactly one [MASK].
mask_pos = (masked_tokens[0]==tokenizer.mask_token_id).nonzero().item()

# activations from the mask position in one layer (third-from-last hidden state)
acts = outputs.hidden_states[-3][0,mask_pos,:].detach()

# same as with GPT?
mask_logitsWrong = acts @ unembeddingT

mask_logitsWrong = mask_logitsWrong.cpu()

In [None]:
### THE CORRECT WAY
# Pass the hidden state through the whole prediction module (model.cls.predictions),
# not just its decoder weights; presumably this applies the head's transform layers
# and output bias that the wrong way skipped.
# TODO(student): call it on `acts` and detach/move the result before plotting
mask_logitsCorrect = model.cls.predictions(

In [None]:
# predictions: wrong-way lens (left) vs correct-way lens (right)
_,axs = plt.subplots(1,2,figsize=(14,4))

# the WRONG way
# TODO(student): argmax over mask_logitsWrong; the y-label says "Z-scored", so
# presumably z-score the logits before plotting them
top_token_idx_wrong = torch.argmax(,dim=-1).cpu()

axs[0].plot(,,'go',markersize=9)
axs[0].plot(,'k.',alpha=.3)
axs[0].set(xlim=[-10,tokenizer.vocab_size+9],xlabel='Token index',ylabel='Z-scored output logits',
              title=f'(WRONG WAY) Predicted token is "{tokenizer.decode(top_token_idx_wrong)}"')


# the CORRECT way
# TODO(student): same steps as above, using mask_logitsCorrect
top_token_idx_correct = torch.argmax(

# NOTE(review): this 'go' call has one blank where axs[0] has two (x and y) -- check
axs[1].plot(,'go',markersize=9)
axs[1].plot(,'k.',alpha=.3)
axs[1].set(xlim=[-10,tokenizer.vocab_size+9],xlabel='Token index',ylabel='Z-scored output logits',
              title=f'(CORRECT WAY) Predicted token is "{tokenizer.decode(top_token_idx_correct)}"')

plt.tight_layout()
plt.show()

# Exercise 5: Repeat for all layers (logit-lens)

In [None]:
# initialize
# TODO(student): both arrays are indexed [layeri,idx] below, so presumably shaped
# (numlayers, number of target tokens). The later plots label their x-axis with
# tokens[0,1:-1] (every token except the first and last).
# NOTE(review): predictedTokens entries are later passed to tokenizer.decode, and
# np.zeros defaults to float -- consider an integer dtype; verify decode accepts them.
mask_zscores = np.zeros(
predictedTokens = np.zeros((

# loop over tokens, replace with [MASK], and get logits
# (if iterating over tokens[0,1:-1], the position in the full sequence is idx+1)
for idx,tok in

  # make a copy and replace a token with mask
  masked_tokens =
  masked_tokens
  # confirmation:
  #print(masked_tokens[0])

  # forward pass through the model (with hidden states exported!)
  # TODO(student): as in the Exercise 4 forward pass, request output_hidden_states=True



  ### loop over layers
  for layeri in range(numlayers):

    # get internal logits for the masked position
    # TODO(student): hidden state of layer `layeri` at the masked position
    acts = outputs.

    # logit lens: decode the internal activations with the model's own prediction head
    mask_logits = model.

    # get the masked prediction and its z-score
    predicted_token =
    predictedTokens[layeri,idx] =
    mask_zscores[layeri,idx] =

In [None]:
# `text` is the unmasked sentence defined in Exercise 3
print('Original text:\n', text, '\n')
# TODO(student): decode the predictedTokens row for the first and for the last transformer
# block (presumably row 1 and row numlayers-1; row 0 is the initial embedding layer)
print('Predictions at first transformer block:\n',
print('Predictions at final transformer block:\n',

In [None]:
# lines visualization: one line of z-scores per layer, colored by layer index along
# the plasma colormap
plt.figure(figsize=(12,4))

for layeri in range(numlayers):
  # TODO(student): this layer's row of mask_zscores
  plt.plot(,'o-',color=mpl.cm.plasma((layeri+1)/numlayers),label=f'L{layeri}')

plt.gca().set(xticks=range(len(tokens[0,1:-1])),xticklabels=[tokenizer.decode(t) for t in tokens[0,1:-1]],
              ylabel='Z-score',title='Z-scores of predicted tokens')

plt.legend(fontsize=11.3,bbox_to_anchor=(1,1.03),ncol=2)
plt.show()

In [None]:
# and as an image: rows are layers (origin='lower' puts layer 0 at the bottom),
# columns are target words
plt.figure(figsize=(13,5))
plt.imshow(,aspect='auto',origin='lower',cmap='plasma')
plt.colorbar(pad=.01,label='Z-score')

plt.gca().set(xlabel='Target word',ylabel='Layer',
              xticks=range(len(tokens[0,1:-1])),xticklabels=[tokenizer.decode(t) for t in tokens[0,1:-1]])
plt.show()

# Exercise 6: Logit Lens text heatmap

In [None]:
# min-max scale z-scores into [0,1]: the heatmap cell passes scaled_zscores[layeri,xi]
# to mpl.cm.Reds, which maps floats in [0,1] onto the colormap
# TODO(student): (z - min) / (max - min), over the whole matrix or per layer
scaled_zscores =


In [None]:
fig,ax = plt.subplots(1,figsize=(13,10))

# original text (separated into a list of decoded tokens)
# TODO(student): one decoded string per target token, e.g. over tokens[0,1:-1] to match
# the x-tick labels of the earlier plots. Left as [], numTokens is 0 and the
# xi/numTokens positions below raise ZeroDivisionError.
target = [] # use list-comprehension
numTokens = len(target)

# loop over layers (layer 0 is drawn at the top, later layers further down)
for layeri in range(numlayers):

  # y-axis coordinate for this layer
  yCoord = 1-layeri/numlayers

  # print the layer number in the left margin
  ax.text(-.1,yCoord,f'Layer {layeri}:',ha='right')

  # loop over the predicted tokens in this layer; each box is shaded by its scaled z-score
  # NOTE(review): tok comes from predictedTokens -- if that array is float, decode may reject it
  for xi,tok in enumerate(predictedTokens[layeri]):
    ax.text(xi/numTokens,yCoord,tokenizer.decode(tok),ha='center',
            bbox=dict(boxstyle='round,pad=0.3', facecolor=mpl.cm.Reds(scaled_zscores[layeri,xi]), edgecolor='none',alpha=.4))

ax.axis('off')

# finally, draw the target tokens at the bottom
# (yCoord still holds the last layer's value from the loop above)
ax.text(-.1,yCoord-.05,f'Target:',ha='right',fontweight='bold')
for xi,tok in enumerate(target):
  ax.text(xi/numTokens,yCoord-.05,tok,ha='center',fontsize=12,fontweight='bold')

plt.show()