In [1]:
from datasets import load_dataset, load_from_disk
import torch
import torch.nn as nn
import os
import numpy as np
import torch.nn.functional as F
from preprocessing import add_representations, fen_to_piece_maps

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
## Splitting the dataset to make it more manageable

input_path = "data/lichess_db_eval.jsonl"
output_part1 = "data/lichess_db_eval_part1.jsonl"
# output_part2 = "data/lichess_db_eval_part2.jsonl"
# output_part3 = "data/lichess_db_eval_part3.jsonl"

n_split = 60_000_000    # take 60 million rows at a time

with open(input_path, "r") as in_f, \
     open(output_part1, "w") as out_f:

    for i, line in enumerate(in_f):
        if i < n_split: 
            out_f.write(line)
        else: 
            print("Breaking")
            break

In [18]:
def batch_select_highest_depth_pv_scaled(batch, c_max=1000):
    selected = {"fen": [], "target": []}

    for fen, evals in zip(batch["fen"], batch["evals"]):
        if not evals:
            continue

        best_eval = max(evals, key=lambda e: e["depth"])
        if not best_eval["pvs"]:
            continue

        pv = best_eval["pvs"][0]
        cp = pv.get("cp")
        mate = pv.get("mate")

        if mate is not None:
            target = 1.0 if mate > 0 else -1.0
        elif cp is not None:
            target = max(-1.0, min(1.0, cp / c_max))        # essentially, any centipawn evaluation above a 1000 (i.e. 10 pawns worth of) is clamped to + or - 1, equating it to mate
        else:
            continue  # skip if no usable score

        selected["fen"].append(fen)
        selected["target"].append(target)

    return selected

In [3]:
dataset = load_dataset("json", data_files="data/lichess_db_eval_part1.jsonl", split="train")

Generating train split: 60000000 examples [01:28, 676697.58 examples/s] 


In [12]:
dataset = dataset.map(
    lambda batch: batch_select_highest_depth_pv_scaled(batch),
    batched=True,
    batch_size=64,
    num_proc=1,
    remove_columns=["evals"],
    new_fingerprint="processed_dataset"
)

Map: 100%|██████████| 60000000/60000000 [51:22<00:00, 19466.14 examples/s]  


In [17]:
dataset.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part1"))

Saving the dataset (8/8 shards): 100%|██████████| 60000000/60000000 [00:39<00:00, 1531498.89 examples/s]


In [19]:
dataset_part2 = load_dataset("json", data_files="data/lichess_db_eval_part2.jsonl", split="train")
dataset_part2 = dataset_part2.map(
    lambda batch: batch_select_highest_depth_pv_scaled(batch),
    batched=True,
    batch_size=64,
    num_proc=1,
    remove_columns=["evals"],
    new_fingerprint="processed_dataset"
)
dataset.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part2"))

Generating train split: 60000001 examples [01:47, 558218.34 examples/s]
Map: 100%|██████████| 60000001/60000001 [47:48<00:00, 20916.20 examples/s]  
Saving the dataset (8/8 shards): 100%|██████████| 60000000/60000000 [00:40<00:00, 1482589.98 examples/s]


In [None]:
dataset_part3 = load_dataset("json", data_files="data/lichess_db_eval_part3.jsonl", split="train")
dataset_part3 = dataset_part3.map(
    batch_select_highest_depth_pv_scaled,
    batched=True,
    batch_size=64,
    num_proc=1,
    remove_columns=["evals"],
    new_fingerprint="processed_dataset"
)
dataset.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part3"))

Generating train split: 33340784 examples [28:52, 19247.76 examples/s]  


DatasetGenerationError: An error occurred while generating the dataset

In [29]:
print(len(dataset_part2))

60000001


In [34]:
batch = dataset[:20]
for fen, target in zip(batch["fen"], batch["target"]):
    print(f"FEN: {fen} | Target: {target}")

