In [1]:
import os, sys

sys.path.append(os.path.abspath(".."))

from jax.config import config

config.update("jax_enable_x64", True)


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")


Using renderer: notebook_connected


In [3]:
import circuitsvis as cv
import circuitsvis.examples as cve

cve.hello("Alex")

In [4]:
# Import stuff
import numpy as np
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px

from jaxtyping import Float, Int, Array
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, GPT2TokenizerFast
import dataclasses
import datasets
from IPython.display import HTML

from tx.hooks import HookPoint, StoreHook
from tx.models.gpt2 import PretrainedGPT2Model
from tx.network import GenerativeModel

In [5]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(tensor, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(tensor, labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [6]:
model = GenerativeModel.from_pretrained("gpt2")


Loading GPT2 params...


2023-10-11 04:39:54.241890: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Using sep_token, but it is not set yet.
Using pad_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.


In [7]:
model_description_text = """## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. See my explainer for documentation of all supported models, and this table for hyper-parameters and the name used to load them. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly. 

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!"""
loss = model(model_description_text, return_type="loss")
print("Model loss:", loss)

Model loss: (Array(4.18662207, dtype=float64), {})


In [8]:
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = model.to_tokens(gpt2_text, prepend_bos=True)
gpt2_logits, gpt2_output = model.run_with_intermediates(gpt2_tokens)
gpt2_cache = gpt2_output["intermediates"]


In [9]:
attention_pattern = gpt2_cache["block_0"]["attn"][HookPoint.ATTN_SCORES.value][0]
print(attention_pattern.shape)
gpt2_str_tokens = model.to_str_list(gpt2_text, prepend_bos=True)

(12, 33, 33)


In [10]:
print("Layer 0 Head Attention Patterns:")
cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)

Layer 0 Head Attention Patterns:


In [11]:
# layer_to_ablate = 0
# head_index_to_ablate = 8


# def head_ablation_hook(
#     value: Float[Array, "... S NH HD"], hook: HookPoint
# ) -> Float[Array, "... S NH HD"]:
#     print(f"Shape of the value tensor: {value.shape}")
#     value[..., :, head_index_to_ablate, :] = 0.0
#     return value

# original_loss = model(gpt2_tokens, return_type="loss")
# ablated_loss = model.run_with_hooks(
#     gpt2_tokens,
#     return_type="loss",
#     fwd_hooks=[(utils.get_act_name("v", layer_to_ablate), head_ablation_hook)],
# )
# print(f"Original Loss: {original_loss.item():.3f}")
# print(f"Ablated Loss: {ablated_loss.item():.3f}")
