# Image → FEN Sequence Model

This notebook trains a CNN encoder + GRU sequence decoder that predicts a FEN-like string from a schematic image of a chess board. Labels come from the image filenames (without extension): each filename encodes the board position, with `-` separating the ranks (where standard FEN uses `/`).

In [2]:
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

print('TensorFlow version:', tf.__version__)

# Avoid OOM errors by letting TensorFlow grow GPU memory on demand instead of
# reserving it all up front. Memory growth must be configured before any GPU
# has been initialized, so this runs first.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Seed Python's `random`, NumPy and TensorFlow in one call so dataset shuffling
# and weight initialization are reproducible across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

gpus

TensorFlow version: 2.13.1


[]

In [7]:
import tensorflow as tf
print("TF version:", tf.__version__)
# Count the GPUs visible to TensorFlow (0 means training runs on CPU).
num_gpus = len(tf.config.list_physical_devices("GPU"))
print("Num GPUs:", num_gpus)


TF version: 2.13.1
Num GPUs: 0


## 1. Paths and label extraction

We assume the following directory structure:

- `data/train`: training images
- `data/val`: validation images
- `data/test`: test images

Each image filename (without extension) is a FEN-like string, e.g.:
`1b1B1b2-2pK2q1-4p1rB-7k-8-8-3B4-3rb3.jpeg` → label string `1b1B1b2-2pK2q1-4p1rB-7k-8-8-3B4-3rb3`.

In [None]:
# Dataset root. Override it with the DATA_ROOT environment variable to point
# the notebook at another location without editing code; defaults to ./data.
data_root = Path(os.environ.get('DATA_ROOT', 'data'))
train_dir = data_root / 'train'
val_dir   = data_root / 'val'
test_dir  = data_root / 'test'

def get_image_label_pairs(folder: Path, exts=(".jpeg", ".jpg", ".png")):
    """Collect the image files in ``folder`` and their filename-derived labels.

    Args:
        folder: Directory to scan (non-recursive). A ``str`` is also accepted.
        exts: File extensions to accept, including the leading dot. Matching
            is case-insensitive on both sides.

    Returns:
        ``(image_paths, labels)``: parallel lists of path strings and filename
        stems (the name without its final extension). Both are sorted by path,
        because ``Path.iterdir()`` yields entries in arbitrary order and the
        split order (and any ``[:max_samples]`` subset) should be reproducible.

    Raises:
        FileNotFoundError: if ``folder`` does not exist.
    """
    folder = Path(folder)
    if not folder.exists():
        raise FileNotFoundError(f"Folder not found: {folder}")

    allowed = {ext.lower() for ext in exts}
    image_paths = []
    labels = []
    for p in sorted(folder.iterdir()):
        if p.is_file() and p.suffix.lower() in allowed:
            image_paths.append(str(p))
            labels.append(p.stem)  # filename without extension
    return image_paths, labels

# Load (path, label) pairs for each split.
train_paths, train_labels = get_image_label_pairs(train_dir)
val_paths,   val_labels   = get_image_label_pairs(val_dir)
test_paths,  test_labels  = get_image_label_pairs(test_dir)

# Report every split, including validation, so an empty or missing split is
# noticed before training starts.
print(len(train_paths), 'train images')
print(len(val_paths),   'val images')
print(len(test_paths),  'test images')
print('Example label:', train_labels[0] if train_labels else 'N/A')

## 2. Character-level vocabulary

We build a character-level vocabulary from all FEN-like labels. We also add special tokens: `<PAD>`, `<SOS>`, `<EOS>`.

In [None]:
# Labels from every split feed the vocabulary and max length, so every label
# the notebook later encodes is covered (encode_label raises KeyError on an
# unseen character).
all_labels = [*train_labels, *val_labels, *test_labels]
if len(all_labels) == 0:
    raise ValueError('No labels found. Check that your data folders contain images.')

chars = sorted(set(''.join(all_labels)))
print('Chars in dataset:', chars)

PAD_TOKEN = '<PAD>'
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'

# PAD comes first so its id is 0: the decoder embedding uses mask_zero=True
# and greedy decoding pre-fills sequences with zeros.
vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, *chars]
char2idx = dict(zip(vocab, range(len(vocab))))
idx2char = dict(enumerate(vocab))

vocab_size = len(vocab)
print('Vocab size:', vocab_size)

max_len_raw = max(map(len, all_labels))
max_len = max_len_raw + 2  # room for the SOS and EOS tokens
print('Max sequence length:', max_len)

## 3. Encode labels to integer sequences

For each label string we build:

- **Decoder input**: `[SOS, c1, c2, ..., cN]`
- **Target output**: `[c1, c2, ..., cN, EOS]`

Both are padded to `max_len`.

In [None]:
def encode_label(label: str):
    """Encode a label string as padded (decoder_input, target) id arrays.

    decoder_input is ``[SOS, c1, ..., cN]`` and target is ``[c1, ..., cN, EOS]``.
    Each is cut to ``max_len`` entries and right-padded with the PAD id.

    Returns:
        Two int32 NumPy arrays of length ``max_len``.

    Raises:
        KeyError: if ``label`` contains a character missing from ``char2idx``.
    """
    char_ids = [char2idx[ch] for ch in label]
    pad_id = char2idx[PAD_TOKEN]

    def _fit(seq):
        # Truncate first, then pad the remainder up to max_len.
        seq = seq[:max_len]
        return seq + [pad_id] * (max_len - len(seq))

    decoder_in = _fit([char2idx[SOS_TOKEN], *char_ids])
    target = _fit([*char_ids, char2idx[EOS_TOKEN]])
    return np.array(decoder_in, dtype=np.int32), np.array(target, dtype=np.int32)

# quick sanity check
# Quick sanity check: both encoded sequences should have length max_len.
test_in, test_out = encode_label(all_labels[0])
print(f'Example encoded input length: {len(test_in)}')
print(f'Example encoded target length: {len(test_out)}')

## 4. `tf.data` pipeline

We create a dataset that yields:

- Inputs: `{ 'image': image_tensor, 'decoder_input': token_ids }`
- Target: `decoder_output` token IDs.

Images are resized to 256×256 and normalized to `[0, 1]`.

In [None]:
IMG_SIZE: int = 256  # side length (pixels) that every image is resized to


def load_image(path):
    """Read an image file into a float32 tensor of shape (IMG_SIZE, IMG_SIZE, 3).

    Pixel values are scaled to [0, 1] by ``convert_image_dtype`` before
    bilinear resizing.
    """
    raw_bytes = tf.io.read_file(path)
    # decode_jpeg (rather than decode_image) gives a statically rank-3 result.
    # Per the TF docs it also decodes PNG and non-animated GIF input.
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    pixels.set_shape([None, None, 3])  # height/width unknown, 3 channels

    pixels = tf.image.convert_image_dtype(pixels, tf.float32)
    return tf.image.resize(pixels, (IMG_SIZE, IMG_SIZE))


def encode_label_tf(label):
    """Graph-compatible wrapper around ``encode_label`` for ``Dataset.map``.

    Args:
        label: Scalar ``tf.string`` tensor holding a UTF-8 label.

    Returns:
        ``(decoder_input, target)`` int32 tensors of static shape ``(max_len,)``.
    """
    def _encode_py(label_tensor):
        return encode_label(label_tensor.numpy().decode('utf-8'))

    dec_in, dec_out = tf.py_function(
        func=_encode_py,
        inp=[label],
        Tout=[tf.int32, tf.int32],
    )
    # py_function output has unknown shape; restore it so batching works.
    for seq in (dec_in, dec_out):
        seq.set_shape((max_len,))
    return dec_in, dec_out

def make_dataset(paths, labels, batch_size=32, shuffle=False):
    """Build a batched pipeline yielding ``({'image', 'decoder_input'}, target)``.

    Labels are encoded once, up front, with ``encode_label`` (NumPy). Doing this
    per element through ``tf.py_function`` would make the parallel map call
    back into Python and contend for the GIL. The map now does only graph-native
    image loading.

    Args:
        paths: Image file paths (strings).
        labels: Label strings, parallel to ``paths``.
        batch_size: Number of examples per batch.
        shuffle: If True, shuffle the whole dataset (buffer = dataset size)
            before decoding images.

    Returns:
        A prefetched ``tf.data.Dataset`` of batches.

    Raises:
        ValueError: if ``paths`` and ``labels`` differ in length.
        KeyError: if a label contains a character missing from the vocabulary.
    """
    if len(paths) != len(labels):
        raise ValueError(
            f"paths and labels must have equal length ({len(paths)} != {len(labels)})"
        )

    encoded = [encode_label(lab) for lab in labels]
    if encoded:
        dec_in = np.stack([pair[0] for pair in encoded])
        dec_out = np.stack([pair[1] for pair in encoded])
    else:
        dec_in = dec_out = np.zeros((0, max_len), dtype=np.int32)

    ds = tf.data.Dataset.from_tensor_slices(
        (tf.constant(list(paths), dtype=tf.string), dec_in, dec_out)
    )

    if shuffle:
        # Dataset.shuffle rejects a zero buffer, so keep at least 1.
        ds = ds.shuffle(buffer_size=max(1, len(paths)))

    def _process(path, seq_in, seq_out):
        return {'image': load_image(path), 'decoder_input': seq_in}, seq_out

    ds = ds.map(_process, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

batch_size = 32
train_ds = make_dataset(train_paths, train_labels, batch_size=batch_size, shuffle=True)

# Peek at a single batch to confirm the tensor shapes the model will receive.
for inputs_batch, targets_batch in train_ds.take(1):
    print('Image batch shape:', inputs_batch['image'].shape)
    print('Decoder input shape:', inputs_batch['decoder_input'].shape)
    print('Decoder target shape:', targets_batch.shape)


## 5. CNN encoder

We use EfficientNetB0 (ImageNet-pretrained, initially frozen) as the image encoder. Its final feature map is global-average-pooled and passed through a ReLU dense layer to produce a 256-dimensional embedding.

In [None]:
def build_encoder(img_size=IMG_SIZE, embed_dim=256):
    """Build the image encoder: frozen EfficientNetB0 backbone + dense embedding.

    Args:
        img_size: Side length of the square RGB input.
        embed_dim: Size of the output embedding vector.

    Returns:
        A Keras ``Model`` named 'encoder' that maps images of shape
        ``(img_size, img_size, 3)`` with values in [0, 1] (as produced by
        ``load_image``) to ``(embed_dim,)`` ReLU features.
    """
    base = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=(img_size, img_size, 3)
    )
    base.trainable = False  # start frozen; fine-tune later if needed

    inputs = layers.Input(shape=(img_size, img_size, 3), name='image')
    # Keras' EfficientNet includes its own rescaling/normalization layers and
    # expects raw pixels in [0, 255]. The input pipeline yields [0, 1], so
    # scale back up; otherwise the pretrained backbone sees near-black images.
    x = layers.Rescaling(255.0, name='to_0_255')(inputs)
    x = base(x, training=False)
    # NOTE(review): global average pooling discards *where* features occur on
    # the board; a spatially aware head may be needed for exact positions.
    x = layers.GlobalAveragePooling2D()(x)
    encoded = layers.Dense(embed_dim, activation='relu', name='image_embedding')(x)
    return Model(inputs, encoded, name='encoder')

encoder = build_encoder(img_size=IMG_SIZE, embed_dim=256)
encoder.summary()

## 6. GRU-based sequence decoder

The decoder takes:

- A sequence of token IDs (decoder input)
- The image embedding (from the encoder)

and produces logits over the vocabulary at each time step.

In [None]:
def build_decoder(vocab_size, max_len, hidden_dim=256, embed_dim=128, image_dim=256):
    """Build the GRU decoder conditioned on an image embedding.

    Args:
        vocab_size: Number of token ids (output logits per step).
        max_len: Length of the (padded) decoder input sequence.
        hidden_dim: GRU state size.
        embed_dim: Token embedding size.
        image_dim: Size of the incoming image embedding. It must match the
            encoder's output size; 256 matches ``build_encoder``'s default.

    Returns:
        A Keras ``Model`` named 'decoder' taking ``[token_ids, image_embedding]``
        and returning per-step logits of shape ``(max_len, vocab_size)``.
    """
    dec_input_tokens = layers.Input(shape=(max_len,), name='decoder_input')
    image_feat       = layers.Input(shape=(image_dim,), name='image_embedding')

    # Token embedding; id 0 (PAD) is masked so padded steps are ignored downstream.
    x = layers.Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        mask_zero=True,
        name='token_embedding'
    )(dec_input_tokens)

    # Project the image embedding to the GRU's initial hidden state.
    init_state = layers.Dense(hidden_dim, activation='tanh')(image_feat)

    gru_out = layers.GRU(
        hidden_dim,
        return_sequences=True,
        name='decoder_gru'
    )(x, initial_state=init_state)

    logits = layers.Dense(vocab_size, name='vocab_logits')(gru_out)
    return Model([dec_input_tokens, image_feat], logits, name='decoder')

decoder = build_decoder(vocab_size=vocab_size, max_len=max_len, hidden_dim=256, embed_dim=128)
decoder.summary()

## 7. Full encoder–decoder model

Combine the encoder and decoder into a single model that takes an image and the decoder input sequence, and outputs logits over the vocabulary at each time step.

In [None]:
# Symbolic inputs; the dict keys match what make_dataset yields.
image_inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='image')
token_inputs = layers.Input(shape=(max_len,), name='decoder_input')
model_inputs = {'image': image_inputs, 'decoder_input': token_inputs}

# Reuse the encoder/decoder instances so that training `model` updates the
# same weights that predict_fen_for_image calls directly.
img_emb = encoder(image_inputs)
logits = decoder([token_inputs, img_emb])

model = Model(inputs=model_inputs, outputs=logits, name='image_to_fen_model')
model.summary()

## 8. Training

We train the model with teacher forcing using sparse categorical cross-entropy on the sequence. The metric `token_accuracy` measures per-time-step accuracy (not full-sequence exact match).

In [None]:
EPOCHS = 20  # adjust for dataset size and available compute

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=loss_fn,
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='token_accuracy')],
)

# Track generalization on the held-out validation split, which was loaded in
# section 1 but otherwise unused. Skip validation if that split is empty.
val_ds = make_dataset(val_paths, val_labels, batch_size=batch_size) if val_paths else None

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
)

## 9. Greedy decoding (inference)

We now implement a simple greedy decoding loop to turn an image into a FEN-like string.

In [None]:
def decode_tokens(token_ids):
    """Turn a sequence of token ids back into a label string.

    Decoding stops at the first EOS or PAD id. Other special tokens (e.g. SOS)
    and ids missing from ``idx2char`` are skipped.
    """
    stop_ids = {char2idx[EOS_TOKEN], char2idx[PAD_TOKEN]}
    specials = (SOS_TOKEN, EOS_TOKEN, PAD_TOKEN)
    pieces = []
    for raw_id in token_ids:
        token_id = int(raw_id)
        if token_id in stop_ids:
            break
        token = idx2char.get(token_id)
        if token is None or token in specials:
            continue
        pieces.append(token)
    return ''.join(pieces)

def predict_fen_for_image(path):
    """Greedily decode the label string for a single image file.

    The encoder runs once. The decoder is then re-run on the growing token
    sequence (PAD-filled beyond the current step). At each step the argmax of
    the logits at the last filled position is appended, stopping at EOS or
    after ``max_len - 1`` tokens.
    """
    image_batch = tf.expand_dims(load_image(path), 0)  # add batch dim
    image_embedding = encoder(image_batch, training=False)

    tokens = np.zeros((1, max_len), dtype=np.int32)
    tokens[0, 0] = char2idx[SOS_TOKEN]

    step = 1
    while step < max_len:
        all_logits = decoder([tokens, image_embedding], training=False)
        chosen = tf.argmax(all_logits[:, step - 1, :], axis=-1).numpy()[0]
        tokens[0, step] = chosen
        if chosen == char2idx[EOS_TOKEN]:
            break
        step += 1

    return decode_tokens(tokens[0, 1:])  # drop the leading SOS

# Quick smoke test: predict one test image and compare with its filename label.
if not test_paths:
    print('No test images found; skipping sample prediction.')
else:
    sample_path = test_paths[0]
    pred = predict_fen_for_image(sample_path)
    print('Sample prediction :', pred)
    print('Ground truth       :', Path(sample_path).stem)

## 10. Exact-sequence evaluation

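We now evaluate the model by computing exact string-match accuracy on the test set: each image is decoded greedily, and a prediction counts as correct only if the whole string equals the label.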

In [None]:
def evaluate_exact_match(paths, labels, max_samples=None):
    """Print and return the fraction of images whose decoded string equals its label.

    Args:
        paths: Image file paths.
        labels: Expected label strings, parallel to ``paths``.
        max_samples: If given, evaluate only the first ``max_samples`` items.

    Returns:
        Accuracy in [0, 1]; 0.0 when there is nothing to evaluate.
    """
    if max_samples is not None:
        paths, labels = paths[:max_samples], labels[:max_samples]

    total = len(paths)
    correct = sum(
        1
        for img_path, expected in zip(paths, labels)
        if predict_fen_for_image(img_path) == expected
    )

    acc = correct / total if total > 0 else 0.0
    print(f'Exact string match accuracy: {acc * 100:.2f}% ({correct}/{total})')
    return acc

# Evaluate on (a prefix of) the test set; lower max_samples for a faster run.
if not test_paths:
    print('No test data to evaluate.')
else:
    evaluate_exact_match(test_paths, test_labels, max_samples=200)