# Workbook to generate Factual Board Answering (FBA) samples
---

In [1]:
import pandas as pd
from pathlib import Path

from fba.fba_generator import fba_generator
from utils.sampling_manager import SamplingManager

In [2]:
# Load the base data. The behaviorcloning dataset is required.
# Change `_DATASET_CHOICE` to switch between the available dataset sizes.
_DATASET_FILES = {
    "1k": "chess_data/deepmind_behaviorcloning_1k.csv",      # Smaller dataset for testing
    "10k": "chess_data/deepmind_behaviorcloning_10k.csv",    # Medium dataset for testing
    "100k": "chess_data/deepmind_behaviorcloning_100k.csv",  # Larger dataset
}
_DATASET_CHOICE = "10k"

DATA_FILE = _DATASET_FILES[_DATASET_CHOICE]
df = pd.read_csv(DATA_FILE)

In [3]:
# ================================================
# Sampling criteria.
# These set the target distribution of the final dataset for each task.
# ================================================

# Share of samples per move-count range, keyed by (low, high).
# (40, None) is the open-ended top range (it appears as "40+" in the run output).
_MOVECOUNT_SHARES = {
    (0, 9): 0.15,
    (10, 19): 0.30,
    (20, 29): 0.25,
    (30, 39): 0.20,
    (40, None): 0.10,
}

# Share of samples per player value ("w" / "b").
_PLAYER_SHARES = dict(w=0.5, b=0.5)

# Criteria applied to every task, regardless of task-specific criteria.
BASE_SAMPLING_CRITERIA = dict(movecount=_MOVECOUNT_SHARES, player=_PLAYER_SHARES)

# Task-specific sampling criteria, merged over the base criteria at sampling time.
# Layout: task name -> criterion name -> {bucket label: target share}.
# Tasks mapped to an empty dict add nothing beyond the base criteria.
TASK_SAMPLING_CRITERIA = {
    "is_check": {
        "is_check": dict(n=0.50, w=0.25, b=0.25),
        "is_check_gen": dict(tp=0.40, fp=0.10, tn=0.50),
    },
    "large_mat_adv": {
        "large_mat_adv_gen": dict(tp=0.40, fp=0.10, tn=0.50),
    },
    "mat_bal": {
        "mat_bal": dict(y=0.50, n=0.50),
    },
    "is_legal": {
        "is_legal_gen": dict(tp=0.50, fp=0.10, tn=0.40),
        "is_legal_piece": dict(p=0.10, b=0.20, n=0.20, r=0.20, q=0.20, k=0.10),
        "is_legal_in_check": dict(y=0.1, n=0.9),
    },
    "under_attack": {
        "under_attack_gen": dict(tp=0.40, fp=0.20, tn=0.40),
    },
    "mat_adv_value": {
        "mat_adv_abs": {
            "0-100": 0.4,
            "100-300": 0.3,
            "300+": 0.3,
        },
    },
    "win_prob": {
        "win_prob": {
            "0-0.2": 0.2,
            "0.2-0.4": 0.2,
            "0.4-0.6": 0.2,
            "0.6-0.8": 0.2,
            "0.8-1": 0.2,
        },
    },
    "mobility": {
        "mobility_piece": dict(p=0.06, b=0.2, n=0.2, r=0.2, q=0.25, k=0.09),
        "mobility_moves": {
            "0-1": 0.25,
            "2-3": 0.3,
            "4-5": 0.3,
            "6+": 0.15,
        },
    },
    "contrastive_ntp": {
        "contrastive_ntp": {
            "1": 0.5,
            "2": 0.5,
            "None": 0,
        },
        "contrastive_ntp_piece": dict(p=0.05, b=0.2, n=0.2, r=0.2, q=0.25, k=0.1),
    },
    "cloze_capture": {
        # Piece shares plus an explicit zero share for the "None" bucket.
        "cloze_piece": {**dict(p=0.1, b=0.2, n=0.2, r=0.2, q=0.2, k=0.1), "None": 0},
    },
    "bestmove": {},
    "multi_sample": {},
    "bestline": {},
}


# Bucket columns whose distributions are printed after sampling, per task.
# Every task reports the two base buckets first, followed by its own buckets.
_COMMON_BUCKET_COLUMNS = ["movecount_bucket", "player_bucket"]

_TASK_BUCKET_COLUMNS = {
    "is_check": ["is_check_bucket", "is_check_gen_bucket"],
    "large_mat_adv": ["large_mat_adv_gen_bucket"],
    "mat_bal": ["mat_bal_bucket"],
    "is_legal": ["is_legal_gen_bucket", "is_legal_piece_bucket", "is_legal_in_check_bucket"],
    "under_attack": ["under_attack_gen_bucket"],
    "mat_adv_value": ["mat_adv_abs_bucket"],
    "win_prob": ["win_prob_bucket"],
    "mobility": ["mobility_piece_bucket", "mobility_moves_bucket"],
    "contrastive_ntp": ["contrastive_ntp_piece_bucket", "contrastive_ntp_bucket"],
    "cloze_capture": ["cloze_piece_bucket"],
    "bestmove": [],
    "multi_sample": [],
    "bestline": [],
}

