### Setup

In [None]:
import matplotlib.pyplot as plt
import math
import networkx as nx
import numpy as np
import os
import torch
import transformer_lens
import random

from tqdm import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig
import transformer_lens.utils as utils

from tree_generation import generate_example, parse_example, GraphDataset
from utils import *
from interp_utils import *

: 

### Model Training

In [None]:
# Dataset size, and the value passed as the first argument to GraphDataset /
# generate_example (presumably the number of graph nodes -- confirm in tree_generation).
n_examples = 172_000
n_states = 16

# Build the dataset backed by "dataset.txt", show its first example, and wrap it in
# train/test loaders. NOTE(review): the literal 32 is presumably the batch size --
# confirm against get_loaders.
dataset = GraphDataset(n_states, "dataset.txt", n_examples)
dataset.visualize_example(0)
train_loader, test_loader = get_loaders(dataset, 32)

: 

In [None]:
# Model configuration: 6 layers with a single 128-dimensional attention head each.
cfg = HookedTransformerConfig(
    n_layers=6,
    d_model=128,
    # NOTE(review): presumably one shorter than the longest sequence because inputs
    # drop the final token for next-token prediction -- confirm against the training code.
    n_ctx=dataset.max_seq_length - 1,
    n_heads=1,
    d_mlp=512,
    d_head=128,
    #attn_only=True,
    d_vocab=len(dataset.idx2tokens),
    # Use the GPU when one is present, otherwise fall back to CPU so the notebook
    # still runs on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
    attention_dir= "causal",
    act_fn="gelu",
)
model = HookedTransformer(cfg)


: 

In [None]:
# Checkpoint file used by the load-or-train cell below.
model_name = "model_epoch0.pt"

: 

In [None]:
# Reuse saved weights when the checkpoint file exists; otherwise train a new model
# and save its weights for the next run.
if os.path.exists(model_name):
    # map_location remaps the stored tensors onto the model's device, so a checkpoint
    # written on a GPU machine still loads when running elsewhere.
    model.load_state_dict(torch.load(model_name, map_location=model.cfg.device))
else:
    train(model, train_loader, test_loader, n_epochs=100, learning_rate=3e-4)
    torch.save(model.state_dict(), model_name)

: 

In [None]:
start_seed = 250_000
num_samples = 1_000

# Accuracy of the model on freshly generated graphs, reported separately for each
# edge ordering. Graph size comes from n_states so it matches the dataset.
for order in ["forward", "backward", "random"]:
    total_correct = 0
    for seed in range(start_seed, start_seed + num_samples):
        graph = generate_example(n_states, seed, order=order)
        # `pred` is left at notebook level on purpose: the next cell reads it.
        pred, correct = eval_model(model, dataset, graph)
        if correct:
            total_correct += 1

    print(f"{order}: {100* total_correct / num_samples:.4f}%")

: 

In [None]:
# NOTE(review): `pred` is whatever the final iteration of the evaluation loop in the
# previous cell left behind, so this cell only works after that cell has run.
labels, cache = get_example_cache(pred, model, dataset)

: 

In [None]:
# Call display_head once per (layer, head) pair for the cached example.
for layer_idx in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        fig = display_head(cache, labels, layer_idx, head_idx, show=True)

: 

### Experiment: understanding embeddings

In [None]:
# Copy the token-embedding matrix to NumPy and subtract each row's mean
# (the mean is taken over the last axis, i.e. per token, not per feature).
embedding_matrix = model.W_E.detach().cpu().numpy()
embedding_matrix = embedding_matrix - embedding_matrix.mean(axis=-1, keepdims=True)

# Scale every row to unit L2 norm.
row_norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
embedding_matrix = embedding_matrix / row_norms

# NOTE(review): the slices assume vocabulary rows 0-2 are non-node tokens, rows 3-18
# are the 16 "incoming" node tokens and rows 19+ the "outgoing" ones -- confirm
# against dataset.idx2tokens.
incoming_embeddings = embedding_matrix[3:19]
outgoing_embeddings = embedding_matrix[19:]
all_nodes = embedding_matrix[3:]

: 

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score

# Linear probe: regress a 0/1 label (which embedding group a row came from) on the
# normalised embeddings.
X = all_nodes
# Labels are sized from the two slices that make up all_nodes instead of the
# literals 32 and 16, so they stay aligned if the vocabulary layout changes.
y = np.concatenate([np.zeros(len(incoming_embeddings)), np.ones(len(outgoing_embeddings))])

probe = LinearRegression().fit(X, y)
# Mean squared error on the same rows the probe was fitted on (train loss only).
y_pred = probe.predict(X)
loss = mean_squared_error(y, y_pred)
loss

: 

