|<h2>Book:</h2>|<h1><a href="https://open.substack.com/pub/mikexcohen/p/llm-breakdown-16-tokenization-words" target="_blank">50 ML projects to understand LLMs</a></h1>|
|-|:-:|
|<h2>Project:</h2>|<h1><b>[27] Impact of layer-specific noise and scaling</b></h1>|
|<h2>Author:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the book may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
### matplotlib adjustments (commented lines are for dark mode)

# render inline figures as SVG (resolution-independent, crisper output)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# notebook-wide plotting defaults; uncomment the color entries for a dark theme
plot_style = {
  # 'figure.facecolor': '#282a2c',
  # 'figure.edgecolor': '#282a2c',
  # 'axes.facecolor':   '#282a2c',
  # 'axes.edgecolor':   '#DDE2F4',
  # 'axes.labelcolor':  '#DDE2F4',
  # 'xtick.color':      '#DDE2F4',
  # 'ytick.color':      '#DDE2F4',
  # 'text.color':       '#DDE2F4',
  'axes.spines.right': False,
  'axes.spines.top': False,
  'axes.titleweight': 'bold',
  'axes.labelweight': 'bold',
  'savefig.dpi': 300,
}
plt.rcParams.update(plot_style)

# **Part 1: Model, tokens, and clean activations**

In [None]:
# load in GPT2-large and its tokenizer
# (one checkpoint name for both, so the tokenizer always matches the model --
#  previously the tokenizer came from 'gpt2' while the model came from 'gpt2-large')
model_name = 'gpt2-large'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,output_hidden_states=True)
model.eval()

In [None]:
# put the model on the GPU when one exists, otherwise keep it on the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
model.to(device);

In [None]:
print('This LLM has {} transformer layers.'.format(model.config.n_layer))

In [None]:
# prompt plus the token we expect the model to predict next
text = 'Pay no attention to that man behind the'
tokens = tokenizer.encode(text,return_tensors='pt')
# only the first id is kept, i.e. this assumes ' curtain' encodes to a single token
target_token = tokenizer.encode(' curtain')[0]

# list every prompt token id next to the text it decodes to
for tok_id in tokens[0]:
  print(f'Token {tok_id:>5} is "{tokenizer.decode(tok_id)}"')

In [None]:
# unperturbed forward pass; gradients are not needed (~8s on CPU, <1s on GPU)
with torch.no_grad():
  outputs_clean = model(input_ids=tokens.to(device))

# shape of the first hidden-state tensor
outputs_clean.hidden_states[0].size()

In [None]:
# log-softmax over the vocabulary at the final token position
logsm = outputs_clean.logits[0,-1,:].log_softmax(dim=-1)

# log-probability of the target token (" curtain"); used as a reference later
log_sm_target_clean = logsm[target_token].item()

# most likely next token, reported with ITS OWN probability
# (the target's probability is only correct when the argmax happens to be the target)
nextword_clean = torch.argmax(logsm)
prob_nextword_clean = np.exp(logsm[nextword_clean].item())
print(f'Next word is "{tokenizer.decode(nextword_clean)}" (token index {nextword_clean}) with {prob_nextword_clean:.2%} probability.')

In [None]:
fig,ax = plt.subplots(figsize=(10,3))

# green guides cross at the target token's index and its log-softmax value
ax.axhline(log_sm_target_clean,color='g',linestyle='--',linewidth=.5)
ax.axvline(target_token,color='g',linestyle='--',linewidth=.5)

# one marker per vocabulary entry
ax.plot(logsm.cpu(),'kh',markerfacecolor=[.7,.9,.7,.3],markersize=4)
ax.set(xlabel='Token index',ylabel='log-softmax prob',
       title='log-softmax of final token',xlim=[-150,tokenizer.vocab_size+150])

fig.tight_layout()
fig.savefig('ch5_proj27_part1.png')
plt.show()

# **Part 2: Hook to inject noise**

