# Imports

In [None]:
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy

# Go Game

In [None]:
class Piece:
    """A stone placed on the Go board.

    Constructing a piece immediately attaches it to a group: either a new
    single-stone group, or the (merged) group(s) of the same-player stones
    already standing on its orthogonal neighbours in ``state``.

    Args:
        action: ``(x, y)`` board coordinate of the piece.
        player: ``1`` or ``-1`` (``0`` is used for territory calculations).
        state: board as a list of rows; empty points hold ``0``, occupied
            points hold ``Piece`` objects.
        args: game settings; ``args[0]`` and ``args[1]`` are the board
            dimensions.
    """

    def __init__(self, action, player, state, args) -> None:
        self.args = args
        self.state = state
        self.action = action  # (x, y)
        self.player = player  # 1 or -1 or 0 (for territory calculations)
        self.x_dim = args[0]
        self.y_dim = args[1]
        self.action_size = self.x_dim * self.y_dim
        # list of on-board neighbour coordinates
        self.neighbors = self.get_neighbors(action)
        self.group = self.search_group(action, player, state)

    def get_neighbors(self, action):
        """Return the orthogonal neighbours of ``action`` that lie on the board."""
        x, y = action[0], action[1]
        candidates = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
        return [(nx, ny) for nx, ny in candidates
                if 0 <= nx < self.x_dim and 0 <= ny < self.y_dim]

    def search_group(self, action, player, state):
        """Attach this piece to a group and return that group.

        Same-player groups found on the neighbouring points are merged into
        the first one found; if there are none, a new group is created.
        """
        groups = []
        for nx, ny in self.neighbors:
            occupant = state[nx][ny]
            if occupant == 0 or occupant.player != player:
                continue
            # Two neighbours can belong to the same group; collect each group
            # once (by identity) so it is never merged into itself.
            if not any(occupant.group is seen for seen in groups):
                groups.append(occupant.group)

        if not groups:
            # Use the settings stored on the instance rather than a
            # module-level ``args`` that may not exist.
            group = Group(action, player, state, self.args)
            group.add_piece(self)
            return group

        target = groups[0]
        for other in groups[1:]:
            target.merge_group(other)

        target.add_piece(self)
        return target

    def __str__(self) -> str:
        return str(self.player)


class Group:
    """A set of connected same-player pieces on the Go board.

    Args:
        action: coordinate of the piece that created the group.
        player: owner of the group.
        state: board the group was created on.
        args: game settings; ``args[0]`` and ``args[1]`` are the board
            dimensions.
    """

    def __init__(self, action, player, state, args) -> None:
        self.args = args
        self.state = state
        self.action = action
        self.player = player
        self.x_dim = args[0]
        self.y_dim = args[1]
        self.pieces = []
        self.liberties = []

    def add_piece(self, piece):
        """Append ``piece`` to the group and point the piece back at it."""
        self.pieces.append(piece)
        piece.group = self

    def search_liberties(self, player, state):
        """Return the set of empty coordinates adjacent to any piece of the group.

        ``player`` is accepted for call compatibility but is not used.
        """
        return {
            neighbor
            for piece in self.pieces
            for neighbor in piece.neighbors
            if state[neighbor[0]][neighbor[1]] == 0
        }

    def merge_group(self, group):
        """Absorb ``group`` into this group and return ``self``."""
        if group is self:
            # Merging a group into itself would duplicate every piece (and
            # later inflate the capture count), so it is a no-op.
            return self

        self.pieces += group.pieces

        for piece in group.pieces:
            piece.group = self

        return self

    def capture(self, state):
        """Remove every piece of the group from ``state``.

        Returns:
            The number of pieces in the group.
        """
        quant = len(self.pieces)
        for piece in self.pieces:
            state[piece.action[0]][piece.action[1]] = 0
            piece.group = None
        return quant


# ##################################################################### #