In [None]:
from sklearn.decomposition import PCA


# Full PCA of the normalised embeddings in all_nodes; the plot shows how much
# variance each principal component carries.
pca = PCA()
pca.fit(all_nodes)

explained_variance = pca.explained_variance_ratio_
plt.plot(explained_variance)
plt.xlabel("Principal component index")
plt.ylabel("Explained variance ratio")
plt.title("PCA of normalised token embeddings (all_nodes)")

plt.show()

: 

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score

# NOTE(review): this cell repeats the linear-probe cell that precedes the PCA plot;
# consider deleting one of the two.
X = all_nodes
# Labels are sized from the two slices that make up all_nodes instead of the
# literals 32 and 16.
y = np.concatenate([np.zeros(len(incoming_embeddings)), np.ones(len(outgoing_embeddings))])

probe = LinearRegression().fit(X, y)
# Mean squared error on the same rows the probe was fitted on (train loss only).
y_pred = probe.predict(X)
loss = mean_squared_error(y, y_pred)
loss

: 

In [None]:
def add_low_rank_hook(layer, k, n_examples=1_000, positions=None):
    """Attach a hook that projects one layer's attention values onto a rank-k PCA subspace.

    The PCA basis is fitted on `blocks.{layer}.attn.hook_v` activations collected from
    randomly ordered examples that the model currently answers correctly. Hooks that
    are already attached stay active while those activations are gathered.

    Relies on the notebook-level names `model`, `dataset`, `n_states` and `PCA`.

    Args:
        layer: Index of the transformer block whose `hook_v` is modified.
        k: Number of principal components kept by the projection.
        n_examples: How many seeds (0 .. n_examples - 1) to try when collecting
            activations; incorrect predictions are skipped.
        positions: Sequence positions to collect and to overwrite. Defaults to
            1, 4, ..., 43, i.e. every third position of the first 45.
            NOTE(review): presumably the second token of each of the 15 edges --
            confirm against the tokenizer.

    Raises:
        ValueError: If none of the sampled examples was predicted correctly, so
            there is nothing to fit the PCA on.
    """
    from functools import partial

    if positions is None:
        positions = list(range(1, 45, 3))
    else:
        positions = list(positions)

    X = []
    for seed in range(n_examples):
        # Sample example
        test_graph = generate_example(n_states, seed, order="random")
        pred, correct = eval_model(model, dataset, test_graph)
        if not correct:
            continue
        labels, cache = get_example_cache(pred, model, dataset)
        # Keep the value vectors of batch item 0, head 0, at the chosen positions.
        X.append(cache[f"blocks.{layer}.attn.hook_v"][0, positions, 0])

    if not X:
        raise ValueError(
            f"No correctly predicted example among {n_examples} seeds; "
            f"cannot fit a PCA for layer {layer}."
        )

    X = torch.cat(X, dim=0).detach().cpu().numpy()

    # k is the rank of the approximation applied by the hook.
    pca = PCA(n_components=k).fit(X)

    def low_rank_hook(value, hook, position):
        # `value` is indexed as (batch, position, head, feature); only head 0 at the
        # selected positions is replaced by its rank-k reconstruction.
        np_value = value[:, position, 0, :].detach().cpu().numpy()
        b, p, d = np_value.shape
        low_rank = pca.transform(np_value.reshape(b * p, d))
        reconstructed = pca.inverse_transform(low_rank).reshape(b, p, d)
        # Match the activation's own device and dtype instead of assuming CUDA.
        value[:, position, 0, :] = torch.from_numpy(reconstructed).to(
            device=value.device, dtype=value.dtype
        )
        return value

    temp_hook_fn = partial(low_rank_hook, position=positions)
    model.blocks[layer].attn.hook_v.add_hook(temp_hook_fn)

: 

In [None]:
# Start from a hook-free model, then add a rank-16 hook to layers 1 through 5 in order.
# Each call gathers its activations with the hooks added before it still attached.
model.reset_hooks()
for hooked_layer in range(1, 6):
    add_low_rank_hook(hooked_layer, 16)

: 

In [None]:
start_seed = 250_000
num_samples = 100

# Same accuracy check as the baseline, run while the hooks attached in the previous
# cell are still on the model. Graph size comes from n_states so it matches the dataset.
for order in ["forward", "backward", "random"]:
    total_correct = 0
    for seed in range(start_seed, start_seed + num_samples):
        graph = generate_example(n_states, seed, order=order)
        pred, correct = eval_model(model, dataset, graph)
        if correct:
            total_correct += 1

    print(f"{order}: {100* total_correct / num_samples:.4f}%")

: 

In [None]:
# Gather layer-1 value vectors (batch item 0, head 0, positions 1, 4, ..., 43) from
# every correctly predicted example among seeds 0-999.
X = []

