# AlphaZero for Generalized Game Reinforcement Learning

## Imports

In [None]:
import torch
from torch.optim import Adam
import random
import numpy as np
import os

## Go

Explicação do jogo

### Game Implementation

O que vamos implementar, nuances, etc.

In [None]:
class Go():
    '''
    # Description:
    Go rules engine used by the AlphaZero pipeline.

    Board states are 2-D arrays of shape ``(row_count, column_count)`` holding
    ``EMPTY`` (0), ``BLACK`` (1) or ``WHITE`` (-1). Flat actions are laid out
    row-major (``row * column_count + col``); the extra action
    ``row_count * column_count`` is the pass move.
    '''

    EMPTY = 0
    BLACK = 1
    WHITE = -1
    # Temporary values written by ``count`` while flood-filling a group;
    # ``restore_board`` turns them back into stones / empty points.
    BLACKMARKER = 4
    WHITEMARKER = 5
    LIBERTY = 8

    def __init__(self):
        self.row_count = 9
        self.column_count = 9
        self.komi = 6.5
        # One action per intersection plus one pass action.
        self.action_size = self.row_count * self.column_count + 1
        # NOTE(review): these three lists are not read by any method of this class.
        self.liberties = []
        self.block = []
        self.seki_liberties = []

    def get_initial_state(self):
        '''
        # Returns:
        An empty ``(row_count, column_count)`` float board.
        '''
        return np.zeros((self.row_count, self.column_count))

    def count(self, x, y, state: list, player:int , liberties: list, block: list) -> tuple[list, list]:
        '''
        # Description:
        Recursively flood-fills the group of ``player`` stones that contains the
        point at column ``x``, row ``y``.

        Visited stones are overwritten with BLACKMARKER/WHITEMARKER and visited
        empty points with LIBERTY, so every stone and every liberty is recorded
        exactly once. ``state`` is modified in place; call ``restore_board``
        afterwards to remove the markers.

        # Returns:
        A tuple ``(liberties, block)`` with the ``(row, col)`` coordinates of the
        group's distinct liberties and of the group's stones.
        '''
        piece = state[y][x]
        if piece == player:
            block.append((y, x))
            # Mark the stone so the recursion never visits it twice.
            if player == self.BLACK:
                state[y][x] = self.BLACKMARKER
            else:
                state[y][x] = self.WHITEMARKER

            if y - 1 >= 0:
                liberties, block = self.count(x, y - 1, state, player, liberties, block)  # north
            if x + 1 < self.column_count:
                liberties, block = self.count(x + 1, y, state, player, liberties, block)  # east
            if y + 1 < self.row_count:
                liberties, block = self.count(x, y + 1, state, player, liberties, block)  # south
            if x - 1 >= 0:
                liberties, block = self.count(x - 1, y, state, player, liberties, block)  # west

        elif piece == self.EMPTY:
            # Mark the liberty so a point shared by several stones is counted once.
            state[y][x] = self.LIBERTY
            liberties.append((y, x))

        return liberties, block

    def clear_block(self, block: list, state: list) -> list:
        '''
        # Description:
        Removes the stones listed in ``block`` (``(row, col)`` pairs) from the board.

        # Returns:
        The board (modified in place) with those points set to EMPTY.
        '''
        for y, x in block:
            state[y][x] = self.EMPTY
        return state

    def restore_board(self, state: list) -> list:
        '''
        # Description:
        Undoes the markers written by ``count``: BLACKMARKER -> BLACK,
        WHITEMARKER -> WHITE, LIBERTY -> EMPTY. Works for non-square boards.

        # Returns:
        The board (modified in place) with the markers removed.
        '''
        for y in range(len(state)):
            for x in range(len(state[y])):
                val = state[y][x]
                if val == self.BLACKMARKER:
                    state[y][x] = self.BLACK
                elif val == self.WHITEMARKER:
                    state[y][x] = self.WHITE
                elif val == self.LIBERTY:
                    state[y][x] = self.EMPTY
        return state

    def print_board(self, state: list) -> None:
        '''
        # Description:
        Draws the board in the console, with column indices across the top and
        row indices down the left side.

        # Returns:
        None
        '''
        # Print column coordinates
        print("   ", end="")
        for j in range(len(state[0])):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(len(state[0])):
            print("---", end="")
        print()

        # Print rows with row coordinates
        for i in range(len(state)):
            print(f"{i:2}|", end=" ")
            for j in range(len(state[0])):
                print(f"{str(int(state[i][j])):2}", end=" ")
            print()

    def captures(self, state: list,player: int, a:int, b:int) -> tuple[bool, list]:
        '''
        # Description:
        Removes every group of ``player`` stones orthogonally adjacent to
        (row ``a``, column ``b``) that has no liberties left. Call it with the
        opponent's colour right after a stone has been placed at ``(a, b)``.

        # Returns:
        A tuple ``(captured, state)``: whether any stones were removed, and the
        board (modified in place).
        '''
        check = False
        neighbours = []
        if a > 0:
            neighbours.append((a - 1, b))
        if a < self.row_count - 1:
            neighbours.append((a + 1, b))
        if b > 0:
            neighbours.append((a, b - 1))
        if b < self.column_count - 1:
            neighbours.append((a, b + 1))

        for row, col in neighbours:
            if state[row][col] != player:
                continue
            liberties, block = self.count(col, row, state, player, [], [])
            if len(liberties) == 0:
                state = self.clear_block(block, state)
                check = True
            # Remove the markers written by ``count`` before the next neighbour.
            state = self.restore_board(state)

        return check, state

    def set_stone(self, a, b, state, player):
        '''Places ``player``'s stone at (row ``a``, column ``b``) in place and returns the board.'''
        state[a][b] = player
        return state

    def get_next_state(self, state, action, player):
        '''
        # Description:
        Plays ``action`` for ``player`` and removes any opponent groups it captures.
        The board is modified in place. Legality is not checked here (see
        ``is_valid_move``), and there is no ko / repetition check.

        # Returns:
        The updated board.
        '''
        if action == self.row_count * self.column_count:
            return state  # pass move: the board is unchanged

        # Row-major decoding: action = row * column_count + col.
        a = action // self.column_count
        b = action % self.column_count

        state = self.set_stone(a, b, state, player)
        state = self.captures(state, -player, a, b)[1]
        return state

    def is_valid_move(self, state: list, action: tuple, player: int) -> bool:
        '''
        # Description:
        Checks whether ``player`` may place a stone at ``action = (row, col)``.
        A move is rejected when the point is occupied, or when it is suicide
        (the new stone's group has no liberties and the move captures nothing).
        Ko / repeated positions are NOT checked. ``state`` is never modified.

        # Returns:
        A boolean confirming the validity of the move.
        '''
        a = action[0]
        b = action[1]

        if state[a][b] != self.EMPTY:
            return False

        # Integer working copy: the markers written by ``count``/``captures`` stay off ``state``.
        statecopy = np.copy(state).astype(np.int8)
        statecopy = self.set_stone(a, b, statecopy, player)

        # A capturing move always frees at least one liberty, so it is legal.
        if self.captures(statecopy, -player, a, b)[0]:
            return True

        libs, _ = self.count(b, a, statecopy, player, [], [])
        return len(libs) > 0

    def get_valid_moves(self, state, player):
        '''
        # Description:
        Builds the legal-move mask for ``player``.

        # Returns:
        A flat int8 vector of length ``row_count * column_count + 1`` in
        row-major order; the last entry (pass) is always 1.
        '''
        valid = np.zeros((self.row_count, self.column_count))
        for row in range(self.row_count):
            for col in range(self.column_count):
                if self.is_valid_move(state, (row, col), player):
                    valid[row][col] = 1

        return np.concatenate([valid.reshape(-1), [1]]).astype(np.int8)

    def get_value_and_terminated(self, state, action, player):
        '''
        # Description:
        Scores ``state`` from the perspective of ``player``.

        # Returns:
        ``(value, terminated)``: ``value`` is 1 if ``player`` is ahead on the
        current score estimate and -1 otherwise; ``terminated`` is the endgame
        flag from ``scoring``. ``action`` is not used.
        '''
        score, endgame = self.scoring(state)
        # ``score`` is black's margin after komi: positive means black is ahead.
        if player == self.BLACK:
            player_ahead = score > 0
        else:
            player_ahead = score < 0
        return (1 if player_ahead else -1), endgame

    def scoring(self, state):
        '''
        # Description:
        Estimates the score and decides whether the game is over.

        # Returns:
        ``(score, endgame)`` where ``score`` is black territory minus
        (white territory + komi), using ``count_influenced_territory_enhanced``.
        ``endgame`` becomes True once fewer than one fifth of the points
        (16 of 81 on 9x9) are still empty.
        '''
        empty = int(np.count_nonzero(np.asarray(state) == self.EMPTY))
        endgame = empty < (self.column_count * self.row_count) // 5

        black, white = self.count_influenced_territory_enhanced(state)
        return black - (white + self.komi), endgame

    def count_influenced_territory_enhanced(self, board):
        '''
        # Description:
        Estimates territory by one-step influence: each empty point counts for
        black if the sum of its orthogonal neighbours is positive, for white if
        it is negative, and for nobody on a tie.

        # Returns:
        ``(black_territory, white_territory)`` as point counts.
        '''
        rows, cols = len(board), len(board[0])
        black_territory = 0
        white_territory = 0

        for x in range(rows):
            for y in range(cols):
                if board[x][y] != 0:
                    continue
                influence = 0
                for nx, ny in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                    if 0 <= nx < rows and 0 <= ny < cols:
                        influence += board[nx][ny]
                if influence > 0:
                    black_territory += 1
                elif influence < 0:
                    white_territory += 1

        return black_territory, white_territory

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def get_encoded_state(self, state):
        '''
        # Returns:
        A float32 array of shape ``(3, rows, cols)``: one-hot planes for
        WHITE (-1), EMPTY (0) and BLACK (1), in that order.
        '''
        layer_1 = np.where(np.array(state) == -1, 1, 0).astype(np.float32)
        layer_2 = np.where(np.array(state) == 0, 1, 0).astype(np.float32)
        layer_3 = np.where(np.array(state) == 1, 1, 0).astype(np.float32)

        result = np.stack([layer_1, layer_2, layer_3]).astype(np.float32)

        return result

    def change_perspective(self, state, player):
        # Multiplying by -1 swaps the colours so the side to move always appears as +1.
        return state * player

