|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Editing hidden states</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Hidden-state scaling and token loss</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dulm_x/?couponCode=202509" target="_blank">udemy.com/course/dulm_x/?couponCode=202509</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Model, hooks, tokens

In [None]:
# one checkpoint name shared by the tokenizer and the model
model_name = 'gpt2-medium'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# evaluation mode; as the last expression of the cell, the notebook also echoes the model
model.eval()

In [None]:
# dictionary of scaling factors (key: layer number, value: multiplier for that layer's hidden states)
scaling_fact_dict = {}

# hooking functions
def implant_hook(layer_number):
  """
  Build a forward hook that rescales the hidden states leaving one module.

  layer_number : key that is looked up in the global `scaling_fact_dict`

  The returned function has the (module, input, output) forward-hook signature.
  Each time it is called it checks `scaling_fact_dict` for `layer_number`:
    - no matching key -> `output` is returned untouched
    - matching key    -> a scaled copy of the hidden states is returned
                         (the incoming tensor is never modified in-place)
  """
  def hook(module, input, output):

    # only change this layer if there's a matching key in the dictionary
    # (the global is looked up at call time, so re-assigning the dictionary later takes effect)
    if layer_number not in scaling_fact_dict:
      return output

    scale = scaling_fact_dict[layer_number]

    # tuple/list output: the hidden states are the first element, the rest passes through unchanged
    if isinstance(output,(tuple,list)):
      hidden, *rest = output
      return tuple([hidden*scale]+rest)

    # bare-tensor output: scale it directly (star-unpacking a tensor would split it
    # along its first dimension instead of separating hidden states from extras)
    return output*scale

  return hook

# remove any hooks left over from an earlier run of this cell; otherwise each block
# would carry several hooks and be scaled more than once per forward pass
for handle in globals().get('hook_handles',[]):
  handle.remove()

# attach one hook per transformer block and keep the handles so they can be removed later
hook_handles = []
for layeri in range(model.config.n_layer):
  hook_handles.append( model.transformer.h[layeri].register_forward_hook(implant_hook(layeri)) )

In [None]:
# quote from Einstein (missing final word is "curious")
text = 'I have no special talent. I am only passionately'
tokens = tokenizer.encode(text,return_tensors='pt')

# encode the missing word (with its leading space) and pull the index out of the tensor
target_ids = tokenizer.encode(' curious',return_tensors='pt')
target_token = target_ids[0].item() # .item() requires exactly one element
target_token

In [None]:
# the baseline must be computed with the intervention switched off, so empty the
# scaling dictionary here instead of relying on the order in which cells were run
scaling_fact_dict = {}

# forward pass without gradient tracking
with torch.no_grad():
  out = model(tokens)

# logits of the unmodified model
pure_logits = out.logits.detach()

# Exercise 2: Scale one layer

In [None]:
# switch on the intervention for a single entry: key 2 is multiplied by .6
scaling_fact_dict = {2:.6}

# forward pass with that scaling in place (no gradient tracking)
with torch.no_grad():
  out = model(tokens)
  scal_logits = out.logits.detach()

In [None]:
# overlay the final-token logits from the unscaled and the scaled forward pass
_,ax = plt.subplots(figsize=(8,3))
ax.plot(pure_logits[0,-1,:],'b.',markersize=2,alpha=.1,label='Unscaled')
ax.plot(scal_logits[0,-1,:],'r.',markersize=2,alpha=.1,label='Scaled')

# the plotted dots are tiny and faint, so make the legend markers larger and opaque
legend = ax.legend()
for handle in legend.legend_handles:
  handle.set(markersize=5,alpha=1)

ax.set(xlabel='Vocab elements',xlim=[0,model.config.vocab_size],ylabel='Final logits')

plt.show()

In [None]:
# final-token logits of both forward passes
x_pure = pure_logits[0,-1,:]
y_scal = scal_logits[0,-1,:]

# scatter the scaled logits against the unscaled logits
_,ax = plt.subplots(figsize=(5,4))
ax.plot(x_pure,y_scal,'k.',alpha=.3,markersize=2)

# their correlation (2x2 matrix; the off-diagonal element is the coefficient)
r = np.corrcoef(x_pure,y_scal)

# the first key of the scaling dictionary is the layer reported in the title
scaled_layer = list(scaling_fact_dict.keys())[0]

