In [2]:
import sys
import os 

# Select the self-play config (and the moves directory override) before the
# project's `configuration` module is imported in a later cell, and make the
# project sources importable (presumably `configuration` lives in ../src).
os.environ["CONFIG_PATHS"] = "../configs/self_play.yaml"
os.environ["CONFIG_OVERRIDES"] = 'game.moves_directory="../data/moves_10"'
# Guarded so re-running this cell does not keep appending duplicate entries.
if "../src" not in sys.path:
    sys.path.append("../src")

In [78]:
import numpy as np
from tqdm import tqdm

from configuration import config, moves_data

# Board dimensions and the precomputed move tables, as resolved by the project
# configuration (presumably driven by the CONFIG_* environment set in the first cell).
BOARD_SIZE = config()["game"]["board_size"]
MOVES_DATA = moves_data()
# Number of cells covered by each move: new_occupieds summed over its last two
# axes (presumably indexed [move, row, col] — TODO confirm).
MOVE_SIZES = np.sum(MOVES_DATA["new_occupieds"], axis=(1, 2))

In [4]:
# Static, board-independent tables derived from the move data.

# Number of distinct piece types; the width of the piece one-hot encoding below.
NUM_PIECES = 21

# Per-move neighbour-count signature: every cell a move covers holds 1 + the number
# of its 4-neighbours that the same move also covers; uncovered cells hold 0.
# astype() already returns a fresh array, so no separate .copy() is needed.
move_edge_counts = MOVES_DATA["new_occupieds"].astype(np.int8)
move_edge_counts[:, 1:, :] += MOVES_DATA["new_occupieds"][:, :-1, :]
move_edge_counts[:, :-1, :] += MOVES_DATA["new_occupieds"][:, 1:, :]
move_edge_counts[:, :, 1:] += MOVES_DATA["new_occupieds"][:, :, :-1]
move_edge_counts[:, :, :-1] += MOVES_DATA["new_occupieds"][:, :, 1:]
moves_with_counts = move_edge_counts * MOVES_DATA["new_occupieds"]

# One-cell masks for the four corners, at grid positions (0, 0), (0, N-1),
# (N-1, N-1) and (N-1, 0) of the last two axes, where N = BOARD_SIZE.
start_corners = np.zeros((4, BOARD_SIZE, BOARD_SIZE), dtype=bool)
start_corners[0, 0, 0] = True
start_corners[1, 0, BOARD_SIZE - 1] = True
start_corners[2, BOARD_SIZE - 1, BOARD_SIZE - 1] = True
start_corners[3, BOARD_SIZE - 1, 0] = True

# INITIAL_MOVES_ENABLED[c, m] is True when move m covers corner c; fetch_augmented
# uses row 0 as player 0's opening moves.
INITIAL_MOVES_ENABLED = np.any(
    MOVES_DATA["new_occupieds"] & start_corners[:, np.newaxis, :, :], axis=(2, 3)
)

# Reject out-of-range piece indices up front: negative values would otherwise
# silently wrap around in the np.eye() row lookup.
assert (
    (MOVES_DATA["piece_indices"] >= 0).all()
    and (MOVES_DATA["piece_indices"] < NUM_PIECES).all()
), "piece_indices must lie in [0, NUM_PIECES)"
piece_indices_one_hot = np.eye(NUM_PIECES)[MOVES_DATA["piece_indices"]]

In [92]:
def _with_neighbor_counts(occupancy):
    """Return the neighbour-count signature of an occupancy grid.

    Each occupied cell becomes 1 + the number of its occupied 4-neighbours in the
    same grid; unoccupied cells become 0. The last two axes are treated as the grid,
    so any leading (batch / player) axes pass through unchanged.
    """
    counts = occupancy.astype(np.int8)  # astype() copies, so `occupancy` is untouched
    counts[..., 1:, :] += occupancy[..., :-1, :]
    counts[..., :-1, :] += occupancy[..., 1:, :]
    counts[..., :, 1:] += occupancy[..., :, :-1]
    counts[..., :, :-1] += occupancy[..., :, 1:]
    return counts * occupancy


