In [None]:
# Reload edited (non-excluded) modules automatically before each cell executes,
# so changes to project code such as leela_interp are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle

import chess
import iceberg as ice
import leela_interp.tools.figure_helpers as fh
import matplotlib.pyplot as plt
import numpy as np
import torch
from leela_interp import Lc0sight, LeelaBoard, patching
from leela_interp.core.iceberg_board import palette

In [None]:
# Load the puzzle table (a DataFrame: it is later filtered, iterated with
# iterrows() and indexed with iloc).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("interesting_puzzles.pkl", mode="rb") as puzzles_file:
    puzzles = pickle.load(puzzles_file)
len(puzzles)

In [None]:
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = Lc0sight("lc0.onnx", device=device)

In [None]:
# Precomputed residual-stream patching results, mapped directly onto `device`.
# NOTE(review): presumably shaped (n_puzzles, n_layers, 64) and row-aligned with
# `puzzles` — the cells below index it that way; confirm against the producer.
all_effects = torch.load(
    f="results/global_patching/residual_stream_results.pt",
    map_location=device,
)
all_effects.shape

In [None]:
def get_effects_data(mask=None):
    """Group per-layer patching effects by which kind of square was patched.

    For every selected puzzle, squares fall into four disjoint categories:

    * the target square of the 1st principal-variation (PV) move,
    * the target squares of the 3rd and all later PV moves ("follow-up"),
    * the square(s) whose piece differs between clean and corrupted boards,
    * every remaining square.

    For each category and layer, the strongest effect over its squares is kept
    (``amin``: effects are negative, so the minimum is morally the biggest
    effect). Puzzles whose categories overlap, or where the corrupted or
    follow-up category would be empty, are skipped so each effect is
    attributed to exactly one category.

    Args:
        mask: Optional boolean array (or list) of length ``len(puzzles)``
            selecting which puzzles to include. Defaults to all puzzles.

    Returns:
        A list of four dicts, each with ``"effects"`` (a list holding one
        per-layer numpy array per kept puzzle) and ``"name"`` (legend label).
    """
    if mask is None:
        mask = np.ones(len(puzzles), dtype=bool)
    # Accept plain lists too; `.sum()` below needs an array.
    mask = np.asarray(mask)

    # NOTE(review): assumes all_effects is (n_puzzles, n_layers, 64) with rows
    # aligned to `puzzles`, so row i of `effects` is the i-th selected puzzle.
    effects = all_effects[mask]

    candidate_effects = []
    follow_up_effects = []
    patching_square_effects = []
    other_effects = []
    skipped = []

    def _strongest(puzzle_effects, board, squares):
        """Per-layer minimum (strongest) effect over the given square names."""
        square_idxs = [board.sq2idx(square) for square in squares]
        # puzzle_effects is 2-D here, so this is an unambiguous column gather.
        return puzzle_effects[:, square_idxs].amin(-1).cpu().numpy()

    for i, (puzzle_id, puzzle) in enumerate(puzzles[mask].iterrows()):
        board = LeelaBoard.from_puzzle(puzzle)
        corrupted_board = LeelaBoard.from_fen(puzzle.corrupted_fen)
        pv = puzzle.principal_variation

        # Squares whose piece differs in the corrupted position.
        patching_squares = [
            chess.SQUARE_NAMES[square]
            for square in chess.SQUARES
            if board.pc_board.piece_at(square)
            != corrupted_board.pc_board.piece_at(square)
        ]
        # Moves are UCI-style strings (e.g. "c2c8"); [2:4] is the destination.
        candidate_squares = [pv[0][2:4]]
        # The 3rd move's target plus every later move's target count as follow-up.
        follow_up_squares = [move[2:4] for move in pv[2:]]

        patching_set = set(patching_squares)
        candidate_set = set(candidate_squares)
        follow_up_set = set(follow_up_squares)
        if (
            # Empty categories would make amin() fail on a zero-size dimension.
            not patching_set
            or not follow_up_set
            or patching_set & candidate_set
            or patching_set & follow_up_set
            or candidate_set & follow_up_set
        ):
            skipped.append(puzzle_id)
            continue

        # Select the puzzle first: mixing a scalar and a list index around a
        # slice (effects[i, :, idxs]) orders result dimensions differently in
        # NumPy and PyTorch, so don't rely on it.
        puzzle_effects = effects[i]
        covered_squares = patching_set | candidate_set | follow_up_set

        candidate_effects.append(_strongest(puzzle_effects, board, candidate_squares))
        follow_up_effects.append(_strongest(puzzle_effects, board, follow_up_squares))
        patching_square_effects.append(
            _strongest(puzzle_effects, board, patching_squares)
        )
        other_idxs = [
            sq_idx
            for sq_idx in range(64)
            if board.idx2sq(sq_idx) not in covered_squares
        ]
        other_effects.append(puzzle_effects[:, other_idxs].amin(-1).cpu().numpy())

    n_selected = int(mask.sum())
    # Guard against an all-False mask instead of dividing by zero.
    skipped_frac = len(skipped) / n_selected if n_selected else 0.0
    print(
        f"Skipped {len(skipped)} out of {n_selected} puzzles ({skipped_frac:.2%})"
    )

    return [
        {"effects": candidate_effects, "name": "1st move target"},
        {"effects": follow_up_effects, "name": "3rd move target"},
        {"effects": patching_square_effects, "name": "Corrupted square(s)"},
        {"effects": other_effects, "name": "Other squares"},
    ]

