In [None]:
from checkers.game import Game
import chess
import numpy as np
import math
import torch
import torch.nn as nn
from copy import deepcopy
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of (state, reward, policy) training samples.

    Once `buffer_size` samples are stored, each new sample overwrites the
    oldest one.
    """

    def __init__(self, buffer_size):
        # maximum number of samples kept
        self.total_size = buffer_size

        # stored (state, reward, policy) tuples
        self._data_buffer = []
        # slot the next sample is written to
        self._next_idx = 0

    def __len__(self):
        return len(self._data_buffer)

    def add(self, state, reward, policy):
        """Store one sample, overwriting the oldest once the buffer is full."""
        trans = (state, reward, policy)

        if self._next_idx >= len(self._data_buffer):
            # still filling up: grow the list
            self._data_buffer.append(trans)
        else:
            # full: overwrite in ring order
            self._data_buffer[self._next_idx] = trans

        self._next_idx = (self._next_idx + 1) % self.total_size

    def _encode_sample(self, indices):
        """Gather the samples at `indices` into three stacked numpy arrays.

        Returns (states, rewards, policies), each with the batch on axis 0.
        """
        samples = [self._data_buffer[idx] for idx in indices]
        # np.asarray avoids a copy when possible. np.array(..., copy=False)
        # raises ValueError under NumPy 2 whenever a copy is unavoidable,
        # e.g. for the plain Python int rewards stored by the training loop.
        state_list = [np.asarray(state) for state, _, _ in samples]
        rewards_list = [np.asarray(reward) for _, reward, _ in samples]
        policy_list = [np.asarray(policy) for _, _, policy in samples]
        return np.array(state_list), np.array(rewards_list), np.array(policy_list)

    def sample_batch(self, batch_size):
        """Sample `batch_size` stored samples uniformly, with replacement.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self._data_buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        indices = [np.random.randint(0, len(self._data_buffer)) for _ in range(batch_size)]
        return self._encode_sample(indices)