FEN: 7r/1p3k2/p1bPR3/5p2/2B2P1p/8/PP4P1/3K4 b - - | Target: 0.069
FEN: 8/4r3/2R2pk1/6pp/3P4/6P1/5K1P/8 b - - | Target: 0.0
FEN: 6k1/6p1/8/4K3/4NN2/8/8/8 w - - | Target: 1.0
FEN: r1b2rk1/1p2bppp/p1nppn2/q7/2P1P3/N1N5/PP2BPPP/R1BQ1RK1 w - - | Target: 0.026
FEN: 6k1/4Rppp/8/8/8/8/5PPP/6K1 w - - | Target: 1.0
FEN: 6k1/6p1/6N1/4K3/4N3/8/8/8 b - - | Target: 1.0
FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - | Target: 0.015
FEN: 8/8/2N2k2/8/1p2p3/p7/K7/8 b - - | Target: 0.0
FEN: 8/1r6/2R2pk1/6pp/3P4/6P1/5K1P/8 w - - | Target: 0.0
FEN: 1R4k1/3q1pp1/6n1/b2p2Pp/2pP2b1/p1P5/P1BQrPPB/5NK1 b - - | Target: -0.057
FEN: 8/5kp1/6N1/4K3/4N3/8/8/8 w - - | Target: 1.0
FEN: 1k1r1r2/pbp3pp/1p1q1p2/2p2Q2/4P3/1P1PB3/P1P3PP/4RRK1 w - - | Target: 0.008
FEN: 8/3B4/8/p4p1k/5P1p/Pb6/1P4P1/6K1 w - - | Target: 0.676
FEN: r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6 | Target: 0.007
FEN: 1R6/3q1ppk/6n1/b2p2Pp/2pP2b1/p1P5/P1B1rPPB/2Q2NK1 b - - | Target: -0.076
FEN: 3r4/1p3k2/p1bPR3/5p2/2B2P

In [None]:
sample_piece_maps = fen_to_piece_maps(dataset[0]["fen"])        # 7r/1p3k2/p1bPR3/5p2/2B2P1p/8/PP4P1/3K4 b - -
for row in sample_piece_maps[0]:
    print(row[:8])

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [None]:
dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part1"))

In [3]:
# Split the dataset into two halves
half_index = len(dataset) // 2
dataset_partial = dataset.select(range(half_index))

# Save the first half to the specified folder
dataset_partial.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_partial"))

Saving the dataset (4/4 shards): 100%|██████████| 30000000/30000000 [00:08<00:00, 3484144.24 examples/s]


In [4]:
batch = dataset_partial[:20]
for fen, target in zip(batch["fen"], batch["target"]):
    print(f"FEN: {fen} | Target: {target}")

FEN: 7r/1p3k2/p1bPR3/5p2/2B2P1p/8/PP4P1/3K4 b - - | Target: 0.069
FEN: 8/4r3/2R2pk1/6pp/3P4/6P1/5K1P/8 b - - | Target: 0.0
FEN: 6k1/6p1/8/4K3/4NN2/8/8/8 w - - | Target: 1.0
FEN: r1b2rk1/1p2bppp/p1nppn2/q7/2P1P3/N1N5/PP2BPPP/R1BQ1RK1 w - - | Target: 0.026
FEN: 6k1/4Rppp/8/8/8/8/5PPP/6K1 w - - | Target: 1.0
FEN: 6k1/6p1/6N1/4K3/4N3/8/8/8 b - - | Target: 1.0
FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - | Target: 0.015
FEN: 8/8/2N2k2/8/1p2p3/p7/K7/8 b - - | Target: 0.0
FEN: 8/1r6/2R2pk1/6pp/3P4/6P1/5K1P/8 w - - | Target: 0.0
FEN: 1R4k1/3q1pp1/6n1/b2p2Pp/2pP2b1/p1P5/P1BQrPPB/5NK1 b - - | Target: -0.057
FEN: 8/5kp1/6N1/4K3/4N3/8/8/8 w - - | Target: 1.0
FEN: 1k1r1r2/pbp3pp/1p1q1p2/2p2Q2/4P3/1P1PB3/P1P3PP/4RRK1 w - - | Target: 0.008
FEN: 8/3B4/8/p4p1k/5P1p/Pb6/1P4P1/6K1 w - - | Target: 0.676
FEN: r2qk2r/3n2p1/1pp1p3/3pPpb1/P2P1nBp/1NB4P/1PP2P2/R3QR1K w kq f6 | Target: 0.007
FEN: 1R6/3q1ppk/6n1/b2p2Pp/2pP2b1/p1P5/P1B1rPPB/2Q2NK1 b - - | Target: -0.076
FEN: 3r4/1p3k2/p1bPR3/5p2/2B2P

In [None]:
dataset = dataset.map(add_representations, batched=True, batch_size=64, num_proc=4)

Map (num_proc=4):  59%|█████▉    | 35462656/60000000 [58:56<1:18:26, 5212.97 examples/s] 

In [None]:
dataset.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_partial"))
# dataset_part2 = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part2"))

In [None]:
dataset_part2 = dataset_part2.map(add_representations, batched=True, batch_size=64, num_proc=4)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x00000190D1AA4D30>>
Traceback (most recent call last):
  File "c:\Users\syeda\miniconda3\envs\DL\lib\site-packages\ipykernel\ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
Map (num_proc=4):   0%|          | 0/60000000 [00:00<?, ? examples/s]

In [None]:
dataset_part2.save_to_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part2"))