In [12]:
from __future__ import print_function
import go
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import torch
import math
from torch.autograd import Variable

In [13]:
class ResNet(nn.Module):
    """Residual convolutional network with a policy head and a value head.

    A stem convolution over a 3-plane board encoding feeds a stack of
    ``ResBlock`` modules; both heads read the output of that stack.

    Args:
        game: object exposing ``row_count``, ``column_count`` and ``action_size``.
        num_resBlocks: number of ``ResBlock`` modules in the backbone.
        num_hidden: channel width of the stem and of every residual block.
        device: device the module is moved to; also kept as ``self.device``.
    """

    def __init__(self, game, num_resBlocks, num_hidden, device):
        super().__init__()

        self.device = device
        board_cells = game.row_count * game.column_count

        # Stem: 3 input planes -> num_hidden feature maps.
        self.startBlock = nn.Sequential(*self._conv_bn_relu(3, num_hidden))

        self.backBone = nn.ModuleList()
        for _ in range(num_resBlocks):
            self.backBone.append(ResBlock(num_hidden))

        # Policy head: one logit per action.
        self.policyHead = nn.Sequential(
            *self._conv_bn_relu(num_hidden, 32),
            nn.Flatten(),
            nn.Linear(32 * board_cells, game.action_size),
        )

        # Value head: a single scalar squashed into [-1, 1] by tanh.
        self.valueHead = nn.Sequential(
            *self._conv_bn_relu(num_hidden, 3),
            nn.Flatten(),
            nn.Linear(3 * board_cells, 1),
            nn.Tanh(),
        )

        self.to(device)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return the [3x3 conv, batch-norm, ReLU] layers shared by the stem and both heads."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of encoded boards."""
        features = self.startBlock(x)
        for block in self.backBone:
            features = block(features)
        return self.policyHead(features), self.valueHead(features)
        
        
class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped in an identity skip connection.

    Args:
        num_hidden: number of input and output channels (the block keeps the
            channel count unchanged so the skip connection can be added directly).
    """

    def __init__(self, num_hidden):
        super().__init__()
        same_conv = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, **same_conv)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, **same_conv)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        skip = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # In-place add on the freshly computed tensor; the caller's input is untouched.
        hidden += skip
        return F.relu(hidden)
        

In [14]:
class Node:
    """One node of the Monte Carlo search tree.

    Args:
        game: game adapter; this class calls ``get_next_state``,
            ``change_perspective`` and ``get_opponent_value`` on it.
        args: hyper-parameter dict; only ``args['C']`` is read here.
        state: board state at this node, from the point of view of the player to move.
        parent: parent ``Node`` or ``None`` for the root.
        action_taken: action index that led from ``parent`` to this node.
        prior: prior probability assigned to ``action_taken`` by the policy.
        visit_count: initial visit count (the search seeds the root with 1).
    """

    def __init__(self, game, args, state, parent=None, action_taken=None, prior=0, visit_count=0):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken
        self.prior = prior

        self.children = []

        self.visit_count = visit_count
        self.value_sum = 0

    def is_fully_expanded(self):
        """Return True once ``expand`` has attached at least one child."""
        return len(self.children) > 0

    def select(self):
        """Return the child with the highest UCB score (``None`` if there are no children)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child):
        """Return the UCB score of ``child`` as seen from this node.

        The child's mean value is rescaled from [-1, 1] to [0, 1] and flipped
        (``1 - ...``) because child states are stored from the opposite player's
        perspective (see ``expand`` and the sign flip in ``backpropagate``).
        """
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * (math.sqrt(self.visit_count) / (child.visit_count + 1)) * child.prior

    def expand(self, policy):
        """Create one child per action whose probability in ``policy`` is > 0.

        Returns:
            The last child created, or ``None`` when no entry of ``policy`` is
            positive (e.g. an all-zero or NaN policy). Previously that case
            raised ``UnboundLocalError``.
        """
        child = None
        for action, prob in enumerate(policy):
            if prob > 0:
                # Play the move as player 1, then flip the board so the child is
                # again expressed from the point of view of the player to move.
                child_state = self.state.copy()
                child_state = self.game.get_next_state(child_state, action, 1)
                child_state = self.game.change_perspective(child_state, player=-1)

                child = Node(self.game, self.args, child_state, self, action, prob)
                self.children.append(child)

        return child

    def backpropagate(self, value):
        """Add ``value`` to this node and propagate it to the root, flipping sign at each level.

        Implemented as a loop rather than recursion so tree depth is not
        limited by the interpreter's recursion limit.
        """
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = node.game.get_opponent_value(value)
            node = node.parent