In [None]:
def plot_residual_effects(effects_data):
    """Plot the mean log-odds reduction per layer for each square category.

    Each category is drawn as a line of the negated mean effect (so reductions
    are positive) with a shaded band of +/- 2 standard errors.

    Args:
        effects_data: List of dicts like those returned by ``get_effects_data``:
            ``"effects"`` is a sequence of equal-length per-layer arrays (one
            per puzzle) and ``"name"`` is the legend label. Entry ``i`` is
            drawn in ``fh.COLORS[i]``.

    Returns:
        The newly created matplotlib ``Axes``.
    """
    fig, ax = plt.subplots()
    fig.set_figwidth(4.5)
    fig.set_figheight(2)

    for i, effect_data in enumerate(effects_data):
        effects = np.asarray(effect_data["effects"])
        if effects.size == 0:
            # Every puzzle was skipped for this category; nothing to draw.
            continue
        # Derive the layer axis from the data instead of assuming 15 layers.
        layers = np.arange(effects.shape[1])
        mean_effects = -effects.mean(axis=0)
        # 2-sigma band: 2 * std / sqrt(n), using the population std (ddof=0).
        stderr_effects = 2 * effects.std(axis=0) / np.sqrt(len(effects))

        ax.plot(
            layers,
            mean_effects,
            label=effect_data["name"],
            color=fh.COLORS[i],
            linestyle="-",
            linewidth=fh.LINE_WIDTH,
        )
        ax.fill_between(
            layers,
            mean_effects - stderr_effects,
            mean_effects + stderr_effects,
            color=fh.COLORS[i],
            alpha=fh.ERROR_ALPHA,
        )

    ax.set_xlabel("Layer")
    ax.set_ylabel("Log odds reduction")
    # Anchor the y-axis at zero so reductions are compared from a common base.
    _, y_max = ax.get_ylim()
    ax.set_ylim(0, y_max)
    ax.spines[["right", "top", "left"]].set_visible(False)
    ax.set_facecolor(fh.PLOT_FACE_COLOR)
    return ax

In [None]:
# Hand-constructed positions that are not part of `puzzles`. Each entry holds the
# clean board, a corrupted board, and the principal variation as UCI-style move
# strings (only moves 0 and 2 are read by `get_example`).
_CUSTOM_EXAMPLES = {
    "backrank": {
        "board": LeelaBoard.from_fen(
            "6k1/pp3ppp/4b3/1B1r4/P2b1q2/8/1PQ2PPP/4R1K1 w - - 0 1"
        ),
        "corrupted_board": LeelaBoard.from_fen(
            "6k1/pp3p1p/4b3/1B1r4/P2b1q2/8/1PQ2PPP/4R1K1 w - - 0 1"
        ),
        "principal_variation": ["c2c8", "e6c8", "e1e8"],
    }
}


