In [1]:
from tensorflow.keras.models import load_model
import numpy as np
from skimage import io, transform, color
from skimage.util import view_as_blocks
import ipywidgets as widgets
from IPython.display import display, clear_output
from PIL import Image
import os
import io as io_lib

In [2]:
# Piece letter for each model output class, by class index: lower case for
# black pieces, upper case for white (FEN convention). The class right after
# the last letter (index 12) is what predictions_to_fen treats as an empty square.
piece_symbols = "prbnkq" + "PRBNKQ"

# Trained per-square classifier, loaded from a Keras HDF5 file in the working directory.
model = load_model('chess_fen_model1.h5')

Metal device set to: Apple M1 Pro


In [3]:
def process_image(img_array, downsample_size=200):
    """Resize a board image and split it into 64 equally sized RGB tiles.

    Parameters
    ----------
    img_array : np.ndarray
        Image as a 2-D (grayscale) array, or a 3-D array with 1 (grayscale),
        2 (grayscale + alpha), 3 (RGB) or 4 (RGBA) channels.
    downsample_size : int, optional
        Side length in pixels that the image is resized to before tiling.
        Must be a positive multiple of 8. The default of 200 gives 25x25
        tiles.

    Returns
    -------
    np.ndarray
        Float array of shape ``(64, s, s, 3)`` with ``s = downsample_size // 8``.
        Tiles are ordered row by row, starting at the top-left corner.

    Raises
    ------
    ValueError
        If ``downsample_size`` is not a positive multiple of 8, or the image
        does not have a supported number of dimensions/channels.
    """
    if downsample_size <= 0 or downsample_size % 8:
        raise ValueError(
            f"downsample_size must be a positive multiple of 8, got {downsample_size}"
        )
    square_size = downsample_size // 8

    img_array = np.asarray(img_array)

    # (H, W, 1) grayscale -> (H, W) so gray2rgb below yields (H, W, 3).
    if img_array.ndim == 3 and img_array.shape[2] == 1:
        img_array = img_array[..., 0]
    # Grayscale + alpha -> RGBA, so it is flattened exactly like any RGBA input.
    if img_array.ndim == 3 and img_array.shape[2] == 2:
        gray, alpha = img_array[..., 0], img_array[..., 1]
        img_array = np.stack([gray, gray, gray, alpha], axis=-1)

    if img_array.ndim == 2:
        img_array = color.gray2rgb(img_array)
    elif img_array.ndim == 3 and img_array.shape[2] == 4:
        img_array = color.rgba2rgb(img_array)
    elif not (img_array.ndim == 3 and img_array.shape[2] == 3):
        raise ValueError(f"Unsupported image shape {img_array.shape}")

    # Resize to a square so the 8x8 grid divides it evenly.
    img_resized = transform.resize(img_array, (downsample_size, downsample_size), anti_aliasing=True)

    # view_as_blocks gives (8, 8, 1, s, s, 3); flatten the grid into 64 tiles.
    tiles = view_as_blocks(img_resized, block_shape=(square_size, square_size, 3))
    return tiles.reshape(64, square_size, square_size, 3)


def predictions_to_fen(predictions, symbols=None):
    """Convert per-square class scores into the piece-placement field of a FEN.

    Parameters
    ----------
    predictions : array-like
        Class scores for the 64 board squares, reshapeable to
        ``(8, 8, n_classes)``. Squares are ordered row by row from the top-left
        square, as produced by ``process_image``; the first row becomes the
        first FEN rank (assumes the image shows the board from White's side).
        Class ``i < len(symbols)`` is the piece ``symbols[i]``. Class
        ``len(symbols)`` (12 for the default symbols) is an empty square.
    symbols : str, optional
        Piece letter for each non-empty class index. Defaults to the
        notebook-level ``piece_symbols``.

    Returns
    -------
    str
        Eight ranks joined by ``'/'``, with runs of empty squares collapsed to
        digits, e.g. ``'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR'``. Only the
        placement field is produced; side to move, castling rights, etc. are
        not included.
    """
    if symbols is None:
        symbols = piece_symbols
    # The class immediately after the last piece letter means "empty square".
    empty_index = len(symbols)

    board = np.asarray(predictions).reshape(8, 8, -1)
    ranks = []
    for row in board:
        rank = ""
        empty_run = 0
        for square in row:
            piece_index = int(np.argmax(square))
            if piece_index == empty_index:
                empty_run += 1
                continue
            # Flush any pending run of empty squares before the piece letter.
            if empty_run:
                rank += str(empty_run)
                empty_run = 0
            rank += symbols[piece_index]
        if empty_run:
            rank += str(empty_run)
        ranks.append(rank)
    return "/".join(ranks)

In [5]:
def handle_upload(change):
    """Run the FEN model on a newly uploaded image and show the result.

    Registered as an observer on ``upload_button`` with ``names='value'``.
    The function:

    1. Takes the first uploaded file.
    2. Converts it to an RGB array.
    3. Tiles it with ``process_image``.
    4. Predicts a class for each square with ``model``.
    5. Renders the image and the predicted FEN placement string into the
       ``output`` widget.

    Parameters
    ----------
    change : dict
        Traitlets change notification. ``change['new']`` is the new widget
        value: a (possibly empty) sequence of uploaded-file dicts whose
        ``'content'`` entry holds the raw bytes. This assumes the ipywidgets 8
        tuple layout, as suggested by the ``FileUpload(value=())`` repr.
    """
    uploaded_files = change['new']
    if not uploaded_files:
        # The value can be reset to an empty tuple; there is nothing to do.
        return
    image_data = uploaded_files[0]['content']

    try:
        img = Image.open(io_lib.BytesIO(image_data))
        # Normalise every mode (P, L, LA, CMYK, RGBA, ...) to 3-channel RGB.
        # Palette images would otherwise become a 2-D array of palette
        # indices that process_image would treat as grey levels.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_array = np.array(img)
    except (OSError, ValueError) as exc:
        # Errors raised inside widget callbacks are not shown in the notebook,
        # so report them in the output area instead.
        with output:
            clear_output(wait=True)
            print("Could not read the uploaded file as an image:", exc)
        return

    # process_image already returns a (64, s, s, 3) batch, one tile per square.
    tiles = process_image(img_array)
    predictions = model.predict(tiles)
    fen = predictions_to_fen(predictions)

    with output:
        clear_output(wait=True)
        display(img)
        print("Predicted FEN:", fen)

# Output area that handle_upload clears and renders the image + FEN into.
output = widgets.Output()

# Single-image upload button; each change to its value calls handle_upload.
upload_button = widgets.FileUpload(accept='image/*', multiple=False)
upload_button.observe(handle_upload, names='value')

display(upload_button, output)


FileUpload(value=(), accept='image/*', description='Upload')

Output()