In [None]:
class CheckersEnv(object):
    """RL-style wrapper around a `checkers.game.Game`.

    Boards are exposed as 8x8 numpy arrays in which player 1's pieces are +1
    (+2 for a king) and player 2's pieces are -1 (-2 for a king). Squares are
    addressed by the library's 1-based position numbers (32 playable squares);
    `diagonalPosToCord` converts them to (row, col).

    Most methods take an optional `game`; when it is None they operate on
    `self.game`, which lets MCTS work on deep copies without touching the
    real game.
    """

    def __init__(self):
        self.game = Game()

    def reset(self):
        """Replace the current game with a fresh one."""
        self.game = Game()

    def _resolve(self, game):
        """Return `game` when one is given, otherwise the environment's own game."""
        return self.game if game is None else game

    def diagonalPosToCord(self, diagonal_pos):
        """Map a 1-based playable-square number to (row, col) on the 8x8 grid.

        Each row holds four playable squares and odd rows are shifted one
        column left, e.g. 1 -> (0, 1), 5 -> (1, 0), 32 -> (7, 6).
        """
        pos_val = diagonal_pos * 2 - 1
        row = pos_val // 8
        col = pos_val % 8 - row % 2
        return row, col

    def getAvailableMoves(self, game = None):
        """Return the legal moves of `game` (default: self.game) as (start, end) tuples."""
        chosenGame = self._resolve(game)
        return [(move[0], move[1]) for move in chosenGame.get_possible_moves()]

    def swapPlayer(self, player):
        """Return the opponent of `player` (1 <-> 2)."""
        return 2 if player == 1 else 1

    def step(self, action, player, game = None, tupleify = False):
        """Play `action` for `player` on `game` (default: self.game), in place.

        Returns:
            (status, s_prime, reward, done, next_player). `status` is 0 on
            success, -1 if it is not `player`'s turn and -2 if the move is
            illegal; on failure the other four fields are None, so callers
            can always unpack five values.
            `reward` is relative to `player` (the mover): +1 if `player` won,
            -1 if the opponent won, and 0 while the game continues or when it
            ended without a winner (`get_winner()` is None, e.g. a draw).
            `s_prime` is a tuple-of-tuples board if `tupleify` else an ndarray.
            `next_player` is `whose_turn()` after the move.
        """
        chosenGame = self._resolve(game)

        if chosenGame.whose_turn() != player:
            print("NOT MY TURN!")
            return -1, None, None, None, None

        if action not in self.getAvailableMoves(chosenGame):
            print("NOT POSSIBLE MOVE!")
            return -2, None, None, None, None

        chosenGame.move([action[0], action[1]])

        done = chosenGame.is_over()
        reward = 0  # draws are not punished
        if done:
            winner = chosenGame.get_winner()
            if winner is not None:
                reward = 1 if winner == player else -1

        if tupleify:
            s_prime = self.tupleifyBoard(game = chosenGame)
        else:
            s_prime = self.npifyBoard(game = chosenGame)

        return 0, s_prime, reward, done, chosenGame.whose_turn()

    def npifyBoard(self, game = None):
        """Return the board of `game` (default: self.game) as an 8x8 float array.

        Captured pieces are skipped; see the class docstring for the encoding.
        """
        chosenGame = self._resolve(game)
        npBoard = np.zeros((8, 8))

        for piece in chosenGame.board.pieces:
            if piece.captured:
                continue
            row, col = self.diagonalPosToCord(piece.position)
            sign = 1 if piece.player == 1 else -1
            npBoard[row][col] = sign * (2 if piece.king else 1)
        return npBoard

    def tupleifyBoard(self, game = None):
        """Hashable (tuple-of-tuples) version of `npifyBoard`, used as MCTS keys."""
        return tuple(map(tuple, self.npifyBoard(game)))

    def actionToNNIndex(self, a):
        """Flatten a (start, end) move into an index of the 32*32 policy vector.

        Index = (start - 1) * 32 + end. Most start/end pairs are never legal;
        the network is expected to learn near-zero probability for them.
        `end` is not shifted to 0-based, so index 1024 (out of range) would
        need the move (32, 32), which presumably never occurs because a move's
        start and end differ.
        """
        start = a[0] - 1
        end = a[1]
        return start * 32 + end

    def nnifyBoard(self, state, player):
        """Encode an 8x8 board as a (1, 4, 8, 8) float32 tensor of one-hot planes.

        The board is first oriented to `player`: for player 2 it is rotated
        180 degrees and negated, so +1/+2 always denote the mover's men/kings.
        Planes: 0 = own men (+1), 1 = own kings (+2), 2 = opponent men (-1),
        3 = opponent kings (-2). `state` itself is never modified.
        """
        board = np.asarray(state, dtype=float)

        if player == 2:
            # Unary minus builds a new array; the previous `board *= -1` wrote
            # through the np.rot90 view back into the caller's board.
            board = -np.rot90(board, 2)

        planes = np.zeros((1, 4, 8, 8))
        for channel, value in enumerate((1, 2, -1, -2)):
            planes[0, channel] = board == value

        return torch.tensor(planes, dtype=torch.float32)

    def nnifyActionProbs(self, actions, probabilities):
        """Scatter per-move probabilities into a dense 32*32 vector via actionToNNIndex."""
        probs = np.zeros((32 * 32,))
        for action, prob in zip(actions, probabilities):
            probs[self.actionToNNIndex(action)] = prob
        return probs


