In [107]:
import torch
import fancy_einsum as fancy
from mech_interp.fixTL import make_official
from mech_interp.visualizations import (
    plot_board_log_probs,
    map_token_to_move_index,
    preprocess_offset_mapping)
from transformer_lens import HookedTransformer
import pandas as pd
from austin_plotly import imshow
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f37f58cf650>

In [206]:
PROBE_NUMBER = 4
PROBE_NAME = f'linear_probes/saved_probes/color_probe_gpt2-chess-uci-hooked_layer_{PROBE_NUMBER}_indexing_df_to_color_state.pth'
probe = torch.load(PROBE_NAME)
print(probe.keys())
linear_probe = probe['linear_probe']

dict_keys(['acc', 'loss', 'lr', 'epoch', 'batch', 'linear_probe', 'linear_probe_name', 'model_name', 'layer', 'indexing_function_name', 'batch_size', 'wd', 'split', 'num_epochs', 'num_classes', 'wandb_project', 'wandb_run_name', 'dataset_prefix'])


In [207]:
MODEL_NAME = make_official()
model = HookedTransformer.from_pretrained(MODEL_NAME)

Loaded pretrained model AustinD/gpt2-chess-uci-hooked into HookedTransformer


In [208]:
DATASET_NAME = 'chess_data/lichess_test.pkl'
df = pd.read_pickle(DATASET_NAME)
print(df.keys())

Index(['WhiteElo', 'BlackElo', 'Result', 'complete_transcript', 'input_ids',
       'offsets', 'transcript', 'WhiteEloBinned', 'WhiteEloBinIndex',
       'BlackEloBinned', 'BlackEloBinIndex', 'fen_stack'],
      dtype='object')


In [209]:
input_ids = torch.tensor(df['input_ids'][0:15].tolist())
logits, cache = model.run_with_cache(input_ids)

In [210]:
resids = cache['resid_post', 11]
probe_output = fancy.einsum(
    "batch pos d_model, d_model rows cols classes -> batch pos rows cols classes",
    resids,
    linear_probe,
)
probe_output.shape

torch.Size([15, 126, 8, 8, 3])

In [211]:
probe_output[2,...,1].squeeze().shape

torch.Size([126, 8, 8])

In [212]:
def plot_probe_output(output: torch.Tensor, clip_size, title):
    if output.device != 'cpu':
        output = output.cpu()
    tensor_np = output.numpy()
    tensor_np = np.clip(tensor_np, -clip_size, clip_size)

    # Create the frames for each slice of the tensor
    frames = [go.Frame(data=[go.Heatmap(z=tensor_np[i, :, :])], name=str(i)) for i in range(tensor_np.shape[0])]

    # Create the initial figure
    fig = go.Figure(
        data=[go.Heatmap(z=tensor_np[0, :, :])],
        layout=go.Layout(
            title=title,
            sliders=[dict(
                steps=[dict(method='animate',
                            args=[[f.name], 
                                {"frame": {"duration": 0, "redraw": True},
                                "mode": "immediate",
                                "transition": {"duration": 0}}],
                            label=str(i))
                        for i, f in enumerate(frames)],
                active=0,
                transition={"duration": 0},
                x=0, y=0,
                currentvalue={"prefix": "Slice: ", "visible": True},
                len=1
            )]
        ),
        
        frames=frames
    )

    fig.show()



import chess

