In [1]:
from enum import Enum
import torch
from torch.utils.data import Dataset
import chess
import chess.engine
from chess.engine import Cp
import helperfuncs
import bitboards
import net
import autoencoder
import alphabeta
import mctsAZ
import mcts_custom

engine = chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish")
SearchType = Enum('SearchType', 'MINIMAX MCTS CUSTOM')
ReinforcementType = Enum('ReinforcementType', 'MC TD PARAM')
winner_to_num = {chess.WHITE: 1, chess.BLACK: 0, None: 0.5}

class Encode(object):
    def __init__(self, encoder):
        self.encoder = encoder
        
    def __call__(self, sample):
        return self.encoder.encode(sample)
    
class SearchDataset(Dataset):
    def __init__(self, size, transform, reinf, game_generator, *args):
        self.data = game_generator.get_dataset(size, reinf, *args)
        self.transform = transform
    
    def __getitem__(self, idx):
        return [self.transform(self.data[idx][0])] + self.data[idx][1:]
    
    def __len__(self):
        return len(self.data)
        
class GameGenerator:
    def __init__(self, max_moves, draw_cutoff, param):
        """
        MAX_MOVES in halfmoves
        DRAW_CUTOFF in centipawns
        PARAM number in range (0, 1) used in PARAM 'ReinforcementType (simple linear combination of TD and Monte-Carlo learning)
        """
        
        self.MAX_MOVES = max_moves
        self.DRAW_CUTOFF = draw_cutoff
        self.PARAM = param

    def get_dataset(self, size, reinf, *args):
        """Get dataset for NN training no smaller than specified 'size'.
        Args are the 'generate_game' function parameters."""
        
        dataset = []

        while len(dataset) < size:
            #generate a new game
            game = self.generate_game(*args)

            winner, state =  -1, game[-1].state
            
            #remove drawing positions
            i = 0
            while state.can_claim_draw():
                state = game[- (1 + i)].state
                i += 1

            outcome = state.outcome()

            #score the game with engine and determine winner based on engine score and draw cutoff
            score = engine.analyse(state, chess.engine.Limit(time=1))["score"].white()
            if score > Cp(self.DRAW_CUTOFF):
                winner = chess.WHITE
            elif score < Cp(self.DRAW_CUTOFF):
                winner = chess.BLACK
            else:
                winner = None
                
            #calculating TD learning targets
            td_values = []
            for nd in game:
                td_values.append(self.get_evaluation(args[3], nd))
                
            values = []
            for i in range(len(td_values)):
                if((i+1) == len(td_values)):
                    values.append(winner_to_num[winner])
                else:
                    values.append(td_values[i+1])
                    
                print(values[i], levels=2)

                
            for i, nd in enumerate(game):
                if args[3] == SearchType.MINIMAX:
                    position = bitboards.bitboard_to_cnn_input(bitboards.bitboard(nd.get_node().state)).unsqueeze(0).cuda()
                    dataset.append([position, values[i]])

                elif args[3] == SearchType.MCTS:
                    position = bitboards.bitboard_to_cnn_input(bitboards.bitboard(nd.state)).unsqueeze(0).cuda()
                    moves = [move.uci() for move in nd.moves]
                    policy = helperfuncs.policy_from_probability([[moves[i], child.actionValue] for i, child in enumerate(nd.childNodes)])
                    dataset.append([position, values[i], policy.cuda()])

                elif args[3] == SearchType.CUSTOM:
                    position = bitboards.bitboard_to_cnn_input(bitboards.bitboard(nd.state)).unsqueeze(0).cuda()
                    moves = [move.uci() for move in nd.moves]
                    choiceProbability = nd.choiceProbability
                    policy = helperfuncs.policy_from_probability([[moves[i], choiceProbability.value(i)] for i in range(len(choiceProbability.x))])
                    dataset.append([position, values[i], policy.cuda()])

        return dataset

    def generate_game(self, board, nnet, encoder, search_tree, *args):
        """Generate the chess game given the starting 'board' position. Args depend on chosen search type. 
        Three parameters for MINIMAX: depth, lower bound and higher bound of aspiration window.
        One parameter for MCTS and CUSTOM: number of rollouts."""

        game, moves = [], 0

        while not self.stop_cond(board, moves):

            if search_tree == SearchType.MINIMAX:
                node = alphabeta.alphabeta(alphabeta.Node(board), args[0], args[1], args[2], nnet, encoder)
                board = node.get_node().state

            elif search_tree == SearchType.MCTS:
                tree = mctsAZ.Mcts(board, nnet, encoder)
                node = tree.search(args[0])
                board = node.state

            elif search_tree == SearchType.CUSTOM:
                tree = mcts_custom.Mcts(board, nnet, encoder)
                node = tree.search(args[0])
                board = node.state
                
            tree.print_tree(2)

            game.append(node)
            moves += 1
            
        return game

    def stop_cond(self, board, moves):
        '''Stops the game when it has reached terminal position or more moves than allowed were played.'''
        end = False

        if board.is_checkmate() or board.is_stalemate() or board.is_insufficient_material():
            end = True
        elif moves > self.MAX_MOVES:
            end = True

        return end

    
    def get_evaluation(self, search_type, node):
        '''Gets the node evaluation computed during search.'''            
        return node.evaluation
    