class Go:
    """Rules engine for a small Go board.

    Args:
        args: ``[x_dim, y_dim, komi]``.

    The board ("state") is a list of ``x_dim`` rows with ``y_dim`` entries;
    an empty point holds ``0`` and an occupied point holds a ``Piece``.
    The action ``(-1, -1)`` means "pass".
    """

    def __init__(self, args) -> None:
        self.args = args
        self.x_dim = args[0]
        self.y_dim = args[1]
        self.komi = args[2]
        # One action per board point, same definition as Piece.action_size;
        # ResNet reads ``game.action_size`` when building its policy head.
        self.action_size = self.x_dim * self.y_dim
        self.state = self.get_initial_state()
        self.player = 1
        self.previous_equals = False  # True when the previous move was a pass
        self.statelist = []  # matrices of every position reached so far
        self.prisioners = [0, 0]  # index 0: player -1, index 1: player 1

    def get_initial_state(self):
        """Return an empty ``x_dim`` by ``y_dim`` board."""
        return [[0] * self.y_dim for _ in range(self.x_dim)]

    def _neighbor_coords(self, x, y):
        """Return the on-board orthogonal neighbours of ``(x, y)``."""
        candidates = [(x-1, y), (x+1, y), (x, y-1), (x, y+1)]
        return [(nx, ny) for nx, ny in candidates
                if 0 <= nx < self.x_dim and 0 <= ny < self.y_dim]

    def put_piece(self, state, action, piece: "Piece"):
        """Place ``piece`` at ``action`` and capture enemy groups left without liberties.

        Every captured stone is credited to ``piece.player`` in
        ``self.prisioners``.  ``state`` is modified in place and returned.
        """
        captured = 0

        state[action[0]][action[1]] = piece
        piece.group.liberties = piece.group.search_liberties(
            piece.player, state)

        for nx, ny in piece.neighbors:
            neighbor = state[nx][ny]
            if neighbor == 0:
                continue
            neighbor.group.liberties = neighbor.group.search_liberties(
                neighbor.player, state)
            if neighbor.player != piece.player and len(neighbor.group.liberties) == 0:
                captured += neighbor.group.capture(state)

        if captured:
            # Captured stones free points next to the new piece's group.
            piece.group.liberties = piece.group.search_liberties(
                piece.player, state)

        self.prisioners[0 if piece.player == -1 else 1] += captured

        return state

    def suicide(self, state, piece) -> bool:
        """Return True when placing ``piece`` is a suicide move.

        The move is suicide when the piece's group ends up with no liberties
        and no neighbouring enemy group is left without liberties either.
        The check runs on a deep copy, so ``state`` and ``piece`` are untouched.
        """
        # Copy both in one call so the copied piece and the copied board keep
        # referring to the same group objects.
        copystate, copypiece = deepcopy((state, piece))

        x, y = copypiece.action
        copystate[x][y] = copypiece  # temporary
        copypiece.group.liberties = copypiece.group.search_liberties(
            copypiece.player, copystate)
        # if it has more than 0 liberties, no suicide
        if len(copypiece.group.liberties) > 0:
            return False

        # if it removes an enemy group, the move gains liberties, so it is legal
        for nx, ny in copypiece.neighbors:
            neighbor = copystate[nx][ny]
            if neighbor == 0:
                continue
            neighbor.group.liberties = neighbor.group.search_liberties(
                neighbor.player, copystate)
            if len(neighbor.group.liberties) == 0 and neighbor.player != copypiece.player:
                return False

        return True

    def get_next_state(self, state, action, player):
        """Apply ``action`` for ``player`` and return the resulting board.

        Returns:
            ``-1`` when this is the second consecutive pass (the winner is
            printed first); the unchanged ``state`` after a single pass or a
            rejected suicide move; otherwise the board with the piece placed.
            The rows of the returned board are shared with ``state``.
        """
        if self.check_skip(state, action, self.player):
            if self.previous_equals:
                self.get_winner(state)
                return -1
            self.player = self.change_player()
            self.previous_equals = True
            return state

        # Building a Piece links it into the groups of the board it is given,
        # so the suicide test runs on a throw-away copy: a rejected move must
        # leave the real groups untouched.
        trial_state = deepcopy(state)
        trial_piece = Piece(action, player, trial_state, self.args)
        if self.suicide(trial_state, trial_piece):
            print("Suicide is an illegal move")
            return state

        next_state = state.copy()
        piece = Piece(action, player, state, self.args)
        next_state = self.put_piece(next_state, action, piece)
        statemat = self.convert_state_to_matrix(next_state)
        # remember the position for the repeated-state check
        self.add_matrix_to_positions(statemat)
        self.player = self.change_player()
        self.previous_equals = False

        return next_state

    def is_valid_move(self, state, action, player):
        """Return True when ``player`` may play ``action`` on ``state``.

        A pass is always valid.  A placement is invalid when it is off the
        board, on an occupied point, a suicide, or recreates a position
        already stored in ``self.statelist``.  Neither ``state`` nor the
        prisoner counts are modified.
        """
        if action == (-1, -1):
            return True

        x, y = action
        if not (0 <= x < self.x_dim and 0 <= y < self.y_dim):
            return False
        if state[x][y] != 0:
            return False

        statecopy = deepcopy(state)
        temppiece = Piece(action, player, statecopy, self.args)

        # put_piece credits captures to self.prisioners; this is only a trial
        # move, so the real counts are restored afterwards.
        saved_prisioners = list(self.prisioners)
        try:
            self.put_piece(statecopy, action, temppiece)
        finally:
            self.prisioners[:] = saved_prisioners

        if self.suicide(statecopy, temppiece):
            return False

        statecopy = self.convert_state_to_matrix(statecopy)
        if any(np.array_equal(statecopy, stateelement) for stateelement in self.statelist):
            print("Invalid Move: Repeated State")
            return False
        return True

    def print_board(self, state):
        """Print both players' scores, the prisoner counts and the board grid."""
        print("\nEvaluation: 1 | %.1f | - | %.1f | -1" %(self.get_score( state, 1), self.get_score(state, -1)))
        print("Prisioners: 1 | %.1f | - | %.1f | -1\n" %(self.prisioners[1], self.prisioners[0]))
        # Print column coordinates
        print("   ", end="")
        for j in range(len(state[0])):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(len(state[0])):
            print("---", end="")
        print()

        # Print rows with row coordinates
        for i, row in enumerate(state):
            print(f"{i:2}|", end=" ")
            for cell in row:
                print(f"{str(cell):2}", end=" ")
            print()

    def change_player(self):
        """Return the opponent of the player currently to move."""
        return -self.player

    def get_valid_moves(self, state, player):
        """Return every ``(x, y)`` placement that ``is_valid_move`` accepts."""
        return [
            (i, j)
            for i in range(len(state))
            for j in range(len(state[0]))
            if state[i][j] == 0 and self.is_valid_move(state, (i, j), player)
        ]

    def check_skip(self, state, action, player):
        """Return True (and announce it) when ``action`` is the pass move."""
        if action == (-1, -1):
            print("Player " + str(player) + " skips")
            return True
        return False

    def evaluate(self, state, player):
        """Return ``(pieces, territory, prisioners)`` for ``player``.

        ``pieces`` is the number of the player's stones on the board.
        ``territory`` is a heuristic: the number of empty points that touch
        one of the player's stones and have at least two occupied
        neighbours.  ``state`` is only read.
        """
        pieces = 0
        prisioners = self.prisioners[0 if player == -1 else 1]
        territory_spaces = set()

        for i in range(self.x_dim):
            for j in range(self.y_dim):
                stone = state[i][j]
                if stone == 0 or stone.player != player:
                    continue
                pieces += 1
                for nx, ny in self._neighbor_coords(i, j):
                    # Only empty points can be territory; stones must never
                    # be overwritten, or later stones would go uncounted.
                    if state[nx][ny] != 0:
                        continue
                    occupied = 0
                    for ax, ay in self._neighbor_coords(nx, ny):
                        if state[ax][ay] != 0:
                            occupied += 1
                    if occupied >= 2:
                        territory_spaces.add((nx, ny))

        return pieces, len(territory_spaces), prisioners

    def get_score(self, state, player):
        """Return stones + territory + prisoners, plus komi for player -1."""
        p1, t1, pri1 = self.evaluate(state, player)
        if player == -1:
            p1 += self.komi

        return p1 + t1 + pri1

    def get_winner(self, state):
        """Print the winner, decided by the same total that ``get_score`` reports.

        Nothing is printed when both totals are equal.
        """
        p1, t1, pri1 = self.evaluate(state, 1)
        p2, t2, pri2 = self.evaluate(state, -1)

        score1 = p1 + t1 + pri1
        score2 = p2 + t2 + pri2 + self.komi  # komi goes to player -1

        if score1 > score2:
            print("Player 1 wins with " + str(p1) + " stones, " + str(pri1) + " prisioners and " + str(t1) +
                  " of territory against " + str(p2) + " stones, " + str(pri2) + " prisioners and " + str(t2) + " of territory")
        elif score2 > score1:
            print("Player -1 wins with " + str(p2) + " stones, " + str(pri2) + " prisioners and " + str(t2) +
                  " of territory against " + str(p1) + " stones, " + str(pri1) + " prisioners and " + str(t1) + " of territory")

    def add_matrix_to_positions(self, matrix):
        """Record ``matrix`` in the list of positions already reached."""
        self.statelist.append(matrix)

    def convert_state_to_matrix(self, state):
        """Return a float array with 1 / -1 for the players' stones and 0 elsewhere."""
        mat = np.zeros((self.x_dim, self.y_dim))
        for i in range(len(state)):
            for j in range(len(state[0])):
                label = str(state[i][j])
                if label == "1":
                    mat[i][j] = 1
                elif label == "-1":
                    mat[i][j] = -1
        return mat


