In this notebook, we want to:

- investigate the spatial structure of the residual stream
- see which tokens the different directions in the residual stream map to

In [34]:
# Generic
import os
from pathlib import Path
from copy import deepcopy
import typing

# Numerical Computing
import numpy as np
import torch
import pandas as pd
# import torch.nn.functional as F
from fancy_einsum import einsum
import einops
from jaxtyping import Float, Int, Bool
import matplotlib.pyplot as plt
from matplotlib import gridspec
from sklearn.decomposition import PCA
import ipywidgets

from muutils.misc import shorten_numerical_to_str
from muutils.nbutils.configure_notebook import configure_notebook
# TransformerLens imports
from transformer_lens import ActivationCache

# Our Code
from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze, LatticeMaze, SPECIAL_TOKENS
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.tokenization.token_utils import strings_to_coords, coords_to_strings
from maze_dataset.constants import _SPECIAL_TOKENS_ABBREVIATIONS

from maze_transformer.training.config import ConfigHolder, ZanjHookedTransformer, BaseGPTConfig
from maze_transformer.utils.dict_shapes import string_dict_shapes
from maze_transformer.mechinterp.plot_weights import plot_embeddings


In [2]:
# Setup (we won't be training any models)
DEVICE: torch.device = configure_notebook(seed=42, dark_mode=False)
print(f"{DEVICE = }")
torch.set_grad_enabled(False)


DEVICE = device(type='cpu')


<torch.autograd.grad_mode.set_grad_enabled at 0x20bf03349d0>

In [3]:
# path to load the model from
MODEL_PATH: str = "../examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj"

In [4]:
MODEL: ZanjHookedTransformer = ZanjHookedTransformer.read(MODEL_PATH)
num_params: int = MODEL.num_params()
print(f"loaded model with {shorten_numerical_to_str(num_params)} params ({num_params = }) from\n{MODEL_PATH}")
TOKENIZER: MazeTokenizer = MODEL.zanj_model_config.maze_tokenizer

loaded model with 1.3M params (num_params = 1274699) from
../examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj


In [5]:
# embed each token in the vocabulary
print(f"{TOKENIZER.token_arr = }")
d_model: int = MODEL.config.d_model

# get the embedding matrix
print(MODEL.W_E.shape)
assert MODEL.W_E.shape == (TOKENIZER.vocab_size, d_model)


TOKENIZER.token_arr = ['<ADJLIST_START>', '<ADJLIST_END>', '<TARGET_START>', '<TARGET_END>', '<ORIGIN_START>', '<ORIGIN_END>', '<PATH_START>', '<PATH_END>', '<-->', ';', '<PADDING>', '(0,0)', '(0,1)', '(1,0)', '(1,1)', '(0,2)', '(2,0)', '(1,2)', '(2,1)', '(2,2)', '(0,3)', '(3,0)', '(3,1)', '(2,3)', '(3,2)', '(1,3)', '(3,3)', '(0,4)', '(2,4)', '(4,0)', '(1,4)', '(4,1)', '(4,2)', '(3,4)', '(4,3)', '(4,4)', '(0,5)', '(5,0)', '(5,1)', '(2,5)', '(5,2)', '(5,3)', '(4,5)', '(5,4)', '(1,5)', '(3,5)', '(5,5)', '(0,6)', '(2,6)', '(4,6)', '(6,0)', '(1,6)', '(6,1)', '(6,2)', '(3,6)', '(6,3)', '(6,4)', '(5,6)', '(6,5)', '(6,6)', '(0,7)', '(7,0)', '(7,1)', '(2,7)', '(7,2)', '(7,3)', '(4,7)', '(7,4)', '(7,5)', '(6,7)', '(7,6)', '(1,7)', '(3,7)', '(5,7)', '(7,7)']
torch.Size([75, 128])


In [6]:
VOCAB_TOKENS: Int[torch.Tensor, "vocab_size"] = torch.arange(TOKENIZER.vocab_size, device=DEVICE)
assert VOCAB_TOKENS.tolist() == TOKENIZER.encode(TOKENIZER.token_arr)

In [7]:
# plot_embeddings(MODEL, token_arr=TOKENIZER.token_arr)

In [42]:
def coordinate_to_color(coord: tuple[float, float], max_val: float = 1.0) -> tuple[float, float, float]:
	"""Maps a coordinate (i, j) to a unique RGB color"""
	coord = np.array(coord)
	if max_val < coord.max():
		raise ValueError(f"max_val ({max_val}) must be at least as large as the largest coordinate ({coord.max()})")
	
	coord = coord / max_val

	return (
		coord[0] * 0.6 + 0.3, # r
		0.5,                  # g
		coord[1] * 0.6 + 0.3, # b
	)



tokens_coords: list[str|tuple[int,int]] = strings_to_coords(TOKENIZER.token_arr, when_noncoord="include")
tokens_coords_only: list[tuple[int,int]] = strings_to_coords(TOKENIZER.token_arr, when_noncoord="skip")
max_coord: int = np.array(tokens_coords_only).max()
token_idxs_coords: list[int] = TOKENIZER.encode(TOKENIZER.coords_to_strings(tokens_coords_only))

