In [31]:
import torch as t
import os
import numpy as np
import sys
from pathlib import Path

# import OthelloBoardState
# The notebook locates its helper code relative to the current working directory,
# so it has to be launched from the `interpretability` folder.
section_dir = Path.cwd()
assert section_dir.name == "interpretability", (
    f"Run this notebook from the 'interpretability' directory (cwd is {section_dir})"
)

OTHELLO_ROOT = (section_dir / "othello_world").resolve()
OTHELLO_MECHINT_ROOT = (OTHELLO_ROOT / "mechanistic_interpretability").resolve()

# Make the helper modules importable. The membership check keeps `sys.path` from
# accumulating duplicate entries when this cell is executed more than once.
if str(OTHELLO_MECHINT_ROOT) not in sys.path:
    sys.path.append(str(OTHELLO_MECHINT_ROOT))

from mech_interp_othello_utils import (
    OthelloBoardState,
    to_string,
    to_int,
)

from training_utils import (
    get_state_stack_one_hot_general,
)

In [32]:
# Game sequences stored under `data/`, resolved relative to the notebook directory.
# NOTE(review): `t.load` unpickles the file, so only load checkpoints you trust.
board_seqs_string = t.load(section_dir / "data/board_seqs_string_train.pth")

In [52]:
# Cell codes used on the boards handled in this notebook.
WHITE = -1  # white piece
BLACK = 1  # black piece
BLANK = 0  # empty square
ACCESIBLE = 2  # empty square touching at least one piece
LEGAL = 3  # empty square on which the side to move can capture

def is_legal_or_accesible(_state, row, col, player):
    """Classify the square ``(row, col)`` for the side that moves next.

    ``player`` is the colour that has just played, so the side to move is
    ``-player``. A square is ``LEGAL`` when, in at least one of the eight
    directions, it is followed by an unbroken run of ``player`` pieces that ends
    on a ``-player`` piece.

    Args:
        _state: 2-D array indexed as ``_state[r, c]``. Only squares equal to
            ``player`` or ``-player`` count as pieces; every other value
            (``BLANK`` or a label code such as ``ACCESIBLE``/``LEGAL``) is
            treated as empty.
        row: Row index of the square to classify.
        col: Column index of the square to classify.
        player: ``BLACK`` or ``WHITE`` - the colour that just moved.

    Returns:
        ``LEGAL`` if the side to move can capture from this square,
        ``ACCESIBLE`` if the square merely touches a piece of either colour,
        ``BLANK`` otherwise.
    """
    n_rows, n_cols = _state.shape
    touches_piece = False
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            if dr == 0 and dc == 0:
                continue
            r, c = row + dr, col + dc
            if not (0 <= r < n_rows and 0 <= c < n_cols):
                continue
            neighbour = _state[r, c]
            if neighbour != player and neighbour != -player:
                continue
            touches_piece = True
            if neighbour != player:
                # A capture line must start with a piece of the colour that just moved.
                continue
            # Follow the run of `player` pieces. Anything that is not a piece ends
            # the run: callers may hand in a board whose empty squares already
            # carry label codes, and those must not be walked over as if occupied.
            r, c = r + dr, c + dc
            while 0 <= r < n_rows and 0 <= c < n_cols:
                cell = _state[r, c]
                if cell == -player:
                    return LEGAL
                if cell != player:
                    break
                r, c = r + dr, c + dc
    return ACCESIBLE if touches_piece else BLANK
                    

def seq_to_state_stack_legal(str_moves):
    """Replay one game and label the empty squares after every move.

    Cell codes in each returned board:
        WHITE (-1) / BLACK (1): occupied squares, copied from ``board.state``
        0: empty square that touches no piece
        2: accessible (empty square touching at least one piece)
        3: legal for the side that moves next

    Args:
        str_moves: Sequence of moves (list or tensor) fed one by one to
            ``OthelloBoardState.umpire``.

    Returns:
        Tensor holding one labelled 8x8 board per move, stacked along dim 0.

    Raises:
        RuntimeError: If the board rejects a move; the message names the move
            and its index, and the original error is chained.

    NOTE(review): the colour that just moved is derived from move parity, which
    assumes no player ever has to pass - confirm against ``OthelloBoardState``.
    """
    if isinstance(str_moves, t.Tensor):
        str_moves = str_moves.tolist()
    board = OthelloBoardState()
    states = []
    for move_idx, move in enumerate(str_moves):
        # The player who just played.
        player = BLACK if move_idx % 2 == 0 else WHITE
        try:
            board.umpire(move)
        except RuntimeError as err:
            # Fail loudly instead of dropping into a debugger and continuing
            # with a board that never received this move.
            raise RuntimeError(
                f"board rejected move {move!r} at index {move_idx}"
            ) from err
        # Classify against an untouched snapshot and write into a separate copy,
        # so labels stored for earlier squares cannot be mistaken for board
        # content while later squares are scanned.
        snapshot = np.copy(board.state)
        labelled = snapshot.copy()
        for row in range(8):
            for col in range(8):
                if snapshot[row, col] == BLANK:
                    labelled[row, col] = is_legal_or_accesible(snapshot, row, col, player)
        states.append(labelled)
    return t.tensor(np.stack(states, axis=0))

