# Notebook config

In [None]:
# Experiment name for this notebook run (presumably used as an output
# directory by cells not shown here — confirm).
directory = "add-patch-v2"

# Imports

In [None]:
%%capture
%%bash
pip install transformers==4.27.4
pip install datasets
pip install nltk

In [None]:
from ast import literal_eval
import functools
import json
import os
import random
import re

# Scientific packages
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
torch.set_grad_enabled(False)
tqdm.pandas()

from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
# Apply every seaborn/matplotlib theme setting in ONE call. A second bare
# `sns.set_theme(style=...)` re-applies the default "notebook" context and
# silently resets the enlarged font sizes below back to their defaults.
sns.set_theme(
    context="notebook",
    style="whitegrid",
    rc={"font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16.0,
        "ytick.labelsize": 16.0,
        "legend.fontsize": 16.0})
# Custom colour subset of Set1 (indices 2-4 and 7 onward).
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]

import altair as alt
# Altair refuses datasets over 5000 rows by default; the layer-sweep frames
# can exceed that.
alt.data_transformers.disable_max_rows()

def softmax(x):
  """Softmax over the last axis of `x`.

  The per-row maximum is subtracted before exponentiating so large logits
  do not overflow.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  weights = np.exp(shifted)
  return weights / np.sum(weights, axis=-1, keepdims=True)

In [None]:
import torch
import transformers
import warnings
from matplotlib.colors import to_hex

# Report whether torch can see CUDA and, if so, list every visible GPU by index.
print(torch.cuda.is_available())
if torch.cuda.is_available():
  num_gpus = torch.cuda.device_count()
  print(f"Number of available GPUs: {num_gpus}")

  for i in range(num_gpus):
      gpu_name = torch.cuda.get_device_name(i)
      print(f"GPU {i}: {gpu_name}")

True
Number of available GPUs: 1
GPU 0: NVIDIA A100-SXM4-40GB


In [None]:
import jax.numpy as jnp
class NpEncoder(json.JSONEncoder):
  """JSON encoder that also serializes NumPy / JAX values and ranges.

  Conversions:
    * NumPy/JAX integer scalars -> int
    * NumPy/JAX floating scalars -> float
    * NumPy booleans (np.bool_) -> bool
    * NumPy/JAX arrays -> nested lists via `.tolist()`
    * range -> list
  Anything else falls through to `json.JSONEncoder.default`, which raises
  TypeError.

  Usage: json.dumps(obj, cls=NpEncoder)
  """

  def default(self, o):
    # The JAX scalar types are listed alongside NumPy's so JAX values are
    # covered even if the two ever stop being aliases.
    if isinstance(o, (np.integer, jnp.integer)):
      return int(o)
    if isinstance(o, (np.floating, jnp.floating)):
      return float(o)
    # np.bool_ is not a subclass of Python bool, so json cannot encode it
    # natively (e.g. the result of `np.any(...)`).
    if isinstance(o, np.bool_):
      return bool(o)
    if isinstance(o, (np.ndarray, jnp.ndarray)):
      return o.tolist()
    if isinstance(o, range):
      return list(o)
    return super(NpEncoder, self).default(o)

# Load model

In [None]:
"""Utility class and functions.