# The loop index stays the notebook-level name `i`, which other cells may read.
for i in range(1_000):
    test_graph = generate_example(n_states, i, order="random")
    pred, correct = eval_model(model, dataset, test_graph)
    if correct:
        labels, cache = get_example_cache(pred, model, dataset)
        X.append(cache["blocks.1.attn.hook_v"][0, list(range(1, 45, 3)), 0])

# Stack the per-example rows into a single NumPy matrix.
X = torch.cat(X, dim=0).detach().cpu().numpy()


: 

In [None]:
# Fit a 128-component PCA on X. The fit must be called on this new estimator:
# calling it on the older notebook-level `pca` would discard n_components=128 and
# rebind pca_128 to that shared object.
pca_128 = PCA(n_components=128).fit(X)

: 

In [None]:
# Fit a 16-component PCA on X. The fit must be called on this new estimator:
# calling it on the older notebook-level `pca` would discard n_components=16 and
# rebind pca_16 to that shared object.
pca_16 = PCA(n_components=16).fit(X)

: 

In [None]:
# Variance carried by each component of pca_16.
explained_variance = pca_16.explained_variance_ratio_
plt.plot(explained_variance)
plt.xlabel("Principal component index")
plt.ylabel("Explained variance ratio")
plt.title(f"PCA of collected hook_v activations ({pca_16.n_components_} components)")

plt.show()

: 

In [None]:
# Sanity check: number of components the fitted estimator actually holds.
pca_16.n_components_

: 

In [None]:
# Show the first 16 principal directions of pca_128 as an image.
# NOTE(review): `imshow` is not imported by name in this notebook; presumably it
# comes from one of the star imports -- confirm.
imshow(pca_128.components_[0:16])

: 

In [None]:
# Seed 999 is the value the loop variable `i` holds after the `for i in range(1_000)`
# collection loop. Spelling it out gives the same graph on a top-to-bottom run
# without depending on leftover kernel state.
test_graph = generate_example(n_states, 999, order="random")

: 

In [None]:
# Display the parsed form of the sampled graph.
parse_example(test_graph)

: 

In [None]:
# Run the attention-knockout search on the sampled graph. It returns two results,
# unpacked as (ablated_edges, important_edges); only important_edges is used below.
# NOTE(review): hooks added earlier are still attached at this point -- confirm that is intended.
ablated_edges, important_edges=attention_knockout_discovery(model, dataset, test_graph)


: 

In [None]:
# important_edges is used as a mapping here; keep only its values.
edge_list = important_edges.values() 

: 

In [None]:
# Materialise the values view so the notebook prints its contents.
list(edge_list)

: 

In [None]:
# Drop the first element of every entry, keeping the rest of each tuple.
modified_list = [entry[1:] for entry in edge_list]

: 

In [None]:
# Inspect the trimmed entries.
modified_list

: 

In [None]:
# Build a directed graph from the trimmed entries.
# NOTE(review): assumes each entry is a (source, target) pair once its first element is removed -- confirm.
G = nx.DiGraph(modified_list)

: 

In [None]:
# Remove the hooks attached earlier so the following cells run the unmodified model.
model.reset_hooks()

: 

In [None]:
# Evaluate one "backward"-ordered example whose seed is drawn from [400_000, 600_000).
# NOTE(review): NumPy's RNG is not seeded anywhere in the visible notebook, so this
# example differs from run to run.
test_graph = generate_example(n_states, np.random.randint(400_000, 600_000), order="backward")
pred, correct = eval_model(model, dataset, test_graph)
if correct:
    # An expression nested inside `if` is not auto-displayed by Jupyter, so this
    # call shows nothing unless parse_example prints by itself.
    parse_example(pred)

: 

In [None]:
# Display the parsed prediction (runs whether or not the prediction was correct).
parse_example(pred)

: 