def fetch_augmented(boards):
    """Recover unused pieces and player 0's valid moves from board occupancies alone.

    A move counts as played by player p when every cell it covers carries, in player
    p's layer of the board, exactly the neighbour-count signature the move has in
    isolation (``moves_with_counts``). In other words, no other cell of that player
    touches the move's footprint edge-to-edge.

    Args:
        boards: occupancy array indexed ``[batch, player, grid, grid]`` (rank 4);
            presumably boolean — TODO confirm against the saved ``occupancies``.

    Returns:
        pieces_available: bool array ``[batch, player, piece]``; True where that
            player has not placed that piece.
        valid_moves: bool array ``[batch, move]``; moves enabled and not ruled out
            for player 0.

    Raises:
        AssertionError: if a piece is detected as placed more than once by a player.
    """
    num_boards = len(boards)
    move_cells = MOVES_DATA["new_occupieds"].astype(bool)
    # Derived from the data instead of hard-coding the table size, so other move
    # directories / board sizes keep working.
    num_moves = move_cells.shape[0]

    board_with_counts = _with_neighbor_counts(boards)

    # Broadcasts to [batch, player, move, grid, grid].
    signature_matches = (
        board_with_counts[:, :, np.newaxis, :, :] == moves_with_counts[np.newaxis, np.newaxis, :, :, :]
    )
    # Only cells the move covers must match; cells outside the move are ignored.
    moves_played = (signature_matches | ~move_cells[np.newaxis, np.newaxis, :, :, :]).all(axis=(3, 4))

    # How many times each player has placed each piece.
    pieces_used = moves_played @ piece_indices_one_hot
    # At most once per piece. An all-zero batch (no moves yet) or an empty batch is valid.
    assert (pieces_used <= 1).all(), "a piece was detected as placed more than once"
    pieces_available = pieces_used == 0

    moves_ruled_out = np.zeros((num_boards, num_moves), dtype=bool)
    moves_enabled = np.repeat(INITIAL_MOVES_ENABLED[:1], num_boards, axis=0).astype(bool)

    for batch_index, player_index, move_index in zip(*np.nonzero(moves_played)):
        # Any player's placement rules moves out for everyone (per the precomputed table).
        moves_ruled_out[batch_index] |= MOVES_DATA["moves_ruled_out_for_all"][move_index]
        if player_index == 0:
            # Player 0's own placements additionally enable / rule out moves for player 0.
            moves_enabled[batch_index] |= MOVES_DATA["moves_enabled_for_player"][move_index]
            moves_ruled_out[batch_index] |= MOVES_DATA["moves_ruled_out_for_player"][move_index]

    valid_moves = moves_enabled & ~moves_ruled_out
    return pieces_available, valid_moves

In [93]:
# Imported here as well so this cell works on a fresh kernel (Restart & Run All);
# previously `glob` was only imported in a later cell.
import glob

# Raw self-play games for this run, sorted for a deterministic processing order.
game_files = sorted(glob.glob('../data/2024-11-23_00-37-50-doublehandedness/games/*.npz'))

In [96]:
import glob
from multiprocessing import Pool

# Root directory of the self-play run whose games are augmented below.
RUN_DIR = '../data/2024-11-23_00-37-50-doublehandedness'

# Raw games, sorted for a deterministic processing order.
game_files = sorted(glob.glob(f'{RUN_DIR}/games/*.npz'))

# Outputs that already exist. Kept as a set so the per-file "already done?" check
# in process_game_file is O(1) instead of a linear scan per game.
created_game_files = set(glob.glob(f'{RUN_DIR}/augmented_games/*.npz'))

def process_game_file(game_file):
    """Write an augmented copy of one saved game under the sibling ``augmented_games`` dir.

    Loads ``occupancies`` from ``game_file`` and recomputes the unused pieces and valid
    moves with ``fetch_augmented``. If the file stores ``valid_moves_array``, the
    recomputation is checked against it. The result is saved as a compressed ``.npz``
    together with ``policies`` and the values array (``values``, or
    ``final_game_values`` for files that lack it). Games whose output path is already
    in ``created_game_files`` are skipped.

    Raises:
        ValueError: if ``game_file`` has no ``/games/`` component, since the output path
            would then equal the input and the source game would be overwritten.
        AssertionError: if the recomputed valid moves differ from the saved ones.
        KeyError: if a required array is missing from the ``.npz``.
    """
    new_name = game_file.replace("/games/", "/augmented_games/")
    if new_name == game_file:
        # replace() was a no-op; saving would overwrite the original game file.
        raise ValueError(f"expected a '/games/' directory in game file path: {game_file!r}")
    if new_name in created_game_files:
        return

    with np.load(game_file) as npz:
        boards = npz["occupancies"]
        unused_pieces, valid_moves = fetch_augmented(boards)
        if "valid_moves_array" in npz:
            # Sanity check: the reconstruction must reproduce what self-play recorded.
            # array_equal also treats a shape mismatch as "not equal" instead of erroring.
            assert np.array_equal(valid_moves, npz["valid_moves_array"]), (
                f"reconstructed valid moves differ from saved ones in {game_file}"
            )
        values = npz["values"] if "values" in npz else npz["final_game_values"]
        policies = npz["policies"]

    # The output directory may not exist yet on a fresh run.
    os.makedirs(os.path.dirname(new_name), exist_ok=True)
    np.savez_compressed(
        new_name,
        boards=boards,
        policies=policies,
        values=values,
        valid_moves_array=valid_moves,
        unused_pieces=unused_pieces,
    )

# Augment every game serially with a progress bar; process_game_file skips games
# whose output path is already listed in created_game_files.
# NOTE(review): `Pool` is imported in this cell but unused — processing is not parallel.
for game_file in tqdm(game_files):
    process_game_file(game_file)

100%|██████████| 1614/1614 [00:45<00:00, 35.59it/s]  
