In [1]:
from enum import Enum
import torch
from torch.utils.data import Dataset
import chess
import net
import autoencoder
import alphabeta
import mctsAZ
import mcts_custom

class SearchType(Enum):
    '''Search algorithm used to choose each move while generating games.'''
    MINIMAX = 1
    MCTS = 2
    CUSTOM = 3


# Move cap consulted by stop_cond when generating a single game.
MAX_MOVES = 10

class SearchDataset(Dataset):
    '''PyTorch dataset holding the samples produced by ``get_dataset``.

    ``size`` and ``*args`` are forwarded unchanged to ``get_dataset``:
    ``size`` is the minimum number of samples to collect and ``*args`` are
    the ``generate_game`` parameters.
    '''

    def __init__(self, size, *args):
        samples = get_dataset(size, *args)
        self.data = samples

    def __len__(self):
        '''Return the number of stored samples.'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Return the stored sample at ``idx``.'''
        return self.data[idx]
    
def get_dataset(size, *args):
    '''Get a dataset for NN training with no fewer than ``size`` samples.

    Games are generated repeatedly with ``generate_game(*args)``. Every node
    returned for a game becomes one sample, so the result may exceed ``size``
    by up to one game's worth of samples.

    Args:
        size: minimum number of samples to collect.
        *args: the ``generate_game`` parameters, forwarded unchanged.
            ``args[3]`` is the ``SearchType`` and selects the sample layout:
              - MINIMAX: ``(state, value)``
              - MCTS:    ``(state, actionValue, priorProbability,
                           [actionValue of each child node])``
              - CUSTOM:  ``(state, actionValue, priorProbability,
                           choiceProbability)``

    Returns:
        list of sample tuples.

    Raises:
        TypeError: if fewer than four positional ``args`` are given.
        ValueError: if ``args[3]`` is not a known ``SearchType``, or if a
            generated game contains no nodes (without this check the loop
            below would never terminate).
    '''
    if len(args) < 4:
        raise TypeError('get_dataset expects at least the board, nnet, encoder '
                        'and search type arguments of generate_game')
    search_type = args[3]

    # One sample builder per search type, chosen once instead of per node.
    sample_builders = {
        SearchType.MINIMAX: lambda nd: (nd.get_node().state, nd.get_val()),
        SearchType.MCTS: lambda nd: (nd.state, nd.actionValue, nd.priorProbability,
                                     [child.actionValue for child in nd.childNodes]),
        SearchType.CUSTOM: lambda nd: (nd.state, nd.actionValue, nd.priorProbability,
                                       nd.choiceProbability),
    }
    if search_type not in sample_builders:
        raise ValueError('Unknown search type: {!r}'.format(search_type))
    to_sample = sample_builders[search_type]

    dataset = []
    while len(dataset) < size:
        game = generate_game(*args)
        # An empty game adds no samples; without this guard the loop would
        # run forever (e.g. when the starting position is already terminal).
        if not game:
            raise ValueError('generate_game returned an empty game; the starting '
                             'position may already be finished')
        dataset.extend(to_sample(nd) for nd in game)

    return dataset
        
def generate_game(board, nnet, encoder, search_tree, *args):
    '''Generate a chess game from the starting ``board`` position.

    At every move the selected search runs from the current position. The
    position of the node it returns becomes the next position. This repeats
    until ``stop_cond`` says generation should end.

    Args:
        board: starting position (the notebook passes a ``chess.Board``).
        nnet: network handed to the search.
        encoder: encoder handed to the search.
        search_tree: ``SearchType`` member selecting the search algorithm.
        *args: search parameters, depending on ``search_tree``:
            MINIMAX: depth, lower bound and upper bound of the aspiration
                window.
            MCTS / CUSTOM: number of rollouts.

    Returns:
        list of the nodes returned by the search, one per move played.

    Raises:
        ValueError: if ``search_tree`` is not a known ``SearchType``.
    '''
    # Resolve the algorithm once and reject unknown types up front. Previously
    # an unknown type fell through every branch and crashed with an
    # UnboundLocalError on ``node``.
    if search_tree == SearchType.MINIMAX:
        mcts_cls = None
    elif search_tree == SearchType.MCTS:
        mcts_cls = mctsAZ.Mcts
    elif search_tree == SearchType.CUSTOM:
        mcts_cls = mcts_custom.Mcts
    else:
        raise ValueError('Unknown search type: {!r}'.format(search_tree))

    game, moves = [], 0

    while not stop_cond(board, moves):
        if mcts_cls is None:
            # The alphabeta result exposes the chosen position via get_node().state.
            node = alphabeta.alphabeta(alphabeta.Node(board), args[0], args[1], args[2], nnet, encoder)
            board = node.get_node().state
        else:
            # Both MCTS variants share the Mcts(board, nnet, encoder).search(rollouts) API.
            node = mcts_cls(board, nnet, encoder).search(args[0])
            board = node.state

        game.append(node)
        moves += 1

    return game
    