# List concatenation yields a fresh list per task, so entries never share state.
PRINT_DISTRIBUTION_COLUMNS = {
    task_name: _COMMON_BUCKET_COLUMNS + task_columns
    for task_name, task_columns in _TASK_BUCKET_COLUMNS.items()
}

In [4]:
# ================================================
# Generation criteria.
# These are the arguments we use for generation for each task.
# The multi-sample task has a 'frequency' setting, which uses a lottery-ticket system to decide
# which FBAs to generate (each using its own generation criteria).
# ================================================

# Actual generation parameters for each task. Adjust these to alter the generation
# behavior of a task (piece frequency, etc.). A task mapped to None has no generation
# arguments of its own.
GENERATION_CONFIG = dict(
    is_check=dict(tp=0.8),
    large_mat_adv=dict(tp=0.8),
    mat_bal=None,
    is_legal=dict(
        choose_legal=dict(legal_you=0.5, legal_opp=0.1, illegal=0.4),
        piece_freq=dict(p=1, n=3, b=3, r=3, q=5, k=3),
        in_check=0.1,
    ),
    under_attack=dict(
        legal_attack=dict(attack_you=0.4, attack_opp=0.2, safe=0.4),
        piece_freq=dict(p=1, n=2, b=2, r=2, q=3, k=1),
    ),
    mat_adv_value=None,
    win_prob=None,
    mobility=dict(
        piece_freq=dict(p=1, n=3, b=3, r=3, q=5, k=1),
    ),
    contrastive_ntp=dict(
        min_threshold=0.25,
        piece_freq=dict(p=1, n=5, b=5, r=5, q=8, k=3),
    ),
    cloze_capture=dict(
        piece_freq=dict(p=1, n=3, b=3, r=3, q=5, k=1),
    ),
    bestmove=None,
    multi_sample=None,  # placeholder; the multi-sample settings are assigned right after this dict
    bestline=dict(
        plies=(4, 6),
        search_depth=10,
    ),
)

# Multi-sample settings. By default each task reuses the generation args already defined
# for it in GENERATION_CONFIG (the very same object, not a copy). To give a task different
# args inside multi-sample only, add `task name -> args` to `_MULTI_SAMPLE_ARG_OVERRIDES`.
_MULTI_SAMPLE_ARG_OVERRIDES = {}

# task name -> (lottery-ticket frequency, max number of samples to take from the task).
# A higher frequency makes the task more likely to be chosen.
_MULTI_SAMPLE_TASKS = {
    "is_check": (2, 1),
    "is_legal": (5, 3),
    "under_attack": (5, 2),
    "mat_adv_value": (5, 1),
    "mobility": (5, 2),
    "cloze_capture": (5, 2),
}

GENERATION_CONFIG["multi_sample"] = {
    # (min, max) number of FBAs to generate for each board
    "generation_samples": (4, 6),
    "tasks": {
        task_name: {
            "frequency": ticket_count,
            "max_samples": sample_cap,
            "args": _MULTI_SAMPLE_ARG_OVERRIDES.get(task_name, GENERATION_CONFIG[task_name]),
        }
        for task_name, (ticket_count, sample_cap) in _MULTI_SAMPLE_TASKS.items()
    },
}

In [5]:
# =====================================
# Helper functions
# =====================================
def generate(task, count, base_df):
    """Generate `count` FBA samples for `task`, sampled to the target distributions.

    Steps:
      1. Pre-sample `base_df` on BASE_SAMPLING_CRITERIA, requesting up to
         `len(base_df)` rows.
      2. Remove the bucket columns added by that pre-sampling pass.
      3. Run `fba_generator` for the task, passing the task's entry from
         GENERATION_CONFIG when it has a non-empty one.
      4. Sample `count` rows on the base criteria merged with the task's
         entry in TASK_SAMPLING_CRITERIA (task entries win on key clashes).

    Args:
        task: Task name; used as the key into GENERATION_CONFIG and
            TASK_SAMPLING_CRITERIA.
        count: Number of samples requested from the final sampling pass.
        base_df: DataFrame of base positions to generate from.

    Returns:
        The result of `SamplingManager.get_samples` for the final pass
        (the caller treats it as a DataFrame).
    """
    # Start by filtering on the base criteria (more efficient with our function).
    base_df = SamplingManager(base_df, BASE_SAMPLING_CRITERIA).get_samples(len(base_df))

    # `DataFrame.drop` is not in-place: without reassigning the result the bucket
    # columns were never actually removed. `errors="ignore"` keeps this a no-op
    # if the sampler did not add one of them.
    base_df = base_df.drop(columns=["movecount_bucket", "player_bucket"], errors="ignore")

    cfg = GENERATION_CONFIG.get(task)
    fba_df = fba_generator(task, base_df, cfg) if cfg else fba_generator(task, base_df)

    # Now do the final sampling on the combined base + task criteria.
    sampler = SamplingManager(fba_df, BASE_SAMPLING_CRITERIA)
    crit = {**BASE_SAMPLING_CRITERIA, **TASK_SAMPLING_CRITERIA.get(task, {})}
    return sampler.get_samples(count, criteria=crit)