In [None]:
class MCTSCheckers(object):
    """AlphaZero-style Monte Carlo tree search guided by a policy/value network.

    Statistics are keyed by the tuple-ified board `s`:
      Q  -- mean backed-up value of `(s, action)`, seen by the player at `s`
      Na -- visit count of `(s, action)`
      N  -- visit count of `s`
      P  -- network policy vector for `s`, indexed via `env.actionToNNIndex`
    A board absent from P is an unexpanded leaf. `search` returns values from
    the caller's perspective: multiplied by `return_multiplier`, which is -1
    when the turn passed to the opponent and 1 when the same player moves
    again.
    """

    def __init__(self, num_searches, env, net, c):
        self.env = env
        # NOTE: not refreshed by env.reset(); the search always reads self.env.game
        self.game = env.game
        self.num_searches = num_searches
        # exploration weight in the PUCT score
        self.c = c

        self.net = net

    def _evaluate(self, s, player):
        """Query the net for board `s` oriented to `player`.

        Runs without autograd so no computation graph is kept alive in the
        search tree; returns (policy as a 1-D numpy array, value as a float).
        Assumes the net returns a single-element value tensor for one board.
        """
        with torch.no_grad():
            p_vals, v = self.net(self.env.nnifyBoard(s, player))
        return p_vals.detach().numpy()[0], v.item()

    def simulate(self, player):
        """Run `num_searches` searches from the current position of `self.env.game`.

        Returns (visit proportions, legal moves) for the root position; falls
        back to a uniform distribution if no root move was visited.
        """
        rootS = self.env.tupleifyBoard()
        self.Q = {}
        self.N = {rootS: 0}
        self.Na = {}
        self.V = {}  # not read anywhere in this class; kept for compatibility
        self.P = {rootS: self._evaluate(rootS, player)[0]}

        for _ in range(self.num_searches):
            # each search mutates its own copy of the game
            self.search(rootS, player, deepcopy(self.env.game))

        valids = self.env.getAvailableMoves()
        visits = [self.Na.get((rootS, a), 0) for a in valids]
        total = float(sum(visits))
        if total == 0:
            return [1.0 / len(valids) for _ in valids], valids
        return [n / total for n in visits], valids

    def search(self, s, player, game, id=0, return_multiplier=-1):
        """One selection/expansion/backup pass from board `s` with `player` to move.

        `game` is advanced in place. Returns the backed-up value multiplied by
        `return_multiplier` (i.e. from the caller's point of view).
        """
        if s not in self.P:
            # Unexpanded leaf: store the network prior and back up its value estimate.
            self.P[s], v = self._evaluate(s, player)
            self.N[s] = 0
            return return_multiplier * v

        N = self.N.get(s, 0)
        priors = self.P[s]
        best, best_a = None, None
        for action in self.env.getAvailableMoves(game):
            key = (s, action)
            prior = priors[self.env.actionToNNIndex(action)]
            if key in self.Q:
                UCB = self.Q[key] + self.c * prior * math.sqrt(N) / (1 + self.Na[key])
            else:
                # unvisited: the tiny epsilon keeps the prior meaningful when N == 0
                UCB = self.c * prior * math.sqrt(N + .000001)

            if best is None or UCB > best:
                best, best_a = UCB, action

        beforePlayer = player
        _, s_prime, r, done, player = self.env.step(best_a, player, game, tupleify=True)

        if done:
            # r is already relative to the player at `s` (the mover)
            value = r
        else:
            multiplier = 1 if player == beforePlayer else -1
            value = self.search(s_prime, player, game, id=id + 1, return_multiplier=multiplier)

        self._backup(s, best_a, value)
        return return_multiplier * value

    def _backup(self, s, action, value):
        """Fold `value` into the running mean Q[(s, action)] and bump visit counts."""
        key = (s, action)
        visits = self.Na.get(key, 0)
        if visits:
            self.Q[key] = (visits * self.Q[key] + value) / (visits + 1)
        else:
            self.Q[key] = value
        self.Na[key] = visits + 1
        self.N[s] += 1

In [None]:
class CheckersConvNet(nn.Module):
    """Policy/value network: three 3x3 conv-BN-ReLU stages feeding two linear heads.

    The heads are sized for an 8x8 board (256 channels * 64 squares), so the
    input must be (batch, input_channels, 8, 8). `forward` returns
    (policy probabilities via softmax over `policy_output_dim`, value in
    [-1, 1] via tanh with shape (batch, 1)).
    """

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Same-padding 3x3 convolution followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def __init__(self, input_channels, policy_output_dim):
        super().__init__()
        widths = (input_channels, 64, 128, 256)

        # Created in the same order as before so parameter init is unchanged.
        self.conv1 = self._conv_block(widths[0], widths[1])
        self.conv2 = self._conv_block(widths[1], widths[2])
        self.conv3 = self._conv_block(widths[2], widths[3])

        flat_features = widths[3] * 8 * 8
        self.policy_head = nn.Linear(flat_features, policy_output_dim)
        # value head: hidden layer then a single scalar output
        self.value_flat = nn.Linear(flat_features, 128)
        self.value_head = nn.Linear(128, 1)

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        features = torch.flatten(x, start_dim=1)

        policy = F.softmax(self.policy_head(features), dim=-1)
        hidden = F.relu(self.value_flat(features))
        value = torch.tanh(self.value_head(hidden))
        return policy, value


