In [1]:
from enum import Enum
import torch
from torch.utils.data import Dataset
import chess
import helperfuncs
import bitboards
import net
import autoencoder
import alphabeta
import mctsAZ
import mcts_custom

class SearchType(Enum):
    '''Search algorithm used to produce training positions.'''
    MINIMAX = 1
    MCTS = 2
    CUSTOM = 3


# Ply limit consulted by ``stop_cond`` when generating a single game.
MAX_MOVES = 10

class Encode:
    '''Callable transform that delegates to ``encoder.encode``.

    Intended to be passed as the ``transform`` of a dataset so every sample
    is encoded on access.
    '''

    def __init__(self, encoder):
        # Any object exposing an ``encode(sample)`` method.
        self.encoder = encoder

    def __call__(self, sample):
        encode_fn = self.encoder.encode
        return encode_fn(sample)
    

class SearchDataset(Dataset):
    '''Map-style dataset of search-generated training samples.

    Samples are built eagerly by ``get_dataset(size, *args)``. On access, the
    first element of a sample is passed through ``transform``; the remaining
    elements are returned unchanged.
    '''

    def __init__(self, size, transform, *args):
        self.data = get_dataset(size, *args)
        self.transform = transform

    def __getitem__(self, idx):
        sample = self.data[idx]
        head = self.transform(sample[0])
        return [head] + sample[1:]

    def __len__(self):
        return len(self.data)
    
def _position_tensor(state):
    '''Encode a board state as a batched CNN input on the GPU (leading dim added by unsqueeze(0)).'''
    return bitboards.bitboard_to_cnn_input(bitboards.bitboard(state)).unsqueeze(0).cuda()


def _node_to_sample(nd, search_tree):
    '''Convert one search-result node into a training sample.

    MINIMAX nodes yield ``[position, value]``; MCTS and CUSTOM nodes yield
    ``[position, value, policy]``.
    '''
    if search_tree == SearchType.MINIMAX:
        position = _position_tensor(nd.get_node().state)
        # NOTE(review): unlike the MCTS/CUSTOM branches, get_val() is not wrapped
        # in a list. If it returns a bare Python number, torch.Tensor(...) will not
        # build a one-element tensor (int -> uninitialised tensor of that length,
        # float -> TypeError). Confirm get_val()'s return type.
        return [position, torch.Tensor(nd.get_val()).cuda()]

    position = _position_tensor(nd.state)
    moves = [move.uci() for move in nd.moves]
    if search_tree == SearchType.MCTS:
        # Assumes nd.moves and nd.childNodes are aligned index-by-index.
        move_probs = [[moves[i], child.actionValue] for i, child in enumerate(nd.childNodes)]
    else:  # SearchType.CUSTOM
        choice_probability = nd.choiceProbability
        move_probs = [[moves[i], choice_probability.value(i)]
                      for i in range(len(choice_probability.x))]
    policy = helperfuncs.policy_from_probability(move_probs)
    return [position, torch.Tensor([nd.actionValue]).cuda(), policy.cuda()]


def get_dataset(size, *args):
    '''Get dataset for NN training no smaller than specified "size".

    Args are the "generate_game" function parameters:
    ``(board, nnet, encoder, search_tree, *search_args)``. Whole games are
    generated and appended until the dataset holds at least ``size`` samples,
    so the result may be larger than ``size``.

    Raises:
        ValueError: if ``search_tree`` (``args[3]``) is not a known SearchType.
    '''
    search_tree = args[3]
    # Fail fast with a clear message instead of an obscure error deep inside
    # game generation.
    if search_tree not in (SearchType.MINIMAX, SearchType.MCTS, SearchType.CUSTOM):
        raise ValueError(f'Unsupported search type: {search_tree!r}')

    dataset = []
    while len(dataset) < size:
        dataset.extend(_node_to_sample(nd, search_tree) for nd in generate_game(*args))
    return dataset
        
def generate_game(board, nnet, encoder, search_tree, *args):
    '''Generate the chess game given the starting "board" position.

    Repeatedly runs the chosen search from the current position, records the
    resulting node, and continues from that node's state until ``stop_cond``
    says to stop. Args depend on chosen search type:
    three parameters for MINIMAX (depth, lower bound and higher bound of the
    aspiration window); one parameter for MCTS and CUSTOM (number of rollouts).

    Returns:
        list of the nodes returned by each search step, in play order.

    Raises:
        ValueError: if ``search_tree`` is not a known SearchType.
    '''
    # MCTS and CUSTOM share the same interface and differ only in the tree class.
    mcts_classes = {SearchType.MCTS: mctsAZ.Mcts, SearchType.CUSTOM: mcts_custom.Mcts}
    if search_tree != SearchType.MINIMAX and search_tree not in mcts_classes:
        # Without this check `node` would be unbound at game.append(node).
        raise ValueError(f'Unsupported search type: {search_tree!r}')

    game, moves = [], 0

    while not stop_cond(board, moves):
        if search_tree == SearchType.MINIMAX:
            node = alphabeta.alphabeta(alphabeta.Node(board), args[0], args[1], args[2], nnet, encoder)
            board = node.get_node().state
        else:
            tree = mcts_classes[search_tree](board, nnet, encoder)
            node = tree.search(args[0])
            board = node.state

        game.append(node)
        moves += 1

    return game
    

def stop_cond(board, moves, max_moves=None):
    '''Return True when game generation should stop.

    Generation stops once ``moves`` reaches the ply limit, or as soon as
    ``board`` reports that the game is over.

    Args:
        board: current position. If it exposes a callable ``is_game_over()``
            (as ``chess.Board`` does), a finished game stops generation.
            Objects without that method are only subject to the ply limit.
        moves: number of plies generated so far.
        max_moves: optional ply limit; defaults to the module-level MAX_MOVES.
    '''
    limit = MAX_MOVES if max_moves is None else max_moves
    # `>=` caps the game at `limit` generated positions; the former `>` let
    # generation run one ply past the limit.
    if moves >= limit:
        return True
    # Searching from a finished position presumably cannot produce a move, so stop there.
    is_game_over = getattr(board, 'is_game_over', None)
    return bool(callable(is_game_over) and is_game_over())