In [1]:
import pandas as pd

# Position dataset; the round-trip cell further down reads its 'FEN' column.
DATA_PATH = 'chessData.csv'
df = pd.read_csv(DATA_PATH)

In [23]:
import chess
import chess.svg
import jax
import jax.numpy as jnp
from IPython.display import SVG

# The bitboard helpers below store positions as jnp.uint64. Without 64-bit mode,
# JAX substitutes 32-bit dtypes for 64-bit ones, so enable it before any arrays exist.
jax.config.update("jax_enable_x64", True)

def print_chessboard_from_fen(fen, size=350):
    """Render a chess position as an SVG board in the notebook output.

    Args:
        fen: a full FEN string, or only its piece-placement field. A
            placement-only string gets default trailing fields (" w - - 0 1")
            appended so it parses; a full FEN (e.g. a row of the dataset) is
            used unchanged instead of having a second set of fields appended.
        size: board width/height in pixels, forwarded to ``chess.svg.board``.
    """
    # `display` is only a builtin inside IPython; import it explicitly so the
    # helper keeps working if it is moved into a plain .py module.
    from IPython.display import display

    fields = fen.split()
    if len(fields) == 1:
        # Placement-only FEN: supply side to move, castling, en passant and clocks.
        fen = fields[0] + " w - - 0 1"

    board = chess.Board(fen)
    display(SVG(chess.svg.board(board, size=size)))


def bitboards_to_array(bb: jnp.ndarray) -> jnp.ndarray:
    """Expand 64-bit bitboards into binary 8x8 planes.

    Args:
        bb: a single bitboard or a 1-D sequence/array of bitboards. Values are
            converted to ``jnp.uint64``; this requires ``jax_enable_x64``,
            otherwise JAX would use a 32-bit dtype.

    Returns:
        uint8 array of shape (n, 8, 8) holding 0/1. A scalar input gives n == 1.
        Bit ``8 * k + j`` of a bitboard lands at row ``7 - k``, column ``j``.
        Under python-chess square numbering (a1 = bit 0), row 0 is therefore
        rank 8 and column 0 is file a.
    """
    # atleast_1d lets a single bitboard be passed without wrapping it in a list.
    boards = jnp.atleast_1d(jnp.asarray(bb, dtype=jnp.uint64))[:, jnp.newaxis]
    # Shifts 56, 48, ..., 0: extract bytes from most- to least-significant so the
    # top byte becomes row 0.
    shifts = 8 * jnp.arange(7, -1, -1, dtype=jnp.uint64)
    # Casting to uint8 keeps only the low 8 bits of each shifted value.
    byte_rows = (boards >> shifts).astype(jnp.uint8)
    # bitorder="little" sends bit 0 of each byte to column 0.
    bits = jnp.unpackbits(byte_rows, bitorder="little")
    return bits.reshape(-1, 8, 8)

def fen_to_bitboard(fen):
    """Encode a FEN position as a stack of 13 binary 8x8 planes.

    Args:
        fen: FEN string accepted by ``chess.Board``.

    Returns:
        uint8 JAX array of shape (13, 8, 8):
          * planes 0-5: black pawns, knights, bishops, rooks, queens, kings;
          * planes 6-11: the same piece types for white;
          * plane 12: 1 on empty squares, 0 on occupied ones.
        Row 0 is rank 8 and column 0 is file a (see ``bitboards_to_array``).

    Only piece placement is encoded. Side to move, castling rights, the
    en-passant square and the move clocks are discarded.
    """
    board = chess.Board(fen)

    # Black first, then white; within each colour python-chess's PIECE_TYPES
    # order (pawn, knight, bishop, rook, queen, king).
    bitboards = jnp.array(
        [
            board.pieces_mask(piece_type, color)
            for color in (chess.BLACK, chess.WHITE)
            for piece_type in chess.PIECE_TYPES
        ],
        dtype=jnp.uint64,
    )
    piece_planes = bitboards_to_array(bitboards)

    # A square is empty iff no piece plane has a 1 there; one vectorized
    # reduction instead of a Python loop over the 12 planes.
    empty_plane = jnp.logical_not(piece_planes.any(axis=0)).astype(jnp.uint8)

    return jnp.concatenate((piece_planes, empty_plane[jnp.newaxis]), axis=0)
     

def bitboard_to_fen(bitboard):
    """Rebuild a FEN string from the piece planes produced by ``fen_to_bitboard``.

    Args:
        bitboard: array of shape (>=12, 8, 8). Planes 0-5 are black
            p, n, b, r, q, k and planes 6-11 are white P, N, B, R, Q, K; any
            nonzero entry marks a piece. Row 0 is rank 8 and column 0 is
            file a. Extra planes (e.g. the empty-square plane at index 12) are
            ignored.

    Returns:
        A full FEN string. Only piece placement is recovered. The planes carry
        no side-to-move, castling, en-passant or clock information, so those
        fields always come back as ``w - - 0 1`` (python-chess's defaults for a
        board built from ``chess.Board(None)``).
    """
    piece_order = ['p', 'n', 'b', 'r', 'q', 'k', 'P', 'N', 'B', 'R', 'Q', 'K']

    # Fetch to host once. Indexing a device array square by square dispatches
    # a separate JAX operation for every one of the 768 lookups.
    planes = jax.device_get(bitboard)

    board = chess.Board(None)
    for index, symbol in enumerate(piece_order):
        piece = chess.Piece.from_symbol(symbol)  # reused for every square of this plane
        rows, cols = planes[index].nonzero()  # row-major, like the original nested loops
        for row, col in zip(rows.tolist(), cols.tolist()):
            # Row 0 is rank 8, so flip the row to get python-chess's rank index.
            board.set_piece_at(chess.square(col, 7 - row), piece)

    return board.fen()

# Round-trip the first dataset position: FEN -> (13, 8, 8) planes -> FEN.
# `fen_strings` holds a single FEN string (name kept for later cells).
# .iloc selects by position, so this does not depend on the frame's index labels.
fen_strings = df['FEN'].iloc[0]

bitboard = fen_to_bitboard(fen_strings)
fen = bitboard_to_fen(bitboard)

print("FEN Notation:", fen_strings)
print("BitBoard:", bitboard.shape)
print("FEN Notation Recovered:", fen)

# The encoding keeps only piece placement, so only the first FEN field can round-trip.
assert fen.split()[0] == fen_strings.split()[0], "piece placement did not round-trip"


FEN Notation: rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1
BitBoard: (13, 8, 8)
FEN Notation Recovered: rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR w - - 0 1