def get_example(index: int | str, layers=(0, 8, 11, 12), exponent=1.0):
    """Render per-layer patching-effect heatmaps for one example position.

    Args:
        index: Either a positional row of ``puzzles``/``all_effects`` (an
            integer), whose precomputed effects are used, or a key of
            ``_CUSTOM_EXAMPLES`` (a str), whose effects are computed on the fly
            with ``patching.residual_stream_activation_patch``.
        layers: Layer indices to render, one board per layer.
        exponent: Power applied to ``|effect|`` before coloring; values below
            1 compress large effects so smaller ones stay visible.

    Returns:
        ``(plots, first_move_target, third_move_target, patching_squares)``
        where ``plots`` holds one ``(annotated_board, plain_board)`` pair per
        layer. The annotated board adds a caption and colored dots on the 1st
        and 3rd move targets and on the corrupted square(s).
    """
    # Copy into a list: accepts any iterable (the default is an immutable
    # tuple) and gives a plain list index for the tensor and the patch call.
    layers = list(layers)

    if isinstance(index, str):
        example = _CUSTOM_EXAMPLES[index]
        board = example["board"]
        corrupted_board = example["corrupted_board"]
        principal_variation = example["principal_variation"]
        # Custom examples have no precomputed effects, so patch them now using
        # the hand-specified corrupted board.
        effect = patching.residual_stream_activation_patch(
            model=model,
            boards=board,
            corrupted_boards=corrupted_board,
            location_batch_size=64,
            layers=layers,
        )
    else:
        # Puzzles from the loaded set already specify a corrupted position, and
        # their effects were precomputed into all_effects.
        effect = all_effects[index, layers]
        puzzle = puzzles.iloc[index]
        board = LeelaBoard.from_puzzle(puzzle)
        corrupted_board = LeelaBoard.from_fen(puzzle.corrupted_fen)
        principal_variation = puzzle.principal_variation

    first_move_target = principal_variation[0][2:4]
    third_move_target = principal_variation[2][2:4]

    # Negate so log-odds reductions (negative effects) map to positive values;
    # exponent < 1 compresses magnitudes to keep colors readable.
    effect_for_palette = -torch.pow(effect.abs(), exponent) * effect.sign()
    effect_for_palette = effect_for_palette.cpu().numpy().ravel()

    colormap_values, _ = palette(effect_for_palette, cmap=fh.EFFECTS_CMAP)
    # Split the flat color list back into one 64-square chunk per layer.
    colormap_values = [
        colormap_values[start : start + 64]
        for start in range(0, 64 * len(layers), 64)
    ]

    # Squares whose piece differs in the corrupted position.
    patching_squares = [
        chess.SQUARE_NAMES[square]
        for square in chess.SQUARES
        if board.pc_board.piece_at(square) != corrupted_board.pc_board.piece_at(square)
    ]

    # (square name, fh.COLOR_DICT key) for every dot drawn on each board.
    markers = [
        (first_move_target, "first_target"),
        (third_move_target, "third_target"),
    ] + [(square, "corrupted") for square in patching_squares]

    new_plots = []
    for j, layer in enumerate(layers):
        # Effects are negative for reductions, so the largest log-odds
        # reduction is minus the most negative effect on this layer.
        max_effect = -effect[j].min().item()

        current_board_og = board.plot(heatmap=colormap_values[j])
        caption_text = (
            f"\\texttt{{L{layer}}}: max log odds reduction = {max_effect:.2f}"
        )
        title = ice.Tex(tex=caption_text).scale(1.8)
        current_board = current_board_og + title.relative_to(
            current_board_og, ice.UP * 10
        )

        with current_board:
            for square, color_key in markers:
                current_board += ice.Point(
                    current_board_og.square(
                        chess.parse_square(square)
                    ).relative_bounds.corners[ice.CENTER],
                    color=ice.Color.from_hex(fh.COLOR_DICT[color_key]),
                ).scale(3)

        new_plots.append((current_board, current_board_og))

    return new_plots, first_move_target, third_move_target, patching_squares

In [None]:
# Aggregate effects over every puzzle (mask=None selects all of them).
effects_data = get_effects_data(mask=None)

