In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pickle

import chess
import iceberg as ice
import torch
from leela_interp import Lc0sight, LeelaBoard
from leela_interp.tools import figure_helpers as fh

In [None]:
# Read the pickled puzzle collection from the working directory.
# NOTE: unpickling can execute arbitrary code, so only load a file you trust.
with open("interesting_puzzles.pkl", "rb") as puzzle_file:
    puzzles = pickle.load(puzzle_file)
# Show how many puzzles were loaded.
len(puzzles)

In [None]:
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Load the network from the ONNX file in the working directory.
model = Lc0sight("lc0.onnx", device=device)

In [None]:
# Index label of the puzzle used for the figure, as configured in figure_helpers.
idx = fh.PUZZLE_LOC
# Label-based lookup (.loc), not positional (.iloc): idx is matched against the
# index labels of `puzzles`, not its row order.
puzzle = puzzles.loc[idx]
# Clean position, built from the puzzle record.
board = LeelaBoard.from_puzzle(puzzle)
# Corrupted counterpart, built from the FEN stored on the same puzzle record.
corrupted_board = LeelaBoard.from_fen(puzzle.corrupted_fen)
display(board)
# NOTE(review): presumably shows the model's move choice for the clean position;
# confirm against Lc0sight.pretty_play.
model.pretty_play(board)
display(corrupted_board)

In [None]:
# Find the first square (in chess.SQUARES order) whose piece differs between the
# clean and the corrupted position; that square's name is the patching square.
for square in chess.SQUARES:
    if board.pc_board.piece_at(square) != corrupted_board.pc_board.piece_at(square):
        patching_square = chess.SQUARE_NAMES[square]
        break
else:
    # Without this guard an identical pair of boards would leave
    # `patching_square` unassigned (NameError on a fresh kernel) or, worse,
    # silently reuse a stale value left over from a previous run of this cell.
    raise ValueError(
        "Clean and corrupted boards are identical; no patching square found."
    )
patching_square

In [None]:
# Principal variation stored on the puzzle record. The figure cell below indexes
# it per move and pushes each entry onto the board with push_uci, so entries are
# used as UCI move strings.
pv = puzzle.principal_variation

In [None]:
def get_policy(board: LeelaBoard, top_k: int = 3):
    """Return the model's top moves for ``board``, keyed by SAN notation.

    Runs the module-level ``model`` on ``board``, asks it for the ``top_k``
    highest-ranked moves of the resulting policy, and re-keys that mapping
    from UCI move strings to SAN strings rendered against ``board``.

    Args:
        board: Position to evaluate.
        top_k: Number of moves requested from ``model.top_moves``.

    Returns:
        A dict mapping each move's SAN string to the value that
        ``model.top_moves`` reported for it, in the order it was reported.
    """
    # Only the policy output is needed; the other two outputs are discarded.
    policy, _wdl, _ = model.play(board, return_probs=True)

    san_to_prob = {}
    for uci_move, prob in model.top_moves(board, policy, top_k=top_k).items():
        san = board.pc_board.san(chess.Move.from_uci(uci_move))
        san_to_prob[san] = prob
    return san_to_prob

In [None]:
# From https://github.com/revalo/iceberg/blob/main/examples/neural_network.py
class NeuralNetwork(ice.DrawableWithChild):
    """Schematic fully-connected network: columns of circles joined by lines.

    One column of ellipses is created per entry in ``layer_node_counts``. Every
    node of a column is connected by a straight line to every node of the next
    column (from the middle-right of the first to the middle-left of the second).
    """

    # Number of nodes in each layer, one entry per column.
    layer_node_counts: tuple[int, ...]
    # Radius of each node; the ellipse bounds are a square of side 2 * radius.
    node_radius: float = 30
    # Gap between neighbouring nodes within one layer.
    node_vertical_gap: float = 30
    # Gap between neighbouring layers.
    layer_gap: float = 150
    # Outline colour, fill colour and outline thickness of each node.
    node_border_color: ice.Color = ice.Colors.WHITE
    node_fill_color: ice.Color = None
    node_border_thickness: float = 3
    # Path style used for every connecting line.
    line_path_style: ice.PathStyle = ice.PathStyle(ice.Colors.WHITE, thickness=3)

    def setup(self):
        """Create the node ellipses for every layer and build the drawing."""
        # [layer_index, node_index]
        self._layer_nodes = [
            [
                ice.Ellipse(
                    rectangle=ice.Bounds(
                        top=0,
                        left=0,
                        bottom=self.node_radius * 2,
                        right=self.node_radius * 2,
                    ),
                    border_color=self.node_border_color,
                    fill_color=self.node_fill_color,
                    border_thickness=self.node_border_thickness,
                )
                for _ in range(layer_node_count)
            ]
            for layer_node_count in self.layer_node_counts
        ]

        self._node_vertical_gap = self.node_vertical_gap
        self._layer_gap = self.layer_gap
        self._line_path_style = self.line_path_style

        self._initialize_based_on_nodes()

    def _initialize_based_on_nodes(self):
        """Lay out the nodes, draw the connecting lines and set the child."""
        # Arrange the circles: each layer is a vertical stack, and the layers
        # are placed side by side horizontally.
        nodes_arranged = ice.Arrange(
            [
                ice.Arrange(
                    circles,
                    arrange_direction=ice.Arrange.Direction.VERTICAL,
                    gap=self._node_vertical_gap,
                )
                for circles in self.layer_nodes
            ],
            arrange_direction=ice.Arrange.Direction.HORIZONTAL,
            gap=self._layer_gap,
        )

        # Draw the lines: connect every node to every node of the next layer.
        self._lines = []
        for layer_a, layer_b in zip(self.layer_nodes[:-1], self.layer_nodes[1:]):
            for circle_a in layer_a:
                for circle_b in layer_b:
                    start = nodes_arranged.child_bounds(circle_a).corners[
                        ice.Corner.MIDDLE_RIGHT
                    ]
                    end = nodes_arranged.child_bounds(circle_b).corners[
                        ice.Corner.MIDDLE_LEFT
                    ]

                    line = ice.Line(start, end, self._line_path_style)
                    self._lines.append(line)

        # All the children in this composition.
        # Nodes are drawn on top of lines.
        # Build a fresh list rather than aliasing self._lines, so that
        # self._lines keeps holding only the line drawables.
        children = [*self._lines, nodes_arranged]

        self.set_child(ice.Compose(children))

    @property
    def layer_nodes(self) -> list[list[ice.Drawable | ice.Ellipse]]:
        """Node drawables indexed as ``[layer_index][node_index]``."""
        return self._layer_nodes