In [None]:
class MCTSStandard(object):
    """Plain UCT Monte Carlo tree search with uniformly random rollouts (no network).

    Statistics are keyed by the tuple-ified board `s`:
      T  -- sum of backed-up values of `(s, action)`, seen by the player at `s`
      Na -- visit count of `(s, action)`
      N  -- visit count of `s`; a board absent from N is an unexpanded leaf
    `search` returns values from the caller's perspective: multiplied by
    `return_multiplier`, which is -1 when the turn passed to the opponent and
    1 when the same player moves again.
    """

    def __init__(self, num_searches, env, c):
        self.env = env
        # NOTE: not refreshed by env.reset(); the search always reads self.env.game
        self.game = env.game
        self.num_searches = num_searches
        # NOTE: stored for API parity with MCTSCheckers, but the UCB score
        # below uses a fixed sqrt(2) exploration weight and never reads it.
        self.c = c

    def simulate(self, player):
        """Run `num_searches` searches from the current position of `self.env.game`.

        Returns (visit proportions, legal moves) for the root position; falls
        back to a uniform distribution if no root move was visited.
        """
        rootS = self.env.tupleifyBoard()
        self.T = {}
        self.N = {rootS: 0}
        self.Na = {}

        for _ in range(self.num_searches):
            # each search mutates its own copy of the game
            self.search(rootS, player, deepcopy(self.env.game))

        valids = self.env.getAvailableMoves()
        visits = [self.Na.get((rootS, a), 0) for a in valids]
        total = float(sum(visits))
        if total == 0:
            return [1.0 / len(valids) for _ in valids], valids
        return [n / total for n in visits], valids

    def search(self, s, player, game, id=0, return_multiplier=-1):
        """One selection/expansion/rollout/backup pass from board `s` with `player` to move.

        `game` is advanced in place. Returns the backed-up value multiplied by
        `return_multiplier` (i.e. from the caller's point of view).
        """
        if s not in self.N:
            # Unexpanded leaf: register it and estimate its value with a rollout.
            self.N[s] = 0
            p1_value = self.rollout(player, game)
            # convert to the perspective of the player at `s`, then to the caller's
            value = p1_value if player == 1 else -p1_value
            return return_multiplier * value

        N = self.N[s]
        best, best_a = None, None
        for action in self.env.getAvailableMoves(game):
            key = (s, action)
            # unvisited moves get a tiny count so their exploration bonus is huge
            Na = self.Na.get(key, 0) or .0001
            UCB = self.T.get(key, 0) / Na + math.sqrt((2 * math.log(N + 1)) / Na)

            if best is None or UCB > best:
                best, best_a = UCB, action

        beforePlayer = player
        _, s_prime, r, done, player = self.env.step(best_a, player, game, tupleify=True)

        if done:
            # r is already relative to the player at `s` (the mover)
            value = r
        else:
            multiplier = 1 if player == beforePlayer else -1
            value = self.search(s_prime, player, game, id=id + 1, return_multiplier=multiplier)

        key = (s, best_a)
        self.T[key] = self.T.get(key, 0) + value
        self.Na[key] = self.Na.get(key, 0) + 1
        self.N[s] += 1
        return return_multiplier * value

    def rollout(self, player, game):
        """Play uniformly random moves on `game` until it ends.

        Returns the final reward from player 1's perspective. `env.step`
        reports the reward relative to the player who moved, so the sign is
        taken from that mover rather than from the next player it returns.
        """
        while True:
            mover = player
            action = random.choice(self.env.getAvailableMoves(game))
            _, _, reward, done, player = self.env.step(action, mover, game, tupleify=True)

            if done:
                return reward if mover == 1 else -reward

In [None]:
def _batch_to_tensor(batch_data):
    """Convert a (states, rewards, policies) tuple of numpy arrays to float32 tensors.

    Returns a dict with keys 'state', 'reward' and 'policy'.
    """
    state_arr, reward_arr, policy_arr = batch_data
    return {
        'state': torch.tensor(state_arr, dtype=torch.float32),
        'reward': torch.tensor(reward_arr, dtype=torch.float32),
        'policy': torch.tensor(policy_arr, dtype=torch.float32),
    }