In [None]:
# Layers highlighted with vertical guides here and rendered as boards below.
layers = [0, 10, 14]

plt.ioff()
# plot_residual_effects creates its own figure, so clearing the current one
# only left an empty figure open per run; close stale figures instead.
plt.close("all")
fh.set()
plot_residual_effects(effects_data)
for layer in layers:
    plt.axvline(x=layer, color="k", linestyle="--", alpha=0.8, linewidth=1)
plt.tight_layout()
fig = plt.gcf()
plot = ice.MatplotlibFigure(figure=fig)
plot

In [None]:
# Showcase puzzle: convert its DataFrame label into a positional row.
example_loc = fh.PUZZLE_LOC
example_iloc = int(puzzles.index.get_indexer([example_loc]).item())
# get_indexer returns -1 for a missing label, which would silently select the
# last puzzle; fail loudly instead.
assert example_iloc >= 0, f"{example_loc!r} not found in puzzles.index"
boards, first_move_target, third_move_target, patching_squares = get_example(
    example_iloc, layers=layers, exponent=0.5
)

# Vertical offset of the Bezier control points linking the plot to the boards.
connector_offset = 50

# Only the annotated boards are used here; drop the plain ones.
ice_boards = tuple(annotated for annotated, _plain in boards)

boards_arranged = ice.Arrange(ice_boards, gap=30).scale(0.5)

legend_items = [
    "1st move target",
    "3rd move target",
    "Corrupted square(s)",
    "Other squares",
]
legend_line_length = 30
legend_line_thickness = 2
legend_lines = [
    ice.Line(
        (0, 0),
        (legend_line_length, 0),
        path_style=ice.PathStyle(
            color=ice.Color.from_hex(fh.COLORS[i]), thickness=legend_line_thickness
        ),
    )
    for i in range(len(legend_items))
]
legend_texts = [ice.Tex(tex=text).scale(1.5) for text in legend_items]

legend_lines_and_texts = [
    ice.Arrange([line, text], gap=10, arrange_direction=ice.HORIZONTAL)
    for line, text in zip(legend_lines, legend_texts)
]

# Append each remaining row beneath the rows placed so far (10 units padding).
legend_contents = legend_lines_and_texts[0]
for legend_row in legend_lines_and_texts[1:]:
    legend_contents += legend_row.pad_top(10).relative_to(
        legend_contents, ice.TOP_LEFT, ice.BOTTOM_LEFT
    )

legend_title = ice.Tex(tex="\\textbf{Patched Square}").scale(1.5)
legend_contents = ice.Arrange(
    [legend_title, legend_contents], gap=10, arrange_direction=ice.VERTICAL
)

legend_background = ice.Rectangle(
    legend_contents.pad(10).bounds,
    border_color=ice.BLACK.with_alpha(0.1),
    fill_color=ice.BLACK.with_alpha(0.01),
    border_thickness=3,
    border_radius=10,
)

legend = legend_background.add_centered(legend_contents)

legend_and_plot = ice.Arrange(
    [legend.scale(0.8).pad_bottom(20), plot],
    gap=30,
    arrange_direction=ice.HORIZONTAL,
)
boards_arranged = boards_arranged.next_to(legend_and_plot, ice.DOWN * 30)

# Connectors start at the top of the plot's y-range.
ax = plot.figure.axes[0]
y_min, y_max = ax.get_ylim()

connect_path_style = ice.PathStyle(
    color=ice.BLACK, thickness=1.5, dashed=True, dash_intervals=(5, 5), dash_phase=-1.5
)

# Draw a dashed curve from each highlighted layer on the plot to its board.
with boards_arranged:
    for layer, board in zip(layers, ice_boards):
        sx, sy = plot.axes_coordinates(layer, y_max)
        ex, ey = board.relative_bounds.bottom_middle
        boards_arranged += ice.CubicBezier(
            points=[
                (sx, sy),
                (sx, sy - connector_offset),
                (ex, ey + connector_offset),
                (ex, ey),
            ],
            path_style=connect_path_style,
        ).opacity(0.8)


scene = boards_arranged.pad(10).background(ice.WHITE)
scene.scale(3)