def print_distributions(df, cols):
    """Print the normalized value distribution of each listed column.

    For every name in `cols`, prints a blank line and the column name followed
    by a colon, then that column's value proportions sorted by value.

    Args:
        df: DataFrame containing the columns to summarize.
        cols: Iterable of column names in `df`.
    """
    for column_name in cols:
        print(f"\n{column_name}:")
        proportions = df[column_name].value_counts(normalize=True)
        ordered = proportions.sort_index()
        print(ordered)

# Generate our FBA Samples

In [6]:
# Desired number of samples for each task; uncomment an entry to generate that task.
# The tasks listed above 'bestmove' (called 'predict_bestmove' in the original note,
# presumably the same task) are single-task generations, i.e. each sample asks just one problem.
# multi_sample generates multiple FBA queries per board.
NUM_SAMPLES_PER_TASK = dict(
    # is_check=50,
    # large_mat_adv=50,
    # mat_bal=50,
    # is_legal=50,
    # under_attack=50,
    # mat_adv_value=50,
    # win_prob=50,
    # mobility=50,
    # contrastive_ntp=50,
    # cloze_capture=50,
    bestmove=1_000,
    multi_sample=1_000,
    bestline=1_000,
)

# Label used for each task in the output file name.
# By default the label is the task name with underscores turned into hyphens;
# tasks whose label differs from that rule are listed in `_LABEL_OVERRIDES`.
_LABELLED_TASKS = (
    "is_check",
    "large_mat_adv",
    "mat_bal",
    "is_legal",
    "under_attack",
    "mat_adv_value",
    "win_prob",
    "mobility",
    "contrastive_ntp",
    "cloze_capture",
    "bestmove",
    "multi_sample",
    "bestline",
)
_LABEL_OVERRIDES = {"multi_sample": "multi-fba"}

TASK_TO_LABEL_MAP = {
    task_name: _LABEL_OVERRIDES.get(task_name, task_name.replace("_", "-"))
    for task_name in _LABELLED_TASKS
}

In [7]:
# For every requested task: generate the samples, print the distribution of its
# bucket columns, and write the task's chat column to a JSON Lines file.
for task, count in NUM_SAMPLES_PER_TASK.items():
    print(f"\n=== {task} ===")
    samples = generate(task, count, df)
    print_distributions(samples, PRINT_DISTRIBUTION_COLUMNS[task])
    outpath = Path(f"processed_data/fba_{TASK_TO_LABEL_MAP[task]}_{count}.jsonl")
    # `to_json` does not create missing directories, so make sure the output
    # folder exists instead of failing after the whole generation has run.
    outpath.parent.mkdir(parents=True, exist_ok=True)
    samples[f"{task}_chat"].to_json(outpath, orient="records", lines=True)


=== bestmove ===
[906/9060] 34384.52 samples/s
[1812/9060] 33873.73 samples/s
[2718/9060] 34082.29 samples/s
[3624/9060] 33814.20 samples/s
[4530/9060] 34140.35 samples/s
[5436/9060] 29492.92 samples/s
[6342/9060] 29289.26 samples/s
[7248/9060] 29571.46 samples/s
[8154/9060] 29695.46 samples/s
[9060/9060] 29613.82 samples/s
Total Number of generation errors: 0

movecount_bucket:
movecount_bucket
0-9      0.15
10-19    0.30
20-29    0.25
30-39    0.20
40+      0.10
Name: proportion, dtype: float64

player_bucket:
player_bucket
b    0.5
w    0.5
Name: proportion, dtype: float64

=== multi_sample ===
[906/9060] 1885.44 samples/s
[1812/9060] 1867.59 samples/s
[2718/9060] 1902.43 samples/s
[3624/9060] 1896.71 samples/s
[4530/9060] 1899.33 samples/s
[5436/9060] 1889.67 samples/s
[6342/9060] 1887.79 samples/s
[7248/9060] 1893.49 samples/s
[8154/9060] 1884.13 samples/s
[9060/9060] 1886.67 samples/s
Total Number of generation errors: 0

movecount_bucket:
movecount_bucket
0-9      0.15
10-19   