# Interactive two-player Go session on the console.
args = [5, 5, 5.5]  # x, y, komi
go = Go(args)
state = go.get_initial_state()

while True:
    go.print_board(state)

    action = input("Input move x,y | -1,-1 to pass: \n")
    action = action.split(",")

    try:
        action = (int(action[0]), int(action[1]))
    except (ValueError, IndexError):
        # non-numeric text, or fewer than two comma-separated fields
        print("Invalid Move: Not a number")
        continue

    # A pass has no board coordinate, so it skips the placement checks;
    # get_next_state handles it.
    if action != (-1, -1):
        if not (0 <= action[0] < go.x_dim and 0 <= action[1] < go.y_dim):
            print("Invalid Move: Out of Bounds")
            continue

        if state[action[0]][action[1]] != 0 or not go.is_valid_move(state, action, go.player):
            print("Invalid Move: Not a valid move")
            continue

    state = go.get_next_state(state, action, go.player)

    # get_next_state returns -1 after two consecutive passes
    if state == -1:
        break


# Attaxx Game

In [None]:
class Attaxx:
    """Rules for the Attaxx board game on an ``x_dim`` by ``y_dim`` grid.

    Args:
        args: sequence whose first two items are ``x_dim`` and ``y_dim``.

    A state is an array indexed ``state[x][y]`` holding ``1``, ``-1`` or
    ``0`` (empty).  An action is ``(a, b, a1, b1)``: move the piece on
    ``(a, b)`` to ``(a1, b1)``.
    """

    def __init__(self, args):
        self.x_dim = args[0]
        self.y_dim = args[1]
        self.action_size = self.x_dim * self.y_dim

    def get_initial_state(self):
        """Return the starting board: player 1 on two opposite corners, -1 on the other two."""
        state = np.zeros((self.x_dim, self.y_dim))
        last_x = self.x_dim - 1
        last_y = self.y_dim - 1
        state[0][0] = 1
        state[last_x][last_y] = 1
        # First index runs over x_dim and second over y_dim, so the corners
        # are (0, last_y) and (last_x, 0) on non-square boards too.
        state[0][last_y] = -1
        state[last_x][0] = -1
        return state

    def get_next_state(self, state, action, player):
        """Apply ``action`` in place and return ``state``.

        A move of distance 2 on either axis vacates the origin; any other
        move leaves the origin occupied.  Enemy pieces around the
        destination are then converted.
        """
        a, b, a1, b1 = action
        if abs(a-a1) == 2 or abs(b-b1) == 2:
            state[a][b] = 0
        state[a1][b1] = player
        self.capture_pieces(state, action, player)
        return state

    def _on_board(self, x, y):
        """Return True when ``(x, y)`` is a coordinate of the board."""
        return 0 <= x < self.x_dim and 0 <= y < self.y_dim

    def is_valid_move(self, state, action, player):
        """Return True when ``player`` may perform ``action`` on ``state``.

        Both squares must be on the board, the origin must hold one of the
        player's pieces, the destination must be empty, the move may span at
        most 2 squares per axis, and (1, 2) / (2, 1) offsets are rejected.
        """
        a, b, a1, b1 = action
        # Checked first: a negative index would silently wrap around and a
        # too-large one would raise IndexError.
        if not (self._on_board(a, b) and self._on_board(a1, b1)):
            return False

        dx, dy = abs(a-a1), abs(b-b1)
        if dx > 2 or dy > 2:
            return False
        if state[a1][b1] != 0 or state[a][b] != player:
            return False
        if (dx == 1 and dy == 2) or (dx == 2 and dy == 1):
            return False

        return True

    def capture_pieces(self, state, action, player):
        """Convert every enemy piece in the 3x3 area around the destination."""
        _, _, a1, b1 = action
        rows, cols = len(state), len(state[0])
        # Clamp the window to the board instead of relying on IndexError.
        for i in range(max(a1-1, 0), min(a1+2, rows)):
            for j in range(max(b1-1, 0), min(b1+2, cols)):
                if state[i][j] == -player:
                    state[i][j] = player

    def check_available_moves(self, state, player):
        """Return True when ``player`` has at least one valid move."""
        for i in range(self.x_dim):
            for j in range(self.y_dim):
                if state[i][j] != player:
                    continue
                for a in range(self.x_dim):
                    for b in range(self.y_dim):
                        if self.is_valid_move(state, (i, j, a, b), player):
                            return True
        return False

    def get_valid_moves(self, state, player):
        """Return the set of all valid ``(a, b, a1, b1)`` actions for ``player``.

        ``state`` is only read.
        """
        possible_moves = set()

        for i in range(self.x_dim):
            for j in range(self.y_dim):
                if state[i][j] == player:
                    possible_moves.update(
                        self.get_moves_at_point(state, player, i, j))

        return possible_moves

    def get_moves_at_point(self, state, player, a, b):
        """Return the list of valid actions that start from ``(a, b)``."""
        return [
            (a, b, i, j)
            for i in range(self.x_dim)
            for j in range(self.y_dim)
            if self.is_valid_move(state, (a, b, i, j), player)
        ]

    def check_board_full(self, state):
        """Return True when no square of ``state`` is empty."""
        return all(0 not in row for row in state)

    def check_win_and_over(self, state):
        """Return ``(winner, game_over)``.

        ``winner`` is ``1`` or ``-1`` for a win, ``2`` for a draw on a full
        board and ``0`` while the game is still running.  The game is over
        when one player has no pieces left or the board is full.
        """
        count_player1 = 0
        count_player2 = 0

        for i in range(self.x_dim):
            for j in range(self.y_dim):
                if state[i][j] == 1:
                    count_player1 += 1
                elif state[i][j] == -1:
                    count_player2 += 1

        if count_player1 == 0:
            return -1, True
        if count_player2 == 0:
            return 1, True

        if self.check_board_full(state):
            if count_player1 > count_player2:
                return 1, True
            if count_player2 > count_player1:
                return -1, True
            return 2, True

        return 0, False

    def get_value_and_terminated(self, state):
        """Return ``check_win_and_over(state)`` unchanged."""
        winner, game_over = self.check_win_and_over(state)
        return winner, game_over

    def print_board(self, state):
        """Print the board as integers with row and column coordinates."""
        state = state.astype(int)
        # Print column coordinates
        print("   ", end="")
        for j in range(len(state[0])):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(len(state[0])):
            print("---", end="")
        print()

        # Print rows with row coordinates
        for i in range(len(state)):
            print(f"{i:2}|", end=" ")
            for j in range(len(state[0])):
                print(f"{str(state[i][j]):2}", end=" ")
            print()

