In [None]:
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
import einops
import torch
from tqdm import tqdm
import numpy as np
from fancy_einsum import einsum
import chess
import numpy as np
import csv
from dataclasses import dataclass
from torch.nn import MSELoss, L1Loss
import pandas as pd
import pickle
import os
import logging
import plotly.graph_objects as go
from functools import partial

import chess_utils
import train_test_chess
from train_test_chess import Config, LinearProbeData

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output


In [None]:
# Nothing in this notebook needs gradients: disable autograd globally for the
# session. The trailing ';' keeps Jupyter from echoing the returned
# set_grad_enabled object as cell output.
torch.set_grad_enabled(False);

In [None]:
# Load the intervention tracker saved by an earlier run.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
file_path = "output_tracker.pkl"
with open(file_path, mode="rb") as tracker_file:
    output_tracker = pickle.load(tracker_file)

In [None]:
# Quick sanity check: the (row, col) cells recorded in layer 5 of the tracker.
print(output_tracker[5]["cells"])

In [None]:
# Peek at the first tracked entry of layer 5: its piece grid, then the square
# (row, col) and piece code recorded for that same entry.
board_state = output_tracker[5]["original_piece_grid"][0]
print(board_state.shape)

(r, c), piece = output_tracker[5]["cells"][0], output_tracker[5]["pieces"][0]
print(piece)
print(r, c)

In [None]:

# Signed piece code -> display character: lowercase for negative codes,
# uppercase for positive codes, "." for an empty square (0).
INT_TO_CHAR = dict(zip(range(-6, 7), "kqrbnp.PNBRQK"))

# Mapping of signed piece codes (-6..6) to one-hot channel indices (0..12).
# I'm duplicating this from chess_utils.py for easy reference; it must stay in
# sync with that file.
PIECE_TO_ONE_HOT_MAPPING = {piece_code: piece_code + 6 for piece_code in range(-6, 7)}

# python-chess piece types -> positive integer codes 1..6 (pawn first, king
# last); INT_TO_PIECE is the inverse lookup.
PIECE_TO_INT = {
    piece_type: code
    for code, piece_type in enumerate(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING),
        start=1,
    )
}

INT_TO_PIECE = {code: piece_type for piece_type, code in PIECE_TO_INT.items()}

# One-hot channel indices for the empty square, the white pawn (code 1) and
# the black king (code -6).
blank_index, white_pawn_index, black_king_index = (
    PIECE_TO_ONE_HOT_MAPPING[code] for code in (0, 1, -6)
)

