In [None]:
from gym_go.envs import go_env
from gym_go import gogame, govars
import torch
from torch.utils.data import DataLoader


In [197]:
from RL_GoBot import var
from RL_GoBot.model import GoBot
from RL_GoBot.learning import get_optimizer, compute_loss, test, train
from RL_GoBot.data_base import GoDatabase

In [198]:
from math import sqrt


def action2d(action) :
    """
    Convert a flat action index into a (row, column) pair.

    The index is split using `var.BOARD_SIZE` as the row length, so the pass
    move (`var.BOARD_SIZE**2`) maps to `(var.BOARD_SIZE, 0)`.
    """
    row, column = divmod(action, var.BOARD_SIZE)
    return row, column


def roll_policy(state, net, prt = False):
    """
    Play a game to its end from `state`, picking moves greedily from `net`.

    At each turn the legal move with the highest network score is played.
    Once `var.MAX_TURNS` moves have been played inside the rollout, only the
    pass move is played so that the game terminates.

    Args:
        state: game state handled by `gogame`, indexed as state[channel, row, col].
        net: callable such that `net([state]).result[0]` gives one score per
            action (board points followed by the pass move).
        prt: unused; kept so existing callers that pass it keep working.

    Returns:
        The score of the finished game as seen by the player who was to move
        in the starting `state`.
    """
    pass_action = var.BOARD_SIZE**2
    # NOTE(review): channel 2 is presumably the turn indicator (0 = black,
    # 1 = white) — confirm against gym_go.govars.
    player = state[2,0,0]

    game_turns = 0
    while not gogame.game_ended(state):
        action = pass_action
        # Channel 3 full of ones means no board point is playable: only pass is left.
        if game_turns < var.MAX_TURNS and not torch.all(torch.tensor(state[3,:,:]) == 1):
            result = net([state]).result[0]
            # Rank actions from the highest to the lowest score. The original
            # ascending sort made the rollout play the least preferred legal move.
            ranked_actions = sorted(range(len(result)), key=lambda i: result[i], reverse=True)
            invalid_moves = gogame.invalid_moves(state)
            # Best-ranked legal move; fall back to pass if none is found.
            action = next((a for a in ranked_actions if not invalid_moves[a]), pass_action)

        state = gogame.next_state(state, action)
        game_turns += 1

    score = gogame.winning(state, var.KOMI)  # from black's absolute perspective
    if player :
        score = - score
    return score        # from the perspective of the player to move at the starting state
 

In [199]:
import random

class Node:
    """
    One node of the search tree: the move `action` played from `state`.

    Attributes:
        state: state *before* `action` is played.
        action: flat action index, or None for a root that carries no move.
        p: prior probability given to this move by its parent.
        depth: distance from the node the tree was first built on.
        end_game: DONE channel of the state reached after `action`.
        N: visit count.
        Wv / Wr: accumulated network values / rollout results.
        Q: mean action value, mixing Wv and Wr with `var.QU_RATIO`.
        next_nodes: children, filled by `MCTS.extend_node`.
        first_search: True until the node has been evaluated once.
        value / roll: network value and rollout result of the first evaluation.
    """

    def __init__(self, state, action: int, p: torch.Tensor, depth=0):
        self.state = state
        self.action = action
        self.p = p
        self.depth = depth
        self.end_game = self.next_state()[govars.DONE_CHNL,0,0]

        self.N = 0
        self.Wv = 0
        self.Wr = 0
        self.Q = 0

        self.next_nodes = []

        self.first_search = True
        self.value = 0
        self.roll = 0

    def next_state(self):
        """Return the state reached after playing `action` (or `state` itself if action is None)."""
        if self.action is None:
            return self.state
        return gogame.next_state(self.state, self.action)

    def action_2d(self):
        """Return the action as (row, column), or the string "None" when there is no action."""
        if self.action is None:
            return "None"
        return self.action // var.BOARD_SIZE, self.action % var.BOARD_SIZE


    def update_Q(self):
        """Recompute Q from Wv and Wr; left unchanged while the node has no visit."""
        if self.N > 0:
            self.Q = ((1 - var.QU_RATIO) * self.Wv + var.QU_RATIO * self.Wr) / self.N

    def __str__(self):
        # Return the description instead of printing it: __str__ must be free of
        # side effects so that str(node), f-strings and logging behave as expected.
        return "\n".join([
            "___ node info ___",
            "action : {} | depth : {}".format(self.action_2d(), self.depth),
            "N : {}, P : {}".format(self.N, self.p),
            str(self.state[3]),
        ])


