## Setup

### Setup 1

In [76]:
# Work out where the chapter directory lives and add its `exercises` folder to
# sys.path (presumably so that helper modules imported later can be found).
import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"  # not referenced again in the visible part of this file
# "./" when the chapter folder sits in the current directory; otherwise the
# part of the current working directory path that precedes the chapter name.
chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
sys.path.append(chapter_dir + f"{chapter}/exercises")

import os
os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader
import numpy as np
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
import itertools
import random
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
from typing import List, Union, Optional, Tuple, Callable, Dict
import typeguard
from functools import partial
# from torcheval.metrics.functional import multiclass_f1_score
from sklearn.metrics import f1_score as multiclass_f1_score
import copy
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
import pandas as pd

# Make sure exercises are in the path
# exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# section_dir = exercises_dir / "part6_othellogpt"
# section_dir = "interpretability"
# if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from neel_plotly import scatter, line
# import part6_othellogpt.tests as tests

device = t.device("cuda" if t.cuda.is_available() else "cpu")

### Setup 2

In [77]:
# True only when this code runs as the top-level script
MAIN = __name__ == "__main__"

# Architecture of the model; the checkpoint downloaded below is loaded into a
# model built from this config.
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre",
    device=device,
)
model = HookedTransformer(cfg)

# Fetch the synthetic-games checkpoint and load its weights into the model
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# An example input
sample_input = t.tensor([[
    20, 19, 18, 10,  2,  1, 27,  3, 41, 42, 34, 12,  4, 40, 11, 29, 43, 13, 48, 56,
    33, 39, 22, 44, 24,  5, 46,  6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37,  9,
    25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7
]]).to(device)

# The argmax of the output (ie the most likely next move from each position)
sample_output = t.tensor([[
    21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5, 33,  5,
    52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59, 50, 28, 14, 28,
    28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15, 14, 15,  8,  7,  8
]]).to(device)

# Sanity check on the loaded weights: the model's argmax prediction at every
# position of the example input must equal the stored expected output.
assert (model(sample_input).argmax(dim=-1) == sample_output.to(device)).all()

# os.chdir(section_dir)
# Treat the current working directory as the section directory and make it importable
section_dir = Path.cwd()
sys.path.append(str(section_dir))
print(section_dir.name)

# Expected location of the `othello_world` checkout and its
# `mechanistic_interpretability` sub-folder, relative to the section directory
OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# The clone step is disabled, so the checkout has to exist already
# if not OTHELLO_ROOT.exists():
#     !git clone https://github.com/likenneth/othello_world

# Needed for the `mech_interp_othello_utils` import just below
sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    plot_board,
    plot_single_board,
    plot_board_log_probs,
    to_string,
    to_int,
    int_to_label,
    string_to_label,
    OthelloBoardState
)

# Load board data as ints (i.e. 0 to 60)
board_seqs_int = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_int_small.npy"), dtype=t.long)
# Load board data as "strings" (i.e. 0 to 63 with middle squares skipped out)
board_seqs_string = t.tensor(np.load(OTHELLO_MECHINT_ROOT / "board_seqs_string_small.npy"), dtype=t.long)

# The four middle squares must never appear as a move in the "string" data,
# and the largest id in the int data must be 60
assert all([middle_sq not in board_seqs_string for middle_sq in [27, 28, 35, 36]])
assert board_seqs_int.max() == 60

# NOTE: `num_games` is rebound to 50 further down, when the focus games are selected
num_games, length_of_game = board_seqs_int.shape

# Square indices a move can be played on: 0..63 without the four center squares
stoi_indices = [sq for sq in range(64) if sq not in (27, 28, 35, 36)]

# Row letters of the board; columns are written as the digits 0-7, so a label looks like `E2`
alpha = "ABCDEFGH"


def to_board_label(i):
    """Return the label (row letter followed by column digit) of square index `i`."""
    row = i // 8
    col = i % 8
    return f"{alpha[row]}{col}"


