In [1]:
import transformer_lens

# Load a model (eg GPT-2 Small)
model = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")

# Run the model and get logits and activations
logits, activations = model.run_with_cache("Hello World")

  from .autonotebook import tqdm as notebook_tqdm
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gpt2-small into HookedTransformer


In [2]:
from datasets import load_dataset

dataset = load_dataset("Skylion007/openwebtext", split="train")

In [3]:
import numpy as np
lengths = np.array([len(x["text"]) for x in dataset])

In [None]:
print("mean of lengths: ", np.mean(lengths))
print("std of lengths: ", np.std(lengths))
print("min of lengths: ", np.min(lengths))
print("max of lengths: ", np.max(lengths))
print("total lengths: ", np.sum(lengths)) # roughly 3x tokens => 13B tokens first order magnitude
# around 13B * 100 dims * 2 bytes => 2.5PB


mean of lengths:  4914.9018131169
std of lengths:  6662.991617660633
min of lengths:  316
max of lengths:  100000
total lengths:  39386887788


In [7]:
import torch
import torch.nn as nn
from typing import List, Tuple
from jaxtyping import Float

class NaiveNormalizingFlow(nn.Module):
    """
    Very naive normalizing flow model using the very first model introduced in this
    blogpost: https://lilianweng.github.io/posts/2018-10-13-flow-models/.

    It works like this:
        - You have a sequence of shifts and scales
        - Each subsequent element shifts the second shift scales set of indices
            by the shift of the first layer and scales by the scale of the
            first one too.
    NOTE: does not support general index mask (TODO you might want to implement this?)
    """
    def __init__(
            self,
            dim: int,
            shift_scales: List[Tuple[nn.Module, nn.Module]],
            split_index: List[Tuple[int, bool]]
        ):
        super().__init__()
        self.dim = dim
        self.shift_scales = nn.ModuleList(shift_scales)
        # NOTE: we train a bunch of linears, but they must all be invertible.
        # TODO(Adriano) add linear layers
        # self.linears = nn.ModuleList([nn.Linear(dim, dim, bias=True) for _ in range(len(shift_scales))])
        self.split_index = split_index
        assert len(split_index) == len(shift_scales)
        assert all(sum(ratio) == dim for ratio in split_index)

    def forward(self, x: Float[torch.Tensor, "batch dim"]) -> Float[torch.Tensor, "batch dim"]:
        """
        Implement sequential shift and scaling, alternating by index.
        """
        for (shift, scale), (index, index_is_upper) in zip(self.shift_scales, self.split_index):
            assert x.ndim == 2, f"x.shape={x.shape} (you should just flatten into ndim=2 into batch dim)"
            assert x.shape[1] == self.dim

            # Calculate the shift and scale
            sh = shift(x[..., :index]) if index_is_upper else shift(x[..., index:])
            assert sh.ndim == 2
            assert sh.shape[1] == self.dim - index if index_is_upper else index # Need to fit in REST of indices; fmt: skip
            # ...
            sc = scale(x[..., :index]) if index_is_upper else scale(x[..., index:])
            assert sc.ndim == 2
            assert sc.shape[1] == self.dim - index if index_is_upper else index # Need to fit in REST of indices; fmt: skip

            # Apply the scale and shift
            if index_is_upper: # Means that you READ from BELOW index and WRITE ABOVE index
                x[..., index:] = x[..., index:] * sc.exp()
                x[..., index:] = x[..., index:] + sh
            else: # Means that READ from ABOVE index and WRITE BELOW index
                x[..., :index] = x[..., :index] * sc.exp()
                x[..., :index] = x[..., :index] + sh
        return x
    
    # TODO(Adriano) implement this
    # def inverse(y: Float["batch dim", torch.Tensor]) -> Float["batch dim", torch.Tensor]:
    #     raise NotImplementedError("Not implemented")

    # def log_det_jacobians(self, x: Float["batch dim", torch.Tensor]) -> Float["batch dim", torch.Tensor]:
    #     raise NotImplementedError("Not implemented")
    # def log_det_jacobian(self, x: Float["batch dim", torch.Tensor]) -> Float["batch dim", torch.Tensor]:
    #     return torch.sum(self.log_det_jacobians(x), dim=-1)
    
    # def log_prob(self, x: Float["batch dim", torch.Tensor]) -> Float["batch dim", torch.Tensor]:
    #     raise NotImplementedError("Not implemented")
    
    # def sample(self, n: int) -> Float["batch dim", torch.Tensor]:
    #     raise NotImplementedError("Not implemented")
    
    

In [9]:
################ Test that the flow forwards method works (i.e. to give you output) #################
batch_size = 100
dim = 128
split_index = [(64, True), (64, False)]
mlps = [nn.Sequential(
    nn.Linear(d if u else dim - d, d if u else dim - d),
    nn.ReLU(),
    nn.Linear(d if u else dim - d, d if u else dim - d),
    nn.ReLU(),
    nn.Linear(d if u else dim - d, d if u else dim - d),
) for d, u in split_index]
assert len(mlps) == len(split_index)

mlps[0](torch.randn(10, 64)) # Make sure no error
mlps[1](torch.randn(10, 64)) # Make sure no error

flow = NaiveNormalizingFlow(dim, mlps, split_index)

x = torch.randn(batch_size, dim)
y = flow(x)
print(y.shape)


AssertionError: 

In [18]:
"""
Collect some activations :/
"""
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from collect_activations import collect_all_activations
device = "cuda:0"

# hf_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
# model = HookedTransformer.from_pretrained('google/gemma-2-9b-it', device=device)
# text = "bob is in paris, alice is in tokyo, actually bob is in london... bob is in"
# text_tok = hf_tokenizer(text, return_tensors = "pt").to(device)["input_ids"]
# # print(message_tokens)
# response = model.generate(text_tok, max_new_tokens = 8)
# print(hf_tokenizer.decode(response[0], skip_special_tokens = True))
# model = model.to(device)
# text = dataset[0]["text"]
# text_tok = model.to_tokens(text)
# model(text_tok)
# Test that we are able to collet some activations
test_hook_names = ["blocks.6.hook_resid_pre"]
test_msg = "I am a happy watermelon. Every day I roll roll roll down the hill to see my friend the"
test_tok = model.tokenizer(test_msg, return_tensors = "pt").to(device)["input_ids"]
# test_resp_tok = model.generate(test_tok, max_new_tokens = 32)
# test_resp = hf_tokenizer.decode(test_resp_tok[0], skip_special_tokens = True)
# # debug
# test_brk = "\n" + "="*100 + "\n"
# print(test_msg, test_brk, test_tok, test_brk, test_resp_tok, test_brk, test_resp)

# run em
outputs, losses = collect_all_activations(model, test_tok, test_hook_names, inference_batch_size=20)
print(f"Outputs has shape: {outputs.shape}")
print(f"Losses has shape {losses.shape}")

Average loss is: 4.412842750549316: 100%|██████████| 1/1 [00:00<00:00, 37.28it/s]

Outputs has shape: torch.Size([1, 1, 21, 768])
Losses has shape torch.Size([1])





In [16]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h