ax.set(xlabel='Unscaled logits',ylabel='Scaled logits',
       title=f'Impact of scaling layer {scaled_layer} (r = {r[0,1]:.6f})')
plt.show()

# Exercise 3: Sort and print predicted tokens

In [None]:
# token indices ordered from the largest to the smallest final-token logit
pure_sorted = torch.argsort(pure_logits[0,-1,:],descending=True)
scal_sorted = torch.argsort(scal_logits[0,-1,:],descending=True)

# table header, indented to line up with the prompt that starts each row
indent = ' '*(len(text)+3)
print(indent,'   Unscaled   |    Scaled')
print(indent,'-'*30)

# the prompt is identical on every row, so decode it once
prompt = tokenizer.decode(tokens[0])

# ten highest-ranked tokens of each forward pass, side by side
for pure_tok,scal_tok in zip(pure_sorted[:10],scal_sorted[:10]):
  print(f'{prompt}... {tokenizer.decode(pure_tok):>12}  | {tokenizer.decode(scal_tok)}')

In [None]:
# log-softmax over the vocabulary at the final position
pure_logprobs = F.log_softmax(pure_logits[0,-1,:],dim=-1)
scal_logprobs = F.log_softmax(scal_logits[0,-1,:],dim=-1)

# loss for the target item = negative log-softmax at the target index
pure_target_loss = -pure_logprobs[target_token]
scal_target_loss = -scal_logprobs[target_token]

# (layer, factor) pair of the first entry in the scaling dictionary
dictItems = list(scaling_fact_dict.items())[0]
scaled_layer,scale_factor = dictItems

print(f'When scaling layer {scaled_layer} by {scale_factor}:','\n')
print(f'Target loss in unscaled model: {pure_target_loss:.3f}')
print(f'Target loss in scaled model  : {scal_target_loss:.3f}')

# Exercise 4: Target token loss by layer scaling

In [None]:
# scaling factor applied to one transformer block at a time
scaling_factor = .6

# initialize (one result per transformer block)
n_layers = model.config.n_layer
targetLosses = np.zeros(n_layers)
targetCorrs = np.zeros(n_layers)

# log-softmax for the unscaled logits
pure_logsm = F.log_softmax(pure_logits[0,-1,:],dim=-1)

# loop over layers
for layeri in range(n_layers):

  # recreate the scaling dictionary from scratch (so it only ever holds the current layer)
  scaling_fact_dict = {layeri:scaling_factor}

  # forward pass
  with torch.no_grad(): out=model(tokens)

  # get the target log-softmax and loss
  scal_logsm = F.log_softmax(out.logits[0,-1,:],dim=-1)
  targetLosses[layeri] = -scal_logsm[target_token]

  # and their correlation (each row of the stacked matrix is one variable)
  catX = torch.stack((pure_logsm,scal_logsm),dim=0)
  targetCorrs[layeri] = torch.corrcoef(catX)[0,1]

# switch the intervention off again; otherwise the dictionary would keep the
# last layer's entry and any later forward pass would still be scaled
scaling_fact_dict = {}

In [None]:
# two panels, one point per scaled transformer block:
# target-token loss (left) and correlation with the unscaled log-softmax (right)
_,axs = plt.subplots(1,2,figsize=(12,3.5))

axs[0].plot(targetLosses,'ks-',markerfacecolor=[.7,.9,.7],markersize=10,linewidth=.5,label='Scaled')
# horizontal reference line: target loss computed from the unscaled log-softmax
axs[0].axhline(-pure_logsm[target_token],color=[.9,.7,.7],linewidth=2,zorder=-3,label='Unscaled')

# raw string, so the LaTeX backslashes are not read as (invalid) Python escape sequences
axs[0].set(xlabel='Transformer block with hidden-state scaling',
              ylabel=r'Target token loss $\left(-\ln(\hat{y})\right)$',title='Target losses')
axs[0].legend()

axs[1].plot(targetCorrs,'ks-',markerfacecolor=[.7,.7,.9],markersize=10,linewidth=.5)
axs[1].set(xlabel='Transformer block with hidden-state scaling',
              ylabel='Logit correlation',title='Correlation between scaled and unscaled log-sm')

plt.tight_layout()
plt.show()