In [None]:
# hooking functions
def implant_noise_hook(layer_number, noise_scale=.5):
  """Build a forward hook that adds Gaussian noise to one transformer block's output.

  The hook does nothing unless ``layer_number`` equals the global ``layer2noise``
  (looked up each time the hook runs). That lets a hook sit on every block while
  the perturbed layer is chosen by reassigning ``layer2noise``.

  Args:
    layer_number: index of the block this hook is attached to.
    noise_scale: noise standard deviation as a fraction of the hidden states'
      standard deviation (default .5, i.e. std/2 as before).

  Returns:
    A callable with the ``(module, inputs, output)`` signature expected by
    ``register_forward_hook``.
  """
  def noise_hook(module,inputs,output):

    # only change one layer
    if layer_number != layer2noise:
      return output

    # the block output may be a tuple (hidden, *extras) or a bare tensor depending
    # on the transformers version; unpacking a bare tensor would split the batch dim
    is_tuple = isinstance(output, tuple)
    if is_tuple:
      hidden, *rest = output
    else:
      hidden, rest = output, []

    # noise level is relative to the spread of this layer's activations
    h_std = hidden.std() * noise_scale
    noise = torch.randn_like(hidden)*h_std

    # add that noise to the hidden states (in place)
    hidden += noise

    # hand the output back in the same form it arrived
    return tuple([hidden]+rest) if is_tuple else hidden
  return noise_hook


# attach a noise hook to every transformer block; only the block whose index
# equals layer2noise perturbs its output, the rest pass it through
handles = [model.transformer.h[layer_idx].register_forward_hook(implant_noise_hook(layer_idx))
           for layer_idx in range(model.config.n_layer)]

In [None]:
# perturb a single block (the hooks read this global)
layer2noise = 15

# forward pass with the noise hook active
with torch.no_grad():
  outputs_noise = model(input_ids=tokens.to(device))

# shape of the first hidden-state tensor (compare with the clean run)
outputs_noise.hidden_states[0].size()

In [None]:
# norm of the clean-vs-noisy difference for EVERY hidden state the model returns.
# Index 0 precedes the first block and index i+1 follows block i. Looping over all
# entries, rather than range(n_layer), keeps noise in the last block visible.
diffnorms = torch.stack([
  torch.norm(hs_c.cpu()-hs_n.cpu())
  for hs_c,hs_n in zip(outputs_clean.hidden_states,outputs_noise.hidden_states)
])

# and plot; the dashed line marks the first hidden state after the noised block
plt.figure(figsize=(10,3))
plt.plot(diffnorms,'kh',markerfacecolor=[.9,.7,.7],markersize=12)
plt.axvline(layer2noise+1,color='k',linestyle='--',zorder=-10)
plt.axhline(0,color='k',linestyle=':',zorder=-10)
plt.gca().set(xlabel='Layer',ylabel='Norm of difference')

plt.tight_layout()
plt.savefig('ch5_proj27_part2.png')
plt.show()

In [None]:
# predicted next token and its probability in the noised model
final_logits_noise = outputs_noise.logits[0,-1,:]
max_logit = torch.argmax(final_logits_noise)        # index of the top-scoring token
sm_noise = final_logits_noise.softmax(dim=-1)        # plain probabilities (not log-softmax)

# clean model: probability of ITS top token, not of the target token
sm_clean = outputs_clean.logits[0,-1,:].softmax(dim=-1)

print(f' Clean model: next token is "{tokenizer.decode(nextword_clean)}" with {sm_clean[nextword_clean].item():.2%} probability.')
print(f'Noised model: next token is "{tokenizer.decode(max_logit)}" with {sm_noise[max_logit].item():.2%} probability.')

# **Part 3: Impacts of layer-specific noising**

In [None]:
# per noised layer: column 0 = log-softmax of the target token,
#                   column 1 = log-softmax of the top-scoring token
log_sm_targets = torch.zeros((model.config.n_layer,2))