# Labels of the playable squares, and of every square on the 8x8 board
board_labels = [to_board_label(sq) for sq in stoi_indices]
full_board_labels = [to_board_label(sq) for sq in range(64)]

def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Plot a square (8 by 8) input as a board through `imshow`.

    Rows are labelled with the letters of `alpha`, columns with "0".."7".
    With `diverging_scale` a red/blue scale centred on 0 is used, otherwise a
    blue scale with no midpoint. Extra keyword arguments are passed on to
    `imshow` and override these defaults; `facet_col=0` plots a stack of boards.
    """
    if diverging_scale:
        colour_scale, midpoint = "RdBu", 0.
    else:
        colour_scale, midpoint = "Blues", None

    plot_options = dict(
        y=list(alpha),
        x=[str(col) for col in range(8)],
        color_continuous_scale=colour_scale,
        color_continuous_midpoint=midpoint,
        aspect="equal",
    )
    # Caller-supplied options win over the defaults above
    plot_options.update(kwargs)
    imshow(state, **plot_options)

# Select a window of 50 consecutive games, starting at game 30000, to analyse.
# NOTE: this rebinds `num_games`, which previously held the total number of
# games in `board_seqs_int`.
start = 30000
num_games = 50
focus_games_int = board_seqs_int[start : start + num_games]
focus_games_string = board_seqs_string[start: start + num_games]

# Run the model on every focus game with the last move dropped, keeping the
# logits and the activation cache
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].to(device))
# Bare expression: only shows something when it is the last line of a notebook cell
focus_logits.shape

def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` holding 1 at the given indices and 0 elsewhere."""
    encoded = t.zeros(size=(num_classes,), dtype=t.float32)
    encoded[list_of_ints] = 1.0
    return encoded

# Per focus game and per move: the 8x8 board state and a 64-entry mask of valid moves
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = t.zeros((num_games, 60, 64), dtype=t.float32)

# Replay every focus game move by move on a fresh board
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        # presumably `umpire` plays the move and updates `board.state` — see mech_interp_othello_utils
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        # Valid moves after move j, encoded as a 0/1 vector over the 64 squares
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())

print("focus states:", focus_states.shape)
print("focus_valid_moves", tuple(focus_valid_moves.shape))

# full_linear_probe = t.load(OTHELLO_MECHINT_ROOT / "main_linear_probe.pth", map_location=device)

# Probe weights that `cache_to_logit` applies to the residual stream.
# NOTE(review): no `map_location` is passed here (the commented-out load above
# passes one) — confirm the checkpoint ends up on the intended device.
# NOTE(review): the file name suggests a layer-6 probe, while `cache_to_logit`
# reads `resid_post` at `LAYER`, which is set to 4 later in this file — confirm.
linear_probe2 = t.load("probes/linear/resid_6_linear.pth")

# Expected probe layout: (modes, d_model, board rows, board columns, options per square)
rows = 8
cols = 8
options = 3
assert linear_probe2.shape == (1, cfg.d_model, rows, cols, options)

# Index constants; none of these five names is used in the visible part of this file
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2

# Creating values for linear probe (converting the "black/white to play" notation into "me/them to play")

# The string literal below is disabled code; as an expression statement it has no effect
'''LAYER = 6
game_index = 0
move = 29'''

# Board-state encoding by colour
BLANK1 = 0
BLACK = 1
WHITE = -1

# MINE = 0
# YOURS = 1
# BLANK2 = 2

# Tile-state encoding relative to the player; `get_tile_state` writes MINE and YOURS
EMPTY = 0
YOURS = 1
MINE = 2

interpretability
focus states: (50, 60, 8, 8)
focus_valid_moves (50, 60, 64)


## Code

In [78]:
from utils import plot_game
from training_utils import get_state_stack_num_flipped
from utils import plot_probe_outputs