In [None]:
# Colour of the move arrow drawn on each board: the best-move colour with "80"
# appended (presumably a hex alpha suffix making it semi-transparent — TODO confirm).
MOVE_COLOR = fh.BEST_MOVE_COLOR + "80"
# Spacing between the board, the network glyph and the policy bar.
# NOTE(review): despite the name, this is only ever multiplied with ice.DOWN
# below, i.e. it is used as the gap between stacked elements.
_HORIZONTAL_GAP = 25
# Distance kept between an arrow's endpoints and the elements it connects.
_ARROW_PADDING = 3
# Shared styling for the connector arrows (board -> network -> policy bar).
_ARROW_KWARGS = {
    "line_path_style": ice.PathStyle(ice.BLACK, thickness=2),
    "head_length": 7,
}

# One panel per position: board on top, network glyph in the middle, policy bar
# at the bottom, connected by arrows.
board_plots: list[ice.Drawable] = []
# Work on a copy so that pushing moves does not alter `board`.
board_copy = board.copy()

# Three positions: the initial puzzle position and the positions reached after
# each of the first two principal-variation moves.
for i in range(3):
    # Highlight the characters [2:4] of the move string (the destination square
    # of a UCI move) for the first and third moves; no highlight for the second.
    if i == 0:
        heatmap = {pv[i][2:4]: fh.COLOR_DICT["first_target"]}
    elif i == 2:
        heatmap = {pv[i][2:4]: fh.COLOR_DICT["third_target"]}
    else:
        heatmap = None

    if i == 0:
        caption = "Initial puzzle position"
    elif i == 1:
        caption = "Position after 1st move"
    else:
        caption = "Position after 2nd move"

    # Board for the current position with the next principal-variation move
    # drawn as an arrow.
    board_plot = board_copy.plot(
        arrows={pv[i]: MOVE_COLOR},
        heatmap=heatmap,
        show_lastmove=False,
    )
    # Caption placed relative to the board (ice.UP direction).
    board_plot += ice.Tex(tex=caption).scale(2).relative_to(board_plot, ice.UP * 10)
    board_plot = board_plot.crop(board_plot.bounds)
    new_board = board_plot

    # Decorative network glyph, scaled down and placed relative to the board
    # (ice.DOWN direction).
    net = (
        NeuralNetwork(
            layer_node_counts=(2, 3, 3, 2),
            node_border_color=ice.Colors.BLACK,
            node_border_thickness=6,
            line_path_style=ice.PathStyle(ice.Colors.BLACK, thickness=6),
        )
        .scale(0.2)
        .relative_to(board_plot, ice.DOWN * _HORIZONTAL_GAP)
    )
    new_board += net

    # Bar chart of the model's top moves for the current position.
    policy = get_policy(board_copy)
    policy_plot = fh.PolicyBar(
        numbers=list(policy.values()),
        bar_labels=list(policy.keys()),
        label_font_family=fh.FONT_FAMILY,
        bar_height=150,
        # Ellipses are enabled only when all three requested moves came back.
        ellipses=len(policy) == 3,
        use_tex=True,
        bar_width=20,
    )
    new_board += policy_plot.relative_to(new_board, ice.DOWN * _HORIZONTAL_GAP)

    # NOTE(review): the `with new_board:` context is presumably what makes
    # `relative_bounds` resolve against this panel — confirm in the iceberg docs.
    with new_board:
        # Connector: board -> network glyph.
        arrow1 = ice.Arrow(
            start=board_plot.relative_bounds.bottom_middle + ice.DOWN * _ARROW_PADDING,
            end=net.relative_bounds.top_middle + ice.UP * _ARROW_PADDING,
            **_ARROW_KWARGS,
        )
        # Connector: network glyph -> policy bar.
        arrow2 = ice.Arrow(
            start=net.relative_bounds.bottom_middle + ice.DOWN * _ARROW_PADDING,
            end=policy_plot.relative_bounds.top_middle + ice.UP * _ARROW_PADDING,
            **_ARROW_KWARGS,
        )
        new_board += arrow1
        new_board += arrow2

    board_plots.append(new_board)

    # Advance to the next position; must happen after plotting and querying the
    # policy so that both describe the position before this move.
    board_copy.push_uci(pv[i])

# Combine the panels into one scene: start from the first panel and attach each
# remaining panel relative to the scene built so far, shifted along ice.RIGHT.
first_panel, *remaining_panels = board_plots
scene = first_panel
for panel in remaining_panels:
    placed_panel = panel.relative_to(scene).move(*ice.RIGHT * 20)
    scene += placed_panel

# Last expression of the cell: the enlarged scene is what gets displayed.
scene.scale(2)