def stop_cond(board, moves, max_moves=None):
    '''Return True when game generation should stop.

    Generation stops when either of these holds:
      - ``moves`` has reached the move cap.
      - ``board.is_game_over()`` reports the game has ended.

    Args:
        board: current position; must provide ``is_game_over()`` (e.g. a
            ``chess.Board``).
        moves: number of moves already generated in this game.
        max_moves: move cap; ``None`` (default) uses the module-level
            ``MAX_MOVES``.

    Returns:
        bool: True if no further move should be searched.
    '''
    if max_moves is None:
        max_moves = MAX_MOVES
    # `>=` so that exactly `max_moves` moves are generated (the previous `>`
    # produced max_moves + 1). A finished game must not be searched further,
    # so a terminal position also ends generation. This check previously
    # ignored `board` entirely.
    return moves >= max_moves or board.is_game_over()

In [2]:
import math

# Settings for this small MCTS dataset generation run.
N_SAMPLES = 20    # minimum number of samples requested from get_dataset
N_ROLLOUTS = 50   # MCTS rollouts per move
N_SHOW = 4        # number of samples to print
SEED = 0

# Seed torch's global RNG so torch-based randomness (e.g. freshly initialised
# weights) is reproducible.
# NOTE(review): the search modules may also draw from `random`/numpy; seed
# those too if exact reproducibility is needed.
torch.manual_seed(SEED)

# Fall back to CPU when no GPU is present. On a CUDA machine this is the same
# as the previous `.cuda()` calls.
# NOTE(review): assumes net/autoencoder/mctsAZ do not hard-code `.cuda()`
# internally; confirm before relying on the CPU path.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = get_dataset(N_SAMPLES, chess.Board(), net.Net().to(device),
                      autoencoder.autoencoder().to(device), SearchType.MCTS, N_ROLLOUTS)

# Same layout as before (one sample per line, then a blank line). Slicing
# avoids an IndexError if fewer than N_SHOW samples were produced.
for sample in dataset[:N_SHOW]:
    print(sample)
print()

  return self.sig(value), [self.softmax(policy[0]), self.softmax(policy[0])]


(Board('rnbqkbnr/pppppppp/8/8/6P1/8/PPPPPP1P/RNBQKBNR b KQkq - 0 1'), 0.4999206890451147, [['g8h6', 0.00023833855121070321], ['g8f6', 0.00023662285485583773], ['b8c6', 0.00024338003865519153], ['b8a6', 0.0002513123600769718], ['h7h6', 0.00024435574927915774], ['g7g6', 0.00023585672873076453], ['f7f6', 0.0002420726895711947], ['e7e6', 0.0002505785216254733], ['d7d6', 0.0002477247685211681], ['c7c6', 0.0002344633408333137], ['b7b6', 0.00024931136844944315], ['a7a6', 0.00024730850771863266], ['h7h5', 0.00024290525986879648], ['g7g5', 0.0002398015287669534], ['f7f5', 0.00024484218517522097], ['e7e5', 0.0002518755132213668], ['d7d5', 0.00024514230600825027], ['c7c5', 0.0002358766034450295], ['b7b5', 0.00025336415656942307], ['a7a5', 0.00024412974182670168]], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.4999129229121738, 0])
(Board('rnbqkbnr/p1pppppp/8/1p6/6P1/8/PPPPPP1P/RNBQKBNR w KQkq - 0 2'), 0.5000792399365851, [['g1h3', 0.00024760148113849717], ['g1f3', 0.0002393641141253850