# In [8]:
from enum import Enum
import torch
from torch.utils.data import Dataset
import chess
import chess.engine
from chess.engine import Cp
import helperfuncs
import bitboards
import net
import autoencoder
import alphabeta
import mctsAZ
import mcts_custom

# --- Self-play / adjudication settings -------------------------------------
MAX_MOVES = 10    # cap on self-play game length, counted in half-moves
DRAW_CUTOFF = 0   # centipawn threshold used when adjudicating unfinished games
# Weight in (0, 1) used by ReinforcementType.PARAM, which linearly combines the
# search (TD-style) value with the final game result (Monte-Carlo target).
param = 0.2
# UCI engine used to adjudicate games that stop before a natural end.
# NOTE(review): started at import time; no engine.quit() is visible here.
engine = chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish")


class SearchType(Enum):
    """Search algorithm used to choose moves during self-play."""
    MINIMAX = 1
    MCTS = 2
    CUSTOM = 3


class ReinforcementType(Enum):
    """How value targets are built: Monte-Carlo, TD, or a weighted mix."""
    MC = 1
    TD = 2
    PARAM = 3


# Game winner (chess.WHITE, chess.BLACK, or None for a draw) -> value target.
winner_to_num = {chess.WHITE: 1, chess.BLACK: 0, None: 0.5}

class Encode:
    """Callable transform that runs each sample through an encoder.

    Wraps any object exposing an ``encode(sample)`` method so it can be used
    wherever a plain transform callable is expected.
    """

    def __init__(self, encoder):
        # Any object with an ``encode`` method; stored as-is.
        self.encoder = encoder

    def __call__(self, sample):
        """Return ``self.encoder.encode(sample)``."""
        encode_fn = self.encoder.encode
        return encode_fn(sample)
    
class SearchDataset(Dataset):
    """PyTorch ``Dataset`` of self-play training samples.

    Samples are generated eagerly by ``get_dataset``. Each is a list whose
    first element is a position; ``transform`` is applied to that element on
    access, and the remaining elements (the value target and, for MCTS /
    CUSTOM searches, a policy target) are returned unchanged.
    """

    def __init__(self, size, transform, reinf, *args):
        # size / reinf / *args are forwarded to get_dataset unchanged.
        self.data = get_dataset(size, reinf, *args)
        self.transform = transform

    def __getitem__(self, idx):
        sample = self.data[idx]
        encoded = self.transform(sample[0])
        return [encoded] + sample[1:]

    def __len__(self):
        return len(self.data)
    
def get_dataset(size, reinf, *args):
    '''Generate self-play training samples until at least ``size`` exist.

    Whole games are added at a time, so the result can be longer than
    ``size``.

    Args:
        size: minimum number of samples to produce.
        reinf: ``ReinforcementType`` forwarded to ``get_value`` to build each
            value target.
        *args: forwarded verbatim to ``generate_game``, i.e.
            ``(board, nnet, encoder, search_tree, *search_args)``; ``args[3]``
            is therefore the ``SearchType``.

    Returns:
        A list of samples: ``[position, value]`` for MINIMAX and
        ``[position, value, policy]`` for MCTS / CUSTOM, tensors on CUDA.

    Raises:
        ValueError: if ``generate_game`` returns an empty game, or the search
            type is not supported.
    '''
    dataset = []

    while len(dataset) < size:
        game = generate_game(*args)
        if not game:
            # Same arguments would give the same empty game on every call;
            # fail loudly instead of an opaque IndexError on game[-1].
            raise ValueError("generate_game returned an empty game; the start "
                             "position may already be finished")

        search_type = args[3]
        # Adjudicate the board actually reached after the last move, read the
        # same way generate_game advances its board for this search type.
        winner = _adjudicate_winner(_final_state(game[-1], search_type))

        for nd in game:
            val = get_value(reinf, search_type, winner, nd)
            dataset.append(_make_sample(nd, search_type, val))

    return dataset


def _final_state(node, search_type):
    '''Return the board reached after ``node``'s move (mirrors generate_game).'''
    if search_type == SearchType.MINIMAX:
        return node.get_node().state
    return node.state


def _adjudicate_winner(state):
    '''Return chess.WHITE, chess.BLACK or None (draw) for a final ``state``.

    Finished games use the board's own outcome. Unfinished (cut-off) games are
    scored by a 1-second engine analysis from White's point of view: above
    ``+DRAW_CUTOFF`` centipawns is a White win, below ``-DRAW_CUTOFF`` a Black
    win, and anything inside that band a draw.
    '''
    outcome = state.outcome()
    if outcome:
        return outcome.winner

    score = engine.analyse(state, chess.engine.Limit(time=1))["score"].white()
    # Symmetric draw band; identical to a single threshold when DRAW_CUTOFF == 0.
    if score > Cp(DRAW_CUTOFF):
        return chess.WHITE
    if score < Cp(-DRAW_CUTOFF):
        return chess.BLACK
    return None


