In [5]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import chess

from beschess.data.embedding import BalancedBatchSampler, PuzzleDataset, DirectLoader, generate_split_indices
# from beschess.utils import tensor_to_board

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [7]:
puzzle_classes = np.load('../data/processed/tag_classes.npy', allow_pickle=True)
classes = {idx + 1: name for idx, name in enumerate(puzzle_classes)}
classes[0] = 'quiet'
classes

{1: 'LinearAttack',
 2: 'DoubleAttack',
 3: 'MatingNet',
 4: 'Overload',
 5: 'Displacement',
 6: 'Sacrifice',
 7: 'EndgameTactic',
 8: 'PieceEndgame',
 0: 'quiet'}

In [8]:
data_dir = '../data/processed/'
quiet_boards_file = data_dir + 'quiet_boards_preeval.npy'
puzzle_boards_file = data_dir + 'boards_packed.npy'
puzzle_labels_file = data_dir + 'tags_packed.npy'

quiet_boards = np.load(quiet_boards_file, mmap_mode='r')
puzzle_boards = np.load(puzzle_boards_file, mmap_mode='r')
puzzle_labels = np.load(puzzle_labels_file, mmap_mode='r')

dataset = PuzzleDataset(
    quiet_boards=quiet_boards,
    puzzle_boards=puzzle_boards,
    puzzle_labels=puzzle_labels,
)

splits = generate_split_indices(dataset)
q_train, p_train = splits['train']
q_val, p_val = splits['val']
q_test, p_test = splits['test']

print(f"Training set: {len(q_train)} ({len(q_train) / len(quiet_boards):.2%}) quiet puzzles, {len(p_train)} ({len(p_train) / len(puzzle_boards):.2%}) tagged puzzles")
print(f"Validation set: {len(q_val)} ({len(q_val) / len(quiet_boards):.2%}) quiet puzzles, {len(p_val)} ({len(p_val) / len(puzzle_boards):.2%}) tagged puzzles")
print(f"Test set: {len(q_test)} ({len(q_test) / len(quiet_boards):.2%}) quiet puzzles, {len(p_test)} ({len(p_test) / len(puzzle_boards):.2%}) tagged puzzles")

train_dl = DataLoader(
    dataset,
    batch_sampler=BalancedBatchSampler(
        dataset, 
        q_train, 
        p_train,
        batch_size=4,
    ),
)

# train_direct_dl = DirectLoader(
#     dataset,
#     BalancedBatchSampler(
#         dataset, 
#         q_train, 
#         p_train,
#         batch_size=4,
#     ),
#     device=device
# )


Training set: 1705295 (90.00%) quiet puzzles, 3753369 (90.00%) tagged puzzles
Validation set: 94739 (5.00%) quiet puzzles, 208521 (5.00%) tagged puzzles
Test set: 94739 (5.00%) quiet puzzles, 208521 (5.00%) tagged puzzles


Building label map: 100%|██████████| 3753369/3753369 [00:29<00:00, 128350.25it/s]


In [None]:
print("Visualizing samples from dataloader...")
for batch_idx, (boards, labels) in enumerate(train_dl):
    for board, label in zip(boards, labels):
        label_idx = label.nonzero()
        print(f"Class: {[classes[i.item()] for i in label_idx]}")
        # b = tensor_to_board(board)
        # print(b.fen())
        # display(b)
    break

print("Visualizing samples from directloader...")
for batch_idx, (boards, labels) in enumerate(train_direct_dl):
    for board, label in zip(boards, labels):
        label_idx = label.nonzero()
        print(f"Class: {[classes[i.item()] for i in label_idx]}")
        # b = tensor_to_board(board)
        # print(b.fen())
        # display(b)
    break

Visualizing samples from dataloader...
torch.Size([9])
Class: ['quiet']
torch.Size([9])
Class: ['quiet']
torch.Size([9])
Class: ['MatingNet', 'Overload']
torch.Size([9])
Class: ['DoubleAttack', 'MatingNet']
Visualizing samples from directloader...


NameError: name 'train_direct_dl' is not defined