### Plots

In [79]:
'''game_index = 0
move = 4
end_move = 16
LAYER = 5
square = "D5"
square_tuple = (3, 5)
tile_state_clean = 1'''

plot_game(focus_games_string, game_index=0, end_move = 16)
'''# plot_single_board(focus_games_string[game_index, :move+1], title="Original Game (black plays E0)")
# plot_single_board(focus_games_string[game_index, :move].tolist()+[to_string(to_int("C4"))], title="Corrupted Game (blank plays C0)")
focus_states_num_flipped = get_state_stack_num_flipped(focus_games_string)
imshow(
        focus_states_num_flipped[game_index, :end_move],
        facet_col=0,
        facet_col_wrap=8,
        facet_labels=[f"Move {i}" for i in range(0, end_move)],
        title="First 16 moves of first game",
        color_continuous_scale="Greys",
        y = [i for i in alpha],
    )
flipped_list = list(focus_states_num_flipped[game_index, :end_move, 3, 5])
first_flip = True if flipped_list[0] == 1 else False
flipped_list = [first_flip] + [flipped_list[i-1] < flipped_list[i] for i in range(1, end_move)]
print(len(flipped_list))
flipped_list = [i for i in range(0, end_move) if flipped_list[i]]
print(flipped_list)'''
# plot_single_board(int_to_label(moves_int))
# plot_probe_outputs(focus_cache, full_linear_probe, 5, game_index, 4)

'# plot_single_board(focus_games_string[game_index, :move+1], title="Original Game (black plays E0)")\n# plot_single_board(focus_games_string[game_index, :move].tolist()+[to_string(to_int("C4"))], title="Corrupted Game (blank plays C0)")\nfocus_states_num_flipped = get_state_stack_num_flipped(focus_games_string)\nimshow(\n        focus_states_num_flipped[game_index, :end_move],\n        facet_col=0,\n        facet_col_wrap=8,\n        facet_labels=[f"Move {i}" for i in range(0, end_move)],\n        title="First 16 moves of first game",\n        color_continuous_scale="Greys",\n        y = [i for i in alpha],\n    )\nflipped_list = list(focus_states_num_flipped[game_index, :end_move, 3, 5])\nfirst_flip = True if flipped_list[0] == 1 else False\nflipped_list = [first_flip] + [flipped_list[i-1] < flipped_list[i] for i in range(1, end_move)]\nprint(len(flipped_list))\nflipped_list = [i for i in range(0, end_move) if flipped_list[i]]\nprint(flipped_list)'

In [80]:
print(focus_games_string[0])

tensor([19, 18, 17, 29, 45, 42, 43, 44, 21, 46, 37, 22, 51, 20, 15,  9, 47, 23,
         1, 34, 11, 50, 25, 13, 12, 60, 30,  3,  5, 14,  6, 52, 10,  0, 26, 54,
        58, 24, 16, 49, 41, 31, 61,  2, 32,  7, 38, 63, 56,  4, 53, 59, 62, 39,
        55, 33, 57, 48,  8, 40])


### Rest

In [94]:
# Layer whose `resid_post` activations `cache_to_logit` feeds to the linear probe
LAYER = 4

@dataclass
class Arguments:
    """Settings and per-move state of one activation-patching experiment.

    All fields have defaults, so an empty instance can be created and filled in
    attribute by attribute. `move`, `tile_state_clean` and `tile_state_corrupt`
    are overwritten by `activation_patching_from_inputs` on every move it
    processes.
    """
    # Clean and corrupted move sequences; each is truncated to the current move and run through the model
    clean_input: Optional[Tensor] = None
    corrupted_input: Optional[Tensor] = None
    # Board label of the square whose probe logits are measured, e.g. "F1"
    square: Optional[str] = None
    # First move index (inclusive) and last move index (exclusive) to run patching for
    corrupted_move: Optional[int] = None
    end_move: Optional[int] = None
    # Move currently being analysed
    move: Optional[int] = None
    # Probe option index of the square's state at `move` in the clean / corrupted game
    tile_state_clean: Optional[int] = None
    tile_state_corrupt: Optional[int] = None
    # Also patch resid_pre / resid_post (in addition to attn_out / mlp_out)
    include_resid: bool = True
    # Also patch each attention head separately
    include_heads: bool = True