class MCTS:
    """
    Monte-Carlo tree search guided by a policy/value network and rollouts.

    Attributes:
        root: node the searches start from.
        net: network used for priors, values and rollouts.
        temperature: exponent control used when turning visit counts into a policy.
        root_depth: stored at construction; not read by this class.
        policy: last policy computed by `tree_policy`, or None.
    """

    def __init__(self, net: GoBot, root_node: Node, temperature = var.TEMPERATURE_MCTS):
        self.root = root_node
        self.net = net
        self.temperature = temperature 
        self.root_depth = 0
        self.policy = None


    def extend_node(self, node):
        """Create the children of `node`: one child per action with a non-zero prior."""
        # Re-base the accumulated totals on the visits made before expansion.
        node.Wv = node.N * node.value
        node.Wr = node.N * node.roll
        next_state = node.next_state()

        # Network prediction for the state reached after the node's action.
        result = self.net([next_state]).result[0]

        # Mask invalid moves, then softmax so that they get a prior of exactly 0.
        invalid_moves = gogame.invalid_moves(next_state)  # also contains the pass move
        result[invalid_moves == 1] = float('-inf')
        prior_probs = torch.softmax(result, dim=-1)

        for action, p in enumerate(prior_probs):
            if p != 0 : 
                node.next_nodes.append(Node(
                    next_state,
                    action,
                    p,
                    node.depth + 1
                ))


    def push_search(self, node: Node, prt=False):
        """
        Run one search pass through `node` and return `(-value, -rollout)`.

        The returned pair is negated so that the caller receives it from the
        other player's point of view.
        """
        # First visit: evaluate the node with the network and one rollout.
        if node.first_search:
            node.first_search = False
            next_state = node.next_state()
            # NOTE(review): the network is called with a bare state here but with
            # a one-element list in extend_node — confirm GoBot accepts both.
            node.value = self.net(next_state).value
            node.roll = roll_policy(next_state, self.net, prt)
            node.Q = (1 - var.QU_RATIO) * node.value + var.QU_RATIO * node.roll
        
        if node.end_game:
            node.N += 1
            return -node.value, -node.roll

        # Enough visits: expand (once) and descend into the best child.
        if node.N >= var.N_THRESHOLD:
            if not node.next_nodes:
                self.extend_node(node)

            # Child selection by the UCT score.
            # NOTE(review): a child's Q mixes `value` and `roll`, both evaluated on
            # the state reached after the child's action, and roll_policy scores
            # from the perspective of the player to move in that state. Maximising
            # the child's Q may therefore favour the opponent — confirm the sign.
            node_to_search = max(
                node.next_nodes,
                key=lambda n: n.Q + var.C_PUCT * n.p * sqrt(node.N) / (1 + n.N)
            )

            value, output = self.push_search(node_to_search, prt)
            node.Wv += value
            node.Wr += output
            node.update_Q()

        else:
            value = node.value
            output = node.roll

        node.N += 1
        return -value, -output


    def best_next_node(self):
        """Return the root child with the largest visit count N, or None if the root has no child."""
        if not self.root.next_nodes:
            print("Error : No policy is possible without expending this node")
            return None
        best_node = max(self.root.next_nodes, key=lambda n: n.N)
        return best_node
    

    def next_node(self):
        """Sample a root child according to the policy, then clear the cached policy."""
        if self.policy is None:
            self.tree_policy()
            
        r = random.random()
        self.affiche_policy()
        print("r : ", r)
        cumulative = 0
        for node in self.root.next_nodes:
            cumulative += self.policy[node.action]
            if cumulative > r:
                print("action : ", node.action)
                self.policy = None
                return node

        self.policy = None
        # Floating-point rounding can leave the cumulative sum just below r:
        # fall back to the last child. The children live on the root node; the
        # original `self.next_nodes` does not exist on MCTS and raised AttributeError.
        return self.root.next_nodes[-1]
    

    def tree_policy(self):
        """
        Build, store and return the policy derived from the root children's visit counts.

        The result is a list of `var.BOARD_SIZE**2 + 1` probabilities indexed by action.
        Raises ZeroDivisionError when no root child has been visited.
        """
        visited_N = [0 for i in range(var.BOARD_SIZE**2 + 1)]
        for node in self.root.next_nodes:
            visited_N[node.action] = node.N

        visited_N_temperated = [N**(1 / self.temperature) for N in visited_N]
        diviseur = sum(visited_N_temperated)
        self.policy = [N / diviseur for N in visited_N_temperated]
        return self.policy


    def affiche_policy(self):
        """Print the stored policy as a board-shaped tensor, followed by the pass probability."""
        tensor = torch.tensor(self.policy[:-1])      # 1D tensor of the board moves
        # Use the configured board size instead of a hard-coded 7x7.
        board = tensor.view(var.BOARD_SIZE, var.BOARD_SIZE)
        print("policy : ", board)
        print("pass : ", self.policy[-1])

    def __str__(self):
        # Delegate to the root node and return the text rather than printing it.
        return str(self.root)