# Graphical Interface

In [2]:
# to be added

# Monte Carlo Tree Search

In [1]:
import numpy as np
import math
from attax_game import Attaxx

class Node():
    """One position in the Monte Carlo search tree.

    Args:
        game: object providing ``get_valid_moves``, ``get_next_state`` and
            ``get_value_and_terminated``.
        C: exploration constant used by ``get_ucb``.
        state: board array; an integer copy is stored on the node.
        player: the player to move in ``state``.
        parent: node this one was expanded from (``None`` for the root).
        action_taken: action that led from ``parent`` to this node.

    NOTE(review): ``value`` is whatever ``game.get_value_and_terminated``
    returns and is added unchanged at every level in ``backpropagate``,
    while ``get_ucb`` rescales the mean as if it lay in [-1, 1] -- confirm
    that the two conventions agree for the game in use.
    """

    def __init__(self, game, C, state, player, parent = None, action_taken = None) -> None:
        self.game = game
        self.C = C
        self.state = state.astype(int)
        self.parent = parent
        self.action_taken = action_taken
        self.player = player
        self.children = []
        self.expandable_moves = game.get_valid_moves(self.state, player)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True when every move has been expanded and a child exists."""
        return len(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (``None`` without children)."""
        return max(self.children, key=self.get_ucb, default=None)

    def get_ucb(self, child):
        """Return the UCB score of ``child``; both visit counts must be positive."""
        q_value = 1 - ((child.value_sum/child.visit_count)+1) / 2
        return q_value + self.C * math.sqrt(math.log(self.visit_count) / child.visit_count)

    def expand(self, player=None):
        """Create, attach and return a child for one random unexpanded move.

        The move is played by ``self.player``, the player
        ``expandable_moves`` was generated for; ``player`` is kept only so
        existing ``expand(player)`` calls still work, and is ignored.

        Raises:
            ValueError: if there is no unexpanded move left.
        """
        moves_arr = list(self.expandable_moves)
        action = moves_arr[np.random.choice(len(moves_arr))]
        self.expandable_moves.remove(action)

        child_state = self.game.get_next_state(
            self.state.copy(), action, self.player)
        # Argument order is (game, C, state, player, parent, action_taken):
        # the opponent moves next, and this node is the child's parent.
        child = Node(self.game, self.C, child_state,
                     -self.player, self, action)
        self.children.append(child)
        return child

    def simulate(self, player=None):
        """Play random moves from this node until the game ends; return the value.

        Args:
            player: player to move first in the rollout; defaults to
                ``self.player``.

        A player without valid moves passes; if both players have to pass in
        a row, the last non-terminal value is returned.
        """
        value, is_terminal = self.game.get_value_and_terminated(self.state)

        if is_terminal:
            return value

        rollout_state = self.state.copy()
        rollout_player = self.player if player is None else player
        previous_passed = False
        while True:
            # get_valid_moves may return a set, which cannot be indexed or
            # handed to np.random.choice directly.
            valid_moves = list(
                self.game.get_valid_moves(rollout_state, rollout_player))
            if not valid_moves:
                if previous_passed:
                    return value
                previous_passed = True
                rollout_player = -rollout_player
                continue
            previous_passed = False

            action = valid_moves[np.random.choice(len(valid_moves))]
            rollout_state = self.game.get_next_state(
                rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(
                rollout_state)
            if is_terminal:
                return value
            rollout_player = -rollout_player

    def backpropagate(self, value):
        """Add ``value`` and one visit to this node and all its ancestors."""
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            self.parent.backpropagate(value)

class MCTS():
    """Monte Carlo tree search driver built on ``Node``.

    Args:
        game: game object handed to every ``Node``.
        C: exploration constant handed to every ``Node``.
        num_searches: number of simulations run by ``search``.
    """

    def __init__(self, game, C, num_searches) -> None:
        self.game = game
        self.C = C
        self.num_searches = num_searches

    def search(self, state, player):
        """Run ``num_searches`` simulations from ``state`` with ``player`` to move.

        Returns:
            Dict mapping each action tried at the root to the visit count of
            the corresponding child.
        """
        # define root
        root = Node(self.game, self.C, state, player)

        for _ in range(self.num_searches):
            node = root
            mover = player  # player to move at ``node``; flips every level

            # selection
            while node.is_fully_expanded():
                node = node.select()
                mover = -mover

            value, is_terminal = self.game.get_value_and_terminated(node.state)

            if not is_terminal:
                # expansion (skipped when the player to move has no move)
                if node.expandable_moves:
                    node = node.expand(mover)
                    mover = -mover
                # simulation
                value = node.simulate(mover)

            node.backpropagate(value)

        # Collected once, after all simulations have been run.
        return {child.action_taken: child.visit_count
                for child in root.children}

# Console match: human (player 1) against MCTS (player -1).
attaxx_game = Attaxx([5, 5])

args = {
    'C': 1.41,
    'num_searches': 1000
}

state = attaxx_game.get_initial_state()
mcts = MCTS(attaxx_game, args['C'], args['num_searches'])
player = 1

while True:
    attaxx_game.print_board(state)
    print("So true")

    if not attaxx_game.check_available_moves(state, player):
        # The side to move is blocked: the turn passes to the opponent.
        print(f"Player {player} has no moves and passes")
        player = -player
        continue

    if player == 1:
        print("Player 1")
        # print(attaxx_game.get_valid_moves(state, player))
        try:
            # input format: 0 0 0 0
            a, b, a1, b1 = tuple(int(x.strip()) for x in input().split(' '))
        except ValueError:
            print("Invalid input: expected four integers separated by spaces")
            continue
        action = (a, b, a1, b1)
    else:
        print("Player -1 MCTS")
        # print(attaxx_game.get_valid_moves(state, player))
        mcts_prob = mcts.search(state, player)
        # choose the action with the most visits, not the largest key
        action_selected = max(mcts_prob, key=mcts_prob.get)
        print(f"Acoes: {action_selected}")
        action = action_selected

    if not attaxx_game.is_valid_move(state, action, player):
        continue

    attaxx_game.get_next_state(state, action, player)
    player = - player
    winner, win = attaxx_game.check_win_and_over(state)
    if win:
        attaxx_game.print_board(state)
        print(f"player {winner} wins")
        break




# Neural Network

In [None]:
class ResNet(nn.Module):
    """Residual network with a policy head and a value head.

    Args:
        game: object exposing ``x_dim``, ``y_dim`` and ``action_size``.
        num_resBlocks: number of ``ResBlock`` layers in the backbone.
        num_hidden: channel count of the backbone.

    The first convolution takes 3 input channels.  ``forward`` returns
    ``(policy, value)``: ``policy`` has ``game.action_size`` outputs per
    sample and ``value`` is a single tanh-squashed number per sample.
    """

    def __init__(self, game, num_resBlocks, num_hidden):
        super().__init__()
        board_cells = game.y_dim * game.x_dim

        stem_layers = [
            nn.Conv2d(3, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU(),
        ]
        self.startBlock = nn.Sequential(*stem_layers)

        tower = []
        for _ in range(num_resBlocks):
            tower.append(ResBlock(num_hidden))
        self.backBone = nn.ModuleList(tower)

        policy_channels = 32
        policy_layers = [
            nn.Conv2d(num_hidden, policy_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(policy_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(policy_channels * board_cells, game.action_size),
        ]
        self.policyHead = nn.Sequential(*policy_layers)

        value_channels = 3
        value_layers = [
            nn.Conv2d(num_hidden, value_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(value_channels),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(value_channels * board_cells, 1),
            nn.Tanh(),
        ]
        self.valueHead = nn.Sequential(*value_layers)

    def forward(self, x):
        """Run stem, backbone and both heads; return ``(policy, value)``."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        return self.policyHead(features), self.valueHead(features)
    
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a skip connection.

    Args:
        num_hidden: number of input and output channels.
    """

    def __init__(self, num_hidden):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **conv_options)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        out = self.bn2(self.conv2(hidden))
        # in-place add of the skip connection
        out += x
        return F.relu(out)

# Alpha Zero

In [None]:
class AlphaZero:
    """Self-play training loop around a model, an optimizer and a game.

    Args:
        model: network whose ``state_dict`` is saved after each iteration.
        optimizer: optimizer whose ``state_dict`` is saved alongside it.
        game: game object used for self-play.
        args: dict read for ``'num_iterations'``,
            ``'num_selfPlay_iterations'`` and ``'num_epochs'``.

    NOTE(review): this class builds ``MCTS(game, args, model)`` and calls
    ``search(state)`` with one argument, and it calls
    ``change_perspective``, ``get_opponent``, ``get_opponent_value``,
    ``get_encoded_state`` and a two-argument ``get_value_and_terminated``
    on ``game``.  Confirm which MCTS and game classes are meant to provide
    that interface.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """Play one game against itself and return its training samples.

        Returns:
            One ``(encoded_state, action_probs, outcome)`` tuple per move.
            ``outcome`` is the final value for moves made by the player who
            moved last, and ``game.get_opponent_value(value)`` for the
            other player's moves.
        """
        history = []
        player = 1
        state = self.game.get_initial_state()

        while True:
            neutral_state = self.game.change_perspective(state, player)
            action_probs = self.mcts.search(neutral_state)
            history.append((neutral_state, action_probs, player))

            action = np.random.choice(self.game.action_size, p=action_probs)
            state = self.game.get_next_state(state, action, player)

            value, is_terminal = self.game.get_value_and_terminated(state, action)
            if not is_terminal:
                player = self.game.get_opponent(player)
                continue

            samples = []
            for past_state, past_probs, past_player in history:
                outcome = value
                if past_player != player:
                    outcome = self.game.get_opponent_value(value)
                encoded = self.game.get_encoded_state(past_state)
                samples.append((encoded, past_probs, outcome))
            return samples

    def train(self, memory):
        """Placeholder: training on ``memory`` is not implemented yet."""
        pass

    def learn(self):
        """Alternate self-play and training, saving checkpoints every iteration."""
        total_iterations = self.args['num_iterations']
        for iteration in range(total_iterations):
            memory = []

            # collect self-play games with the model in eval mode
            self.model.eval()
            games_left = self.args['num_selfPlay_iterations']
            while games_left > 0:
                memory += self.selfPlay()
                games_left -= 1

            # train on the collected samples
            self.model.train()
            for _ in range(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")

# Runtime

In [None]:
# Game selector read by the setup cell below: "Go" or "Attaxx".
t = "Go"

In [None]:
# Build the game chosen by ``t``.
if t == "Go":
    args = [5, 5, 5.5]  # x, y, komi
    game = Go(args)
elif t == "Attaxx":
    args = [5, 5]
    game = Attaxx(args)
else:
    # Fail here with a clear message instead of a NameError on ``game`` below.
    raise ValueError(f"Unknown game selector: {t!r}")

state = game.get_initial_state()

# ResNet reads game.x_dim, game.y_dim and game.action_size;
# 9 residual blocks with 64 hidden channels.
model = ResNet(game, 9, 64)

In [None]:
# Adam optimizer over all parameters of the model built above.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Hyper-parameters. AlphaZero.learn reads 'num_iterations',
# 'num_selfPlay_iterations' and 'num_epochs'.
# NOTE(review): 'C' and 'num_searches' are presumably meant for the tree
# search -- confirm where they are consumed.
# NOTE: this rebinds ``args``, which held the game settings list until here.
args = {
    'C': 2,
    'num_searches': 60,
    'num_iterations': 3,
    'num_selfPlay_iterations': 10,
    'num_epochs': 4
}