In [82]:
def square_tuple_from_square(square : str):
    """Convert a board label such as "F1" into a (row, column) index pair.

    The row is the position of the label's first character in `alpha`; the
    column is the label's second character parsed as an integer.
    """
    row = alpha.index(square[0])
    col = int(square[1])
    return (row, col)

In [83]:
def cache_to_logit(cache: ActivationCache, args: Arguments) -> Float[Tensor, "1"]:
    '''
    Probe logit difference for one square at one move.

    The `resid_post` activations at layer `LAYER` (first batch element) are
    projected through `linear_probe2`. From the result, the entry for position
    `args.move` and the square named by `args.square` is taken, and the logit
    of option `args.tile_state_corrupt` is subtracted from the logit of option
    `args.tile_state_clean`.
    '''
    row, col = square_tuple_from_square(args.square)
    residual = cache["resid_post", LAYER][0]
    # Apply the probe at every position, then keep the first probe mode
    probe_out = einops.einsum(
        residual,
        linear_probe2,
        'pos d_model, modes d_model rows cols options -> modes pos rows cols options',
    )[0]
    square_logits = probe_out[args.move, row, col]
    return square_logits[args.tile_state_clean] - square_logits[args.tile_state_corrupt]
    

In [84]:
def patching_metric(patched_cache: ActivationCache, corrupted_logit_diff: float, clean_logit_diff: float, args: Arguments) -> Float[Tensor, "1"]:
    '''
    Normalised patching score computed from the activations of a patched run.

    The probe logit difference of the patched run (see `cache_to_logit`) is
    rescaled linearly so that it is 0 when it equals `corrupted_logit_diff`
    and 1 when it equals `clean_logit_diff`.
    '''
    patched = cache_to_logit(patched_cache, args)
    recovered = patched - corrupted_logit_diff
    full_range = clean_logit_diff - corrupted_logit_diff
    return recovered / full_range

In [104]:
def patch_final_move_output(
    activation: Float[Tensor, "batch seq ..."],
    hook: HookPoint,
    clean_cache: ActivationCache,
    head: Optional[int] = None
) -> Float[Tensor, "batch seq ..."]:
    '''
    Hook function which overwrites the final sequence position of `activation`
    (first batch element only) with the clean value stored under `hook.name`.

    activation:  tensor to patch; it is modified in place and also returned.
    hook:        hook point being run; only its `name` is used.
    clean_cache: activations of the clean run, looked up by hook name.
    head:        if given, `activation` is indexed as [batch, seq, head, ...]
                 and only that head's slice at the final position is replaced;
                 if None, the whole final position is replaced.

    NOTE(review): earlier positions are left untouched. The caller in this file
    also patches at moves after the one where the clean and corrupted sequences
    diverge, so earlier positions can differ between the two runs — confirm
    that patching only the final position is what is intended there.
    '''
    clean_activation = clean_cache[hook.name]
    # Compare against None explicitly so that head index 0 is handled as a head
    if head is not None:
        activation[0, -1, head, :] = clean_activation[0, -1, head, :]
    else:
        activation[0, -1, :] = clean_activation[0, -1, :]
    return activation