### Graphical Interface Implementation

In [None]:
import pygame

# Per-player RGB colours.
# NOTE(review): not referenced by the drawing code in this cell
# (draw_piece / hover_to_select use BLACK/WHITE) — confirm before removing.
data={'player1':(201,153,255),
      'player2':(179,236,255),
      }

# Number of grid lines per side, plus the RGB colours used for drawing.
SIZE_BOARD = 9
BLACK = (0,0,0)
WHITE = (255,255,255)
GREEN = (140, 217, 166)


# Initialise pygame and set the window icon before the window is created by
# set_mode below. 'image.png' is resolved relative to the current working
# directory, so loading will fail if the file is not there.
pygame.init()
pygame_icon = pygame.image.load('image.png')
pygame.display.set_icon(pygame_icon)

# Window geometry in pixels. CELL_SIZE is the spacing between grid lines
# ((600 - 50) // 9 = 61). PIECE_SIZE is the circle radius used for stones
# ((600 - 100) // 9 // 3 = 18).
SCREEN_SIZE=600
SCREEN_PADDING = 50
CELL_SIZE = (SCREEN_SIZE - SCREEN_PADDING) // SIZE_BOARD
PIECE_SIZE = (SCREEN_SIZE - 2*SCREEN_PADDING) // SIZE_BOARD // 3

