# Alpha Zero For Generalized Game Reinforcement Learning

### Imports

In [None]:
import os
import math
import time
import json
import torch
import random
import pygame
import pickle
import socket
import numpy as np
import torch.nn as nn
from tqdm.notebook import trange
from torch.optim import Adam
import torch.nn.functional as F

## Go

Go is an ancient board game believed to have originated in China more than 2,500 years ago (legend places its invention over 4,000 years ago). It is usually played on a 19x19 grid by two players, one using black stones and the other white stones. The game is known for its simple rules yet great strategic depth. The primary objective in Go is to control more territory on the board than your opponent; territory is gained by surrounding empty areas of the board with your stones. Players also aim to capture their opponent's stones by completely surrounding them.

### Game Implementation

In [None]:
class Go():
    '''
    # Description:
    Go rules engine used by the AlphaZero pipeline.

    The board is a `size` x `size` NumPy array holding EMPTY (0), BLACK (1) and
    WHITE (-1). Placement actions are encoded as `row * column_count + col`; the
    final action index (`row_count * column_count`) is a pass. Scoring is a
    heuristic (local influence + eyes + group strength), not a full territory count.
    Ko and positional repetition are not tracked.
    '''

    EMPTY = 0
    BLACK = 1
    WHITE = -1
    # Temporary values written by count() while flood-filling; undone by restore_board().
    BLACKMARKER = 4
    WHITEMARKER = 5
    LIBERTY = 8

    def __init__(self, size, komi):
        '''
        # Description:
        Creates a square Go game with `size` x `size` intersections.

        `komi` is the compensation added to White's score in `scoring()`.
        '''
        self.row_count = size
        self.column_count = size
        # Use the caller-supplied komi (it used to be hard-coded to 5.5).
        self.komi = komi
        # One action per intersection plus a final pass action.
        self.action_size = self.row_count * self.column_count + 1
        # Not used by the methods of this class; kept for external callers.
        self.liberties = []
        self.block = []
        self.seki_liberties = []

    def get_initial_state(self):
        '''
        # Description:
        Creates an empty board.

        # Returns:
        A (row_count, column_count) float array filled with zeros.
        '''
        return np.zeros((self.row_count, self.column_count))

    def count(self, x, y, state: list, player: int, liberties: list, block: list) -> tuple[list, list]:
        '''
        # Description:
        Recursive flood fill from column `x`, row `y` over the connected group of `player` stones.

        Visited stones are overwritten with BLACKMARKER/WHITEMARKER and discovered
        liberties with LIBERTY, so each point is collected only once. The board
        is left marked: callers must call `restore_board()` (or discard the board)
        afterwards.

        # Returns:
        `(liberties, block)`: lists of `(row, col)` tuples for the group's empty
        neighbours and for the group's stones. The passed-in lists are appended to.
        '''
        piece = state[y][x]
        if piece == player:
            block.append((y, x))
            state[y][x] = self.BLACKMARKER if player == self.BLACK else self.WHITEMARKER

            if y - 1 >= 0:
                liberties, block = self.count(x, y - 1, state, player, liberties, block)  # north
            if x + 1 < self.column_count:
                liberties, block = self.count(x + 1, y, state, player, liberties, block)  # east
            if y + 1 < self.row_count:
                liberties, block = self.count(x, y + 1, state, player, liberties, block)  # south
            if x - 1 >= 0:
                liberties, block = self.count(x - 1, y, state, player, liberties, block)  # west

        elif piece == self.EMPTY:
            state[y][x] = self.LIBERTY
            liberties.append((y, x))

        return liberties, block

    def clear_block(self, block: list, state: list) -> list:
        '''
        # Description:
        Removes every stone listed in `block` (as `(row, col)` tuples) from the board.

        # Returns:
        The same board object, modified in place.
        '''
        for y, x in block:
            state[y][x] = self.EMPTY
        return state

    def restore_board(self, state: list) -> list:
        '''
        # Description:
        Undoes the temporary marks left by `count()`. Marked stones become BLACK or
        WHITE again and marked liberties become EMPTY.

        # Returns:
        The same board object, modified in place.
        '''
        for y in range(len(state)):
            for x in range(len(state[0])):
                val = state[y][x]
                if val == self.BLACKMARKER:
                    state[y][x] = self.BLACK
                elif val == self.WHITEMARKER:
                    state[y][x] = self.WHITE
                elif val == self.LIBERTY:
                    state[y][x] = self.EMPTY
        return state

    def print_board(self, state: list) -> None:
        '''
        # Description:
        Prints the board to the console with row and column coordinates.

        # Returns:
        None
        '''
        # Column coordinates
        print("   ", end="")
        for j in range(len(state[0])):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(len(state[0])):
            print("---", end="")
        print()

        # Rows with row coordinates
        for i in range(len(state)):
            print(f"{i:2}|", end=" ")
            for j in range(len(state[0])):
                print(f"{str(int(state[i][j])):2}", end=" ")
            print()

    def captures(self, state: list, player: int, a: int, b: int) -> tuple[bool, list]:
        '''
        # Description:
        Checks the four neighbours of the point (row `a`, column `b`). Any group of
        `player` stones there that has no liberties is removed. `player` is the
        colour being captured, i.e. the opponent of whoever just moved.

        # Returns:
        `(captured, state)`, where `captured` is True if at least one group was
        removed. `state` is modified in place.
        '''
        check = False
        neighbours = []
        if a > 0:
            neighbours.append((a - 1, b))
        if a < self.row_count - 1:
            neighbours.append((a + 1, b))
        if b > 0:
            neighbours.append((a, b - 1))
        if b < self.column_count - 1:
            neighbours.append((a, b + 1))

        for row, col in neighbours:
            if state[row][col] != player:
                continue
            # count() takes (column, row).
            liberties, block = self.count(col, row, state, player, [], [])
            if len(liberties) == 0:
                state = self.clear_block(block, state)
                check = True
            # Remove the flood-fill marks before inspecting the next neighbour.
            state = self.restore_board(state)

        return check, state

    def set_stone(self, a, b, state, player):
        '''
        # Description:
        Writes `player` at (row `a`, column `b`) without applying any rules.
        Use `get_next_state()` to play a move.

        # Returns:
        The same board object, modified in place.
        '''
        state[a][b] = player
        return state

    def get_next_state(self, state, action, player):
        '''
        # Description:
        Plays `action` for `player`: places the stone and removes any opponent
        groups left without liberties. Action `row_count * column_count` is a pass
        and leaves the board unchanged.

        Legality (occupied point, suicide) is NOT checked here; see
        `is_valid_move()`. Ko and move history are not tracked.

        # Returns:
        The updated board. `state` is modified in place.
        '''
        if action == self.row_count * self.column_count:
            return state  # pass move

        a, b = divmod(action, self.column_count)
        state = self.set_stone(a, b, state, player)
        return self.captures(state, -player, a, b)[1]

    def is_valid_move(self, state: list, action: tuple, player: int) -> bool:
        '''
        # Description:
        Checks whether `player` may place a stone at `action = (row, col)`.

        A move is invalid if the point is occupied, or if it is suicide: the new
        group has no liberties and the move captures nothing. Ko / repeated
        positions are NOT checked. `state` itself is not modified.

        # Returns:
        True if the move is allowed, otherwise False.
        '''
        a, b = action[0], action[1]

        if state[a][b] != self.EMPTY:
            return False

        # Work on an int8 copy so the caller's board is left untouched.
        statecopy = np.copy(state).astype(np.int8)
        statecopy = self.set_stone(a, b, statecopy, player)

        captured, statecopy = self.captures(statecopy, -player, a, b)
        if captured:
            return True

        libs, _ = self.count(b, a, statecopy, player, [], [])
        return len(libs) > 0

    def _count_empty(self, state) -> int:
        '''Returns the number of EMPTY intersections on the board.'''
        return sum(
            1
            for x in range(self.row_count)
            for y in range(self.column_count)
            if state[x][y] == self.EMPTY
        )

    def get_valid_moves(self, state, player):
        '''
        # Description:
        Builds a flat mask of legal actions for `player`.

        Entry `row * column_count + col` is 1 when `is_valid_move()` accepts that
        point. The last entry is the pass action. It is enabled only once fewer
        than a third of the intersections are empty, i.e. roughly 2/3 of the board
        is filled.

        # Returns:
        An np.int8 array of length `action_size`.
        '''
        mask = np.zeros((self.row_count, self.column_count))
        for a in range(self.row_count):
            for b in range(self.column_count):
                if self.is_valid_move(state, (a, b), player):
                    mask[a][b] = 1

        board_points = self.row_count * self.column_count
        can_pass = self._count_empty(state) < max(1, board_points // 3)
        return np.concatenate([mask.reshape(-1), [1 if can_pass else 0]]).astype(np.int8)

    def get_value_and_terminated(self, state, action, player):
        '''
        # Description:
        Evaluates the board from `player`'s point of view. `action` is unused.

        # Returns:
        `(value, terminated)`. `value` is 1 if the heuristic score favours
        `player` and -1 otherwise (a zero score counts as -1). `terminated`
        comes from `scoring()`. The value is returned even when the game is
        not over.
        '''
        score, endgame = self.scoring(state)
        if player == self.BLACK:
            won = score > 0
        else:
            won = score < 0
        return (1 if won else -1), endgame

    def scoring(self, state: list) -> tuple[float, bool]:
        '''
        # Description:
        Heuristic score of the position:

        (black_influence + black_eyes + black_strong_groups)
            - (white_influence + white_eyes + white_strong_groups + komi)

        The game is treated as over once fewer than a quarter of the
        intersections are empty.

        # Returns:
        `(score, endgame)`. A positive score favours Black.
        '''
        board_points = self.row_count * self.column_count
        endgame = self._count_empty(state) < max(1, board_points // 4)

        black, white = self.count_influenced_territory_enhanced(state)
        black_eyes, black_strong_groups = self.count_eyes_and_strong_groups(state, self.BLACK)
        white_eyes, white_strong_groups = self.count_eyes_and_strong_groups(state, self.WHITE)

        black += black_eyes + black_strong_groups
        white += white_eyes + white_strong_groups

        return black - (white + self.komi), endgame

    def count_influenced_territory_enhanced(self, board: list) -> tuple[int, int]:
        '''
        # Description:
        Assigns each empty point to the side whose stones dominate its four
        orthogonal neighbours. The neighbour values are summed, with BLACK = 1 and
        WHITE = -1. A positive sum counts for Black, a negative sum for White, and
        zero for nobody.

        # Returns:
        Tuple `(black_territory, white_territory)`.
        '''
        rows, cols = len(board), len(board[0])
        black_territory = 0
        white_territory = 0

        for x in range(rows):
            for y in range(cols):
                if board[x][y] != 0:
                    continue
                score = 0
                for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                    nx, ny = x + dx, y + dy
                    if 0 <= nx < rows and 0 <= ny < cols:
                        score += board[nx][ny]
                if score > 0:
                    black_territory += 1
                elif score < 0:
                    white_territory += 1

        return black_territory, white_territory

    def is_eye(self, board, x, y, player):
        '''
        # Description:
        Heuristic eye test used by `scoring()`.

        A point counts as an eye for `player` when all of the following hold:
        - it is empty;
        - every on-board orthogonal neighbour belongs to `player`;
        - at least three diagonal neighbours belong to `player`.

        NOTE(review): edge and corner points have at most two on-board
        diagonals, so they can never count as eyes here.

        # Returns:
        True if the point is treated as an eye, otherwise False.
        '''
        if board[x][y] != self.EMPTY:
            return False

        rows, cols = len(board), len(board[0])
        for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < rows and 0 <= ny < cols and board[nx][ny] != player:
                return False

        friendly_diagonals = 0
        for dx, dy in ((1, 1), (1, -1), (-1, 1), (-1, -1)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < rows and 0 <= ny < cols and board[nx][ny] == player:
                friendly_diagonals += 1

        return friendly_diagonals >= 3

    def count_eyes_and_strong_groups(self, board, player):
        '''
        # Description:
        Counts `player`'s eyes (see `is_eye()`) and "strong" groups.

        A group is strong when its liberty tally is at least 2. The tally adds 1
        for every (stone, empty neighbour) pair, so a liberty shared by several
        stones of the group is counted more than once.

        # Returns:
        Tuple `(eyes, strong_groups)`.
        '''
        eyes = 0
        strong_groups = 0
        visited = set()

        def dfs(x, y):
            # Liberty tally of the group containing (x, y); 0 if already visited.
            if (x, y) in visited or board[x][y] != player:
                return 0

            visited.add((x, y))
            liberties = 0
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nx, ny = x + dx, y + dy
                if not (0 <= nx < len(board) and 0 <= ny < len(board[0])):
                    continue
                if board[nx][ny] == self.EMPTY:
                    liberties += 1
                elif board[nx][ny] == player:
                    liberties += dfs(nx, ny)
            return liberties

        for x in range(len(board)):
            for y in range(len(board[0])):
                if board[x][y] == player:
                    if (x, y) not in visited and dfs(x, y) >= 2:  # arbitrary strength threshold
                        strong_groups += 1
                elif self.is_eye(board, x, y, player):
                    eyes += 1

        return eyes, strong_groups

    def get_opponent(self, player):
        '''
        # Description:
        Returns the other player (BLACK <-> WHITE).
        '''
        return -player

    def get_opponent_value(self, value):
        '''
        # Description:
        Converts a value to the opponent's point of view by negating it.
        '''
        return -value

    def get_encoded_state(self, state):
        '''
        # Description:
        One-hot encodes the board into three planes for the neural network:
        - plane 0: white stones (-1)
        - plane 1: empty points (0)
        - plane 2: black stones (1)

        # Returns:
        A float32 array of shape (3, rows, cols).
        '''
        board = np.array(state)
        layer_1 = np.where(board == -1, 1, 0).astype(np.float32)
        layer_2 = np.where(board == 0, 1, 0).astype(np.float32)
        layer_3 = np.where(board == 1, 1, 0).astype(np.float32)
        return np.stack([layer_1, layer_2, layer_3]).astype(np.float32)

    def change_perspective(self, state, player):
        '''
        # Description:
        Multiplies the board by `player`. For player -1 this swaps the colours,
        so the side to move is always represented as 1.

        # Returns:
        A new board array.
        '''
        return state * player

### Graphical Interface Implementation

In [None]:
# Go GUI layout constants (sizes and positions are in pixels, colours are RGB).
# NOTE(review): SIZE_BOARD, BLACK, SCREEN_SIZE, SCREEN_PADDING, CELL_SIZE,
# PIECE_SIZE and `screen` are reassigned by the Attaxx GUI cell further down.
# The go* helpers read these globals at call time, so whichever cell ran last wins.
SIZE_BOARD = 9  # grid drawn as 9x9; presumably must match the Go(size) instance — verify
BLACK = (0,0,0)
WHITE = (255,255,255)
WOOD = (216, 166, 91)  # board background colour
SCREEN_SIZE = 600
SCREEN_PADDING = 50  # pixel position of the first grid line (goto_pixels(0))
CELL_SIZE = (SCREEN_SIZE - SCREEN_PADDING) // SIZE_BOARD  # pixel distance between adjacent grid lines
PIECE_SIZE = (SCREEN_SIZE - 2*SCREEN_PADDING) // SIZE_BOARD // 3  # stone radius passed to pygame.draw.circle
# Creates the pygame display surface as a side effect of running this cell.
screen=pygame.display.set_mode((SCREEN_SIZE,SCREEN_SIZE))

def goto_pixels(x):
    '''Maps a Go grid index (row or column) to the pixel coordinate of that grid line.'''
    offset = x * CELL_SIZE
    return offset + SCREEN_PADDING

def goto_coord(x):
    '''
    Maps a pixel coordinate to the index of the nearest Go grid line.

    A pixel at least half a cell past line `i` snaps to line `i + 1`.
    Pixels outside the board can yield an index that is negative or
    >= SIZE_BOARD; callers are expected to check the result against the
    list of valid moves.

    The previous implementation compared
    `goto_pixels(closest) - (<bool>)`, which is always non-zero. Its branch
    therefore always fired and snapped with a quarter-cell bias, leaving a
    dead hover zone on one side of every intersection.
    '''
    return (x - SCREEN_PADDING + CELL_SIZE // 2) // CELL_SIZE

def godraw_board():
    '''Draws the wooden board background and its SIZE_BOARD x SIZE_BOARD grid of lines.'''
    span = CELL_SIZE * (SIZE_BOARD - 1)
    far_edge = span + SCREEN_PADDING
    pygame.draw.rect(screen, WOOD, rect=(SCREEN_PADDING, SCREEN_PADDING, span, span))
    for line_index in range(SIZE_BOARD):
        pos = goto_pixels(line_index)
        # Vertical line, then horizontal line, both 3 px wide.
        pygame.draw.line(screen, BLACK, (pos, SCREEN_PADDING), (pos, far_edge), 3)
        pygame.draw.line(screen, BLACK, (SCREEN_PADDING, pos), (far_edge, pos), 3)

def godraw_piece(x,y,player):
    '''
    Draws a stone at grid point (x, y) with a 3 px black outline.

    Player -1 is filled BLACK and any other player WHITE.
    NOTE(review): the Go class defines BLACK = 1, so player 1 is drawn white
    here; confirm the caller maps players to colours as intended.
    '''
    center = (goto_pixels(x), goto_pixels(y))
    fill = WHITE if player != -1 else BLACK
    pygame.draw.circle(screen, fill, center, PIECE_SIZE)
    pygame.draw.circle(screen, BLACK, center, PIECE_SIZE, 3)

def gohover_to_select(player,valid_moves,click):
    '''
    Draws a translucent preview stone when the mouse hovers within PIECE_SIZE
    pixels of a valid intersection.

    `valid_moves` is tested for membership with a 2-element list `[x, y]`, so
    it is presumably a list of [x, y] lists (verify against the caller).
    `click` is currently unused.

    Returns [None, None, player] in every case.
    '''
    mouse_x, mouse_y = pygame.mouse.get_pos()
    x, y = goto_coord(mouse_x), goto_coord(mouse_y)
    if [x, y] not in valid_moves:
        return [None, None, player]

    center = (goto_pixels(x), goto_pixels(y))
    distance = pygame.math.Vector2(center[0] - mouse_x, center[1] - mouse_y).length()
    if distance < PIECE_SIZE:
        # Per-pixel-alpha overlay so the preview stone is semi-transparent.
        overlay = pygame.Surface((SCREEN_SIZE, SCREEN_SIZE), pygame.SRCALPHA)
        if player == 1:
            pygame.draw.circle(overlay, (255, 255, 255, 200), center, PIECE_SIZE)
        if player == -1:
            pygame.draw.circle(overlay, (0, 0, 0, 200), center, PIECE_SIZE)
        pygame.draw.circle(overlay, BLACK, center, PIECE_SIZE, 3)
        screen.blit(overlay, (0, 0))

    return [None, None, player]

## Attaxx

Attaxx (more commonly spelled Ataxx) is a strategy board game that was released as an arcade game in 1990 by Leland Corporation. It is traditionally played on a 7x7 grid, and each player starts with two pieces placed in opposite corners of the board. The objective is to have more pieces on the board than your opponent when the game ends.

### Game Implementation

In [None]:
class Attaxx:
    '''
    # Description:
    Attaxx (Ataxx) rules engine.

    The board is a NumPy array of shape (column_count, row_count) holding 1,
    -1 and 0 (empty). A move `(a, b, a1, b1)` moves or clones the piece at
    `state[a][b]` to `state[a1][b1]`. Moves are encoded as integers in
    `[0, action_size - 1)` and the last action index is a pass.
    '''

    def __init__(self, args):
        # args = (column_count, row_count)
        self.column_count = args[0]
        self.row_count = args[1]
        # Every (source square, destination square) pair plus one pass action.
        self.action_size = (self.column_count * self.row_count) ** 2 + 1

    def get_initial_state(self):
        '''
        # Description:
        Creates the starting board. Player 1 occupies two opposite corners and
        player -1 the other two.

        # Returns:
        A float array of shape (column_count, row_count).
        '''
        state = np.zeros((self.column_count, self.row_count))
        last_a = self.column_count - 1  # last index along the first axis
        last_b = self.row_count - 1     # last index along the second axis
        state[0][0] = 1
        state[last_a][last_b] = 1
        state[0][last_b] = -1
        state[last_a][0] = -1
        return state

    def get_next_state(self, state, action, player):
        '''
        # Description:
        Applies an encoded action for `player`. A jump (distance 2 on either
        axis) vacates the source square, while a one-step move clones the piece.
        Opponent pieces adjacent to the destination are then converted. The pass
        action leaves the board unchanged.

        # Returns:
        The same board object, modified in place.
        '''
        if action == self.action_size - 1:
            return state
        move = self.int_to_move(action)
        a, b, a1, b1 = move[0], move[1], move[2], move[3]
        if abs(a - a1) == 2 or abs(b - b1) == 2:
            state[a][b] = 0
        state[a1][b1] = player
        self.capture_pieces(state, move, player)
        return state

    def is_valid_move(self, state, action, player):
        '''
        # Description:
        Checks a move `(a, b, a1, b1)` for `player`. The source must hold the
        player's piece, the destination must be empty and different from the
        source, and the displacement must be at most 2 on each axis. The
        knight-like offsets (1, 2) and (2, 1) are not allowed.

        # Returns:
        True if the move is allowed, otherwise False.
        '''
        a, b, a1, b1 = action
        da, db = abs(a - a1), abs(b - b1)
        if da == 0 and db == 0:
            return False
        if da > 2 or db > 2:
            return False
        if state[a1][b1] != 0 or state[a][b] != player:
            return False
        if (da == 1 and db == 2) or (da == 2 and db == 1):
            return False
        return True

    def capture_pieces(self, state, action, player):
        '''
        # Description:
        Converts every opponent piece in the 8-neighbourhood of the destination
        square `(action[2], action[3])` to `player`. Squares outside the board
        are skipped explicitly; the old code relied on negative-index wraparound
        and swallowed IndexError instead.

        # Returns:
        None; `state` is modified in place.
        '''
        _, _, a1, b1 = action
        height = len(state)
        width = len(state[0])
        for i in range(max(a1 - 1, 0), min(a1 + 2, height)):
            for j in range(max(b1 - 1, 0), min(b1 + 2, width)):
                if state[i][j] == -player:
                    state[i][j] = player

    def check_available_moves(self, state, player):
        '''
        # Description:
        Reports whether `player` has at least one legal move.
        '''
        for i in range(self.column_count):
            for j in range(self.row_count):
                if state[i][j] != player:
                    continue
                for a in range(self.column_count):
                    for b in range(self.row_count):
                        if self.is_valid_move(state, (i, j, a, b), player):
                            return True
        return False

    def move_to_int(self, move):
        '''
        # Description:
        Encodes `(a, b, a1, b1)` as one integer using a mixed radix that matches
        the board axes (a, a1 < column_count; b, b1 < row_count). For square
        boards this equals the original base-`column_count` encoding, and it
        stays collision-free and below `action_size - 1` on non-square boards.
        '''
        a, b, a1, b1 = move[0], move[1], move[2], move[3]
        return ((a * self.row_count + b) * self.column_count + a1) * self.row_count + b1

    def int_to_move(self, num):
        '''
        # Description:
        Inverse of `move_to_int()`.

        # Returns:
        The move as a list `[a, b, a1, b1]`.
        '''
        num, b1 = divmod(num, self.row_count)
        num, a1 = divmod(num, self.column_count)
        num, b = divmod(num, self.row_count)
        a = num % self.column_count
        return [a, b, a1, b1]

    def get_valid_moves(self, state, player):
        '''
        # Description:
        Builds a 0/1 mask over all `action_size` actions for `player`. The pass
        action (last entry) is always 0. `state` is not modified.

        # Returns:
        A list of ints of length `action_size`.
        '''
        encoded = set()
        for i in range(self.column_count):
            for j in range(self.row_count):
                if state[i][j] == player:
                    for move in self.get_moves_at_point(state, player, i, j):
                        encoded.add(self.move_to_int(move))

        return [1 if index in encoded else 0 for index in range(self.action_size)]

    def get_moves_at_point(self, state, player, a, b):
        '''
        # Description:
        Lists every legal move `(a, b, i, j)` for the piece at `(a, b)`.
        '''
        return [
            (a, b, i, j)
            for i in range(self.column_count)
            for j in range(self.row_count)
            if self.is_valid_move(state, (a, b, i, j), player)
        ]

    def check_board_full(self, state):
        '''
        # Description:
        Returns True when no square is empty.
        '''
        for row in state:
            if 0 in row:
                return False
        return True

    def check_win_and_over(self, state, action):
        '''
        # Description:
        Decides whether the game is over. It is over when one side has no
        pieces left or the board is full. `action` is unused; it is kept so the
        signature matches the other games.

        # Returns:
        `(winner, game_over)`. `winner` is 1 or -1 for the winning player, 2
        for a draw on a full board, and 0 while the game continues.
        '''
        count_player1 = 0
        count_player2 = 0
        for i in range(self.column_count):
            for j in range(self.row_count):
                if state[i][j] == 1:
                    count_player1 += 1
                elif state[i][j] == -1:
                    count_player2 += 1

        if count_player1 == 0:
            return -1, True
        if count_player2 == 0:
            return 1, True

        if self.check_board_full(state):
            if count_player1 > count_player2:
                return 1, True
            if count_player2 > count_player1:
                return -1, True
            return 2, True

        return 0, False

    def get_value_and_terminated(self, state, action, player):
        '''
        # Description:
        Returns `check_win_and_over(state, action=None)` unchanged. The winner
        is reported as an absolute player id, not relative to `player`.
        '''
        winner, game_over = self.check_win_and_over(state, action=None)
        return winner, game_over

    def print_board(self, state):
        '''
        # Description:
        Prints the board to the console with coordinates.
        '''
        state = state.astype(int)
        # Column coordinates
        print("   ", end="")
        for j in range(len(state[0])):
            print(f"{j:2}", end=" ")
        print("\n  +", end="")
        for _ in range(len(state[0])):
            print("---", end="")
        print()

        # Rows with row coordinates
        for i in range(len(state)):
            print(f"{i:2}|", end=" ")
            for j in range(len(state[0])):
                print(f"{str(state[i][j]):2}", end=" ")
            print()

    def get_encoded_state(self, state):
        '''
        # Description:
        One-hot encodes the board into three float32 planes: -1 pieces, empty
        squares and 1 pieces, stacked in that order.
        '''
        board = np.array(state)
        layer_1 = np.where(board == -1, 1, 0).astype(np.float32)
        layer_2 = np.where(board == 0, 1, 0).astype(np.float32)
        layer_3 = np.where(board == 1, 1, 0).astype(np.float32)
        return np.stack([layer_1, layer_2, layer_3]).astype(np.float32)

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        # Multiplying by -1 swaps sides so the player to move is always 1.
        return state * player

### Graphical Interface Implementation

In [None]:
# Attaxx GUI layout constants (sizes and positions are in pixels, colours are RGB).
# NOTE(review): this cell reassigns SIZE_BOARD, BLACK, SCREEN_SIZE, SCREEN_PADDING,
# CELL_SIZE, PIECE_SIZE and `screen` from the Go GUI cell above. Helpers from both
# cells read these globals at call time.
SIZE_BOARD = 6  # grid drawn as 6x6; presumably must match the Attaxx args — verify
RED = (238, 167, 255)  # fill colour for player -1 pieces
BLUE = (113, 175, 255)  # fill colour for other (player 1) pieces
GRAY = (115, 115, 115)  # board background
BLACK = (0, 0, 0)  # grid lines and piece outlines
SCREEN_SIZE=600
SCREEN_PADDING = 100  # pixel position of the first piece centre (atto_pixels(0))
CELL_SIZE = (SCREEN_SIZE - SCREEN_PADDING) // SIZE_BOARD  # pixel distance between adjacent cells
PIECE_SIZE = (SCREEN_SIZE - 2*SCREEN_PADDING) // SIZE_BOARD // 3  # piece radius; atdraw_board also offsets grid lines by 2*PIECE_SIZE
# Creates the pygame display surface as a side effect of running this cell.
screen=pygame.display.set_mode((SCREEN_SIZE,SCREEN_SIZE))

def atto_pixels(x):
    '''Maps an Attaxx cell index (row or column) to the pixel coordinate of the piece centre.'''
    offset = x * CELL_SIZE
    return offset + SCREEN_PADDING

def atto_coord(x):
    '''
    Maps a pixel coordinate to the index of the nearest Attaxx cell centre.

    A pixel at least half a cell past centre `i` maps to cell `i + 1`.
    Pixels outside the board can yield an index that is negative or
    >= SIZE_BOARD; callers check the result against pieces and valid moves.

    The previous implementation compared
    `atto_pixels(closest) - (<bool>)`, which is always non-zero. Its branch
    therefore always fired and snapped with a quarter-cell bias instead of to
    the nearest centre.
    '''
    return (x - SCREEN_PADDING + CELL_SIZE // 2) // CELL_SIZE

def atdraw_board():
    '''
    Draws the Attaxx board: a gray background rectangle plus SIZE_BOARD + 1
    vertical and horizontal 3 px lines. The lines are shifted back by
    2 * PIECE_SIZE so the piece centres (atto_pixels) fall inside the cells.
    NOTE(review): the gray rectangle starts at SCREEN_PADDING, not at the
    shifted grid origin, so it does not line up with the grid; confirm this
    is intended.
    '''
    inner = CELL_SIZE * (SIZE_BOARD - 1)
    pygame.draw.rect(screen, GRAY, rect=(SCREEN_PADDING, SCREEN_PADDING, inner, inner))

    origin = SCREEN_PADDING - PIECE_SIZE * 2
    far = origin + CELL_SIZE * SIZE_BOARD
    for i in range(SIZE_BOARD + 1):
        offset = origin + i * CELL_SIZE
        # Vertical line, then horizontal line.
        pygame.draw.line(screen, BLACK, (offset, origin), (offset, far), 3)
        pygame.draw.line(screen, BLACK, (origin, offset), (far, offset), 3)



def atdraw_piece(x,y,player):
    '''Draws a piece at cell (x, y): RED for player -1, BLUE otherwise, with a 3 px black outline.'''
    center = (atto_pixels(x), atto_pixels(y))
    fill = BLUE if player != -1 else RED
    pygame.draw.circle(screen, fill, center, PIECE_SIZE)
    pygame.draw.circle(screen, BLACK, center, PIECE_SIZE, 3)


def athover_to_select(sel_x, sel_y, player, valid_moves, click, selected_piece, cur_pieces, last_click_time):
    '''
    Handles one frame of mouse interaction for the Attaxx GUI. Depending on the
    click, it selects or deselects one of `player`'s pieces, places a move, and
    draws move hints.

    Parameters:
    - sel_x, sel_y: board coordinates of the selected piece (-1 after a deselection).
    - player: side to move (1 or -1).
    - valid_moves: targets tested with a 2-element list [x, y] and read as move[0], move[1];
      presumably a list of [x, y] lists — verify against the caller.
    - click: truthy when the mouse button was clicked this frame.
    - selected_piece: whether a piece is currently selected.
    - cur_pieces: list of [x, y, player] entries. A new entry is APPENDED in place
      when a move is placed.
    - last_click_time: pygame tick (ms) of the last accepted selection click, used to debounce clicks.

    Returns:
    [sel_x, sel_y, player, selected_piece, last_click_time]. `player` is flipped
    after a move has been placed.
    '''

    current_time = pygame.time.get_ticks()
    mouse_x, mouse_y = pygame.mouse.get_pos()
    x, y = None, None
    # Is the pointer over one of the current player's pieces?
    if ([atto_coord(mouse_x), atto_coord(mouse_y), player] in cur_pieces):
        x, y = atto_coord(mouse_x), atto_coord(mouse_y)

        if click and current_time - last_click_time > 100:  # 100 milliseconds debounce
            if selected_piece:
                # Clicking a different own piece while one is selected changes nothing
                # except refreshing last_click_time below.
                if x == sel_x and y == sel_y: #deselection
                    selected_piece = False
                    sel_x = -1
                    sel_y = -1
                    
            else: #selection
                selected_piece = True
                sel_x = x
                sel_y = y

            last_click_time = current_time
    
    # Place a move. A click that just landed on an own piece refreshed last_click_time,
    # so this branch cannot also fire in the same frame.
    # NOTE(review): sel_x/sel_y and last_click_time are not reset after placing a move.
    if click and current_time - last_click_time > 100 and [atto_coord(mouse_x), atto_coord(mouse_y)] in valid_moves and selected_piece:

        cur_pieces.append([atto_coord(mouse_x), atto_coord(mouse_y),player])
        player = -player
        selected_piece = False
    
    # Draw hollow circles on valid moves if a piece is selected
    if selected_piece:
        for move in valid_moves:
            px, py = atto_pixels(move[0]), atto_pixels(move[1])
            # Translucent overlay in the current player's piece colour (BLUE for 1, RED otherwise).
            s = pygame.Surface((SCREEN_SIZE, SCREEN_SIZE), pygame.SRCALPHA)
            pygame.draw.circle(s, (113, 175, 255, 100) if player == 1 else (238, 167, 255, 100), (px, py), PIECE_SIZE, 3)  # Change color as needed
            screen.blit(s, (0, 0))
    return [sel_x, sel_y, player, selected_piece, last_click_time]

## MCTS

Monte Carlo Tree Search (MCTS) is a heuristic search algorithm commonly used in decision-making processes for games and other domains. It iteratively builds a search tree by simulating random plays and selecting moves based on statistical evaluation. The key steps of MCTS include selection, expansion, simulation, and backpropagation. This algorithm has proven effective in games like Go and chess, providing a balance between exploration and exploitation to find optimal strategies.

For this project, and in the context of AlphaZero, this process is aided by a Neural Network.

In [None]:
class Node:
    '''
    # Alpha Zero Node
    ## Description:
        A node for the AlphaZero MCTS. It contains the state, the action taken to get to the state, the prior probability of the action, the visit count, the value sum, and the children of the node.
    ## Methods:
        - `is_expanded()`: Returns whether the node has been expanded.
        - `select()`: Selects the best child node based on the UCB.
        - `get_ucb()`: Returns the UCB of a child node.
        - `serialize()`: Returns the subtree rooted at this node as a JSON string.
        - `deserialize()`: Rebuilds a subtree from the output of `serialize()`.
        - `expand()`: Expands the node by adding children.
        - `backpropagate()`: Backpropagates the value of the node to the parent node.
        '''
    def __init__(self, game, args, state, player, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.player = player
        self.prior = prior
        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_expanded(self):
        '''
        # is_expanded
        ## Description:
            Returns whether the node has been expanded.
        ## Returns:
            - `bool`: Whether the node has at least one child.'''
        return len(self.children) > 0

    def select(self):
        '''
        # Description:
        Selects the best child node from the current node's children using the Upper Confidence Bound (UCB) score from `get_ucb()`.

        # Returns:
        The child with the highest UCB value; ties are broken uniformly at random.
        '''
        best_child = []
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = [child]
                best_ucb = ucb
            elif ucb == best_ucb:
                best_child.append(child)

        return best_child[0] if len(best_child) == 1 else random.choice(best_child)

    def get_ucb(self, child):
        '''
        # Description:
        Calculates the UCB score of `child`:
        `-(child.value_sum / child.visit_count) + prior * C * sqrt(parent_visits) / (child_visits + 1)`.
        The mean value is negated because `backpropagate()` flips the sign at every level,
        so a child's value is stored from the opponent's point of view.
        An unvisited child contributes only the exploration term.

        # Returns:
        The UCB value for the given child node.
        '''
        if child.visit_count == 0:
            q_value = child.prior * self.args['C'] * (math.sqrt(self.visit_count)) / (child.visit_count + 1)
        else:
            q_value = -(child.value_sum / child.visit_count) + child.prior * self.args['C'] * (math.sqrt(self.visit_count)) / (child.visit_count + 1)
        return q_value

    def serialize(self):
        '''
        # Description:
        Serializes the subtree rooted at this node to a JSON string.
        `game`, `args` and `parent` are not stored: game objects are not JSON
        serializable and storing the parent would make the structure cyclic.
        Pass them back to `deserialize()` instead.

        # Returns:
        A JSON string containing state, action, player, prior, visit count,
        value sum and (recursively) the children of this node.
        '''
        return json.dumps(self._to_dict())

    def _to_dict(self):
        '''Recursively converts this node and its children into JSON-compatible builtins.'''
        state = self.state.tolist() if isinstance(self.state, np.ndarray) else self.state
        return {
            'state': state,
            'action_taken': Node._to_builtin(self.action_taken),
            'player': Node._to_builtin(self.player),
            'prior': Node._to_builtin(self.prior),
            'visit_count': Node._to_builtin(self.visit_count),
            'value_sum': Node._to_builtin(self.value_sum),
            'children': [child._to_dict() for child in self.children],
        }

    @staticmethod
    def _to_builtin(value):
        '''Unwraps NumPy scalars (e.g. np.float32), some of which `json` cannot encode.'''
        return value.item() if isinstance(value, np.generic) else value

    @staticmethod
    def deserialize(node_json, game=None, args=None, parent=None):
        '''
        # Description:
        Rebuilds a subtree from a string produced by `serialize()`.
        Children are reconstructed recursively and linked back to their parent.

        # Parameters:
        - `node_json`: JSON string returned by `serialize()`.
        - `game`, `args`: attached to every rebuilt node (they are not serialized).
        - `parent`: parent of the rebuilt root node (default `None`).

        # Returns:
        The rebuilt root `Node`; its state (when present) is a NumPy array.
        '''
        return Node._from_dict(json.loads(node_json), game, args, parent)

    @staticmethod
    def _from_dict(node_data, game, args, parent):
        '''Recursively builds a Node (and its children) from the output of `_to_dict()`.'''
        state = node_data['state']
        node = Node(
            game=game,
            args=args,
            state=np.array(state) if state is not None else None,
            player=node_data['player'],
            parent=parent,
            action_taken=node_data['action_taken'],
            prior=node_data['prior'],
            visit_count=node_data['visit_count'],
        )
        node.value_sum = node_data['value_sum']
        node.children = [Node._from_dict(child, game, args, node) for child in node_data['children']]
        return node

    def expand(self, policy):
        '''
        # Description:
        Adds one child per action whose probability in `policy` is greater than zero.
        The action is played as player 1 on a copy of the state, and the result is passed
        through `game.change_perspective(..., player=-1)` (presumably so the child state is
        seen from the opponent's side). The child's player is `game.get_opponent(self.player)`
        and its prior is the action's probability.

        # Returns:
        None
        '''
        for action, prob in enumerate(policy):
            if prob > 0:
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)
                child = Node(self.game, self.args, child_state, self.game.get_opponent(self.player), self, action, prob)
                self.children.append(child)

    def backpropagate(self, value):
        '''
        # Description:
        Backpropagation step: adds `value` to this node's value sum, increments its visit
        count, then recurses into the parent with `game.get_opponent_value(value)`.

        # Returns:
        None
        '''
        self.value_sum += value
        self.visit_count += 1

        if self.parent is not None:
            value = self.game.get_opponent_value(value)
            self.parent.backpropagate(value)

class MCTS:
    '''
    # MCTS
    ## Description:
        Batched AlphaZero Monte Carlo Tree Search. The network supplies move priors
        (policy head) and leaf evaluations (value head). Finished searches are cached
        in `tree_dict`, keyed by `str(state) + str(player)`, so a position that is
        searched again before the cache is cleared reuses the stored result.
    '''
    def __init__(self, model, game, args):
        self.model = model
        self.game = game
        self.args = args
        self.tree_dict = {}

    def _cache_key(self, state, player):
        '''Builds the `tree_dict` key for a (state, player) pair.'''
        return str(state) + str(player)

    def _predict(self, state):
        '''
        # Description:
        Runs the network on a single state (batch of one).

        # Returns:
        `(policy, value)`: `policy` is a softmax-normalised NumPy vector and `value`
        is the raw value tensor returned by the model.
        '''
        encoded = torch.tensor(self.game.get_encoded_state(state), device=self.model.device).unsqueeze(0)
        policy, value = self.model(encoded)
        policy = torch.softmax(policy, axis=1).squeeze(0).cpu().numpy()
        return policy, value

    def _mask_policy(self, policy, state, player):
        '''
        # Description:
        Zeroes illegal actions of `policy` in place and renormalises it.
        For Attaxx the last action is enabled exactly when `get_valid_moves`
        reports no valid action at all (presumably the pass move), and disabled otherwise.

        # Returns:
        The masked, renormalised policy (the same array object that was passed in).
        '''
        valid_moves = self.game.get_valid_moves(state, player)

        if self.args["game"] == "Attaxx":
            valid_moves[-1] = 1 if np.sum(valid_moves) == 0 else 0

        policy *= valid_moves
        policy /= np.sum(policy)
        return policy

    @torch.no_grad()
    def search(self, states, player):
        """
        # Description:
        Performs Monte Carlo Tree Search (MCTS) for every state in `states`.
        The root priors are mixed with Dirichlet noise
        (`args['dirichlet_epsilon']`, `args['dirichlet_alpha']`), then
        `args['num_mcts_searches']` simulations are run. In Go, two consecutive
        passes (the last action index) are treated as terminal.

        # Returns:
        A list with one NumPy array per input state, holding the visit-count
        distribution over all `game.action_size` actions of the root's children.
        """
        action_prob_list = []
        pass_action = self.game.action_size - 1 if self.game is not None else None

        for state in states:
            key = self._cache_key(state, player)

            # Reuse a previous search of the same position before doing any work.
            if key in self.tree_dict:
                action_prob_list.append(self.tree_dict[key])
                continue

            root = Node(self.game, self.args, state, player, visit_count=1)

            policy, _ = self._predict(state)
            policy = (1 - self.args['dirichlet_epsilon']) * policy + self.args['dirichlet_epsilon'] \
                * np.random.dirichlet([self.args['dirichlet_alpha']] * self.game.action_size)
            root.expand(self._mask_policy(policy, state, player))

            for _ in range(self.args['num_mcts_searches']):
                node = root
                while node.is_expanded():
                    node = node.select()

                value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken, node.player)
                value = self.game.get_opponent_value(value)

                # Go: a pass answered by a pass ends the game.
                if (node.parent is not None
                        and node.action_taken == pass_action
                        and node.parent.action_taken == pass_action
                        and self.args['game'] == 'Go'):
                    is_terminal = True

                if not is_terminal:
                    policy, value = self._predict(node.state)
                    # NOTE(review): valid moves use the root `player` at every depth — confirm
                    # this matches the perspective convention of the game classes.
                    policy = self._mask_policy(policy, node.state, player)
                    value = value.item()
                    node.expand(policy)

                node.backpropagate(value)

            action_probs = np.zeros(self.game.action_size)
            for child in root.children:
                action_probs[child.action_taken] = child.visit_count
            action_probs /= np.sum(action_probs)

            action_prob_list.append(action_probs)
            self.tree_dict[key] = action_probs

        return action_prob_list

## ResNet


A Residual Neural Network (ResNet) is a deep learning architecture designed to address the challenges of training very deep networks. It is built from residual blocks whose skip connections add each block's input to its output, so information and gradients can flow directly through the identity path. This mitigates the vanishing-gradient problem and makes it possible to train much deeper models.

In [None]:
class ResNet(nn.Module):
    '''
    # ResNet
    ## Description:
        A ResNet model for AlphaZero.
        The model takes a batch of encoded states and outputs a policy and a value.
         - The policy is a vector of unnormalised logits over all `game.action_size` actions;
           callers apply softmax (or a logit-based loss such as cross-entropy) to get probabilities.
         - The value is a tanh output in [-1, 1]; -1 is meant to mean the current player loses and 1 that the current player wins.
    ## Parameters:
        - `game`: provides `row_count`, `column_count` and `action_size`.
        - `num_resBlocks`: number of residual blocks in the backbone.
        - `num_hidden`: number of channels used throughout the backbone.
        - `device`: device the model is moved to; also stored as `self.device`.
        - `in_channels`: number of planes in the encoded state (default 3).
        '''
    def __init__(self, game, num_resBlocks, num_hidden, device, in_channels=3):
        super().__init__()
        self.device = device

        self.startBlock = nn.Sequential(
            nn.Conv2d(in_channels, num_hidden, kernel_size=3, padding="same"),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        )
        self.backBone = nn.ModuleList(
            [ResBlock(num_hidden) for _ in range(num_resBlocks)]
        )
        # Policy head: one logit per action (board cells plus the extra last action).
        self.policyHead = nn.Sequential(
            nn.Conv2d(num_hidden, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * game.row_count * game.column_count, game.action_size)
        )
        # Value head: a single scalar squashed into [-1, 1].
        self.valueHead = nn.Sequential(
            nn.Conv2d(num_hidden, 3, kernel_size=3, padding="same"),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3 * game.row_count * game.column_count, 1),
            nn.Tanh()
        )

        self.to(device)

    def forward(self, x):
        '''
        # Description:
        The forward pass of the model (invoked by calling the module directly).

        # Returns:
        - `policy`: policy logits, one per action.
        - `value`: value estimate in [-1, 1].
        '''
        x = self.startBlock(x)
        for resBlock in self.backBone:
            x = resBlock(x)
        policy = self.policyHead(x)
        value = self.valueHead(x)
        return policy, value
        
class ResBlock(nn.Module):
    '''
    # Description:
    Residual block used by the ResNet backbone:
    conv -> batch-norm -> ReLU -> conv -> batch-norm, then the block input is
    added back (identity skip connection) and a final ReLU is applied.
    Channel count and spatial size are preserved ("same" padding).
    '''
    def __init__(self, num_hidden):
        super().__init__()
        # Attribute names are part of the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding="same")
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding="same")
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """
        # Description:
        Forward pass through the residual block.

        # Returns:
        Tensor with the same shape as `x`: relu(bn2(conv2(relu(bn1(conv1(x))))) + x).
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)

## AlphaZero

AlphaZero is an algorithm introduced by DeepMind. Starting from random play and given no domain knowledge except the game rules, a trained agent can reach superhuman performance. It achieves this by combining Monte Carlo Tree Search (MCTS) with a neural network in a policy-iteration framework that keeps learning stable. With these elements combined, the agent can improve purely through self-play.

In [None]:
class AlphaZero:
    '''
    # AlphaZero
    ## Description:
        Training driver: generates self-play games with MCTS, trains the network on
        the resulting (encoded state, search policy, outcome) samples, and saves the
        model, optimizer and replay memory after every iteration.
    '''
    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(model, game, args)

    def augment_state(self, state, probs):
        '''
        # Description:
        Produces the 8 symmetries of a square board (4 rotations and 4 reflections),
        applying the same transformation to the move probabilities. The last entry of
        `probs` (the extra skip/pass action) has no board position and is copied unchanged.

        # Returns:
        `(augmented_states, augmented_action_probs)`: two lists of length 8. Each probability
        entry is a plain list of `row_count * column_count + 1` values.
        '''
        skip_prob = probs[-1]
        # Board-shaped view of the per-cell probabilities, matching the
        # (row_count, column_count) shape of the board state.
        action_probs_matrix = np.array(probs[:-1]).reshape(self.game.row_count, self.game.column_count)

        # np.rot90 rotates counter-clockwise for positive k.
        transforms = [
            lambda m: m,                                # identity
            lambda m: np.rot90(m, k=1),                 # rotate 90° counter-clockwise
            lambda m: np.rot90(m, k=2),                 # rotate 180°
            lambda m: np.rot90(m, k=3),                 # rotate 270° counter-clockwise
            np.fliplr,                                  # mirror left/right
            np.flipud,                                  # mirror up/down
            lambda m: np.rot90(np.fliplr(m), k=1),      # transpose
            lambda m: np.rot90(np.flipud(m), k=1),      # anti-transpose
        ]

        augmented_states = [transform(state) for transform in transforms]
        augmented_action_probs = [
            list(transform(action_probs_matrix).flatten()) + [skip_prob]
            for transform in transforms
        ]
        return augmented_states, augmented_action_probs

    def selfPlay(self):
        '''
        # Description:
        Plays `args['parallel_games']` games in lock-step, choosing each move by sampling
        the MCTS visit distribution sharpened by a temperature that is multiplied by
        `args['cooling_constant']` every move (floored at 0.1). A game ends when the game
        reports a terminal state, after two consecutive passes in Go, or once
        `args['max_moves']` moves have been played.

        # Returns:
        A list of `(encoded_state, action_probs, outcome)` samples from all finished games;
        `outcome` is the final value seen from the player who was to move in that position.
        When `args['augment']` is set, every position contributes its 8 symmetries.
        '''
        player = 1
        num_games = self.args['parallel_games']

        states = [self.game.get_initial_state() for _ in range(num_games)]
        memory = [[] for _ in range(num_games)]
        # One flag per game: did the previous move in *that* game pass? (Go double-pass rule)
        prev_skips = [False] * num_games

        move_count = 0
        temperature = self.args['temperature']
        debugging = False

        returnData = []

        while True:
            if self.args["game"] == "Attaxx" and debugging:
                print("\nSEARCHING...")

            neutral_states_list = [self.game.change_perspective(state, player) for state in states]

            # NOTE(review): the search receives the raw `states` while memory stores the
            # perspective-normalised ones — confirm this is intended when player == -1.
            action_probs_list = self.mcts.search(states, player)

            for game_memory, neutral_state, action_probs in zip(memory, neutral_states_list, action_probs_list):
                game_memory.append((neutral_state, action_probs, player))

            # Walk the games backwards so deleting a finished game at `idx` only shifts
            # games that were already handled in this pass.
            for idx in reversed(range(len(states))):
                action_probs = action_probs_list[idx]
                temperature_action_probs = action_probs ** (1 / temperature)
                temperature_action_probs /= np.sum(temperature_action_probs)

                action = np.random.choice(self.game.action_size, p=temperature_action_probs)

                state = self.game.get_next_state(states[idx], action, player)
                # Store the result: get_next_state may return a new array instead of mutating.
                states[idx] = state

                if self.args["game"] == "Attaxx" and debugging:
                    print(f"Player: {player} with move {self.game.int_to_move(action)}\nBoard:")
                    self.game.print_board(state)

                value, is_terminal = self.game.get_value_and_terminated(state, action, player)

                if action == self.game.action_size - 1 and self.args['game'] == 'Go':
                    if prev_skips[idx]:
                        is_terminal = True
                    else:
                        prev_skips[idx] = True
                else:
                    prev_skips[idx] = False

                if is_terminal or move_count >= self.args['max_moves']:
                    if self.args["game"] == "Attaxx" and debugging:
                        print("GAME OVER\n\n")

                    for hist_neutral_state, hist_action_probs, hist_player in memory[idx]:
                        # `value` is treated as the outcome for `player`, the side that just moved.
                        hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)

                        if self.args['augment']:
                            augmented_states, augmented_action_probs = self.augment_state(hist_neutral_state, hist_action_probs)
                            for augmented_state, augmented_probs in zip(augmented_states, augmented_action_probs):
                                returnData.append((self.game.get_encoded_state(augmented_state), augmented_probs, hist_outcome))
                        else:
                            returnData.append((self.game.get_encoded_state(hist_neutral_state), hist_action_probs, hist_outcome))

                    del memory[idx]
                    del states[idx]
                    del prev_skips[idx]

            if not states:
                return returnData

            player = self.game.get_opponent(player)

            if temperature >= 0.1:
                temperature = temperature * self.args['cooling_constant']
            else:
                temperature = 0.1

            move_count += 1

    def train(self, memory):
        '''
        # Description:
        One epoch of training: shuffles `memory` in place and runs an optimizer step per
        mini-batch of `args['batch_size']` samples. The loss is the cross-entropy between the
        policy logits and the MCTS probabilities plus the MSE between value and outcome.

        # Returns:
        None
        '''
        random.shuffle(memory)
        for batchIdx in range(0, len(memory), self.args['batch_size']):
            sample = memory[batchIdx:batchIdx+self.args['batch_size']]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self, memory = None, LAST_ITERATION=0):
        '''
        # Description:
        Runs iterations `LAST_ITERATION + 1 .. args['num_iterations'] - 1`. Each iteration:
        plays `args['num_selfPlay_iterations']` self-play batches with the model in eval mode;
        with `args['experience_replay']`, adds a random 30% of older samples; trains for
        `args['num_epochs']` epochs; saves model, optimizer and replay memory under
        `DevelopmentModels/<alias>/`.

        # Parameters:
        - `memory`: optional list of earlier samples that seeds the replay memory (copied, not mutated).
        - `LAST_ITERATION`: index of the last completed iteration; training resumes after it.
        '''
        primary_memory = list(memory) if memory is not None else []

        save_dir = f"DevelopmentModels/{self.args['alias']}"
        os.makedirs(save_dir, exist_ok=True)

        for iteration in range(LAST_ITERATION+1, self.args['num_iterations']):
            print(f"Iteration {iteration + 1}")

            # Self-play is pure inference on single positions: use BatchNorm running
            # statistics and do not update them.
            self.model.eval()

            secondary_memory = []
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                secondary_memory += self.selfPlay()
                # Start every self-play run with an empty search cache.
                self.mcts.tree_dict = {}

            training_memory = []
            if self.args['experience_replay']:
                sample_size = int(len(primary_memory) * 0.3)

                training_memory += random.sample(primary_memory, min(sample_size, len(primary_memory)))
                training_memory += secondary_memory

                primary_memory += secondary_memory
            else:
                training_memory += secondary_memory

            print(f"Memory size: {len(training_memory)}")

            self.model.train()

            for epoch in trange(self.args['num_epochs']):
                self.train(training_memory)

            print("\n")

            torch.save(self.model.state_dict(), f"{save_dir}/model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"{save_dir}/optimizer_{iteration}.pt")
            with open(f"{save_dir}/memory_{iteration}.pkl", 'wb') as f:
                pickle.dump(primary_memory, f)
            print("Data Saved!")

## Main

The following code was used both for training and testing the generated models.

In [None]:
# Seed every random number generator used in the notebook (Python, NumPy, PyTorch)
# so that self-play and training runs are reproducible.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint name; assigned inside the __main__ block.
SAVE_NAME = None

if __name__ == '__main__':

    # Go / Attaxx
    GAME = "Attaxx"

    # Board size (7/9 for Go, 4/5/6 for Attaxx)
    SIZE = 6

    # True to load previous model
    # False to start from scratch
    LOAD = True
    LAST_ITERATION = -1

    # Save Name
    SAVE_NAME = "a6x6"

    # False for training
    # True for playing
    TEST = True

    # False if locally 
    # True if playing in the server
    ONLINE = False

    # Train from scratch
    if not LOAD and not TEST:
        LAST_ITERATION=-1

    if GAME == 'Go':
        if SIZE == 7:
            args = {
                'game': 'Go',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 15,    # number of self-play games to play within each iteration
                'num_mcts_searches': 200,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 20,                 # number of epochs for training on self-play data for each iteration
                'batch_size': 16,                 # batch size for training
                'temperature': 3,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.90,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 2,                           # the value of the constant policy
                'experience_replay': True,        # recycle a certain % of old random selfplay data in the current training iteration
                'augment': False,                 # whether to augment the training data with flipped and rotated states
                'parallel_games': 10,            # number of games run in parallel
                'dirichlet_alpha': 0.03,          # the value of the dirichlet noise (alpha)
                'dirichlet_epsilon': 0.25,        # the value of the dirichlet noise (epsilon)
                'alias': ('Go' + SAVE_NAME)
            }

            game = Go(size = SIZE, komi = 5.5)
            model = ResNet(game, 10, 10, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
            
        elif SIZE == 9:
            args = {
                'game': 'Go',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 20,    # number of self-play games to play within each iteration
                'num_mcts_searches': 200,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 60,                 # number of epochs for training on self-play data for each iteration
                'batch_size': 32,                 # batch size for training
                'temperature': 3,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 1,                           # the value of the constant policy
                'experience_replay': True,        # recycle a certain % of old random selfplay data in the current training iteration
                'augment': False,                 # whether to augment the training data with flipped and rotated states
                'parallel_games': 5,            # number of games run in parallel
                'dirichlet_alpha': 0.03,          # the value of the dirichlet noise (alpha)
                'dirichlet_epsilon': 0.03,        # the value of the dirichlet noise (epsilon)
                'alias': ('Go' + SAVE_NAME)
            }

            game = Go(size = SIZE, komi = 5.5)
            model = ResNet(game, 9, 3, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    elif GAME == 'Attaxx':
        game_size = [SIZE,SIZE]
        if SIZE == 4:
            args = {
                'game': 'Attaxx',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 10,                # number of epochs for training on self-play data for each iteration
                'batch_size': 16,                # batch size for training
                'temperature': 3,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.9,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 1,                           # the value of the constant policy
                'dirichlet_alpha': 0.03,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.03,       # the 001value of the dirichlet noise
                'parallel_games': 10,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 4, 8, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

        elif SIZE == 5:
            args = {
                'game': 'Attaxx',
                'num_iterations': 10000,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 10,                # number of epochs for training on self-play data for each iteration
                'batch_size': 64,                # batch size for training
                'temperature': 1,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 1,                           # the value of the constant policy
                'dirichlet_alpha': 0.03,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.03,       # the value of the dirichlet noise
                'parallel_games': 15,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 8, 16, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

        elif SIZE == 6:
            args = {
                'game': 'Attaxx',
                'num_iterations': 20,             # number of highest level iterations
                'num_selfPlay_iterations': 20,  # number of self-play games to play within each iteration
                'num_mcts_searches': 100,         # number of mcts simulations when selecting a move within self-play
                'max_moves': 512,                 # maximum number of moves in a game (to avoid infinite games which should not happen but just in case)
                'num_epochs': 20,                # number of epochs for training on self-play data for each iteration
                'batch_size': 128,                # batch size for training
                'temperature': 1,                 # temperature for the softmax selection of moves
                'cooling_constant': 0.85,         # value that gets multiplied to the temperature to gradually reduce it  
                'C': 1,                           # the value of the constant policy
                'dirichlet_alpha': 0.03,           # the value of the dirichlet noise
                'dirichlet_epsilon': 0.03,       # the value of the dirichlet noise
                'parallel_games': 20,            # number of games run in parallel
                'experience_replay': True,        # we recycle 30% of old random selfplay data in the current training iteration
                'augment': False,                  # whether to augment the training data with flipped and rotated states
                'alias': ('Attaxx' + SAVE_NAME)
            }

            game = Attaxx(game_size)
            model = ResNet(game, 12, 32, device)
            optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    else:
        print("Game Unavailable")

    if LOAD:
        model.load_state_dict(torch.load(f'FinalModels/{SAVE_NAME}/{SAVE_NAME}.pt', map_location=device))
        optimizer.load_state_dict(torch.load(f'FinalModels/{SAVE_NAME}/o{SAVE_NAME}.pt', map_location=device))
    
        #with open(f'DevelopmentModels/{GAME+SAVE_NAME}/memory_{LAST_ITERATION}.pkl', 'rb') as f:
         #   memory = pickle.load(f)
    else:
        memory = None

    if not TEST:
        os.makedirs(f'DevelopmentModels/{GAME+SAVE_NAME}', exist_ok=True)
        alphaZero = AlphaZero(model, optimizer, game, args)
        alphaZero.learn(memory, LAST_ITERATION)

    elif not ONLINE:

        if not LOAD:
            print("No model to test")
            exit()

        global SIZE_BOARD, BLACK, WHITE, WOOD, SCREEN_SIZE

        if GAME == 'Go':

            
            SIZE_BOARD = SIZE
            BLACK = (0,0,0)
            WHITE = (255,255,255)
            WOOD = (216, 166, 91)
            SCREEN_SIZE = 600
            SCREEN_PADDING = 50
            CELL_SIZE = (SCREEN_SIZE - SCREEN_PADDING) // SIZE_BOARD
            PIECE_SIZE = (SCREEN_SIZE - 2*SCREEN_PADDING) // SIZE_BOARD // 3

            PLAYER1 = "AI"
            PLAYER2 = "AI"

            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            #game.print_board(state)

            player = 1
            prev_skip = False
            rendering = True
            click = False
            valid_moves = []

            for i in range(SIZE_BOARD):
                for j in range(SIZE_BOARD):
                    valid_moves.append([i, j])

            cur_pieces = []

            a = None

            if not rendering:
                game.print_board(state)
            else:
                pygame.init()
                #pygame_icon = pygame.image.load('image.png')
                #pygame.display.set_icon(pygame_icon)

                screen=pygame.display.set_mode((SCREEN_SIZE,SCREEN_SIZE))

                pygame.display.set_caption("Go")

            while True:

                if rendering:
                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            pygame.quit()
                        if event.type == pygame.MOUSEBUTTONDOWN:
                            click = True
                            a, b, player = gohover_to_select(player, valid_moves, click)
                        if event.type == pygame.MOUSEBUTTONUP:
                            click = False

                    screen.fill(WOOD)
                    godraw_board()
                    for i in range(0,len(state)):
                        for j in range(0,len(state)):
                            if state[i][j] == 0:
                                if game.is_valid_move(state, (i,j), player):
                                    valid_moves.append([i,j])
                            else:
                                godraw_piece(i,j, state[i][j])

                if player == 1:
                    
                    if PLAYER1 == 'user':

                        valid_move_selected = False

                        a, b, player = gohover_to_select(player, valid_moves, click)

                        if click:

                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            a, b = goto_coord(mouse_x), goto_coord(mouse_y)

                            action = a * SIZE + b

                            state = game.get_next_state(state, action, player)

                            winner, win = game.get_value_and_terminated(state, action, player)
                
                            if action == game.action_size:
                                if prev_skip:
                                    win = True
                                else:
                                    prev_skip = True
                            else:
                                prev_skip = False

                            if win:
                                print(f"player {winner} wins")
                                break

                            player = -player

                    else:
                        tmp_state = game.change_perspective(state, -1)
                        action = mcts.search([tmp_state], -player)                    
                        action = np.argmax(action[0])
                        print(f"\nAlphaZero Action: {action // game.row_count} {action % game.column_count}\n")
                        state = game.get_next_state(state, action, player)

                        winner, win = game.get_value_and_terminated(state, action, player)
                
                        if action == game.action_size:
                            if prev_skip:
                                win = True
                            else:
                                prev_skip = True
                        else:
                            prev_skip = False

                        if win:
                            print(f"player {winner} wins")
                            break

                        player = -player
                else:
                    if PLAYER2 == 'user':
                        valid_move_selected = False

                        a, b, player = gohover_to_select(player, valid_moves, click)

                        if click:

                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            a, b = goto_coord(mouse_x), goto_coord(mouse_y)

                            action = a * SIZE + b

                            state = game.get_next_state(state, action, player)

                            winner, win = game.get_value_and_terminated(state, action, player)
                
                            if action == game.action_size:
                                if prev_skip:
                                    win = True
                                else:
                                    prev_skip = True
                            else:
                                prev_skip = False

                            if win:
                                print(f"player {winner} wins")
                                break

                            player = -player
                    else:
                        action = mcts.search([state], player)                    
                        action = np.argmax(action[0])
                        

                        print(f"\nAlphaZero Action: {action // game.row_count} {action % game.column_count}\n")
                        state = game.get_next_state(state, action, player)

                        winner, win = game.get_value_and_terminated(state, action, player)
                
                        if action == game.action_size:
                            if prev_skip:
                                win = True
                            else:
                                prev_skip = True
                        else:
                            prev_skip = False

                        if win:
                            print(f"player {winner} wins")
                            break

                        player = -player

                pygame.display.flip()

        elif GAME == 'Attaxx':
            PLAYER1 = "AI"
            PLAYER2 = "AI"
            SIZE_BOARD = SIZE

            RED = (238, 167, 255)
            BLUE = (113, 175, 255)
            GRAY = (115, 115, 115)
            BLACK = (0, 0, 0)

            pygame.init()

            SCREEN_SIZE=600
            SCREEN_PADDING = 100
            CELL_SIZE = (SCREEN_SIZE - SCREEN_PADDING) // SIZE_BOARD
            PIECE_SIZE = (SCREEN_SIZE - 2*SCREEN_PADDING) // SIZE_BOARD // 3

            screen=pygame.display.set_mode((SCREEN_SIZE,SCREEN_SIZE))

            pygame.display.set_caption("Attaxx")
            click = False
            valid_moves = []
            for i in range(SIZE_BOARD):
                for j in range(SIZE_BOARD):
                    valid_moves.append([i, j])

            cur_pieces =[]

            player = 1
            selected_piece = False
            last_click_time = 0
            x, y = -1, -1

            mcts = MCTS(model, game, args)
            state = game.get_initial_state()
            sel = False
            while True:

                screen.fill(GRAY)
                atdraw_board()

                for i in range(0,len(state)):
                        for j in range(0,len(state)):
                            if state[i][j] != 0:
                                atdraw_piece(i,j, state[i][j])

                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                    if event.type == pygame.MOUSEBUTTONDOWN:
                        click = True
                        x, y, player, selected_piece, last_click_time = athover_to_select(x, y, player, valid_moves, click, selected_piece, cur_pieces, last_click_time)
                        pygame.time.delay(50)
                    if event.type == pygame.MOUSEBUTTONUP:
                        click = False
                        pygame.time.delay(50)


                for piece in cur_pieces:
                    atdraw_piece(piece[0], piece[1], piece[2])

                if player == 1:

                    if PLAYER1 == 'user':
                        x, y, player, selected_piece, last_click_time = athover_to_select(x, y, player, valid_moves, click, selected_piece, cur_pieces, last_click_time)
                        if x != -1 and y != -1:
                            valid_moves = []
                            for i in range(0,len(state)):
                                for j in range(0,len(state)):
                                    if game.is_valid_move(state, (x,y,i,j), player):
                                        valid_moves.append([i,j])

                        if sel:
                            for move in valid_moves:
                                px, py = atto_pixels(move[0]), atto_pixels(move[1])
                                s = pygame.Surface((SCREEN_SIZE, SCREEN_SIZE), pygame.SRCALPHA)
                                pygame.draw.circle(s, (113, 175, 255, 100) if player == 1 else (238, 167, 255, 100), (px, py), PIECE_SIZE, 3)  # Change color as needed
                                screen.blit(s, (0, 0))

                        if click and sel:
                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            a, b = atto_coord(mouse_x), atto_coord(mouse_y)
                            if a == x and b == y:
                                sel = False
                            else:
                                move = (x,y,atto_coord(mouse_x), atto_coord(mouse_y))

                                action = game.move_to_int(move)
                                state = game.get_next_state(state, action, player)

                                winner, win = game.get_value_and_terminated(state, action, player)
                                if win:
                                    game.print_board(state)
                                    print(f"player {winner} wins")
                                    pygame.quit()
                                    break
                                
                                player = -player
                                x = -1
                                y = -1
                                sel = False
                        
                        if click and not sel:
                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            x, y = atto_coord(mouse_x), atto_coord(mouse_y)
                            if state[x,y] == player:
                                sel = True

                    else:

                        tmp_state = game.change_perspective(state, -1)
                        action = mcts.search([tmp_state], -player)
                        action = np.argmax(action)
                        print(f"\nAlphaZero Action: {game.int_to_move(action)}\n")
                        state = game.get_next_state(state, action, player)
                        winner, win = game.get_value_and_terminated(state, action, player)
                        if win:
                            game.print_board(state)
                            print(f"player {winner} wins")
                            pygame.quit()
                            break

                        player = -player

                else:
                    if PLAYER2 == 'user':
                        x, y, player, selected_piece, last_click_time = athover_to_select(x, y, player, valid_moves, click, selected_piece, cur_pieces, last_click_time)
                        if x != -1 and y != -1:
                            valid_moves = []
                            for i in range(0,len(state)):
                                for j in range(0,len(state)):
                                    if game.is_valid_move(state, (x,y,i,j), player):
                                        valid_moves.append([i,j])

                        if sel:
                            for move in valid_moves:
                                px, py = atto_pixels(move[0]), atto_pixels(move[1])
                                s = pygame.Surface((SCREEN_SIZE, SCREEN_SIZE), pygame.SRCALPHA)
                                pygame.draw.circle(s, (113, 175, 255, 100) if player == 1 else (238, 167, 255, 100), (px, py), PIECE_SIZE, 3)  # Change color as needed
                                screen.blit(s, (0, 0))

                        if click and sel:
                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            a, b = atto_coord(mouse_x), atto_coord(mouse_y)
                            if a == x and b == y:
                                sel = False
                            else:
                                move = (x,y,atto_coord(mouse_x), atto_coord(mouse_y))

                                action = game.move_to_int(move)
                                state = game.get_next_state(state, action, player)

                                winner, win = game.get_value_and_terminated(state, action, player)
                                if win:
                                    game.print_board(state)
                                    print(f"player {winner} wins")
                                    pygame.quit()
                                    break
                                print("brug")
                                player = -player
                                x = -1
                                y = -1
                                sel = False
                        
                        if click and not sel:
                            mouse_x, mouse_y = pygame.mouse.get_pos()
                            x, y = atto_coord(mouse_x), atto_coord(mouse_y)
                            if state[x,y] == player:
                                sel = True

                    else:

                        tmp_state = game.change_perspective(state, -1)
                        action = mcts.search([tmp_state], -player)
                        action = np.argmax(action)
                        print(f"\nAlphaZero Action: {game.int_to_move(action)}\n")
                        state = game.get_next_state(state, action, player)
                        winner, win = game.get_value_and_terminated(state, action, player)
                        if win:
                            game.print_board(state)
                            print(f"player {winner} wins")
                            pygame.quit()
                            break

                        player = -player



                pygame.display.flip()

# Online Client Runtime

In [None]:
ONLINE = True  # set to False to skip the networked client below
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if ONLINE:

    # Mode code: first character 'A' = Attaxx or 'G' = Go, second character = board size.
    # Known values: "A6x6" "G7x7" "G9x9" "A5x5"
    mode = "A6x6"

    def load_model():
        '''
        # Description:
        Builds the game, ResNet, Adam optimizer and MCTS matching the current
        global `mode` (first character 'A' = Attaxx or 'G' = Go, second character
        = board size) and loads the trained weights from FinalModels/<name>/.

        Sets the globals model, optimizer, mcts, game, args and state
        (state is reset to the game's initial board).

        # Raises:
        ValueError if `mode` does not name a game/board size with a trained model.
        '''

        global model, optimizer, mcts, game, args, state

        # (game letter, board size) -> (ResNet size arguments, checkpoint name, display name).
        # The two ResNet integers must match the architecture stored in the checkpoint,
        # otherwise load_state_dict fails.
        checkpoints = {
            ("A", 4): ((4, 8), "a4x4", "Attaxx 4x4"),
            ("A", 5): ((8, 16), "a5x5", "Attaxx 5x5"),
            ("A", 6): ((12, 32), "a6x6", "Attaxx 6x6"),
            ("G", 7): ((10, 10), "g7x7", "Go 7x7"),
            ("G", 9): ((9, 3), "g9x9", "Go 9x9"),
        }

        try:
            game_letter, size = mode[0], int(mode[1])
            resnet_dims, name, label = checkpoints[(game_letter, size)]
        except (IndexError, ValueError, KeyError) as err:
            # Previously an unknown mode silently left game/model undefined or stale.
            raise ValueError(f"No trained model available for mode {mode!r}") from err

        # The configuration is identical for every mode except the game name.
        args = {
            'game': 'Attaxx' if game_letter == "A" else 'Go',
            'num_iterations': 1,              # number of highest level iterations
            'num_selfPlay_iterations': 15,    # number of self-play games to play within each iteration
            'num_mcts_searches': 200,         # number of mcts simulations when selecting a move
            'max_moves': 512,                 # maximum number of moves in a game (guards against infinite games)
            'num_epochs': 20,                 # number of epochs for training on self-play data for each iteration
            'batch_size': 16,                 # batch size for training
            'temperature': 3,                 # temperature for the softmax selection of moves
            'cooling_constant': 0.9,          # multiplied into the temperature to gradually reduce it
            'C': 2,                           # the value of the constant policy
            'dirichlet_alpha': 0.03,          # alpha value of the dirichlet noise
            'dirichlet_epsilon': 0.25,        # epsilon value of the dirichlet noise
            'parallel_games': 10,             # number of games run in parallel
            'experience_replay': True,        # recycle old self-play data in the current training iteration
            'augment': False,                 # whether to augment the training data with flipped and rotated states
        }

        if game_letter == "A":
            game = Attaxx([size, size])
        else:
            game = Go(size=size, komi=5.5)
        model = ResNet(game, *resnet_dims, device)
        optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

        state = game.get_initial_state()

        model.load_state_dict(torch.load(f'FinalModels/{name}/{name}.pt', map_location=device))
        optimizer.load_state_dict(torch.load(f'FinalModels/{name}/o{name}.pt', map_location=device))
        print(f"Successfuly loaded {label}")

        mcts = MCTS(model, game, args)

    def generate_move():
        '''
        # Description:
        Runs MCTS from the global `state` for the global `player`, applies the
        chosen action to `state` and formats it for the server.

        # Returns:
        "MOVE x1 y1 x2 y2" for Attaxx, "MOVE row col" for Go.
        '''

        # `state` is reassigned below. Without this declaration Python treats it
        # as a local, and the first read raises UnboundLocalError.
        global state

        if player == 1:
            # Player 1 searches on the perspective-flipped board with the negated player.
            search_state = game.change_perspective(state, -1)
            search_player = -player
        else:
            search_state = state
            search_player = player

        # Every other call site in this file passes mcts.search a list (batch) of states.
        action = np.argmax(mcts.search([search_state], search_player))

        if mode[0] == "A":
            move = game.int_to_move(action)
            print(f"\nAlphaZero Action: {move}\n")
            reply = f"MOVE {move[0]} {move[1]} {move[2]} {move[3]}"
        else:
            row, col = action // game.row_count, action % game.column_count
            print(f"\nAlphaZero Action: {row} {col}\n")
            reply = f"MOVE {row} {col}"

        state = game.get_next_state(state, action, player)

        return reply

    def apply_opponent_move(response):
        '''
        # Description:
        Parses a server message of the form "<TAG> n1 n2 ..." and applies the
        move to the global `state` for the global `player`.
        Attaxx expects four integers (x1 y1 x2 y2); Go expects two (row col).
        Messages without enough integer fields (e.g. status or END messages)
        are ignored and leave `state` unchanged.

        # Arguments:
        response: decoded text received from the server.
        '''

        # `state` is reassigned below. Without this declaration the read on the
        # right-hand side raises UnboundLocalError.
        global state

        try:
            numbers = [int(token) for token in response.split()[1:]]
        except ValueError:
            return  # not a move message

        is_attaxx = mode[0] == "A"  # first character is the game letter
        if len(numbers) < (4 if is_attaxx else 2):
            return  # not a move message

        if is_attaxx:
            # Use the parsed integers, not the raw characters of `response`.
            action = game.move_to_int(tuple(numbers[:4]))
        else:
            action = numbers[0] * game.row_count + numbers[1]

        state = game.get_next_state(state, action, player)

    def connect_to_server(host='localhost', port=12345):
        '''
        # Description:
        Connects to the game server, sets the global `mode` from the last four
        characters of the server's greeting, loads the matching model and plays
        until the server sends a message containing "END" or closes the connection.

        # Arguments:
        host, port: address of the game server.
        '''
        # `mode` must be global so load_model()/generate_move() see the server's choice.
        global player, mode

        # `with` closes the socket even if a move or the model load raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
            client_socket.connect((host, port))

            response = client_socket.recv(1024).decode()
            print(f"Server ResponseINIT: {response}")

            mode = response[-4:]
            print("Playing:", mode)

            load_model()

            player = 1 if "1" in response else -1
            # Player -1 waits for the opponent's opening move before generating one.
            first = player == -1

            # NOTE(review): on turns where we move, two server messages are read per
            # loop (Response1 and Response2) and both are applied as opponent moves;
            # confirm this matches the server protocol.
            while True:
                if player == 1 or not first:
                    move = generate_move()
                    player = -player
                    time.sleep(1)
                    client_socket.sendall(move.encode())
                    print("Send:", move)

                    # Wait for server response
                    response = client_socket.recv(1024).decode()
                    print(f"Server Response1: {response}")
                    # An empty read means the server closed the connection.
                    if not response or "END" in response:
                        break
                    apply_opponent_move(response)

                first = False
                response = client_socket.recv(1024).decode()
                print(f"Server Response2: {response}")
                if not response or "END" in response:
                    break
                apply_opponent_move(response)

                player = -player

    # Start the online client against the local game server.
    connect_to_server(host='localhost', port=12345)