In [None]:
import time


def one_self_play_MCTS(net):
    """
    Play one full self-play game with MCTS and return its training samples.

    Each sample is a list `[state, mcts_policy, reward]`. The reward of the
    first sample is `gogame.winning` on the final state, and its sign
    alternates from one sample to the next.
    """
    samples = []
    with torch.no_grad():
        state = gogame.init_state(var.BOARD_SIZE)
        tree = MCTS(net, Node(state, None, 1))
        while not gogame.game_ended(state) :
            print("new root \n", tree)
            turn_start = time.time()

            for _ in range(var.N_TREE_SEARCH) :
                tree.push_search(tree.root)
            samples.append([state, tree.tree_policy()])

            # Sample the move to play and re-root the tree on it.
            chosen = tree.next_node()
            state = gogame.next_state(state, chosen.action)
            tree.root = chosen
            print("\n", time.time() - turn_start)

    outcome = gogame.winning(state, var.KOMI)
    black_area, white_area = gogame.areas(state)
    print(black_area, white_area)
    # The reward is attached from the first sample onwards, flipping sign each move.
    for sample in samples:
        sample.append(outcome)
        outcome = -outcome

    return samples

In [None]:
def self_play_MCTS(N, net, db : GoDatabase = None):
    """
    Play `N` self-play games and return the samples of all of them.

    Args:
        N: number of games to play; 0 returns an empty list.
        net: network passed to `one_self_play_MCTS`.
        db: optional database; when given, every game is saved with `save_one_game`.

    Returns:
        A flat list with the samples of every game, in playing order.
    """
    data = []
    for _ in range(N):
        game_moves = one_self_play_MCTS(net)
        if db is not None:
            # NOTE(review): another cell calls save_one_game(episode=..., game=...);
            # confirm this positional call matches GoDatabase's signature.
            db.save_one_game(game_moves)
        # Collect inside the loop: the original extended `data` after the loop,
        # keeping only the last game and raising NameError when N == 0.
        data.extend(game_moves)
    return data

In [None]:
def get_data(db : GoDatabase, batch_size):
    """Return a shuffling DataLoader over `db` that yields batches of `batch_size`."""
    loader = DataLoader(dataset=db, batch_size=batch_size, shuffle=True)
    return loader