vocab_coordinates_colored: list[tuple[
	str, # token
	tuple[int, int]|str, # coordinate
	tuple[float, float, float], # color
]] = list(zip(
	TOKENIZER.token_arr,
	tokens_coords,
	[
		coordinate_to_color(coord, max_val=max_coord) if isinstance(coord, tuple) else (0.0, 1.0, 0.0)
		for coord in tokens_coords
	]
))

print(f"{vocab_coordinates_colored = }")
print(f"{len(token_idxs_coords) = }")
print(f"{token_idxs_coords = }")


vocab_coordinates_colored = [('<ADJLIST_START>', '<ADJLIST_START>', (0.0, 1.0, 0.0)), ('<ADJLIST_END>', '<ADJLIST_END>', (0.0, 1.0, 0.0)), ('<TARGET_START>', '<TARGET_START>', (0.0, 1.0, 0.0)), ('<TARGET_END>', '<TARGET_END>', (0.0, 1.0, 0.0)), ('<ORIGIN_START>', '<ORIGIN_START>', (0.0, 1.0, 0.0)), ('<ORIGIN_END>', '<ORIGIN_END>', (0.0, 1.0, 0.0)), ('<PATH_START>', '<PATH_START>', (0.0, 1.0, 0.0)), ('<PATH_END>', '<PATH_END>', (0.0, 1.0, 0.0)), ('<-->', '<-->', (0.0, 1.0, 0.0)), (';', ';', (0.0, 1.0, 0.0)), ('<PADDING>', '<PADDING>', (0.0, 1.0, 0.0)), ('(0,0)', (0, 0), (0.3, 0.5, 0.3)), ('(0,1)', (0, 1), (0.3, 0.5, 0.3857142857142857)), ('(1,0)', (1, 0), (0.3857142857142857, 0.5, 0.3)), ('(1,1)', (1, 1), (0.3857142857142857, 0.5, 0.3857142857142857)), ('(0,2)', (0, 2), (0.3, 0.5, 0.4714285714285714)), ('(2,0)', (2, 0), (0.4714285714285714, 0.5, 0.3)), ('(1,2)', (1, 2), (0.3857142857142857, 0.5, 0.4714285714285714)), ('(2,1)', (2, 1), (0.4714285714285714, 0.5, 0.3857142857142857)), ('(2

In [39]:

pca_all: PCA = PCA(svd_solver='full')
PCA_RESULTS = pca_all.fit_transform(MODEL.W_E.cpu().numpy().T)

pca_coords: PCA = PCA(svd_solver='full')
PCA_RESULTS_COORDS_ONLY = pca_coords.fit_transform(MODEL.W_E[token_idxs_coords].cpu().numpy().T)


In [41]:
print(f"{MODEL.W_E.shape = }")
print(f"{PCA_RESULTS.shape = }")
print(f"{PCA_RESULTS_COORDS_ONLY.shape = }")


MODEL.W_E.shape = torch.Size([75, 128])
PCA_RESULTS.shape = (128, 75)
PCA_RESULTS_COORDS_ONLY.shape = (128, 64)


In [46]:
# plot the PCA
def plot_pca_colored(
    pca_results: np.ndarray, 
    vocab_colors: list[tuple],
    dim1: int, 
    dim2: int,
    index_map: list[int]|None = None,
) -> None:
    
    fig, ax = plt.subplots(figsize=(5, 5))

    for i in range(pca_results.shape[1]):
        if index_map is not None:
            i_map: int = index_map[i]
        else:
            i_map = i
        token, coord, color = vocab_colors[i_map]
        ax.scatter(
            pca_results[dim1-1, i], 
            pca_results[dim2-1, i], 
            alpha=0.5,
            color=color,
        )
        if isinstance(coord, str):
            ax.text(
                pca_results[dim1-1, i], 
                pca_results[dim2-1, i], 
                _SPECIAL_TOKENS_ABBREVIATIONS[coord],
                fontsize=8,
            )
        
    ax.set_xlabel(f"PC{dim1}")
    ax.set_ylabel(f"PC{dim2}")
    ax.set_title(f"PCA of Survey Responses:\nPC{dim1} vs PC{dim2}")
    ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.show()

# Dropdowns for PCA dimensions
dim1_dropdown = ipywidgets.IntText(
    value=1,
    description='Dim 1:',
    disabled=False
)

dim2_dropdown = ipywidgets.IntText(
    value=2,
    description='Dim 1:',
    disabled=False
)

ipywidgets.interact(
    plot_pca_colored, 
    pca_results=ipywidgets.fixed(PCA_RESULTS_COORDS_ONLY), 
    vocab_colors=ipywidgets.fixed(vocab_coordinates_colored), 
    dim1=dim1_dropdown,
    dim2=dim2_dropdown,
    index_map=ipywidgets.fixed(token_idxs_coords),
);

interactive(children=(IntText(value=1, description='Dim 1:'), IntText(value=2, description='Dim 1:'), Output()…