class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Args:
        game: game adapter used to encode states, list valid moves and score positions.
        args: hyper-parameter dict; this class reads ``num_searches``,
            ``dirichlet_epsilon`` and ``dirichlet_alpha`` (``Node`` reads ``C``).
        model: network returning ``(policy_logits, value)`` and exposing ``.device``.
    """

    def __init__(self, game, args, model):
        self.game, self.args, self.model = game, args, model

    def _predict(self, state):
        """Evaluate one state; return (softmaxed policy as ndarray, raw value tensor)."""
        encoded = torch.tensor(self.game.get_encoded_state(state), device=self.model.device)
        logits, value = self.model(encoded.unsqueeze(0))
        probabilities = torch.softmax(logits, axis=1).squeeze(0).cpu().numpy()
        return probabilities, value

    def _restrict_to_valid(self, policy, state):
        """Zero the entries of ``policy`` for invalid moves (in place) and renormalise."""
        policy *= self.game.get_valid_moves(state)
        policy /= np.sum(policy)
        return policy

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` simulations from ``state``.

        Returns:
            Array of length ``game.action_size`` holding the root children's
            visit counts normalised to sum to 1.
        """
        # TODO: fetch the previous board here.
        root = Node(self.game, self.args, state, visit_count=1)

        # Root priors are mixed with Dirichlet noise before masking invalid moves.
        root_policy, _ = self._predict(state)
        epsilon = self.args['dirichlet_epsilon']
        alpha = self.args['dirichlet_alpha']
        root_policy = (1 - epsilon) * root_policy + epsilon \
            * np.random.dirichlet([alpha] * self.game.action_size)
        # TODO: pass the board in and turn it into a game state internally.
        root.expand(self._restrict_to_valid(root_policy, root.state))

        pass_action = self.game.action_size - 1

        for _ in range(self.args['num_searches']):
            # Selection: walk down until a node without children is reached.
            leaf = root
            while leaf.is_fully_expanded():
                leaf = leaf.select()

            # TODO: the previous board has to be handled here as well.
            # Two consecutive pass actions (parent's move and this node's move).
            parent = leaf.parent
            double_pass = (
                parent is not None
                and parent.action_taken == pass_action
                and leaf.action_taken == pass_action
            )

            value, is_terminal = self.game.get_value_and_terminated(leaf.state, double_pass)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Expansion: replace the game value with the network's estimate.
                leaf_policy, net_value = self._predict(leaf.state)
                leaf_policy = self._restrict_to_valid(leaf_policy, leaf.state)
                value = net_value.item()
                leaf.expand(leaf_policy)

            leaf.backpropagate(value)

        visit_distribution = np.zeros(self.game.action_size)
        for child in root.children:
            visit_distribution[child.action_taken] = child.visit_count
        visit_distribution /= np.sum(visit_distribution)
        return visit_distribution
        

In [18]:
from tqdm.notebook import trange
from multiprocessing import Pool