def _run_patched_and_cache(
    model: HookedTransformer,
    corrupted_input: Tensor,
    hook_name: str,
    hook_fn: Callable,
) -> ActivationCache:
    '''Run `corrupted_input` with `hook_fn` attached at `hook_name` and return the activations cached during that run.'''
    model.reset_hooks()
    # Caching hooks are registered before the patching hook is added by run_with_hooks
    raw_cache = model.add_caching_hooks()
    _ = model.run_with_hooks(
        corrupted_input,
        fwd_hooks = [(hook_name, hook_fn)],
    )
    return ActivationCache(raw_cache, model)


def get_act_patch_resid_pre(
    model: HookedTransformer,
    corrupted_input: Int[Tensor, "... pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[..., Float[Tensor, ""]],
    corrupted_logit_diff: float,
    clean_logit_diff: float,
    args: Arguments
) -> Float[Tensor, "n_components n_layers"]:
    '''
    Patch clean activations into a corrupted run, one component and one layer
    at a time, and score every patched run with `patching_metric`.

    model:                 model to run; all of its hooks are reset by this function.
    corrupted_input:       input for the corrupted run.
    clean_cache:           activations of the clean run, used as the patch source.
    patching_metric:       called as `patching_metric(cache, corrupted_logit_diff,
                           clean_logit_diff, args)` with the cache of each patched run.
    corrupted_logit_diff:  metric baseline of the unpatched corrupted run.
    clean_logit_diff:      metric value of the clean run.
    args:                  `include_resid` adds resid_pre / resid_post to the patched
                           components, `include_heads` adds one row per attention head.

    Returns a tensor with one column per layer. Rows are, in order: attn_out,
    mlp_out, then resid_pre and resid_post if `args.include_resid`, then one
    row per head (patched at the `z` hook) if `args.include_heads`.
    '''
    component_names = ["attn_out", "mlp_out"]
    if args.include_resid:
        component_names += ["resid_pre", "resid_post"]
    n_components = len(component_names)
    # Take the head count from the model instead of assuming 8 heads
    n_heads = model.cfg.n_heads if args.include_heads else 0

    results = t.zeros(n_components + n_heads, model.cfg.n_layers, device=device, dtype=t.float32)
    # The same hook serves every whole-component patch, so build it once
    component_hook = partial(patch_final_move_output, clean_cache=clean_cache)

    try:
        for layer in tqdm(range(model.cfg.n_layers)):
            for row, component in enumerate(component_names):
                cache = _run_patched_and_cache(
                    model, corrupted_input, utils.get_act_name(component, layer), component_hook
                )
                results[row, layer] = patching_metric(cache, corrupted_logit_diff, clean_logit_diff, args)

            for head in range(n_heads):
                head_hook = partial(patch_final_move_output, clean_cache=clean_cache, head=head)
                cache = _run_patched_and_cache(
                    model, corrupted_input, utils.get_act_name("z", layer), head_hook
                )
                results[n_components + head, layer] = patching_metric(cache, corrupted_logit_diff, clean_logit_diff, args)
    finally:
        # Do not leave caching or patching hooks attached once the sweep is over
        model.reset_hooks()

    return results

In [86]:
from utils import seq_to_state_stack

In [87]:
def get_tile_state(square, input_int):
    """Follow one tile through a game, relative to alternating players.

    square:    board label of the tile, e.g. "D3"
    input_int: 1-D sequence of moves as ints (converted with `to_string`)

    Returns a list with one entry per move in `input_int`, saying what the
    tile holds after that move: 0 empty, 1 yours, 2 mine.
    """
    assert len(input_int.shape) == 1
    moves_int = input_int.tolist()
    moves_str = list(map(to_string, moves_int))
    tile = to_string(to_int(square))
    board_states = seq_to_state_stack(moves_str)
    # 0 is blank, -1 is white, 1 is black; work on an integer copy
    tile_states = board_states[:, tile // 8, tile % 8].copy().astype(int)
    assert len(tile_states.shape) == 1
    assert tile_states.shape[0] == len(moves_int)
    # Flip the sign at every even move index: 0 blank, 1 mine, -1 theirs
    tile_states[::2] *= -1
    # Recode 1 first, so the 1s written for YOURS afterwards are not recoded again
    tile_states[tile_states == 1] = MINE  # 2
    tile_states[tile_states == -1] = YOURS  # 1
    assert set(tile_states) <= {0, 1, 2}
    return list(tile_states)  # 0 empty, 1 yours, 2 mine

In [96]:
def activation_patching_from_inputs(args: Arguments):
    '''
    Run activation patching for every move in [args.corrupted_move, args.end_move)
    and plot one line chart of patching scores per move.

    For each move the clean and corrupted inputs are truncated to that move,
    both are run through `model`, and `get_act_patch_resid_pre` patches clean
    activations into the corrupted run. The probe logit difference of both
    unpatched runs is printed for reference.

    `args` is modified in place: `move`, `tile_state_clean` and
    `tile_state_corrupt` are overwritten for every move, because
    `cache_to_logit` reads them from `args`.
    '''
    # State of the square after every move, in the clean and in the corrupted game
    tile_state_list_clean = get_tile_state(args.square, args.clean_input)
    tile_state_list_corrupt = get_tile_state(args.square, args.corrupted_input)
    print(tile_state_list_clean)

    # Row labels of the patching results; they do not depend on the move
    line_labels = ["attn", "mlp"]
    if args.include_resid:
        line_labels += ["resid_pre", "resid_post"]
    if args.include_heads:
        line_labels += [f"head_{head}" for head in range(model.cfg.n_heads)]

    for move in range(args.corrupted_move, args.end_move):
        args.move = move
        args.tile_state_clean = tile_state_list_clean[move]
        args.tile_state_corrupt = tile_state_list_corrupt[move]
        clean_input_short = args.clean_input[:move+1].clone()
        corrupted_input_short = args.corrupted_input[:move+1].clone()

        _, clean_cache = model.run_with_cache(clean_input_short)
        _, corrupted_cache = model.run_with_cache(corrupted_input_short)

        clean_logit_diff = cache_to_logit(clean_cache, args)
        corrupted_logit_diff = cache_to_logit(corrupted_cache, args)

        # These values are probe logit differences, not log probabilities
        print(f"Clean logit diff of {args.square} at move {args.move}: {clean_logit_diff.item()}")
        print(f"Corrupted logit diff of {args.square} at move {args.move}: {corrupted_logit_diff.item()}")

        patching_results = get_act_patch_resid_pre(model, corrupted_input_short, clean_cache, patching_metric, corrupted_logit_diff, clean_logit_diff, args)
        # One label per result row, otherwise the chart would be mislabelled
        assert patching_results.shape[0] == len(line_labels)

        line(patching_results, title=f"Layer Output Patching Effect on {args.square} Logit Diff", line_labels=line_labels, width=750)

In [89]:
# clean_input_int     = t.Tensor(to_int([37, 43, 42, 29, 19, 41, 44, 21, 30, 39, 14, 38, 51, 26, 45, 10, 1, 22, 46, 12, 23, 7, 18, 15, 3, 47, 20, 31])).to(t.int64)
# corrupted_input_int = t.Tensor(to_int([37, 43, 42, 29, 19, 34, 41, 21, 30, 39, 14, 38, 51, 26, 45, 10, 1, 22, 46, 12, 23, 7, 18, 15, 3, 47, 20, 31])).to(t.int64)
# plot_game(to_string(clean_input_int.unsqueeze(0)), game_index=0, end_move = 16)
# plot_game(to_string(corrupted_input_int.unsqueeze(0)), game_index=0, end_move = 16)

In [105]:
# Experiment setup: two games that differ only at move indices 5 and 6
# (clean plays 41 then 44, corrupted plays 34 then 41).
args = Arguments()

args.clean_input     = t.Tensor(to_int([37, 43, 42, 29, 19, 41, 44, 21, 30, 39, 14, 38, 51, 26, 45, 10, 1, 22, 46, 12, 23, 7, 18, 15, 3, 47, 20, 31])).to(t.int64)
args.corrupted_input = t.Tensor(to_int([37, 43, 42, 29, 19, 34, 41, 21, 30, 39, 14, 38, 51, 26, 45, 10, 1, 22, 46, 12, 23, 7, 18, 15, 3, 47, 20, 31])).to(t.int64)
# Index 5: the first move at which the two sequences differ
args.corrupted_move = 6-1
# length = 28
# GAME_INDEX = 0
# Square whose probe logits are tracked; `to_board_label(41)` is "F1"
corrupted_square = 41
args.square = to_board_label(corrupted_square)
# Patching is run for moves 5..15 (the end index is exclusive)
args.end_move = 16
# Patch attn_out / mlp_out and the individual heads, but not resid_pre / resid_post
args.include_resid = False
args.include_heads = True

print("Clean input:", args.clean_input)
print("Corrupted input 1:", args.corrupted_input)

activation_patching_from_inputs(args)


Clean input: tensor([34, 40, 39, 28, 20, 38, 41, 22, 29, 36, 15, 35, 48, 27, 42, 11,  2, 23,
        43, 13, 24,  8, 19, 16,  4, 44, 21, 30])
Corrupted input 1: tensor([34, 40, 39, 28, 20, 33, 38, 22, 29, 36, 15, 35, 48, 27, 42, 11,  2, 23,
        43, 13, 24,  8, 19, 16,  4, 44, 21, 30])
[0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1]
Clean log prob of F1 at move 5: 18.27457046508789
Corrupted log prob of F1 at move 5: -19.908981323242188


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 6: 8.931432723999023
Corrupted log prob of F1 at move 6: -9.048787117004395


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 7: 6.749320983886719
Corrupted log prob of F1 at move 7: -9.987183570861816


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 8: 8.416863441467285
Corrupted log prob of F1 at move 8: -7.762742519378662


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 9: 6.686853408813477
Corrupted log prob of F1 at move 9: -8.681960105895996


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 10: 8.031612396240234
Corrupted log prob of F1 at move 10: -8.110197067260742


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 11: 6.997312545776367
Corrupted log prob of F1 at move 11: -7.539063453674316


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 12: 5.655460357666016
Corrupted log prob of F1 at move 12: -8.857635498046875


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 13: 5.920053958892822
Corrupted log prob of F1 at move 13: -6.69356107711792


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 14: 5.801131248474121
Corrupted log prob of F1 at move 14: -6.46583366394043


  0%|          | 0/8 [00:00<?, ?it/s]

Clean log prob of F1 at move 15: 6.6381683349609375
Corrupted log prob of F1 at move 15: -6.654299259185791


  0%|          | 0/8 [00:00<?, ?it/s]

### Figure out model, resid_pre, resid_post

In [91]:
model.cfg.positional_embedding_type

'standard'

In [92]:
# Sanity check: at layer 0, `resid_pre` for the first position should equal the
# token embedding of the first move plus the positional embedding of position 0.
first_move = focus_games_int[0, 0]
print(first_move)
W_E = model.W_E
W_pos = model.W_pos
# print(W_E[first_move] + W_pos[first_move])
# resid_pre_maybe = W_E[first_move] + W_pos[first_move]
# The positional embedding is indexed by position (0), not by the move id
resid_pre_maybe = W_E[first_move] + W_pos[0]
model.reset_hooks()
# NOTE: rebinds the module-level names `logits` and `cache`
logits, cache = model.run_with_cache(focus_games_int[0, :-1])
# print(cache, cache.keys())
# First batch element, first position
resid_pre = cache["resid_pre", 0][0, 0]
print(resid_pre.shape)
print(resid_pre[:10])
print(resid_pre_maybe[:10])
# Raises if the two vectors are not numerically close
t.testing.assert_close(resid_pre, resid_pre_maybe)

tensor(20)
torch.Size([512])
tensor([-0.0181, -0.0278,  0.0164, -0.0400, -0.0454, -0.0560,  0.0028, -0.0418,
        -0.0092, -0.0979], device='cuda:0')
tensor([-0.0181, -0.0278,  0.0164, -0.0400, -0.0454, -0.0560,  0.0028, -0.0418,
        -0.0092, -0.0979], device='cuda:0', grad_fn=<SliceBackward0>)


### Old

In [93]:
# Stale experiment cell: patches attn_out at layers 0 and 4 in a single run.
# NOTE(review): `clean_cache` and `corrupted_input` are not defined at module
# level in the visible part of this file (the recorded output below is a
# NameError), and `patching_metric` is called with one argument although the
# version defined earlier in this file takes four — update or remove this cell.
model.reset_hooks()
hook_fn = partial(patch_final_move_output, clean_cache=clean_cache)
cache = model.add_caching_hooks()
_ = model.run_with_hooks(
    corrupted_input,
    fwd_hooks = [
        (utils.get_act_name("attn_out", 0), hook_fn),
        (utils.get_act_name("attn_out", 4), hook_fn)]
)
cache = ActivationCache(cache, model)
print(cache.keys())
metric = patching_metric(cache)
print(metric)
# print(cache)

NameError: name 'clean_cache' is not defined

In [None]:
def patch_final_move_output(
    activation: Float[Tensor, "batch seq d_model"],
    hook: HookPoint,
    clean_cache: ActivationCache,
) -> Float[Tensor, "batch seq d_model"]:
    '''
    Hook function which replaces the final sequence position of `activation`
    (first batch element only) with the clean value stored under `hook.name`
    in `clean_cache`. The tensor is modified in place and returned.

    NOTE: this shares its name with the head-aware version defined earlier in
    this file; running this definition rebinds the name to this simpler
    version, which has no `head` parameter.
    '''
    clean_activation = clean_cache[hook.name]
    activation[0, -1, :] = clean_activation[0, -1, :]
    return activation


# Older version of the patching sweep, kept under the "Old" heading.
# NOTE: it shares its name with the version defined earlier in this file;
# running this definition rebinds the name to this one, which takes no `args`
# parameter.
def get_act_patch_resid_pre(
    model: HookedTransformer,
    corrupted_input: Float[Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[ActivationCache], Float[Tensor, ""]],
    corrupted_logit_diff: float,
    clean_logit_diff: float
) -> Float[Tensor, "2 n_layers"]:
    '''
    Returns an array of results, corresponding to the results of patching at
    each (attn_out, mlp_out) for all layers in the model.

    Row 0 holds the attn_out scores and row 1 the mlp_out scores; there is one
    column per layer.
    '''
    # SOLUTION
    model.reset_hooks()
    # Caching hooks are added once here, not once per patched run.
    # NOTE(review): confirm that they are still attached after the first
    # `run_with_hooks` call; otherwise later iterations score stale activations.
    cache = model.add_caching_hooks()
    results = t.zeros(2, model.cfg.n_layers, device=device, dtype=t.float32)
    hook_fn = partial(patch_final_move_output, clean_cache=clean_cache)

    for i, activation in enumerate(["attn_out", "mlp_out"]):
        for layer in tqdm(range(model.cfg.n_layers)):
            _ = model.run_with_hooks(
                corrupted_input,
                fwd_hooks = [(utils.get_act_name(activation, layer), hook_fn)],
            )
            # `cache` is rebound on every iteration, so from the second
            # iteration on an ActivationCache is wrapped in another one
            cache = ActivationCache(cache, model)
            # NOTE(review): three arguments are passed here, while the
            # `patching_metric` defined earlier in this file takes four
            results[i, layer] = patching_metric(cache, corrupted_logit_diff, clean_logit_diff)

    return results