def plot_board_state(board_state: torch.Tensor, clip_size: int = 200) -> "go.Heatmap":
    """
    Build a grayscale heatmap of a board, clipping values to [-clip_size, clip_size].

    NOTE(review): this notebook defines plot_board_state twice; whichever cell ran
    last wins. Keep both definitions identical (or delete one).

    Args:
    board_state (torch.Tensor | np.ndarray): Board values. Torch tensors may live on
        any device; numpy arrays (or array-likes) are accepted as well.
    clip_size (int): The maximum and minimum value for clipping board states.

    Returns:
    go.Heatmap: The generated heatmap.
    """
    if isinstance(board_state, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and also covers non-CUDA devices
        # (e.g. MPS), where an `is_cuda` check would let .numpy() raise.
        board_state = board_state.detach().cpu().numpy()
    clipped = np.clip(np.asarray(board_state), -clip_size, clip_size)
    return go.Heatmap(z=clipped, colorscale='gray')

heatmap = plot_board_state(board_state)

# The title describes what is actually shown: entry 0 of layer 5's
# "original_piece_grid" (`board_state` and `piece` come from the inspection cell).
# int() tolerates piece codes stored as numpy/torch scalars; unknown codes are
# shown as-is instead of raising.
layout = go.Layout(
    title=f"Layer 5 original_piece_grid, index 0 (piece: {INT_TO_CHAR.get(int(piece), piece)})",
    xaxis=dict(ticks='', nticks=8),
    yaxis=dict(ticks='', nticks=8),
    autosize=False,
    width=600,
    height=600
)

# Create figure and plot
fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import torch

def plot_board_states(output_tracker: dict, index: int, metadata_layer: int = 3):
    """
    Plot one row per tracked layer and one column per grid type for a single index,
    with a red marker on the intervened square in every panel.

    Layers are laid out in ascending key order, so layer keys 3..7 map to rows 1..5.

    Args:
    output_tracker (dict): Maps layer number -> per-layer dict. Every layer dict must
        provide the four grid types below as sequences of boards indexable as
        [row, col]. The `metadata_layer` entry must also provide "cells", "pieces",
        "successes", "scales", "original_move" and "modified_move".
    index (int): The index of the state to plot.
    metadata_layer (int): Layer whose metadata feeds the marker and the header text.
        Defaults to 3.

    Returns:
    None: Displays the plot.
    """
    grid_types = ["original_blank_grid", "modified_blank_grid", "original_piece_grid", "modified_piece_grid"]

    meta = output_tracker[metadata_layer]
    r, c = meta["cells"][index]
    piece = INT_TO_CHAR[meta["pieces"][index]]
    success = meta["successes"][index]
    scale = meta["scales"][index]
    original_move = meta["original_move"][index]
    modified_move = meta["modified_move"][index]
    title_text = f"R, C: {r}, {c}, Piece: {piece}, Original: {original_move}, modified: {modified_move}, Success: {success}, Scale: {scale}, Index: {index}"

    layers = sorted(output_tracker)
    n_rows, n_cols = len(layers), len(grid_types)

    # One title per panel, in the row-major order make_subplots expects.
    subplot_titles = [f"Layer {layer}: {grid_type}" for layer in layers for grid_type in grid_types]
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=subplot_titles)

    for row, layer in enumerate(layers, start=1):
        content = output_tracker[layer]
        for col, grid_type in enumerate(grid_types, start=1):
            for trace in plot_board_state_with_marker(content[grid_type][index], r, c):
                fig.add_trace(trace, row=row, col=col)

    plot_size = 400
    fig.update_layout(
        height=plot_size * n_rows,
        width=plot_size * n_cols,
        margin=dict(t=120),  # room for the header above the first row of subplot titles
        showlegend=False,
    )

    # Plotly stores subplot titles as layout annotations, so passing
    # annotations=[...] to update_layout would overwrite them; add_annotation
    # appends instead.
    header = dict(text=title_text, showarrow=False, xref="paper", yref="paper",
                  x=0.5, xanchor="center", font=dict(size=20))
    # Top copy: shifted up so it clears the first row's subplot titles.
    fig.add_annotation(**header, y=1.00, yanchor="bottom", yshift=30)
    # Bottom copy: slightly below the last row.
    fig.add_annotation(**header, y=-0.01, yanchor="top")

    fig.show()