def trainModel(model, v_loss, p_loss, optim, batch):
    """Run one optimisation step of `model` on a replay-buffer batch.

    Args:
        model: network mapping (B, 4, 8, 8) boards to (policy probabilities,
            values); the policy is expected to already be softmax-normalised.
        v_loss: value loss, e.g. nn.MSELoss, applied to (B,) predictions vs targets.
        p_loss: policy loss, e.g. nn.CrossEntropyLoss with probability targets.
        optim: optimizer over `model`'s parameters.
        batch: (states, rewards, policies) numpy arrays from ReplayBuffer.sample_batch.

    Returns:
        The combined policy + value loss as a Python float.
    """
    tensors = _batch_to_tensor(batch)

    # -1 keeps this working for any batch size (previously hard-coded to 32).
    state_tensor = tensors['state'].reshape(-1, 4, 8, 8)
    reward_tensor = tensors['reward'].reshape(-1)
    policy_tensor = tensors['policy']

    policy_estimates, value_estimates = model(state_tensor)

    # Flatten (B, 1) -> (B,) so MSE does not broadcast to a (B, B) comparison.
    value_loss = v_loss(value_estimates.reshape(-1), reward_tensor)

    # CrossEntropyLoss applies log_softmax to its input. Feeding it the log of
    # already-normalised probabilities is exact (log_softmax(log p) == log p),
    # whereas feeding raw probabilities would apply softmax twice.
    log_policy = torch.log(policy_estimates.clamp_min(1e-12))
    policy_loss = p_loss(log_policy, policy_tensor)

    total_loss = value_loss + policy_loss

    optim.zero_grad()
    total_loss.backward()
    optim.step()

    return total_loss.item()


In [None]:
class RandomCheckers(object):
    """Baseline agent that spreads probability uniformly over the legal moves."""

    def __init__(self, env):
        self.env = env

    def simulate(self, player):
        """Return (uniform probabilities, legal moves); `player` is not used."""
        moves = self.env.getAvailableMoves()
        # comprehension (not [1/n] * n) so an empty move list never divides by zero
        uniform = [1 / len(moves) for _ in moves]
        return uniform, moves

In [None]:
def evaluateModels(baseline, comparison, matchCount, env):
    """Play `matchCount` games of `baseline` (player 1) against `comparison` (player 2).

    Each agent's `simulate(player)` must return (probabilities, actions); a
    move is sampled from that distribution.

    Returns:
        One result per game from player 1's perspective: the final reward of
        `env.step`, negated when player 2 made the last move. So +1 means the
        baseline won, -1 the comparison won, 0 the game ended with reward 0.
    """
    results = []
    for _ in range(matchCount):
        env.reset()
        player = 1
        while True:
            agent = baseline if player == 1 else comparison
            probs, actions = agent.simulate(player)
            action = random.choices(actions, weights=probs, k=1)[0]
            # env.step reports the reward for the player who moved and returns
            # the player to move next, so remember the mover for the sign.
            mover = player
            _, _, reward, done, player = env.step(action, mover)
            if done:
                results.append(reward if mover == 1 else -reward)
                break

    return results

def _selfPlayEpisode(env, mcts):
    """Play one self-play game in `env` with `mcts` choosing moves for both sides.

    Returns (positions, policies, players, moves, reward): positions[i] is the
    board before the i-th recorded move, players[i] who made it, policies[i]
    the MCTS visit distribution in network index space, `moves` the number of
    non-final moves, and `reward` the result from player 1's perspective.
    """
    env.reset()
    player = 1
    positions = [env.npifyBoard()]
    policies = []
    players = []
    moves = 0
    while True:
        probs, actions = mcts.simulate(player)
        policies.append(env.nnifyActionProbs(actions, probs))
        players.append(player)

        # sample a move from the MCTS visit distribution
        action = random.choices(actions, weights=probs, k=1)[0]

        mover = player
        status, new_state, reward, done, player = env.step(action, mover)

        if status != 0:
            # Keep positions/policies/players aligned: discard this move's record.
            print('SOMETHING WENT WRONG WHEN TAKING A MOVE!')
            policies.pop()
            players.pop()
            player = mover
            continue

        if done:
            # env.step's reward is relative to the mover; convert to player 1.
            return positions, policies, players, moves, (reward if mover == 1 else -reward)

        positions.append(new_state)
        moves += 1