# NOTE: the loop variable must be named layer2noise -- the noise hooks read it
for layer2noise in range(model.config.n_layer):

  # forward pass with noise injected into this block only
  with torch.no_grad():
    outputs_noise = model(input_ids=tokens.to(device))

  final_logits = outputs_noise.logits[0,-1,:]
  logsm = final_logits.log_softmax(dim=-1)
  maxtok = torch.argmax(final_logits)

  # record both log-probabilities for this layer
  log_sm_targets[layer2noise] = torch.tensor([logsm[target_token].item(), logsm[maxtok].item()])

  # show how the prompt would be completed
  print(f'L{layer2noise:2}: {text}"{tokenizer.decode(maxtok)}"')

In [None]:
# clean-model reference for panel B: log-softmax of the clean model's own top token
clean_logsm = outputs_clean.logits[0,-1,:].log_softmax(dim=-1)
log_sm_max_clean = clean_logsm.max().item()

_,axs = plt.subplots(1,2,figsize=(12,3.5))

# plot the impact of the perturbations
axs[0].axhline(log_sm_target_clean,color='g',linestyle='--',label='Clean')
axs[0].plot(log_sm_targets[:,0],'kh',markerfacecolor=[.9,.7,.7,.7],markersize=12,label='Noisified')
axs[0].legend()
axs[0].set(xlabel='Layer',ylabel='log-softmax prob',
              title=f'A) Impact of noise on log-softmax of "{tokenizer.decode(target_token)}"')

# and the max logit, compared against the clean model's max (not the clean target)
axs[1].axhline(log_sm_max_clean,color='g',linestyle='--',label='Clean')
axs[1].plot(log_sm_targets[:,1],'ks',markerfacecolor=[.7,.7,.9,.7],markersize=12,label='Noisified')
axs[1].legend()
axs[1].set(xlabel='Layer',ylabel='log-softmax prob',
              title='B) Impact of noise on log-softmax of max logit')

# shared y-range spanning both panels: top-token values are >= target values,
# so simply copying panel A's limits could clip points in panel B
ymin = min(panel.get_ylim()[0] for panel in axs)
ymax = max(panel.get_ylim()[1] for panel in axs)
for panel in axs:
  panel.set_ylim(ymin,ymax)

plt.tight_layout()
plt.savefig('ch5_proj27_part3.png')
plt.show()

In [None]:
# detach every noise hook so subsequent forward passes run without them
for hook_handle in handles:
  hook_handle.remove()

# **Part 4: Layer-specific scalar dampening**

In [None]:
# hooking functions
def implant_scale_hook(layer_number, scale=.5):
  """Build a forward hook that multiplies one transformer block's output by a scalar.

  The hook does nothing unless ``layer_number`` equals the global ``layer2scale``
  (looked up each time the hook runs). That lets a hook sit on every block while
  the scaled layer is chosen by reassigning ``layer2scale``.

  Args:
    layer_number: index of the block this hook is attached to.
    scale: multiplicative factor applied to the hidden states (default .5).

  Returns:
    A callable with the ``(module, inputs, output)`` signature expected by
    ``register_forward_hook``.
  """
  def scale_hook(module, inputs, output):

    # only change one layer
    if layer_number != layer2scale:
      return output

    # the block output may be a tuple (hidden, *extras) or a bare tensor depending
    # on the transformers version; unpacking a bare tensor would split the batch dim
    is_tuple = isinstance(output, tuple)
    if is_tuple:
      hidden, *rest = output
    else:
      hidden, rest = output, []

    # in-place method to scale the hidden states
    hidden.mul_(scale)

    # hand the output back in the same form it arrived
    return tuple([hidden]+rest) if is_tuple else hidden
  return scale_hook


# attach a scaling hook to every transformer block; only the block whose index
# equals layer2scale is rescaled, the rest pass their output through
handles = [model.transformer.h[layer_idx].register_forward_hook(implant_scale_hook(layer_idx))
           for layer_idx in range(model.config.n_layer)]