#     def get_value(self, reinf_type, search_type, winner, node):
#         if reinf_type == ReinforcementType.MC:
#             val = winner_to_num[winner]
#         elif reinf_type == ReinforcementType.TD or reinf_type == ReinforcementType.PARAM:
#             if search_type == SearchType.MINIMAX:
#                 val = node.get_val()[0]
#             elif search_type == SearchType.MCTS or search_type == SearchType.CUSTOM:
#                 val = node.actionValue

#         if reinf_type == ReinforcementType.PARAM:
#             val = self.PARAM * val + (1 - self.PARAM) * winner_to_num[winner]

#         return torch.Tensor([val]).cuda()

In [2]:
#checking few basic statistics about generated datasets
import statistics

encoder = autoencoder.autoencoder().cuda()
encoder.load_state_dict(torch.load("autoencoderftest2.pt"))
nnet = net.Net().cuda()
nnet.load_state_dict(torch.load("nnet_mcts.pt"))

args = (chess.Board(), nnet, encoder, SearchType.CUSTOM, 200)
GameGenerator = GameGenerator(32, 0, 0)
dataset = SearchDataset(64, Encode(encoder), ReinforcementType.MC, GameGenerator, *args)

vals, policies, positions = [], [], []
for position, val, policy in dataset:
    vals.append(val.item())
    policies.append(policy)
    positions.append(position)

print("Game result mean: ", statistics.mean(vals), " Standard deviation: ", statistics.stdev(vals))

print("Example policy: ", policies[32][0].detach().cpu().numpy())

* visited: 199, actionValue: 0.507646
	* visited: 10 , actionValue: 0.509109, choiceProbability: 0.103185
	* visited: 5  , actionValue:  0.49031, choiceProbability: 0.036482
	* visited: 8  , actionValue: 0.487378, choiceProbability:  0.02682
	* visited: 7  , actionValue: 0.481156, choiceProbability: 0.026722
	* visited: 8  , actionValue: 0.495283, choiceProbability: 0.029067
	* visited: 8  , actionValue: 0.489985, choiceProbability:  0.02655
	* visited: 13 , actionValue: 0.521813, choiceProbability: 0.022883
	* visited: 13 , actionValue: 0.511621, choiceProbability: 0.098768
	* visited: 12 , actionValue: 0.507984, choiceProbability: 0.090726
	* visited: 8  , actionValue: 0.504631, choiceProbability: 0.028344
	* visited: 9  , actionValue: 0.507316, choiceProbability: 0.028425
	* visited: 11 , actionValue:  0.51001, choiceProbability: 0.023624
	* visited: 4  , actionValue: 0.445211, choiceProbability: 0.030177
	* visited: 15 , actionValue: 0.527745, choiceProbability: 0.109788
	* visited

IndexError: list index out of range