In [None]:
def generate_goal_distance_examples(distance):
    """Build a (clean, corrupted) pair of prompts over nodes 0..15 with goal 15.

    The clean prompt is the chain 0>1>...>15. In the corrupted prompt the node
    `15 - distance` becomes a dead-end branch: its predecessor gets a second edge
    that jumps straight to its successor, and the path skips the dead-end node.

    Args:
        distance: How far the skipped node sits from the goal node 15.

    Returns:
        Tuple `(clean_prompt, corrupted_prompt)` of strings in the
        "edges|goal:path" format.

    Note:
        Only 1 <= distance <= 14 gives a graph confined to nodes 0..15.
        distance=0 emits the edge "14>16" and a path that stops at 14;
        distance=15 emits edges leaving a node "-1".
    """
    goal = 15
    skipped = goal - distance   # node that is left off the corrupted path
    branch = skipped - 1        # node that receives the two outgoing edges

    chain_edges = ",".join(f"{n}>{n+1}" for n in range(goal))
    chain_path = ">".join(str(n) for n in range(goal + 1))
    clean_prompt = f"{chain_edges}|{goal}:{chain_path}"

    edges = [f"{n}>{n+1}" for n in range(branch)]
    edges.append(f"{branch}>{skipped}")
    edges.append(f"{branch}>{skipped + 1}")
    edges.extend(f"{n}>{n+1}" for n in range(skipped + 1, goal))

    path = ">".join(str(n) for n in range(goal + 1) if n != skipped)
    corrupted_prompt = ",".join(edges) + "|" + f"{goal}:{path}"
    return clean_prompt, corrupted_prompt

: 

In [None]:
# Run the logit lens on the corrupted prompt for every distance 0..15.
# NOTE(review): distances 0 and 15 make generate_goal_distance_examples emit node
# labels outside 0..15 ("16" and "-1") -- confirm the tokenizer accepts them.
for i in range(16):
    clean_prompt, corrupted_prompt= generate_goal_distance_examples(i)
    logit_lens(corrupted_prompt, model, dataset)

: 

In [None]:
# Per-layer probability of the correct token, queried at position 46 + (15 - i + 1)
# = 62 - i for each distance i.
# NOTE(review): 46 is presumably the index where the path section starts, so the
# queried position tracks the branching point -- confirm against the tokenizer.
for i in range(16):
    clean_prompt, corrupted_prompt= generate_goal_distance_examples(i)
    logit_lens_correct_probs(corrupted_prompt, model, dataset,46 + (15 - i + 1))

: 

In [None]:
def logit_lens_correct_probs_result(pred, model, dataset, position):
    """Return, layer by layer, the probability given to the correct next token.

    For depths 1 .. n_layers - 1 the cached `ln1` normalized activation of that
    block is projected through the unembedding matrix `model.W_U`; for the final
    depth `ln_final.hook_normalized` is used instead. A softmax over the vocabulary
    turns the projection into probabilities.

    Args:
        pred: Example forwarded to `get_example_cache` to obtain labels and cache.
        model: Model providing `cfg.n_layers` and `W_U`.
        dataset: Dataset providing the `tokens2idx` lookup.
        position: Sequence position to read; the correct token is `labels[position + 1]`.

    Returns:
        List of `n_layers` floats, ordered from the earliest depth to the last.
    """
    labels, cache = get_example_cache(pred, model, dataset)
    target_idx = dataset.tokens2idx[labels[position + 1]]
    last_depth = model.cfg.n_layers

    layer_probs = []
    for depth in range(1, last_depth + 1):
        if depth == last_depth:
            act_name = "ln_final.hook_normalized"
        else:
            act_name = utils.get_act_name("normalized", depth, "ln1")
        # Project batch item 0 onto the vocabulary and normalise over tokens.
        token_dist = (cache[act_name][0] @ model.W_U).softmax(-1)
        layer_probs.append(token_dist[position, target_idx].item())
    return layer_probs


: 

In [None]:
# One row of per-layer probabilities for each distance 0..14, read at position
# 46 + (15 - i + 1) = 62 - i.
# NOTE(review): this sweep stops at 14 while the two sweeps above use range(16) --
# confirm which range is intended.
probs=[]
for i in range(15):
    clean_prompt, corrupted_prompt= generate_goal_distance_examples(i)
    probs.append(logit_lens_correct_probs_result(corrupted_prompt, model, dataset,46 + (15 - i + 1)))


: 

In [None]:
# Heatmap of the collected probabilities (rows: distance, columns: layer).
# NOTE(review): `px` has no explicit import in this notebook; presumably
# plotly.express arrives through one of the star imports -- confirm.
px.imshow(probs)

: 

In [None]:
# Rebuild the prompt pair for distance 0 for manual inspection.
clean_prompt, corrupted_prompt= generate_goal_distance_examples(0)

: 

In [None]:
# Show the corrupted prompt generated above.
corrupted_prompt

: 

In [None]:
# NOTE(review): the earlier sweep passes a fourth positional argument (the position)
# to logit_lens_correct_probs; it is omitted here, so this relies on that parameter
# having a default -- confirm in interp_utils. `pred` is the value left by an earlier cell.
logit_lens_correct_probs(pred, model, dataset)

: 

In [None]:
# Draw another "backward"-ordered example with a seed from [400_000, 600_000).
# NOTE(review): unseeded RNG, so the graph changes between runs.
test_graph = generate_example(n_states, np.random.randint(400_000, 600_000), order="backward")

: 