In [201]:
# Smoke test of the Go environment: play two moves and print the first four
# channels of the returned state after each of them.
env = go_env.GoEnv(size=var.BOARD_SIZE, komi=var.KOMI)
state, reward, done, info = env.step((0,0))
print(state[:4], '\n----\n')
state, reward, done, info = env.step((1,0))
print(state[:4])


[[[1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1.]]

 [[1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]] 
----

[[[1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0.]
  [0. 0.

In [202]:
# Build a new GoBot, run it on the last environment state and print
# `out.result[0]` together with one of its entries (index 10).
net = GoBot()
out = net.forward(state)
print(out.result[0])
print(out.result[0,10])



OutputFormating([ 0.0978,  0.0354, -0.0486,  0.0619, -0.0110, -0.0393, -0.0566,
                 -0.0397,  0.0071,  0.0428,  0.0042, -0.0550, -0.0078,  0.0586,
                  0.0387, -0.0130,  0.0160, -0.0071,  0.0238, -0.0182,  0.0513,
                  0.0203, -0.0594,  0.0764,  0.0039,  0.0482, -0.0049,  0.0076,
                  0.0010,  0.0637, -0.0508, -0.0416, -0.0563,  0.0328, -0.0321,
                  0.0374, -0.0392, -0.0263,  0.0755,  0.0021,  0.0135, -0.0780,
                  0.0898,  0.0912,  0.0785, -0.0205, -0.0145, -0.0646,  0.0547,
                  0.0566], grad_fn=<AliasBackward0>)
OutputFormating(0.0042, grad_fn=<AliasBackward0>)


In [203]:
# Run a single rollout from the initial position and print its score.
state = gogame.init_state(var.BOARD_SIZE)
print(roll_policy(state, net))

-1.0


In [204]:
# Replace `net` with a new GoBot and load the model stored in "test.pth".
net = GoBot()
net.load_model("test.pth")

In [205]:
# Build a tree on the initial position and run N_THRESHOLD + 3 searches from
# the root, i.e. a few more than the visit count at which push_search expands a node.
state = gogame.init_state(var.BOARD_SIZE)
root_node = Node(state, None, 1)
tree = MCTS(net, root_node)
for i in range(var.N_THRESHOLD + 3) :
    tree.push_search(tree.root)


In [None]:
# Play one complete self-play game with the loaded network.
game_1 = one_self_play_MCTS(net)

new root 
 ___ node info ___
action : None | depth : 0
N : 0, P : 1
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
 
 
policy :  tensor([[0.0163, 0.0163, 0.0163, 0.0163, 0.0163, 0.0163, 0.0230],
        [0.0163, 0.0364, 0.0230, 0.0163, 0.0230, 0.0163, 0.0163],
        [0.0163, 0.0230, 0.0000, 0.0163, 0.0163, 0.0163, 0.0163],
        [0.0325, 0.0163, 0.0163, 0.0163, 0.0230, 0.0325, 0.0163],
        [0.0671, 0.0163, 0.0163, 0.0163, 0.0163, 0.0163, 0.0230],
        [0.0163, 0.0282, 0.0163, 0.0230, 0.0230, 0.0163, 0.0230],
        [0.0163, 0.0163, 0.0163, 0.0230, 0.0163, 0.0230, 0.0230]])
pass :  0.02300431950313354
r :  0.16375092393801471
action :  8

 10.586579084396362
new root 
 ___ node info ___
action : (1, 1) | depth : 1
N : 5, P : OutputFormating(0.0219)
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0

In [207]:
# Dump every sample of the game (verbose: each sample holds a full state array).
print(game_1)

[[array([[[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
    

In [None]:
# Open the "test_conter" collection and store the game as episode 0.
db = GoDatabase(collection_name = "test_conter")

db.save_one_game(episode = 0, game = game_1)

In [None]:
# Load episodes back from the database.
# NOTE(review): the meaning of the argument 64 is not visible here — confirm against GoDatabase.
data = db.load_episodes(64)