class AlphaZero:
    """Self-play training loop built around ``MCTS`` and a policy/value model.

    Args:
        model: network returning ``(policy_logits, value)`` and exposing ``.device``.
        optimizer: optimizer stepping ``model``'s parameters.
        game: game adapter (state transitions, encoding, scoring).
        args: hyper-parameter dict; keys read here are ``temperature``,
            ``batch_size``, ``num_iterations``, ``num_selfPlay_iterations``,
            ``num_epochs`` and ``num_parallel_games``.
    """

    def __init__(self, model, optimizer, game, args):
        self.model = model
        self.optimizer = optimizer
        self.game = game
        self.args = args
        self.mcts = MCTS(game, args, model)

    def selfPlay(self):
        """Play one game against itself and return its training samples.

        Returns:
            List of ``(encoded_state, action_probs, outcome)`` tuples, one per
            move, where ``outcome`` is the final value from the point of view
            of the player who was to move in that state.
        """
        memory = []
        player = 1
        state = self.game.get_initial_state()
        pas = 0  # number of consecutive pass moves
        board_superko = None  # board after the previous move; None before the first move

        while True:
            passe = False
            neutral_state = self.game.change_perspective(state, player)
            # TODO: the search is not told about the pass counter or the previous board.
            action_probs = self.mcts.search(neutral_state)

            memory.append((neutral_state, action_probs, player))

            temperature_action_probs = action_probs ** (1 / self.args['temperature'])
            temperature_action_probs /= np.sum(temperature_action_probs)

            action = np.random.choice(self.game.action_size, p=temperature_action_probs)
            row, col = self.game.action_to_row_col(action)

            # NOTE(review): the original author marked this superko handling as
            # unverified. Behaviour is kept as it was: from the second move on,
            # the action is re-drawn once when `superko_state` returns False.
            # `board_superko` holds the same board as `state` at this point and
            # the re-drawn action is not re-checked, so this likely does not
            # enforce superko -- confirm the intended rule against the `go` module.
            if board_superko is not None:
                if not self.game.superko_state(state, player, board_superko, row, col):
                    action = np.random.choice(self.game.action_size, p=temperature_action_probs)

            state = self.game.get_next_state(state, action, player)
            print(temperature_action_probs)
            print(action)
            print(state)

            # The last action index is the pass move; two in a row end the game.
            if action == self.game.action_size - 1:
                pas += 1
                if pas == 2:
                    passe = True
            else:
                pas = 0

            board_superko = state
            value, is_terminal = self.game.get_value_and_terminated(state, passe)

            if is_terminal:
                returnMemory = []
                for hist_neutral_state, hist_action_probs, hist_player in memory:
                    # `player` made the final move; flip the value for the other side.
                    hist_outcome = value if hist_player == player else self.game.get_opponent_value(value)
                    returnMemory.append((
                        self.game.get_encoded_state(hist_neutral_state),
                        hist_action_probs,
                        hist_outcome
                    ))
                return returnMemory

            player = self.game.get_opponent(player)

    def train(self, memory):
        """Run one epoch of gradient steps over ``memory`` (shuffled in place).

        Args:
            memory: list of ``(encoded_state, policy_target, value_target)`` tuples.
        """
        random.shuffle(memory)
        batch_size = self.args['batch_size']
        for batchIdx in range(0, len(memory), batch_size):
            # Plain slicing already clamps at the end of the list. The previous
            # `min(len(memory) - 1, ...)` bound dropped the last sample and could
            # yield an empty batch, which crashed the unpacking below.
            sample = memory[batchIdx:batchIdx + batch_size]
            state, policy_targets, value_targets = zip(*sample)

            state, policy_targets, value_targets = np.array(state), np.array(policy_targets), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=self.model.device)
            policy_targets = torch.tensor(policy_targets, dtype=torch.float32, device=self.model.device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=self.model.device)

            out_policy, out_value = self.model(state)

            policy_loss = F.cross_entropy(out_policy, policy_targets)
            value_loss = F.mse_loss(out_value, value_targets)
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``num_iterations`` rounds, saving checkpoints."""
        for iteration in range(self.args['num_iterations']):
            memory = []

            self.model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations']):
                memory += self.selfPlay()

            self.model.train()
            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}_{self.game}.pt")

    def _selfPlay_worker(self, seed):
        """Play one self-play game in a worker process after reseeding its RNGs.

        Worker processes created by fork inherit the parent's RNG state, so
        without reseeding every worker would sample the same game.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        return self.selfPlay()

    def learn_parallel(self):
        """Like ``learn`` but plays ``num_parallel_games`` games per iteration in a process pool.

        The work is dispatched through a bound method because ``Pool.map`` has
        to pickle its callable, and a function defined inside this method
        cannot be pickled.

        NOTE(review): the instance (model, optimizer, game) is pickled for each
        task; a model living on a CUDA device or a spawn-based start method may
        need extra care -- confirm on the target platform.
        """
        num_processes = self.args['num_parallel_games']

        for iteration in range(self.args['num_iterations']):
            self.model.eval()

            # One distinct seed per game so the workers do not replay identical games.
            seeds = np.random.randint(0, 2 ** 31 - 1, size=num_processes).tolist()

            with Pool(num_processes) as pool:
                result_memory = pool.map(self._selfPlay_worker, seeds)

            memory = [item for sublist in result_memory for item in sublist]

            self.model.train()

            for epoch in trange(self.args['num_epochs']):
                self.train(memory)

            torch.save(self.model.state_dict(), f"model_para{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_para{iteration}_{self.game}.pt")