# Window surface that draw_board, draw_piece and hover_to_select paint on.
screen=pygame.display.set_mode((SCREEN_SIZE,SCREEN_SIZE))

pygame.display.set_caption("gO depression")

def to_pixels(x):
    """Return the on-screen pixel coordinate of grid line ``x``."""
    offset = x * CELL_SIZE
    return offset + SCREEN_PADDING

def to_coord(x):
    """Convert a pixel coordinate to the index of the nearest grid line.

    Inverse of ``to_pixels``: the position is rounded to the closest
    intersection, so the cursor snaps symmetrically within half a cell on
    either side. The result is not clamped to the board; callers compare it
    against ``valid_moves``.
    """
    return (x - SCREEN_PADDING + CELL_SIZE // 2) // CELL_SIZE

def draw_board():
    """Paint the green playing area and the SIZE_BOARD x SIZE_BOARD grid lines."""
    span = CELL_SIZE * (SIZE_BOARD - 1)
    far_edge = span + SCREEN_PADDING
    pygame.draw.rect(screen, GREEN, rect=(SCREEN_PADDING, SCREEN_PADDING, span, span))
    for line in range(SIZE_BOARD):
        pos = to_pixels(line)
        # vertical grid line
        pygame.draw.line(screen, BLACK, (pos, SCREEN_PADDING), (pos, far_edge), 3)
        # horizontal grid line
        pygame.draw.line(screen, BLACK, (SCREEN_PADDING, pos), (far_edge, pos), 3)

def draw_piece(x,y,player):
    """Draw a placed stone at grid point (x, y): a filled circle with a 3px black rim.

    Player -1 is filled BLACK; any other player value is filled WHITE.
    """
    center = (to_pixels(x), to_pixels(y))
    fill = WHITE
    if player == -1:
        fill = BLACK
    pygame.draw.circle(screen, fill, center, PIECE_SIZE)
    pygame.draw.circle(screen, BLACK, center, PIECE_SIZE, 3)

def hover_to_select(player,valid_moves,click):
    """Preview a stone under the mouse and place it when ``click`` is set.

    ``valid_moves`` holds the ``[x, y]`` grid points still free. If the mouse
    maps to one of them:

    * a translucent preview stone is drawn when the cursor is within
      ``PIECE_SIZE`` pixels of the intersection;
    * on ``click``, ``[x, y, player]`` is appended to the global
      ``cur_pieces`` and ``[x, y]`` is removed from ``valid_moves``.

    Returns ``[x, y, -player]`` after a placement (the turn passes), otherwise
    ``[None, None, player]``.
    """
    mouse_x, mouse_y = pygame.mouse.get_pos()
    x, y = to_coord(mouse_x), to_coord(mouse_y)
    if [x, y] not in valid_moves:
        return [None, None, player]

    center = (to_pixels(x), to_pixels(y))
    distance = pygame.math.Vector2(center[0] - mouse_x, center[1] - mouse_y).length()
    if distance < PIECE_SIZE:
        # Draw on a full-screen alpha surface so the preview is translucent.
        overlay = pygame.Surface((SCREEN_SIZE, SCREEN_SIZE), pygame.SRCALPHA)
        if player == 1:
            pygame.draw.circle(overlay, (255, 255, 255, 200), center, PIECE_SIZE)
        if player == -1:
            pygame.draw.circle(overlay, (0, 0, 0, 200), center, PIECE_SIZE)
        pygame.draw.circle(overlay, BLACK, center, PIECE_SIZE, 3)
        screen.blit(overlay, (0, 0))

    # NOTE: a click places the stone even if the cursor is outside the preview radius.
    if click:
        cur_pieces.append([x, y, player])
        valid_moves.remove([x, y])
        return [x, y, -1 * player]
    return [None, None, player]

click = False
# Every intersection starts out free; placed stones are removed from this list.
valid_moves = [[i, j] for i in range(SIZE_BOARD) for j in range(SIZE_BOARD)]

# Stones placed so far, as [x, y, player].
cur_pieces = []
player = 1

running = True
while running:
    # ``click`` is reset every frame and set only on MOUSEBUTTONDOWN, so one
    # press places at most one stone even if the button is held while moving.
    click = False
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.MOUSEBUTTONDOWN:
            click = True
    if not running:
        # Leave before drawing; pygame is shut down after the loop.
        break

    screen.fill(GREEN)
    draw_board()

    for piece in cur_pieces:
        draw_piece(piece[0], piece[1], piece[2])

    x, y, player = hover_to_select(player, valid_moves, click)

    pygame.display.flip()

pygame.quit()



## Attaxx

### Game Implementation

### Graphical Interface Implementation

## ResNet

## AlphaZero

## Main

In [None]:
# Run from this file's directory so the relative 'AlphaZero/Models/...' paths
# used below resolve next to the code. ``__file__`` is not defined when this
# cell runs inside a notebook kernel; the working directory is then left as is.
if '__file__' in globals():
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Seed the torch, random and numpy RNGs for reproducible runs.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Model alias; filled in from user input in the __main__ block.
SAVE_NAME = None

if __name__ == '__main__':

    # Interactive configuration: which game, whether to resume from saved
    # weights, and whether to train or to play against the model.
    GAME = input("Game: (Go/Attaxx) ").strip()
    if GAME not in ('Go', 'Attaxx'):
        # Without this check ``model`` / ``args`` would be undefined further down.
        print(f"Unknown game: {GAME}")
        exit()

    LOAD = input("Load:\nTrue will load a previous model, False will start from scratch (True/False):\n")
    if LOAD.strip().lower() == 'true':
        LOAD = True
        SAVE_NAME = input("Alias of the model: ")
        MODEL = input("Model name: ")
        OPT = input("Optimizer name: ")
    else:
        LOAD = False
        SAVE_NAME = input("Alias of the new model: ")

    TEST = input("Test:\nTrue will play against the model, False will train the model (True/False):\n")
    TEST = TEST.strip().lower() == 'true'

    if GAME == 'Go':
        args = {
            'game': 'Go',
            'num_iterations': 20,             # number of highest level iterations
            'num_selfPlay_iterations': 20,   # number of self-play games to play within each iteration
            'num_mcts_searches': 500,         # number of mcts simulations when selecting a move within self-play
            'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
            'num_epochs': 1200,                  # number of epochs for training on self-play data for each iteration
            'batch_size': 128,                # batch size for training
            'temperature': 1.30,              # temperature for the softmax selection of moves
            'C': 2,                           # the value of the constant policy
            'augment': True,                 # whether to augment the training data with flipped states
            'dirichlet_alpha': 0.03,           # the value of the dirichlet noise (alpha)
            'dirichlet_epsilon': 0.25,        # the value of the dirichlet noise (epsilon)
            'alias': ('Go' + SAVE_NAME)
        }

        game = Go()
        model = ResNet(game, 9, 3, device)
        optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    elif GAME == 'Attaxx':
        game_size = [5,5]
        args = {
            'game': 'Attaxx',
            'num_iterations': 10,              # number of highest level iterations
            'num_selfPlay_iterations': 1000,   # number of self-play games to play within each iteration
            'num_mcts_searches': 500,         # number of mcts simulations when selecting a move within self-play
            'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
            'num_epochs': 500,                  # number of epochs for training on self-play data for each iteration
            'batch_size': 500,                # batch size for training
            'temperature': 1.25,              # temperature for the softmax selection of moves
            'C': 2,                           # the value of the constant policy
            'augment': False,                 # whether to augment the training data with flipped states
            'dirichlet_alpha': 0.3,           # the value of the dirichlet noise
            'dirichlet_epsilon': 0.125,       # the value of the dirichlet noise
            'alias': ('Attaxx' + SAVE_NAME)
        }

        game = Attaxx(game_size)
        model = ResNet(game, 20, 48, device)
        optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    if LOAD:
        model.load_state_dict(torch.load(f'AlphaZero/Models/{GAME+SAVE_NAME}/{MODEL}.pt', map_location=device))
        optimizer.load_state_dict(torch.load(f'AlphaZero/Models/{GAME+SAVE_NAME}/{OPT}.pt', map_location=device))

    if not TEST:
        os.makedirs(f'AlphaZero/Models/{GAME+SAVE_NAME}', exist_ok=True)
        alphaZero = AlphaZero(model, optimizer, game, args)
        alphaZero.learn()
    else:
        if not LOAD:
            print("No model to test")
            exit()
        if GAME == 'Go':
            game = Go()

            # map_location keeps GPU-trained weights loadable on CPU-only machines.
            model.load_state_dict(torch.load(f'AlphaZero/Models/{GAME+SAVE_NAME}/{MODEL}.pt', map_location=device))
            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            game.print_board(state)

            player = 1
            pass_action = game.row_count * game.column_count

            while True:
                if player == 1:
                    # Re-prompt until the human enters "row col" for a legal point, or "pass".
                    while True:
                        raw = input("\nInput your move: ").strip()
                        if raw.lower() == 'pass':
                            action = pass_action
                            break
                        try:
                            a, b = (int(x) for x in raw.split())
                        except ValueError:
                            print("Please enter two integers (row column) or 'pass'")
                            continue
                        if (0 <= a < game.row_count and 0 <= b < game.column_count
                                and game.is_valid_move(state, (a, b), player)):
                            action = a * game.column_count + b
                            break
                        print("Invalid move, try again")
                    print("\n")
                    state = game.get_next_state(state, action, player)
                else:
                    neut = game.change_perspective(state, player)
                    action = mcts.search(neut, player)
                    action = np.argmax(action)
                    if action == pass_action:
                        print("\nAlphaZero Action: pass\n")
                    else:
                        print(f"\nAlphaZero Action: {action // game.column_count} {action % game.column_count}\n")
                    state = game.get_next_state(state, action, player)

                value, terminated = game.get_value_and_terminated(state, action, player)
                if terminated:
                    game.print_board(state)
                    # ``value`` is relative to ``player`` (the side that just moved):
                    # +1 means that side won, so the winner's id is value * player.
                    print(f"player {value * player} wins")
                    exit()

                player = - player
                game.print_board(state)

        elif GAME == 'Attaxx':
            game = Attaxx([5,5])

            model.load_state_dict(torch.load(f'AlphaZero/Models/{GAME+SAVE_NAME}/{MODEL}.pt', map_location=device))
            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            game.print_board(state)

            player = 1

            while True:
                if player == 1:
                    move = tuple(int(x.strip()) for x in input("\nInput your move: ").split())
                    print("\n")
                    action = game.move_to_int(move)
                    state = game.get_next_state(state, action, player)
                else:
                    action = mcts.search(state, player)
                    action = np.argmax(action)
                    print(f"\nAlphaZero Action: {game.int_to_move(action)}\n")
                    state = game.get_next_state(state, action, player)

                winner, win = game.get_value_and_terminated(state, action, player)
                if win:
                    game.print_board(state)
                    # NOTE(review): if Attaxx's value is relative to ``player`` (as in Go),
                    # the winning player's id is winner * player — confirm.
                    print(f"player {winner} wins")
                    exit()

                player = -player
                game.print_board(state)