def build_state_stack_legal(board_seqs_string):
    """Label every game in ``board_seqs_string`` and stack the results.

    Args:
        board_seqs_string: Iterable of move sequences, each in a form accepted
            by ``seq_to_state_stack_legal``.

    Returns:
        Tensor with one leading entry per game, obtained by stacking the
        per-game outputs of ``seq_to_state_stack_legal``. All games must
        therefore produce boards of the same shape.

    Note:
        This function does no filtering of its own; how an invalid sequence is
        handled is decided entirely by ``seq_to_state_stack_legal``.
    """
    return t.stack([seq_to_state_stack_legal(seq) for seq in board_seqs_string])

def state_stack_to_one_hot_accesible(state_stack):
    """Encode "accessible or legal" versus "everything else" as two channels.

    Args:
        state_stack: Tensor of labelled boards whose first two dimensions are
            (games, moves) and whose boards are 8x8.

    Returns:
        Int tensor of shape (games, moves, 8, 8, 2) on ``state_stack``'s device.
        Channel 0 is 1 where the cell code is at least ``ACCESIBLE``;
        channel 1 is its complement.
    """
    n_games, n_moves = state_stack.shape[0], state_stack.shape[1]
    reachable = state_stack >= ACCESIBLE

    encoded = t.zeros(
        (n_games, n_moves, 8, 8, 2),  # (games, moves, rows, cols, options)
        dtype=t.int,
        device=state_stack.device,
    )
    encoded[..., 0] = reachable
    encoded[..., 1] = ~reachable
    return encoded

def state_stack_to_one_hot_legal(state_stack):
    """Encode "legal move" versus "everything else" as two channels.

    Args:
        state_stack: Tensor of labelled boards whose first two dimensions are
            (games, moves) and whose boards are 8x8.

    Returns:
        Int tensor of shape (games, moves, 8, 8, 2) on ``state_stack``'s device.
        Channel 0 is 1 where the cell code equals ``LEGAL``;
        channel 1 is its complement.
    """
    n_games, n_moves = state_stack.shape[0], state_stack.shape[1]
    playable = state_stack == LEGAL

    encoded = t.zeros(
        (n_games, n_moves, 8, 8, 2),  # (games, moves, rows, cols, options)
        dtype=t.int,
        device=state_stack.device,
    )
    encoded[..., 0] = playable
    encoded[..., 1] = ~playable
    return encoded


# Pair the shared board replayer with each of the two encoders. The resulting
# callables are applied to batches of game sequences in the cells below.
get_state_stack_one_hot_accesible = get_state_stack_one_hot_general(
    seq_to_state_stack_legal,
    state_stack_to_one_hot_accesible,
)
get_state_stack_one_hot_legal = get_state_stack_one_hot_general(
    seq_to_state_stack_legal,
    state_stack_to_one_hot_legal,
)

In [53]:
# Smoke test on the first ten games; each sequence is expected to hold 60 moves.
games_str = board_seqs_string[0:10]
assert tuple(games_str.shape) == (10, 60)

print(games_str.shape)

# One two-channel 8x8 encoding per game and move.
one_hot = get_state_stack_one_hot_accesible(games_str)
one_hot.shape

torch.Size([10, 60])


torch.Size([10, 60, 8, 8, 2])

In [54]:
# Game 0, move index 2, channel 1 (the complement of the accessible mask).
one_hot[0, 2, ..., 1]

tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 1, 0],
        [1, 1, 0, 1, 1, 1, 0, 0],
        [1, 1, 0, 1, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 0, 1, 1],
        [1, 1, 1, 0, 0, 0, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)

In [55]:
one_hot_legal = get_state_stack_one_hot_legal(games_str)
# A bare assignment displays nothing, so show the board explicitly.
# NOTE(review): the recorded output looks like channel 0 (1 = legal square) for
# game 0 at move index 2 - confirm this is the slice that was originally shown.
one_hot_legal[0, 2, :, :, 0]

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)