In [20]:
import go  

class Connect2Game:
    """Adapter exposing the ``go`` module through the interface MCTS and AlphaZero call.

    Boards are ``n`` x ``n`` arrays; actions ``0 .. n*n - 1`` address cells in
    row-major order and action ``n*n`` is the pass move.
    """

    def __init__(self, n):
        self.row_count = self.column_count = n
        self.action_size = n * n + 1  # every cell plus the pass move
        self.game_over = False

    def __repr__(self):
        # Also used in the checkpoint file names written by AlphaZero.
        return "Go"

    def get_initial_state(self):
        """Return an all-zero board."""
        shape = (self.row_count, self.column_count)
        return np.zeros(shape)

    def get_next_state(self, state, action, player):
        """Return the board after ``player`` plays ``action`` (or passes) on ``state``."""
        position = go.GameState(state)
        position.turn = player

        if action == self.column_count ** 2:
            return position.pass_turn().board

        row = action // self.column_count
        col = action % self.column_count
        return position.move(row, col).board

    def get_valid_moves(self, state):
        """Return a 0/1 list of length ``action_size``; the pass move is always allowed."""
        mask = [0] * self.action_size
        mask[-1] = 1
        for move in go.check_possible_moves(go.GameState(state)):
            mask[move[0] * self.column_count + move[1]] = 1
        return mask

    def action_to_row_col(self, action):
        """Convert a row-major action index into ``(row, col)``."""
        return action // self.column_count, action % self.column_count

    def get_value_and_terminated(self, state, pas):
        """Return ``(value, is_terminal)`` for ``state`` as computed by ``go.GameState``.

        When ``pas`` is truthy the position is marked as having two passes and,
        as a side effect, ``self.game_over`` is set to True.
        """
        position = go.GameState(state)
        if pas:
            position.pass_count = 2
            self.game_over = True
        return position.get_value_and_terminated(position)

    def print_state(self, state):
        """Print the ``go.GameState`` built from ``state``."""
        print(go.GameState(state))

    def is_game_over(self):
        """Return the flag set by ``get_value_and_terminated``."""
        return self.game_over

    def get_opponent(self, player):
        """Return the other player (players are +1 and -1)."""
        return -player

    def turn(self, player):
        """Return ``player`` unchanged."""
        return player

    def get_opponent_value(self, value):
        """Return ``value`` seen from the opponent's side."""
        return -value

    def change_perspective(self, state, player):
        """Multiply the board by ``player`` so the mover's stones become +1."""
        return state * player

    def get_encoded_state(self, state):
        """Return float32 one-hot planes for stone values -1, 0 and 1.

        A 2-D board yields shape ``(3, rows, cols)``; a 3-D batch of boards
        yields ``(batch, 3, rows, cols)``.
        """
        planes = [state == stone for stone in (-1, 0, 1)]
        encoded = np.stack(planes).astype(np.float32)
        if state.ndim == 3:
            # Stacking put the plane axis first; move the batch axis back in front.
            encoded = np.swapaxes(encoded, 0, 1)
        return encoded

    def superko_state(self, state, player, previous_state, i,j):
        """Return True when ``state``, after capture resolution for ``player``, equals ``previous_state``.

        ``i`` and ``j`` are accepted but not used.
        """
        current_board = go.GameState(state).board
        previous_board = go.GameState(previous_state).board
        current_board, _ = go.check_for_captures(current_board, player)
        return bool(np.array_equal(current_board, previous_board))
        
        



