In [3]:
import os

import chess
import chess.svg
from jax import random as jrandom
import numpy as np

In [8]:
# Jupyter-style exploration utilities for the parameterised BC pipeline.
# This notebook code will:
# 1) Inspect / load a (fen, move) dataset if present (CSV or JSONL), or make a tiny demo set.
# 2) Convert moves into (from_idx, to_idx, promo_id).
# 3) (Optionally) run a forward pass through a minimal param-action head if JAX/Haiku are available.
#
# You can re-run cells as you like.


import os, sys, json, csv, math, textwrap, pathlib, random
from dataclasses import dataclass
from typing import List, Tuple, Optional

import pandas as pd


# ----------------------------
# 0) Helpers: UCI <-> params
# ----------------------------
FILES = "abcdefgh"
RANKS = "12345678"
PROMO_TO_ID = {"":0, "q":1, "r":2, "b":3, "n":4}
ID_TO_PROMO = {v:k for k,v in PROMO_TO_ID.items()}

def square_to_index(file_char: str, rank_char: str) -> int:
    f = FILES.index(file_char)
    r = int(rank_char) - 1
    return r * 8 + f

def index_to_square(idx: int) -> str:
    f = idx % 8
    r = idx // 8
    return f"{FILES[f]}{RANKS[r]}"

def uci_to_params(uci: str) -> Tuple[int, int, int]:
    uci = uci.strip().lower()
    if len(uci) < 4:
        raise ValueError(f"Bad UCI: {uci}")
    f_file, f_rank, t_file, t_rank = uci[0], uci[1], uci[2], uci[3]
    from_idx = square_to_index(f_file, f_rank)
    to_idx = square_to_index(t_file, t_rank)
    promo = ""
    if len(uci) == 5:
        promo = uci[4]
    return from_idx, to_idx, PROMO_TO_ID.get(promo, 0)

def params_to_uci(from_idx: int, to_idx: int, promo_id: int) -> str:
    return f"{index_to_square(from_idx)}{index_to_square(to_idx)}{ID_TO_PROMO.get(promo_id,'')}"


# ------------------------------------------------------------
# 1) Load dataset (try common locations or make a tiny sample)
# ------------------------------------------------------------
def try_load_dataset() -> pd.DataFrame:
    """
    Try a few common file names:
      - /mnt/data/behavioral_cloning.csv
      - /mnt/data/behavioral_cloning.jsonl
      - ../data/behavioral_cloning.csv (relative to current)
    The file format should have columns/fields: 'fen' and 'move' (UCI).
    If nothing is found, we create a tiny demo DataFrame.
    """
    candidates = [
        "/mnt/data/behavioral_cloning.csv",
        "/mnt/data/behavioral_cloning.jsonl",
        "/mnt/data/data.csv",
        "/mnt/data/data.jsonl",
        "../data/behavioral_cloning.csv",
        "../data/behavioral_cloning.jsonl",
    ]
    for path in candidates:
        if os.path.exists(path):
            if path.endswith(".csv"):
                df = pd.read_csv(path)
            else:
                # JSONL
                rows = []
                with open(path, "r", encoding="utf-8") as f:
                    for line in f:
                        rows.append(json.loads(line))
                df = pd.DataFrame(rows)
            # normalize cols
            if "uci" in df.columns and "move" not in df.columns:
                df = df.rename(columns={"uci":"move"})
            if not {"fen","move"}.issubset(df.columns):
                raise ValueError(f"Found {path} but it lacks 'fen'/'move' columns. Columns = {df.columns.tolist()}")
            df["__source_file__"] = path
            return df

    # Fallback tiny demo set (3 legal positions with a couple of moves)
    demo = [
        # Scholars mate ideas etc.
        {"fen":"rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1R b KQkq - 1 1", "move":"d7d5"},
        {"fen":"rnbqkbnr/pppp1ppp/8/4p3/3PP3/5N2/PPP2PPP/RNBQKB1R b KQkq - 0 2", "move":"e5d4"},
        {"fen":"rnbqkbnr/ppp2ppp/3p4/4p3/3PP3/4BN2/PPP2PPP/RN1QKB1R b KQkq - 0 3", "move":"e5d4"},
        # Promotion example
        {"fen":"8/P7/8/8/8/8/8/k6K w - - 0 1", "move":"a7a8q"},
    ]
    df = pd.DataFrame(demo)
    df["__source_file__"] = "<demo>"
    return df


df = try_load_dataset()
print(f"Loaded {len(df)} rows from {df['__source_file__'].unique().tolist()}")
print("Raw dataset preview (fen, move)")
print(df.head(20))


# ----------------------------------------------------------------------
# 2) Apply parameterised transform: add from/to/promo columns & checks
# ----------------------------------------------------------------------
def add_param_columns(df: pd.DataFrame) -> pd.DataFrame:
    from_list, to_list, promo_list = [], [], []
    for uci in df["move"].astype(str).tolist():
        f_idx, t_idx, p_id = uci_to_params(uci)
        from_list.append(f_idx)
        to_list.append(t_idx)
        promo_list.append(p_id)
    out = df.copy()
    out["from_idx"] = from_list
    out["to_idx"] = to_list
    out["promo_id"] = promo_list
    out["reconstructed_uci"] = [params_to_uci(f,t,p) for f,t,p in zip(from_list, to_list, promo_list)]
    out["matches_roundtrip"] = (out["reconstructed_uci"] == out["move"].str.lower())
    return out

df_params = add_param_columns(df)
df_params.head(50)

print("Round-trip check (uci -> params -> uci):",
      f"{df_params['matches_roundtrip'].mean()*100:.1f}% of rows match")