def plot_board_state_with_marker(board_state: torch.Tensor, r: int, c: int, clip_size: int = 200) -> list:
    """
    Generate a grayscale heatmap for a board and overlay a red dot at cell (r, c),
    whose hover text shows the (clipped) value at that cell.

    Args:
    board_state (torch.Tensor | np.ndarray): Board values indexable as [row, col].
        Torch tensors may live on any device; numpy arrays are accepted as well.
    r (int): Row of the cell to mark.
    c (int): Column of the cell to mark.
    clip_size (int): The maximum and minimum value for clipping board states.

    Returns:
    list: [heatmap, marker] — the heatmap trace and the scatter trace for the dot.
    """
    if isinstance(board_state, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and also covers non-CUDA devices
        # (e.g. MPS), where an `is_cuda` check would let .numpy() raise.
        board_state = board_state.detach().cpu().numpy()
    clipped_board_state = np.clip(np.asarray(board_state), -clip_size, clip_size)

    heatmap = go.Heatmap(z=clipped_board_state, colorscale='gray')

    cell_value = clipped_board_state[r, c]
    # Plotly's x axis is the column and its y axis is the row of the heatmap.
    marker = go.Scatter(
        x=[c],
        y=[r],
        mode='markers',
        marker=dict(color='red', size=10),
        text=[f"Value: {cell_value}"],
        hoverinfo='text',  # only the text is shown on hover
        showlegend=False,
    )

    return [heatmap, marker]

def plot_board_state(board_state: torch.Tensor, clip_size: int = 200) -> "go.Heatmap":
    """
    Generate a grayscale heatmap of a board, clipping values to [-clip_size, clip_size].

    NOTE(review): this notebook defines plot_board_state twice; whichever cell ran
    last wins. Keep both definitions identical (or delete one).

    Args:
    board_state (torch.Tensor | np.ndarray): Board values. Torch tensors may live on
        any device; numpy arrays (or array-likes) are accepted as well.
    clip_size (int): The maximum and minimum value for clipping board states.

    Returns:
    go.Heatmap: The generated heatmap.
    """
    if isinstance(board_state, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and also covers non-CUDA devices
        # (e.g. MPS), where an `is_cuda` check would let .numpy() raise.
        board_state = board_state.detach().cpu().numpy()
    clipped = np.clip(np.asarray(board_state), -clip_size, clip_size)
    return go.Heatmap(z=clipped, colorscale='gray')

# Example usage: render the first tracked entry.
plot_board_states(output_tracker, index=0)


In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Interactive browser over tracked entries: step with the buttons or type an index.
# Relies on output_tracker, plot_board_states, widgets, display and clear_output
# from earlier cells.

current_index = 0  # index of the entry currently plotted

# Output produced inside widget callbacks is not reliably shown in the cell (e.g.
# JupyterLab routes it to the log console), so all plotting goes through this
# Output widget, which is displayed once below.
plot_output = widgets.Output()


def update_plot_with_index(index: int):
    """Clamp `index` to the valid range and redraw the board states for it.

    The IntText box is the single source of truth: if it does not already show the
    clamped index, it is updated and its observer (on_text_submit) re-enters this
    function to draw. The box therefore always matches the plot, and each navigation
    draws exactly once.
    """
    global current_index
    # NOTE(review): the length comes from layer 5's grids; all layers are assumed
    # to hold the same number of entries.
    max_index = len(output_tracker[5]["original_piece_grid"]) - 1
    if max_index < 0:
        return  # nothing to plot
    clamped = max(0, min(index, max_index))
    if text_index.value != clamped:
        text_index.value = clamped  # fires on_text_submit -> update_plot_with_index(clamped)
        return
    current_index = clamped
    with plot_output:
        clear_output(wait=True)
        plot_board_states(output_tracker, current_index)


def on_button_click(b):
    """Shared handler for the Previous / Next / Next False buttons."""
    if b.description == 'Next':
        update_plot_with_index(current_index + 1)
    elif b.description == 'Previous':
        update_plot_with_index(current_index - 1)
    elif b.description == 'Next False':
        # Jump to the next entry whose intervention did not succeed; stay on the
        # current entry if there is none, rather than landing on a success.
        successes = output_tracker[3]["successes"]
        limit = min(len(successes), len(output_tracker[5]["original_piece_grid"]))
        next_false_index = next(
            (i for i in range(current_index + 1, limit) if not successes[i]), None
        )
        if next_false_index is not None:
            update_plot_with_index(next_false_index)


def on_text_submit(change):
    """Observer for text_index.value: redraw at the newly entered index."""
    update_plot_with_index(change.new)


# Create buttons for navigation
prev_button = widgets.Button(description="Previous")
next_button = widgets.Button(description="Next")
next_false_button = widgets.Button(description="Next False")

# Attach the event handler to the buttons
prev_button.on_click(on_button_click)
next_button.on_click(on_button_click)
next_false_button.on_click(on_button_click)

# Create an IntText widget for direct index input
text_index = widgets.IntText(
    value=0,
    description='Index:',
    continuous_update=False
)

# Attach the event handler to the IntText widget
text_index.observe(on_text_submit, names='value')

# Show the plot area with the controls below it, then draw the initial entry.
display(plot_output, widgets.HBox([prev_button, text_index, next_button, next_false_button]))
update_plot_with_index(current_index)


In [None]:
# def tensor_to_text(board_state: torch.Tensor) -> np.ndarray:
#     # Create a mapping from numbers to characters
#     # Update this mapping according to your requirements

#     # Convert the tensor to numpy array for easier processing
#     board_array = board_state.numpy()

#     # Create an empty array with the same shape for text
#     text_array = np.empty(board_array.shape, dtype=str)

#     # Fill the text array with corresponding characters
#     for i in range(board_array.shape[0]):
#         for j in range(board_array.shape[1]):
#             text_array[i, j] = INT_TO_CHAR.get(board_array[i, j], str(board_array[i, j]))

#     return text_array

# def plot_board_state_with_text(board_state: torch.Tensor):
#     # Convert the tensor to a text matrix
#     text_matrix = tensor_to_text(board_state)

#     # Define the custom colorscale
#     colorscale = [
#         [0, 'black'],   # Negative values
#         [0.49, 'black'],
#         [0.5, 'grey'],  # Zero
#         [0.51, 'white'],
#         [1, 'white']    # Positive values
#     ]


#     # Create heatmap with text and custom colorscale
#     heatmap = go.Heatmap(
#         z=board_state.numpy(), 
#         text=text_matrix, 
#         showscale=False, 
#         colorscale=colorscale,
#         texttemplate="%{text}"  # Set the texttemplate here
#     )

#     return heatmap
# heatmap = plot_board_state_with_text(move_of_interest_state)

# # Define the layout
# layout = go.Layout(
#     title="Chess board state with text",
#     xaxis=dict(ticks='', nticks=8),
#     yaxis=dict(ticks='', nticks=8),
#     autosize=False,
#     width=600,
#     height=600
# )

# # Create figure and plot
# fig = go.Figure(data=[heatmap], layout=layout)
# fig.show()

In [None]:
# from plotly.subplots import make_subplots

# move_of_interest_probe_out = probe_out[0][0][move_of_interest]
# print(move_of_interest_probe_out.shape)

# fig_rows = 4
# fig_cols = 3
# fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
#     "Chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
#     "Chess board white pawns", "Probe output white pawns clip=5", "Probe output white pawns no clipping",
#     "Chess board black king", "Probe output black king clip=5", "Probe output black king no clipping",
#     "Chess board state", "Probe output board state", "Redundant probe output board state"
# ])


# # Specify the size of each plot
# plot_size = 400  # You can adjust this size

# fig.add_trace(plot_board_state(move_of_interest_state_one_hot[:, :, blank_index]), row=1, col=1)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, blank_index], clip_size=2), row=1, col=2)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, blank_index]), row=1, col=3)

# fig.add_trace(plot_board_state(move_of_interest_state_one_hot[:, :, white_pawn_index]), row=2, col=1)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, white_pawn_index], clip_size=5), row=2, col=2)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, white_pawn_index]), row=2, col=3)

# fig.add_trace(plot_board_state(move_of_interest_state_one_hot[:, :, black_king_index]), row=3, col=1)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, black_king_index], clip_size=5), row=3, col=2)
# fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, black_king_index]), row=3, col=3)

# fig.add_trace(plot_board_state_with_text(move_of_interest_state), row=4, col=1)
# fig.add_trace(plot_board_state_with_text(state_stacks_probe_outputs[0][0][move_of_interest]), row=4, col=2)
# fig.add_trace(plot_board_state_with_text(state_stacks_probe_outputs[0][0][move_of_interest]), row=4, col=2)

# # Adjust the overall size of the figure
# fig.update_layout(height=fig_rows * plot_size, width=fig_cols * plot_size)

# # Show the figure
# fig.show()