Adapted from:
https://github.com/kmeng01/rome/blob/bef95a6afd2ca15d794bdd4e3ee0f24283f9b996/
"""

def set_requires_grad(requires_grad, *models):
  """Set `requires_grad` on modules and/or tensors.

  Args:
    requires_grad: the boolean value to assign.
    *models: any mix of torch.nn.Module (every parameter is updated) and
      torch.Tensor / torch.nn.Parameter (updated directly).

  Raises:
    TypeError: if an argument is neither a Module nor a Tensor. This is an
      explicit raise rather than an `assert`, so the check is not skipped
      under `python -O`.
  """
  for model in models:
    if isinstance(model, torch.nn.Module):
      for param in model.parameters():
        param.requires_grad = requires_grad
    elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
      model.requires_grad = requires_grad
    else:
      raise TypeError("unknown type %r" % type(model))

def remove_hooks_from_model(model):
  """Drop every forward hook registered on `model` or any of its submodules.

  Only post-forward hooks (`_forward_hooks`) are cleared; forward pre-hooks
  are left in place. This removes hooks registered by anyone, not just by
  this notebook.
  """
  hook_tables = [submodule._forward_hooks for submodule in model.modules()]
  for table in hook_tables:
    table.clear()


class ModelAndTokenizer:
  """An object to hold a GPT-style language model and tokenizer.

  Either pass `model_name` so the missing pieces are loaded with
  `transformers.Auto*.from_pretrained`, or pass an already-built `model`
  and/or `tokenizer`, which are used as-is.

  Attributes:
    tokenizer: the tokenizer in use.
    model: the language model in use.
    layer_names: names from `model.named_modules()` matching
      `transformer.h.<i>`, `transformer.layers.<i>`, `gpt_neox.h.<i>` or
      `gpt_neox.layers.<i>`, in module-traversal order.
    num_layers: `len(layer_names)`.
  """


  def __init__(
      self,
      model_name=None,
      model=None,
      tokenizer=None,
      low_cpu_mem_usage=False,
      torch_dtype=None,
      device="cuda",
      ):
    """Load or wrap a model/tokenizer pair.

    Args:
      model_name: hub id or local path passed to `from_pretrained`. Required
        whenever `model` or `tokenizer` is None.
      model: optional pre-built model. It is used as-is: not moved, not put in
        eval mode, gradients untouched.
      tokenizer: optional pre-built tokenizer.
      low_cpu_mem_usage: forwarded to `AutoModelForCausalLM.from_pretrained`.
      torch_dtype: forwarded to `AutoModelForCausalLM.from_pretrained`.
      device: device that a freshly loaded model is moved to. Defaults to
        "cuda", which matches the previous hard-coded `.cuda()`.

    Raises:
      ValueError: if `model_name` is None but the model or tokenizer has to
        be loaded.
    """
    if tokenizer is None:
      if model_name is None:
        raise ValueError("model_name is required when tokenizer is not given")
      tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    if model is None:
      if model_name is None:
        raise ValueError("model_name is required when model is not given")
      model = transformers.AutoModelForCausalLM.from_pretrained(
          model_name, low_cpu_mem_usage=low_cpu_mem_usage,
          torch_dtype=torch_dtype
          )
      # Inference only: freeze the weights and switch to eval mode.
      set_requires_grad(False, model)
      model.eval().to(device)
    self.tokenizer = tokenizer
    self.model = model
    # Transformer block names, e.g. "transformer.h.0", "transformer.h.1", ...
    self.layer_names = [
        n
        for n, _ in model.named_modules()
        if (re.match(r"^(transformer|gpt_neox)\.(h|layers)\.\d+$", n))
    ]
    self.num_layers = len(self.layer_names)

  def __repr__(self):
    """Summarize the model class, its layer count and the tokenizer class."""
    return (
        f"ModelAndTokenizer(model: {type(self.model).__name__} "
        f"[{self.num_layers} layers], "
        f"tokenizer: {type(self.tokenizer).__name__})"
        )

In [None]:
# Load GPT-J 6B in its default dtype onto the GPU.
# Smaller alternative for quick tests: "EleutherAI/pythia-70m-deduped-v0".
# NOTE(review): later cells use mt.model.transformer.h and mt.model.lm_head
# (GPT-J module names); confirm those exist before swapping in a GPT-NeoX model.
mt = ModelAndTokenizer(
    model_name="EleutherAI/gpt-j-6B",
    torch_dtype=None,
    low_cpu_mem_usage=False,
)

In [None]:
# Display the loaded model's module tree (its repr is the cell output).
mt.model

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

In [None]:
# Decoded token strings ordered by token id, so that vocab_list[i] is the
# text of token id i (this assumes ids are contiguous from 0 — confirm for
# non-GPT tokenizers).
sorted_vocab = sorted(mt.tokenizer.vocab.items(), key=lambda token_and_id: token_and_id[1])
vocab_list = [mt.tokenizer.decode(token_id) for _token, token_id in sorted_vocab]

# Scratch

## Unembed

In [None]:
def fmt_layer_sweep_df(output_dst_logits, count=5):
  """Tabulate the top tokens of a per-layer logit sweep in long format.

  Args:
    output_dst_logits: 2-D array of logits indexed as [layer, token_id].
    count: number of highest-logit tokens taken from each layer. The union of
      these token ids over all layers is reported for every layer, so each
      token can be followed across the whole sweep.

  Returns:
    pd.DataFrame with one row per (layer, token) and the columns:
      Layer, Token (decoded via the global `vocab_list`), Logit,
      Softmax (softmax over that layer's full logit vector),
      Rank (1 = highest logit in that layer), and idx (the token id).
  """
  logits = np.asarray(output_dst_logits)
  num_layers, vocab_size = logits.shape
  top_indices = np.unique(np.argsort(logits, axis=-1)[:, -count:].flatten())
  softmax_values = softmax(logits)

  # ranks[layer, token] is the 0-based position of `token` in that layer's
  # descending-logit order, i.e. the inverse permutation of the argsort.
  # It is built once, instead of scanning the full sorted vocabulary for
  # every (layer, token) pair.
  order = np.argsort(-logits, axis=1)
  ranks = np.empty_like(order)
  ranks[np.arange(num_layers)[:, None], order] = np.arange(vocab_size)

  records = [
      {
          'Layer': layer,
          'Token': vocab_list[idx],
          'Logit': logits[layer, idx],
          'Softmax': softmax_values[layer, idx],
          'Rank': ranks[layer, idx] + 1,
          'idx': idx,
      }
      for layer in range(num_layers)
      for idx in top_indices
  ]
  return pd.DataFrame(records)

def plot_layer_sweep(output_dst_logits, count=5):
  """Plot the rank, softmax probability and logit of the top tokens per layer.

  Args:
    output_dst_logits: per-layer logits, passed to `fmt_layer_sweep_df`.
    count: top-`count` tokens per layer to include (union across layers).

  Returns:
    An Altair horizontal concatenation of three line charts against Layer:
    rank (log scale, reversed so rank 1 is at the top), softmax, and logit.
  """
  sweep_df = fmt_layer_sweep_df(output_dst_logits, count)
  line = (
      alt.Chart(sweep_df)
      .encode(
          x=alt.X('Layer:Q', axis=alt.Axis(tickCount=5)),
          tooltip=alt.Tooltip('Token:N'),
          color='Token:N',
      )
      .properties(width=200)
      .mark_line()
  )

  rank_panel = line.encode(y=alt.Y('Rank:Q', scale=alt.Scale(type='log', reverse=True)))
  softmax_panel = line.encode(y='Softmax:Q')
  logit_panel = line.encode(y='Logit:Q')
  return alt.hconcat(rank_panel, softmax_panel, logit_panel)

In [None]:
# Logit lens on an unpatched prompt: decode the last-position hidden state
# after every layer through the unembedding (lm_head).
inputs = mt.tokenizer.encode('What is 23 plus 55?\n', return_tensors='pt').to(device='cuda')
outputs = mt.model(inputs, output_hidden_states=True)

# Intermediate hidden states are raw residual-stream values, so the final
# LayerNorm (ln_f) is applied before lm_head, as in the standard logit lens.
# Without it, the per-layer Logit/Softmax panels are not on the model's
# output scale.
# NOTE(review): HF's GPTJModel records its last hidden state *after* ln_f,
# so that entry is unembedded as-is; the assert below checks this assumption.
final_state_index = len(outputs.hidden_states) - 1
layer_logits_orig = np.stack([
    mt.model.lm_head(
        state[0, -1, :] if i == final_state_index
        else mt.model.transformer.ln_f(state[0, -1, :])
    ).cpu().numpy()
    for i, state in enumerate(outputs.hidden_states)
])
assert np.allclose(layer_logits_orig[-1], outputs.logits[0, -1, :].cpu().numpy(),
                   atol=1e-3, rtol=1e-3)

plot_layer_sweep(layer_logits_orig, 1)

# attn sweep

In [None]:
def patch_layer_sweep(d):
  """Patch a source hidden state into a destination run, one layer pair at a time.

  For each pair (layers_src[i], layers_dst[i]), the hidden state that
  `prompt_src` produces after block layers_src[i] at `position_src` is
  written into the output of block layers_dst[i] at `position_dst` while
  `prompt_dst` is run. The destination's last-position logits are recorded.

  Args:
    d: dict with keys 'prompt_src', 'prompt_dst', 'layers_dst',
      'position_src', 'position_dst', and optionally 'layers_src' (defaults
      to 'layers_dst'). 'layers_dst' must support indexing.

  Returns:
    (logits_src, output_dst_logits): the source run's last-position logits
    as a 1-D array, and an array with one row of destination last-position
    logits per entry of layers_src.
  """
  prompt_src, prompt_dst = d['prompt_src'], d['prompt_dst']
  layers_dst = d['layers_dst']
  position_src, position_dst = d['position_src'], d['position_dst']
  layers_src = d.get('layers_src', layers_dst)

  inp_src = mt.tokenizer.encode(prompt_src, return_tensors='pt').to(device='cuda')
  inp_dst = mt.tokenizer.encode(prompt_dst, return_tensors='pt').to(device='cuda')
  # Bug fix: this previously ran the model on the global `inputs` tensor
  # from the logit-lens cell, so `prompt_src` was silently ignored.
  output_src = mt.model(inp_src, output_hidden_states=True)

  # hidden_states[0] is the embedding output, so block L's output is at L + 1.
  # NOTE(review): in HF GPT-J the last hidden_states entry is taken after the
  # final LayerNorm, so hs_cache_[-1] is not the raw output of the last
  # block — confirm before patching into the last layer.
  hs_cache_ = [output_src['hidden_states'][layer+1][0] for layer in range(mt.num_layers)]

  # TODO: Batch?
  output_dst_logits = []
  for index, layer_src in enumerate(layers_src):
    layer_dst = layers_dst[index]
    patched_hs = hs_cache_[layer_src][position_src]

    def patch(module, input, output):
      # Overwrite the block's output hidden state in place (batch row 0).
      output[0][0, position_dst] = patched_hs

    mt.model.transformer.h[layer_dst].register_forward_hook(patch)
    try:
      output_dst = mt.model(inp_dst)
    finally:
      # Always detach the hook, even if the forward pass fails. Otherwise it
      # stays registered and corrupts every later call to mt.model.
      remove_hooks_from_model(mt.model)

    output_dst_logits.append(output_dst.logits[0, -1, :].cpu().numpy())

  return output_src.logits[0, -1, :].cpu().numpy(), np.array(output_dst_logits)


# Patch the last position of an arithmetic prompt into the last position of
# a pattern-completion prompt, for blocks 0-26.
# NOTE(review): range(27) skips block 27 of the 28 — confirm this is intended.
logits_src, output_dst_logits = patch_layer_sweep(dict(
    prompt_src='What is 667 plus 45?\n',
    prompt_dst='j j 82 82 c c t t X',
    layers_dst=range(27),
    position_src=-1,
    position_dst=-1,
))


In [None]:
# Top-2 tokens per layer from the patched destination runs.
plot_layer_sweep(output_dst_logits, count=2)

In [None]:
# Keep this sweep's source logits. The next sweep rebinds `logits_src` to a
# new array, and the two are compared at the end of the notebook. This is an
# alias, not a copy.
logits_src_0 = logits_src

## old patch

#### old code

In [None]:
def set_hs_patch_hooks(model, hs_patch_config, patch_input=False, generation_mode=False):
    """Register hooks on `model.transformer.h[layer]` that overwrite hidden states.

    Args:
        model: model exposing `model.transformer.h`, an indexable sequence of blocks.
        hs_patch_config: dict mapping a block index to a list of
            `(position, hidden_state)` pairs. Each hidden_state is written into
            batch row 0 at `position`.
        patch_input: if True, patch the block's first positional input with a
            forward pre-hook. Otherwise patch `output[0]` with a forward hook.
        generation_mode: if True, skip patching on calls whose sequence length
            is 1. Every patched position must then be > 0.

    Returns:
        List of hook handles; pass them to `remove_hooks` to detach.
    """
    # Under model.generate(), the prompt's hidden states are cached after the
    # first pass and later steps see inputs/outputs of length 1. Those steps
    # must not be patched, so the hooks skip any length-1 call. Position 0 is
    # disallowed in this mode because a length-1 call cannot be told apart
    # from a one-token initial input.
    if generation_mode:
        for layer in hs_patch_config:
            assert all(pos > 0 for pos, _ in hs_patch_config[layer])

    def make_hook(position_hs):
        # Both hook kinds funnel into this; `states` is (batch, sequence, hidden_state).
        def overwrite(states):
            seq_len = len(states[0])
            if generation_mode and seq_len == 1:
                return
            for pos, hs in position_hs:
                states[0, pos] = hs

        # NOTE(review): forward pre-hooks only receive positional args. If the
        # model calls its blocks with `hidden_states=` as a keyword, `args` is
        # empty and patch_input=True fails — confirm for GPT-J.
        if patch_input:
            return lambda module, args: overwrite(args[0])
        return lambda module, args, output: overwrite(output[0])

    handles = []
    for layer in hs_patch_config:
        block = model.transformer.h[layer]
        register = block.register_forward_pre_hook if patch_input else block.register_forward_hook
        handles.append(register(make_hook(hs_patch_config[layer])))
    return handles

def remove_hooks(hooks):
    """Detach every hook handle in `hooks`, e.g. the list set_hs_patch_hooks returns."""
    for handle in hooks:
        handle.remove()

def make_inputs(tokenizer, prompts, device="cuda"):
  """Tokenize `prompts` into a left-padded batch.

  The pad id is the tokenizer's "[PAD]" special token if it has one,
  otherwise 0. Padded slots get attention-mask 0.

  Returns:
    dict with `input_ids` and `attention_mask` tensors of shape
    (len(prompts), longest prompt length), moved to `device`.
  """
  encoded = [tokenizer.encode(p) for p in prompts]
  width = max(map(len, encoded))
  specials = tokenizer.all_special_tokens
  pad_id = tokenizer.all_special_ids[specials.index("[PAD]")] if "[PAD]" in specials else 0

  ids, mask = [], []
  for toks in encoded:
    gap = width - len(toks)
    ids.append([pad_id] * gap + toks)
    mask.append([0] * gap + [1] * len(toks))
  return {
      "input_ids": torch.tensor(ids).to(device),
      "attention_mask": torch.tensor(mask).to(device),
  }

def decode_tokens(tokenizer, token_array):
  """Decode each token id separately into its own string.

  A 1-D sequence yields a list of strings. An array-like with more than one
  dimension is decoded row by row, giving nested lists.
  """
  is_batched = hasattr(token_array, "shape") and len(token_array.shape) > 1
  if not is_batched:
    return [tokenizer.decode([t]) for t in token_array]
  return [decode_tokens(tokenizer, row) for row in token_array]

def predict_from_input(model, inp):
  """Greedy next-token prediction from the last position of each sequence.

  Returns:
    (preds, p): the argmax token id per batch row, and its softmax probability.
  """
  last_logits = model(**inp)["logits"][:, -1]
  top_prob, top_id = torch.softmax(last_logits, dim=1).max(dim=1)
  return top_id, top_prob

def set_requires_grad(requires_grad, *models):
  """Set `requires_grad` on modules and/or tensors.

  NOTE(review): this duplicates the set_requires_grad defined in the
  "Load model" section; whichever cell ran last wins.

  Args:
    requires_grad: the boolean value to assign.
    *models: any mix of torch.nn.Module (every parameter is updated) and
      torch.Tensor / torch.nn.Parameter (updated directly).

  Raises:
    TypeError: if an argument is neither a Module nor a Tensor. This is an
      explicit raise rather than an `assert`, so the check is not skipped
      under `python -O`.
  """
  for model in models:
    if isinstance(model, torch.nn.Module):
      for param in model.parameters():
        param.requires_grad = requires_grad
    elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
      model.requires_grad = requires_grad
    else:
      raise TypeError("unknown type %r" % type(model))

#### charts

In [None]:
def patch_layer_sweep(d):
  """Patch a source hidden state into a destination run, one layer pair at a time.

  Same contract as the earlier patch_layer_sweep, but built on
  make_inputs / set_hs_patch_hooks / remove_hooks. For each pair
  (layers_src[i], layers_dst[i]), the hidden state that `prompt_src`
  produces after block layers_src[i] at `position_src` overwrites the output
  of block layers_dst[i] at `position_dst` during a run on `prompt_dst`.

  Args:
    d: dict with keys 'prompt_src', 'prompt_dst', 'layers_dst',
      'position_src', 'position_dst', and optionally 'layers_src' (defaults
      to 'layers_dst'). 'layers_dst' must support indexing.

  Returns:
    (logits_src, output_dst_logits): the source run's last-position logits
    as a 1-D array, and an array with one row of destination last-position
    logits per entry of layers_src.
  """
  prompt_src, prompt_dst = d['prompt_src'], d['prompt_dst']
  layers_dst = d['layers_dst']
  position_src, position_dst = d['position_src'], d['position_dst']
  layers_src = d.get('layers_src', layers_dst)

  # Run the model on prompt_src and cache every block's output.
  # hidden_states[0] is the embedding output, so block L's output is at L + 1.
  inp_src = make_inputs(mt.tokenizer, [prompt_src])
  output_src = mt.model(**inp_src, output_hidden_states=True)
  hs_cache_ = [output_src['hidden_states'][layer+1][0] for layer in range(mt.num_layers)]

  # The destination prompt is the same for every layer, so tokenize it once.
  inp_dst = make_inputs(mt.tokenizer, [prompt_dst])

  # TODO: Batch?
  output_dst_logits = []
  for index, layer_src in enumerate(layers_src):
    layer_dst = layers_dst[index]

    # Run prompt_dst while patching in the hidden state from the prompt_src run.
    hs_patch_config = {
        layer_dst: [(position_dst, hs_cache_[layer_src][position_src])]
    }
    patch_hooks = set_hs_patch_hooks(mt.model, hs_patch_config, patch_input=False)
    try:
      output_dst = mt.model(**inp_dst)
    finally:
      # Detach even if the forward pass fails, so no stale hook corrupts later runs.
      remove_hooks(patch_hooks)

    output_dst_logits.append(output_dst.logits[0, -1, :].cpu().numpy())

  return output_src.logits[0, -1, :].cpu().numpy(), np.array(output_dst_logits)

In [None]:
# Same sweep as before, now run with the patch_layer_sweep defined in the cell
# above. The result is plotted with the default top-5 tokens per layer.
logits_src, output_dst_logits = patch_layer_sweep(dict(
    prompt_src='What is 667 plus 45?\n',
    prompt_dst='j j 82 82 c c t t X',
    layers_dst=range(27),
    position_src=-1,
    position_dst=-1,
))
plot_layer_sweep(output_dst_logits, count=5)

In [None]:
# Sanity check: one logit per vocabulary entry (GPT-J's embedding has 50400 rows).
logits_src.shape

(50400,)

In [None]:
# Largest absolute difference between the two sweeps' source logits. The raw
# difference array is truncated when displayed (only its first and last three
# entries are shown), so it cannot confirm that every entry is zero.
np.abs(logits_src_0 - logits_src).max()

array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)