# Quick histograms (counts only, textual to keep deps minimal)
def quick_counts(series: pd.Series, topk: int = 10, name: str = ""):
    counts = series.value_counts().head(topk)
    print(f"\nTop {topk} {name}:")
    for k, v in counts.items():
        print(f"  {k}: {v}")

quick_counts(df_params["from_idx"], name="from squares")
quick_counts(df_params["to_idx"], name="to squares")
quick_counts(df_params["promo_id"], name="promo ids (0=None,1=q,2=r,3=b,4=n)")


# -------------------------------------------------------------------------
# 3) (Optional) Forward pass demo with a tiny NumPy MLP as a stand-in head
# -------------------------------------------------------------------------
# We avoid importing JAX/Haiku here (may not be present). This is a sanity
# demonstration: given a "core" vector and one-hot(from/to), produce logits
# for the last three steps shaped [B, 3, 64], with -1e9 padding for promo.
#
# This is NOT your training model; it's a structure/shape demo you can adapt.

import numpy as np

def tiny_mlp(x, out):
    # single hidden layer for demo
    h = np.tanh(x @ np.random.randn(x.shape[-1], x.shape[-1]) * 0.1)
    return h @ np.random.randn(h.shape[-1], out) * 0.1

def make_demo_forward(batch_from: np.ndarray, batch_to: np.ndarray, core_dim: int = 64):
    """
    Inputs:
      batch_from: [B] ints in [0,63]
      batch_to:   [B] ints in [0,63]
    Returns:
      log_probs: [B, 3, 64]; last step has first 5 classes valid, others ~-1e9
    """
    B = batch_from.shape[0]
    V = 64
    # Random "core" vectors just for a demo
    core = np.random.randn(B, core_dim).astype(np.float32)

    from_oh = np.eye(64, dtype=np.float32)[batch_from]           # [B,64]
    to_oh   = np.eye(64, dtype=np.float32)[batch_to]             # [B,64]

    # Head 1: from logits
    h1 = core
    logits_from = tiny_mlp(h1, V)

    # Head 2: to logits (condition on from)
    h2 = np.concatenate([core, from_oh], axis=-1)
    logits_to = tiny_mlp(h2, V)

    # Head 3: promo (condition on from & to)
    h3 = np.concatenate([core, from_oh, to_oh], axis=-1)
    logits_promo5 = tiny_mlp(h3, 5)
    neg_inf = np.full((B, V-5), -1e9, dtype=np.float32)
    logits_promo = np.concatenate([logits_promo5, neg_inf], axis=-1)

    # Pack to [B, T=3, V=64] and log-softmax
    logits = np.zeros((B, 3, V), dtype=np.float32)
    logits[:, 0, :] = logits_from
    logits[:, 1, :] = logits_to
    logits[:, 2, :] = logits_promo

    # log-softmax per step
    def log_softmax(x, axis=-1):
        x_max = np.max(x, axis=axis, keepdims=True)
        y = x - x_max
        np.exp(y, out=y)
        y_sum = np.sum(y, axis=axis, keepdims=True)
        return (x - x_max) - np.log(y_sum)

    logp = log_softmax(logits, axis=-1)
    return logp  # [B,3,64]

# Demo forward on the first few rows
B = min(8, len(df_params))
batch_from = df_params["from_idx"].values[:B].astype(np.int64)
batch_to = df_params["to_idx"].values[:B].astype(np.int64)
logp_demo = make_demo_forward(batch_from, batch_to, core_dim=64)
print("\nDemo forward shape [B,3,64]:", logp_demo.shape)
print("Sanity check (per-step log-prob sums should be ~0):",
      np.allclose(np.log(np.sum(np.exp(logp_demo), axis=-1)), 0.0, atol=1e-5))


# -------------------------------------------------------------------------
# 4) Save an inspection CSV with the parameterised columns
# -------------------------------------------------------------------------
out_path = "/mnt/data/dataset_parametrised_preview.csv"
df_params.to_csv(out_path, index=False)
print(f"\nSaved a CSV preview with parameterised targets to: {out_path}")


Loaded 4 rows from ['<demo>']
Raw dataset preview (fen, move)
                                                 fen   move __source_file__
0  rnbqkbnr/pppppppp/8/8/4P3/5N2/PPPP1PPP/RNBQKB1...   d7d5          <demo>
1  rnbqkbnr/pppp1ppp/8/4p3/3PP3/5N2/PPP2PPP/RNBQK...   e5d4          <demo>
2  rnbqkbnr/ppp2ppp/3p4/4p3/3PP3/4BN2/PPP2PPP/RN1...   e5d4          <demo>
3                       8/P7/8/8/8/8/8/k6K w - - 0 1  a7a8q          <demo>
Round-trip check (uci -> params -> uci): 100.0% of rows match

Top 10 from squares:
  36: 2
  48: 1
  51: 1

Top 10 to squares:
  27: 2
  56: 1
  35: 1

Top 10 promo ids (0=None,1=q,2=r,3=b,4=n):
  0: 3
  1: 1

Demo forward shape [B,3,64]: (4, 3, 64)
Sanity check (per-step log-prob sums should be ~0): True


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/data/dataset_parametrised_preview.csv'

In [9]:
import jax, jax.numpy as jnp
print("JAX backend:", jax.default_backend())
print("Devices:", jax.devices())

x = jnp.ones((1024, 1024))
y = x @ x
print("Result device:", y.device_buffer.device())

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


JAX backend: cpu
Devices: [CpuDevice(id=0)]
Result device: TFRT_CPU_0