def _comparisonWinrate(results):
    """Fraction of games won by player 2 (the comparison agent), counting 0 results as half."""
    return (results.count(0) * .5 + results.count(-1)) / max(len(results), 1)


def _writeValues(path, values):
    """Write one value per line to `path`, overwriting it."""
    with open(path, 'w') as file:
        for value in values:
            file.write(f"{value}\n")


def trainAndSelfPlay(total_episodes, trainFreq, evalFreq, num_searches, c_constant, lr, weight_decay, batch_size, checkpoint, training_epochs = 5, eval_matches = 11):
    """Alternate AlphaZero-style self-play, training and periodic evaluation.

    Args:
        total_episodes: number of self-play games.
        trainFreq: train after every `trainFreq`-th episode (not episode 0).
        evalFreq: evaluate after every `evalFreq`-th episode (not episode 0).
        num_searches: MCTS simulations per move (all searchers).
        c_constant: exploration constant passed to the searchers.
        lr, weight_decay: Adam hyper-parameters.
        batch_size: replay-buffer samples per optimisation step.
        checkpoint: accepted for compatibility; not used.
        training_epochs: optimisation steps per training round.
        eval_matches: games per evaluation match-up.

    Side effects: writes loss.txt, episodelen.txt, winrates_past.txt and
    winrates_mcts.txt in the working directory.

    Returns:
        (training_loss, episode_lens) lists.
    """
    env = CheckersEnv()
    net = CheckersConvNet(4, 32*32)
    oldNet = CheckersConvNet(4, 32*32)
    oldNet.load_state_dict(net.state_dict())
    mcts = MCTSCheckers(num_searches, env, net, c_constant)
    oldMCTS = MCTSCheckers(num_searches, env, oldNet, c_constant)
    standardMonte = MCTSStandard(num_searches, env, c_constant)

    p_loss = nn.CrossEntropyLoss()
    v_loss = nn.MSELoss()
    optim = torch.optim.Adam(lr=lr, weight_decay=weight_decay, params=net.parameters())

    replay_buffer = ReplayBuffer(50000)  # hard coded buffer size

    episode_lens = []
    training_loss = []
    winrates_past = []
    winrates_random = []  # AlphaZero MCTS vs. plain MCTS (saved as winrates_mcts.txt)

    for episode_t in range(total_episodes):
        print("STARTED EPISODE:", episode_t)
        positions, policies, players, moves, reward = _selfPlayEpisode(env, mcts)
        episode_lens.append(moves)

        print("ENDED EPISODE:", episode_t, "WITH", moves, "MOVES AND PLAYER 1 RESULT", reward)

        # Value target: the final result as seen by whoever moved from that
        # position. Uses the recorded mover rather than assuming strict
        # alternation, since a player can move twice in a row.
        for board, mover, policy in zip(positions, players, policies):
            target = reward if mover == 1 else -reward
            replay_buffer.add(env.nnifyBoard(board, mover), target, policy)

        if episode_t % trainFreq == 0 and episode_t != 0:
            for _ in range(training_epochs):
                batch = replay_buffer.sample_batch(batch_size)
                training_loss.append(trainModel(net, v_loss, p_loss, optim, batch))

        if episode_t % evalFreq == 0 and episode_t != 0:
            print("EVALUATING OLD AlphaZero, and MCTS")
            alphaResults = evaluateModels(oldMCTS, mcts, eval_matches, env)
            print("ALPHA RESULTS: ", alphaResults)
            winrates_past.append(_comparisonWinrate(alphaResults))

            # the current network (player 2) beat the snapshot more often: promote it
            if alphaResults.count(1) < alphaResults.count(-1):
                print("OLD ALPHAZERO NETWORK IS UPDATED")
                oldNet.load_state_dict(net.state_dict())

            ranResults = evaluateModels(standardMonte, mcts, eval_matches, env)
            winrate = _comparisonWinrate(ranResults)
            print("MCTS RESULTS: ", ranResults, winrate)
            winrates_random.append(winrate)

    _writeValues('loss.txt', training_loss)
    _writeValues('episodelen.txt', episode_lens)
    _writeValues('winrates_past.txt', winrates_past)
    _writeValues('winrates_mcts.txt', winrates_random)

    return training_loss, episode_lens