In [213]:
def plot_combined_probe_output(output: torch.Tensor, clip_size, titles, df_data):
    if output.device != 'cpu':
        output = output.cpu()

    tensor_nps = [np.clip(output[..., i].squeeze().flip(-1).flip(-2).numpy(), 0, clip_size) for i in range(len(titles))]

    # Determine the number of frames (should be same for all tensors)
    num_frames = tensor_nps[0].shape[0]

    # Create subplots
    fig = make_subplots(rows=1, cols=len(titles))

    # Create frames for animation
    frames = []
    for frame_idx in range(num_frames):
        frame_data = []
        for i, tensor_np in enumerate(tensor_nps):
            frame_data.append(go.Heatmap(z=tensor_np[frame_idx, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}'))
        frames.append(go.Frame(data=frame_data, name=str(frame_idx)))



    # Add initial data and layout for each subplot
    for i, (title, tensor_np) in enumerate(zip(titles, tensor_nps)):
        fig.add_trace(
            go.Heatmap(z=tensor_np[0, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}'),
            row=1, col=i+1,
        )
        fig.update_yaxes(scaleanchor=f"x{i+1}", scaleratio=1, row=1, col=i+1)
        fig.update_layout(**{f'xaxis{i+1}': dict(anchor=f'y{i+1}', title=title)})

    # Update layout with slider
    fig.update_layout(
        sliders=[dict(
            steps=[dict(method='animate',
                        args=[[str(frame_idx)], 
                              {"frame": {"duration": 0, "redraw": True},
                               "mode": "immediate",
                               "transition": {"duration": 0}}],
                        label=str(frame_idx))
                    for frame_idx in range(num_frames)],
            active=0,
            transition={"duration": 0},
            x=0, y=0,
            currentvalue={"prefix": "Frame: ", "visible": True},
            len=1
        )],
    )
    
    fig.frames=frames
    
    
    if df_data is not None:
        
        # Add pieces to board
        preprocessed_offsets = preprocess_offset_mapping(df_data['offsets'])
        df_data
        fen_stack = df_data['fen_stack']
        for idx, frame in enumerate(fig.frames):
            
            # nasty hack to ensure board matches intuition since <s> is included
            if idx +1 < len(fig.frames):
                idx+=1 
                
            move_idx = map_token_to_move_index(df_data['transcript'],
                                    idx,preprocessed_offsets,zero_index=True)
            fen = fen_stack[move_idx]
            board = chess.Board(fen)
            text = ['']
        

            strip_all = lambda s: s.replace(' ', '').replace('\n', '')
            
            text[-1] = [f"<span style='font-size: 16em;'>{c}</span>" 
                    for c in strip_all(board.unicode(empty_square='·'))]
            
            frame.data[0]["text"] = np.array(text).reshape(8,8)
            frame.data[0]["texttemplate"] = "%{text}"

    fig.show()


In [214]:
def plot_combined_probe_output(output: torch.Tensor, clip_size, titles, df_data):
    if output.device != 'cpu':
        output = output.cpu()
        
    HEIGHT=800
    WIDTH=1200

    tensor_nps = [np.clip(output[..., i].squeeze().flip(-1).numpy(), 0, clip_size) for i in range(len(titles))]

    # Determine the number of frames (should be same for all tensors)
    num_frames = tensor_nps[0].shape[0]

    # Create subplots
    fig = make_subplots(rows=1, cols=len(titles))

    # Create frames for animation
    frames = []
    for frame_idx in range(num_frames):
        frame_data = []
        for i, tensor_np in enumerate(tensor_nps):
            figi = go.Heatmap(z=tensor_np[frame_idx, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}')
            frame_data.append(figi)
        frames.append(go.Frame(data=frame_data, name=str(frame_idx)))



    # Add initial data and layout for each subplot
    for i, (title, tensor_np) in enumerate(zip(titles, tensor_nps)):
        fig.add_trace(
            go.Heatmap(z=tensor_np[0, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}'),
            row=1, col=i+1,
        )
        fig.update_yaxes(scaleanchor=f"x{i+1}", scaleratio=1, row=1, col=i+1)
        fig.update_layout(**{f'xaxis{i+1}': dict(anchor=f'y{i+1}', title=title)})

    # Update layout with slider
    fig.update_layout(
        sliders=[dict(
            steps=[dict(method='animate',
                        args=[[str(frame_idx)], 
                              {"frame": {"duration": 0, "redraw": True},
                               "mode": "immediate",
                               "transition": {"duration": 0}}],
                        label=str(frame_idx))
                    for frame_idx in range(num_frames)],
            active=0,
            transition={"duration": 0},
            x=0, y=0,
            currentvalue={"prefix": "Frame: ", "visible": True},
            len=1
        )],
    )
    
    fig.frames=frames
    
    if df_data is not None:
        preprocessed_offsets = preprocess_offset_mapping(df_data['offsets'])
        fen_stack = df_data['fen_stack']
        move_indices = [map_token_to_move_index(df_data['transcript'], i + 1, preprocessed_offsets, zero_index=True)
                        for i in range(len(fig.frames) - 1)]
        heatmaps_texts = []
        for move_idx in move_indices:
            fen = fen_stack[move_idx]
            board = chess.Board(fen)
            strip_all = lambda s: s.replace(' ', '').replace('\n', '')
            text = [f"<span style='font-size: 14em; align: center;'>{c}</span>" 
                    for c in strip_all(board.__str__())] #.unicode(empty_square='·')
            heatmaps_texts.append(np.array(text).reshape(8, 8))
        
        # Add the text to the initial heatmap traces
        for i in range(len(titles)):
            fig.data[i]['text'] = heatmaps_texts[0]
            fig.data[i]['texttemplate'] = "%{text}"

        # Add the text to the frames
        for idx, frame in enumerate(fig.frames):
            if idx < len(heatmaps_texts):
                frame.data[0]["text"] = heatmaps_texts[idx]
                frame.data[0]["texttemplate"] = "%{text}"
                frame.data[1]["text"] = heatmaps_texts[idx]
                frame.data[1]["texttemplate"] = "%{text}"
                frame.data[2]["text"] = heatmaps_texts[idx]
                frame.data[2]["texttemplate"] = "%{text}"

    fig.update_layout(
        height=HEIGHT,  # Set the desired height
        width=WIDTH,   # Set the desired width
        margin=dict(l=20, r=20, t=20, b=20),  # Adjust margins as needed
        grid=dict(rows=1, columns=len(titles), pattern='independent'),  # Configure the grid for subplots
        # Include any other layout configurations here
    )
    fig.show()



In [215]:
df.keys()

Index(['WhiteElo', 'BlackElo', 'Result', 'complete_transcript', 'input_ids',
       'offsets', 'transcript', 'WhiteEloBinned', 'WhiteEloBinIndex',
       'BlackEloBinned', 'BlackEloBinIndex', 'fen_stack'],
      dtype='object')

In [216]:
clip = 0.4
game_id = 13
# plot_board_log_probs(" ".join(df.iloc[game_id]['transcript'].split(' ')[:-1]),model.tokenizer,logits[game_id,:-1])
plot_combined_probe_output(probe_output[game_id].squeeze(), clip, ['mine','theirs','empty'], df.iloc[game_id])


In [217]:
from mech_interp.visualizations import plot_board_log_probs

plot_board_log_probs()

TypeError: plot_board_log_probs() missing 3 required positional arguments: 'uci_moves', 'tokenizer', and 'logits'

'f2f4 e7e6 g1f3 f8e7 e2e3 e7h4 g2g3 h4f6 d2d4 d7d6 f1d3 h7h6 e1g1 g8e7 c2c4 c7c6 b1c3 b8d7 c1d2 d7b6 b2b4 e8g8 a2a4 a7a6 a4a5 b6d7 d1c2 e7g6 d3g6 f7g6 c2g6 d7b8 e3e4 c8d7 e4e5 f6e7 e5d6 e7d6 c4c5 d6e7 a1e1 e7f6 f3e5 d7e8 g6g4 d8d4 d2e3 d4c3 g4e6 e8f7 e5f7 f8f7 e1d1 c3b4 g3g4 g8f8 g4g5 f7e7 e6d6 f6c3 g5h6 b4a5 h6'