In [21]:

# Training entry point: a 3x3 board, a 9-block / 128-channel network, then self-play training.
game = Connect2Game(3)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = ResNet(game, num_resBlocks=9, num_hidden=128, device=device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

args = dict(
    C=2,                        # exploration weight in Node.get_ucb
    num_searches=600,           # MCTS simulations per move
    num_iterations=8,           # self-play + training rounds
    num_selfPlay_iterations=2,  # games per round in learn()
    num_parallel_games=100,     # pool size / games per round in learn_parallel()
    num_epochs=4,               # passes over the collected samples per round
    batch_size=128,
    temperature=1.25,           # move-sampling temperature in selfPlay()
    dirichlet_epsilon=0.25,     # weight of the Dirichlet noise on root priors
    dirichlet_alpha=0.3,        # Dirichlet concentration parameter
)

alphaZero = AlphaZero(model, optimizer, game, args)
alphaZero.learn()

# Alternative: train with games played in a process pool instead.
#alphaZero = AlphaZero(model, optimizer, game, args)
#alphaZero.learn_parallel()

  0%|          | 0/2 [00:00<?, ?it/s]

Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes in a row
Reason for game ending: 2 passes i

In [None]:
# Rebuild the network with the same architecture used for training.
model = ResNet(game, 9, 128, device)

# Load the pre-trained weights. `map_location` remaps the stored tensors onto
# the current device, so a checkpoint written on a GPU machine also loads on a
# CPU-only one.
model.load_state_dict(torch.load("model_7_Go.pt", map_location=device))
model.eval()

# The optimizer created in the training cell is bound to the previous model's
# parameters; build a fresh one so any further training updates this model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

alphaZero = AlphaZero(model, optimizer, game, args)

In [None]:
# Disabled draft of an interactive human-vs-AlphaZero game. The whole cell is a
# single string literal, so none of the code below is executed.
# NOTE(review) -- issues to resolve before enabling it:
#   * `game.turn` is a method of Connect2Game, so `game.turn == 1` is never true
#     and `game.turn` cannot be passed where a player (+1 / -1) is expected.
#   * `human_player` checks move validity against the initial (empty) board, not
#     the current position, and computes `action` before the bounds check.
#   * `game.is_game_over()` only changes when `get_value_and_terminated` is called
#     with a truthy `pas`, which `play_game` never does.
"""
def print_board(game, state):
    print("Current Board:")
    game.print_state(state)


def human_player(game):
    print("Your turn! Enter your move (row, column) or type 'pass': ")
    while True:
        try:
            move_input = input()
            if move_input.lower() == 'pass':
                return game.row_count * game.column_count  # Pass move
            else:
                row, col = map(int, move_input.split(','))
                action = row * game.column_count + col
                if 0 <= row < game.row_count and 0 <= col < game.column_count and game.get_valid_moves(game.get_initial_state())[action] == 1:
                    return action
                else:
                    print("Invalid move. Please enter a valid move.")
        except ValueError:
            print("Invalid input. Please enter a valid move.")


def play_game(alphaZero, human_player):
    game_state = game.get_initial_state()
    while not game.is_game_over():
        print_board(game, game_state)

        if game.turn == 1:
            action_probs = alphaZero.mcts.search(game.change_perspective(game_state, game.turn))
            action = np.argmax(action_probs)
        else:
            action = human_player(game)

        game_state = game.get_next_state(game_state, action, game.turn)

    print_board(game, game_state)
    print("Game over!")



# Start the game
play_game(alphaZero, human_player)
"""

SyntaxError: incomplete input (3276390864.py, line 1)