In [None]:
# Reproducibility: seed every RNG the training loop draws from
# (Python's `random`, NumPy's global RNG used by the replay buffer, torch init).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Self-play / training hyper-parameters (see trainAndSelfPlay).
total_episodes = 400   # self-play games to generate
trainFreq = 3          # train after every 3rd episode
evalFreq = 30          # evaluate after every 30th episode
num_searches = 60      # MCTS simulations per move
c_constant = 1         # exploration constant for the searchers
lr = .001
weight_decay = 5e-4
batch_size = 32
checkpoint = 200       # passed through; trainAndSelfPlay does not use it
train_loss, episode_len = trainAndSelfPlay(total_episodes, trainFreq, evalFreq, num_searches, c_constant, lr, weight_decay, batch_size, checkpoint)

In [None]:
# Number of non-final moves in each self-play episode.
plt.plot(episode_len)
plt.title("Self-play episode length")
plt.xlabel("Episode")
plt.ylabel("Moves")
plt.show()

In [None]:
# Combined policy + value loss for every optimisation step.
plt.plot(train_loss)
plt.title("Training loss")
plt.xlabel("Optimisation step")
plt.ylabel("Policy + value loss")
plt.show()

In [None]:
# One float per line, written by trainAndSelfPlay; blank lines are skipped so
# a missing trailing newline no longer drops the last value.
winrates_mcts = []
try:
    with open('winrates_mcts.txt', 'r') as w:
        winrates_mcts = [float(line) for line in w if line.strip()]
except FileNotFoundError:
    print("winrates_mcts.txt not found; run the training cell first.")
print(winrates_mcts)

In [None]:
# Share of evaluation games won by AlphaZero MCTS (player 2) against plain MCTS; 0-result games count as half.
plt.plot(winrates_mcts)
plt.title("AlphaZero MCTS vs. plain MCTS")
plt.xlabel("Evaluation round")
plt.ylabel("Win rate")
plt.show()

In [None]:
# NOTE: trainAndSelfPlay in this notebook never writes winrates_random.txt
# (its MCTS win rates go to winrates_mcts.txt); the file presumably comes from
# an earlier run. Tolerate its absence so Restart & Run All does not crash.
winrates_random = []
try:
    with open('winrates_random.txt', 'r') as w:
        winrates_random = [float(line) for line in w if line.strip()]
except FileNotFoundError:
    print("winrates_random.txt not found; it is not produced by this notebook.")

In [None]:
# Values from winrates_random.txt (empty if the file was not found).
plt.plot(winrates_random)
plt.title("Win rates from winrates_random.txt")
plt.xlabel("Evaluation round")
plt.ylabel("Win rate")
plt.show()

In [None]:
# One float per line, written by trainAndSelfPlay; blank lines are skipped so
# a missing trailing newline no longer drops the last value.
winrates_past = []
try:
    with open('winrates_past.txt', 'r') as w:
        winrates_past = [float(line) for line in w if line.strip()]
except FileNotFoundError:
    print("winrates_past.txt not found; run the training cell first.")

In [None]:
# Share of evaluation games won by the current network (player 2) against its older snapshot; 0-result games count as half.
plt.plot(winrates_past)
plt.title("Current network vs. previous snapshot")
plt.xlabel("Evaluation round")
plt.ylabel("Win rate")
plt.show()

In [None]:
# NOTE: no cell in this notebook writes winrates_past2.txt; it presumably comes
# from an earlier run. Tolerate its absence so Restart & Run All does not crash.
winrates_past2 = []
try:
    with open('winrates_past2.txt', 'r') as w:
        winrates_past2 = [float(line) for line in w if line.strip()]
except FileNotFoundError:
    print("winrates_past2.txt not found; it is not produced by this notebook.")

In [None]:
# Values from winrates_past2.txt (empty if the file was not found).
plt.plot(winrates_past2)
plt.title("Win rates from winrates_past2.txt")
plt.xlabel("Evaluation round")
plt.ylabel("Win rate")
plt.show()