def _encode_position(state):
    '''Encode a board as CNN input with a leading batch dim of 1, on CUDA.'''
    return bitboards.bitboard_to_cnn_input(bitboards.bitboard(state)).unsqueeze(0).cuda()


def _make_sample(nd, search_type, val):
    '''Build one training sample ``[position, value(, policy)]`` for node ``nd``.'''
    if search_type == SearchType.MINIMAX:
        return [_encode_position(nd.get_node().state), val]
    if search_type not in (SearchType.MCTS, SearchType.CUSTOM):
        raise ValueError(f"unsupported search type: {search_type!r}")

    position = _encode_position(nd.state)
    moves = [move.uci() for move in nd.moves]
    if search_type == SearchType.MCTS:
        # Policy target: each child's action value keyed by its UCI move.
        scored = [[moves[i], child.actionValue] for i, child in enumerate(nd.childNodes)]
    else:
        probs = nd.choiceProbability
        scored = [[moves[i], probs.value(i)] for i in range(len(probs.x))]
    policy = helperfuncs.policy_from_probability(scored)
    return [position, val, policy.cuda()]
        
def generate_game(board, nnet, encoder, search_tree, *args):
    '''Play one self-play game from ``board`` and return its search nodes.

    Each half-move is chosen by a fresh search from the current position, and
    the node returned by the search is appended to the game. Play stops when
    ``stop_cond`` reports the game is over or the move cap is reached.

    Args:
        board: starting position (presumably a ``chess.Board``).
        nnet, encoder: network and input encoder handed to the search.
        search_tree: ``SearchType`` selecting the search algorithm.
        *args: search-specific parameters.
            MINIMAX: depth, lower bound and upper bound of the aspiration
            window. MCTS / CUSTOM: number of rollouts.

    Returns:
        List of search nodes, one per half-move played.

    Raises:
        ValueError: if ``search_tree`` is not a supported ``SearchType``.
    '''
    game, moves = [], 0

    while not stop_cond(board, moves):
        if search_tree == SearchType.MINIMAX:
            root = alphabeta.Node(board)
            node = alphabeta.alphabeta(root, args[0], args[1], args[2], nnet, encoder)
            # For alphabeta results the next board is read via get_node().
            board = node.get_node().state
        elif search_tree == SearchType.MCTS:
            node = mctsAZ.Mcts(board, nnet, encoder).search(args[0])
            board = node.state
        elif search_tree == SearchType.CUSTOM:
            node = mcts_custom.Mcts(board, nnet, encoder).search(args[0])
            board = node.state
        else:
            # Without this, ``node`` would be unbound below.
            raise ValueError(f"unsupported search type: {search_tree!r}")

        game.append(node)
        moves += 1

    return game

def stop_cond(board, moves, max_moves=None):
    '''Return True when self-play should stop.

    Args:
        board: current position; anything with an ``outcome()`` method that
            returns a truthy value once the game is over.
        moves: number of half-moves already played.
        max_moves: half-move cap; defaults to the module-level ``MAX_MOVES``.

    Returns:
        True if the game is over or ``moves`` has reached ``max_moves``.
    '''
    if max_moves is None:
        max_moves = MAX_MOVES
    # ``>=`` so that at most ``max_moves`` half-moves are played.
    return bool(board.outcome()) or moves >= max_moves

def get_value(reinf_type, search_type, winner, node):
    '''Return the value target for one position as a 1-element CUDA tensor.

    Args:
        reinf_type: ``ReinforcementType``.
            MC: final game result, ``winner_to_num[winner]``.
            TD: the search's own value estimate for ``node``.
            PARAM: ``param * search_value + (1 - param) * result``.
        search_type: ``SearchType`` that produced ``node``. Selects how the
            search value is read: first element of ``node.get_val()`` for
            MINIMAX, ``node.actionValue`` for MCTS / CUSTOM.
        winner: chess.WHITE, chess.BLACK or None (draw).
        node: search node for the position.

    Raises:
        ValueError: for an unsupported ``reinf_type`` or ``search_type``.

    NOTE(review): winner_to_num scores results from White's side; confirm the
    search values use the same convention before blending them (PARAM).
    '''
    if reinf_type == ReinforcementType.MC:
        val = winner_to_num[winner]

    # Explicit membership test: ``x == A or B`` is always true for enum B,
    # which made PARAM silently behave like TD.
    elif reinf_type in (ReinforcementType.TD, ReinforcementType.PARAM):
        if search_type == SearchType.MINIMAX:
            val = node.get_val()[0]
        elif search_type in (SearchType.MCTS, SearchType.CUSTOM):
            val = node.actionValue
        else:
            raise ValueError(f"unsupported search type: {search_type!r}")

        if reinf_type == ReinforcementType.PARAM:
            val = param * val + (1 - param) * winner_to_num[winner]

    else:
        raise ValueError(f"unsupported reinforcement type: {reinf_type!r}")

    return torch.Tensor([val]).cuda()