In [None]:
# scale down a single block (the hooks read this global)
layer2scale = 15

# forward pass with the scaling hook active
with torch.no_grad():
  outputs_scale = model(input_ids=tokens.to(device))

# shape of the first hidden-state tensor (compare with the clean run)
outputs_scale.hidden_states[0].size()

In [None]:
# norm of the clean-vs-scaled difference for EVERY hidden state the model returns.
# Index 0 precedes the first block and index i+1 follows block i. Looping over all
# entries, rather than range(n_layer), keeps a change in the last block visible.
diffnorms = torch.stack([
  torch.norm(hs_c.cpu()-hs_n.cpu())
  for hs_c,hs_n in zip(outputs_clean.hidden_states,outputs_scale.hidden_states)
])

# and plot; the dashed line marks the first hidden state after the scaled block
plt.figure(figsize=(10,3))
plt.plot(diffnorms,'kh',markerfacecolor=[.9,.7,.7],markersize=12)
plt.axvline(layer2scale+1,color='k',linestyle='--',zorder=-10)
plt.axhline(0,color='k',linestyle=':',zorder=-10)
plt.gca().set(xlabel='Layer',ylabel='Norm of difference')

plt.tight_layout()
plt.savefig('ch5_proj27_part4a.png')
plt.show()

In [None]:
# per scaled layer: column 0 = log-softmax of the target token,
#                   column 1 = log-softmax of the top-scoring token
log_sm_targets = torch.zeros((model.config.n_layer,2))

# NOTE: the loop variable must be named layer2scale -- the scaling hooks read it
for layer2scale in range(model.config.n_layer):

  # forward pass with this block's output scaled down
  with torch.no_grad():
    outputs_scale = model(input_ids=tokens.to(device))

  final_logits = outputs_scale.logits[0,-1,:]
  logsm = final_logits.log_softmax(dim=-1)
  maxtok = torch.argmax(final_logits)

  # record both log-probabilities for this layer
  log_sm_targets[layer2scale] = torch.tensor([logsm[target_token].item(), logsm[maxtok].item()])

  # show how the prompt would be completed
  print(f'L{layer2scale:2}: {text}"{tokenizer.decode(maxtok)}"')

In [None]:
# clean-model reference for panel B: log-softmax of the clean model's own top token
clean_logsm = outputs_clean.logits[0,-1,:].log_softmax(dim=-1)
log_sm_max_clean = clean_logsm.max().item()

_,axs = plt.subplots(1,2,figsize=(12,3.5))

# plot the impact of the perturbations
axs[0].axhline(log_sm_target_clean,color='g',linestyle='--',label='Clean')
axs[0].plot(log_sm_targets[:,0],'kh',markerfacecolor=[.9,.7,.7,.7],markersize=12,label='Scaled')
axs[0].legend()
axs[0].set(xlabel='Layer',ylabel='log-softmax prob',
              title=f'A) Impact of down-scaling on log-softmax of "{tokenizer.decode(target_token)}"')

# and the max logit, compared against the clean model's max (not the clean target)
axs[1].axhline(log_sm_max_clean,color='g',linestyle='--',label='Clean')
axs[1].plot(log_sm_targets[:,1],'ks',markerfacecolor=[.7,.7,.9,.7],markersize=12,label='Scaled')
axs[1].legend()
axs[1].set(xlabel='Layer',ylabel='log-softmax prob',
              title='B) Impact of down-scaling on log-softmax of max logit')

# shared y-range spanning both panels: top-token values are >= target values,
# so simply copying panel A's limits could clip points in panel B
ymin = min(panel.get_ylim()[0] for panel in axs)
ymax = max(panel.get_ylim()[1] for panel in axs)
for panel in axs:
  panel.set_ylim(ymin,ymax)

plt.tight_layout()
plt.savefig('ch5_proj27_part4b.png')
plt.show()

# remove the scaling hooks; otherwise the last swept layer stays scaled for any later run
for h in handles:
  h.remove()