#Игра

In [None]:
import numpy as np
import math
import time

#import pyspiel

In [None]:
# Состояние игры крестики-нолики
class State_TicTacToe:
    def __init__(self, board_size=3, win_size=3):
        self.board_size=board_size
        self.win_size = win_size

        self._kernel = self._create_kernel()
        self.restart()

    def restart(self):
        self.state = np.zeros((self.board_size, self.board_size))
        self.turn = -1
        self.rewards = {-1: 0, +1: 0}
        self.is_done = False

    # Создает ядро свертки для расчета побед
    def _create_kernel(self):
        kernel = np.zeros((2 * self.win_size + 2, self.win_size, self.win_size))
        for i in range(self.win_size):
            kernel[i, i, :] = np.ones(self.win_size)
        for i in range(self.win_size, 2 * self.win_size):
            kernel[i, :, i - self.win_size] = np.ones(self.win_size).T
        kernel[2 * self.win_size] = np.eye(self.win_size)
        kernel[2 * self.win_size + 1] = np.fliplr(np.eye(self.win_size))
        return kernel


    # Проверяет победы для состояний states, в кот. ходы были совершены игроками turns, turn={-1, 1}
    def _test_win(self):
        rows, cols, w_size = *self.state.shape, self.win_size
        expanded_states = np.lib.stride_tricks.as_strided(
            self.state,
            shape=(rows - w_size + 1, cols - w_size + 1, w_size, w_size),
            strides=(*self.state.strides, *self.state.strides),
            writeable=False,
        )
        feature_map = np.einsum('xyij,sij->sxy', expanded_states, self._kernel)
        return -self.turn * (feature_map == self.turn * w_size).any().astype(int)


    # Выполнение хода и проверка на некорректный ход (проигрышь) / выигрыш / ничью
    def apply_action(self, action):
        # Проверка корректности хода
        if (self.state[(action)] != 0):
            self.rewards = {self.turn: -1, -self.turn: +1}
            self.is_done = True
            self.turn = 0
            return

        # Совершение хода
        self.state[action] = self.turn

        # Проверка победы
        if self._test_win():
            self.rewards = {self.turn: +1, -self.turn: -1}
            self.is_done = True
            self.turn = 0
            return

        # Проверка ничьи
        if (self.state != 0).all():
            self.is_done = True
            self.turn = 0
            return

        # Иначе, сменить ход
        self.turn = -self.turn

    # Возвращает True, если игра завершена
    def is_terminal(self):
        return self.is_done

    # Возвращает список результатов игры (награды) каждого игрока
    def returns(self):
        return np.array([self.rewards[-1], self.rewards[+1]])

    # Создает и возвращает копию текущего состояния
    def clone(self):
        cloned = State_TicTacToe(self.board_size, self.win_size)
        cloned.state = np.copy(self.state)
        cloned.turn = self.turn
        cloned.rewards = dict(self.rewards)
        cloned.is_done = self.is_done
        return cloned

    # Выводит на экран состояние игры
    def visualize_state(self):
        print(f"player {self.turn}'s turn:")
        print(str(self.state)
              .replace(".", "")
              .replace("[[", "")
              .replace(" [", "")
              .replace("]]", "")
              .replace("]", "")
              .replace("-0", " .")
              .replace("0", ".")
              .replace("-1", " X")
              .replace("1", "O")
        )

    # Возвращает список допустимых действий для текущего или указанного игрока
    def legal_actions(self, player=None):
        return list(zip(*np.where(self.state == 0)))

    # Выводит текущего игрока
    def current_player(self):
        return 0 if self.turn == -1 else 1

# MCTS

In [1]:
import numpy as np
import math
import time

In [2]:
# Простой evaluator, делающий случайные rollout'ы
class RandomRolloutEvaluator(object):
    """
    Evaluator возвращает средний выигрыш, совершая случайные действия из данного состояния, пока игра не закончится.
    n_rollouts - количество случайных rollout'ов.
    """

    def __init__(self, n_rollouts=1, random_state=None):
        self.n_rollouts = n_rollouts
        self._random_state = random_state or np.random.RandomState()

    # Возвращает V(s)
    def evaluate(self, state):
        result = 0
        for _ in range(self.n_rollouts):
            working_state = state.clone()
            while not working_state.is_terminal():
                actions = working_state.legal_actions()
                action = actions[self._random_state.choice(len(actions))]
                working_state.apply_action(action)
            result += np.array(working_state.returns())
        return result / self.n_rollouts

    # Возвращает 𝜋(a|s)
    def prior(self, state):
        legal_actions = state.legal_actions()
        return [(action, 1.0 / len(legal_actions)) for action in legal_actions]

In [3]:
# Вершина дерева поиска
class SearchNode(object):
    """
    A SearchNode represents a state and possible continuations from it. Each child
    represents a possible action, and the expected result from doing so.

    Attributes:
        action: Действие a из родительского узла s.
        player: Игрок, совершивший действие.
        prior: Вероятность выбора действия P(s,a).
        explore_count: Количество выбора данного действия N(s,a).
        total_reward: Сумма наград от родительского узла W(s,a).
            Средняя награда Q(s,a) = W(s,a) / N(s,a).
        outcome: The rewards for all players if this is a terminal node or the
            subtree has been proven, otherwise None.
        children: A list of SearchNodes representing the possible actions from this
            node, along with their expected rewards.
    """
    __slots__ = ["action", "player", "prior", "explore_count", "total_reward", "outcome", "children"]

    def __init__(self, action, player, prior):
        self.action = action      # a
        self.prior = prior        # P(s,a)
        self.explore_count = 0    # N(s,a)
        self.total_reward = 0.0   # W(s,a)

        self.player = player
        self.outcome = None
        self.children = []

    # Возвращает UCT дочернего узла
    def uct_value(self, parent_explore_count, uct_c):
        if self.outcome is not None:
            return self.outcome[self.player]

        if self.explore_count == 0:
            return float("inf")

        return self.total_reward / self.explore_count + \
            uct_c * math.sqrt(math.log(parent_explore_count) / self.explore_count)

    # Возвращает PUCT дочернего узла
    def puct_value(self, parent_explore_count, uct_c):
        if self.outcome is not None:
            return self.outcome[self.player]

        return ((self.explore_count and self.total_reward / self.explore_count) +
                uct_c * self.prior * math.sqrt(parent_explore_count) /
                (self.explore_count + 1))

    # Возвращает лучшее действие в вершине, либо подтвержденное или наиболее посещаемое
    def sort_key(self):
        """ Такой порядок приводит к выбору:
        - Наивысшего подтвержденного результата > 0 over anything else, including a promising but unproven action.
        - Подтвержденная ничья, только если она исследовалась чаще других неопределенных или проигрышных.
        - Неопределенное действие with most exploration over loss of any difficulty
        - Самый сложный проигрыш, если все проигрыши
        - Highest expected reward if explore counts are equal (маловероятно).
        - Longest win, if multiple are proven (unlikely due to early stopping).
        """
        return (0 if self.outcome is None else self.outcome[self.player],
                self.explore_count,
                self.total_reward)

    # Возвращает лучший дочерний узел в порядке ключа сортировки
    def best_child(self):
        return max(self.children, key=SearchNode.sort_key)

In [4]:
# Бот, использующий Monte-Carlo Tree Search алгоритм
class MCTSBot(object):
    def __init__(self,
                 uct_c,
                 max_simulations,
                 evaluator,
                 solve=True,
                 random_state=None,
                 child_selection_fn=SearchNode.uct_value,
                 dirichlet_noise=None,
                 verbose=False):

        self.max_utility = 1.0  #game.max_utility()
        self.uct_c = uct_c
        self.max_simulations = max_simulations
        self.evaluator = evaluator
        self.verbose = verbose
        self.solve = solve

        self._dirichlet_noise = dirichlet_noise
        self._random_state = random_state or np.random.RandomState()
        self._child_selection_fn = child_selection_fn

    # Возвращает политику бота и действие в данном состоянии
    def step_with_policy(self, state):
        t1 = time.time()
        root = self.mcts_search(state)
        best = root.best_child()

        if self.verbose:
            seconds = time.time() - t1
            print("Finished {} sims in {:.3f} secs, {:.1f} sims/s".format(
                root.explore_count, seconds, root.explore_count / seconds))

        mcts_action = best.action
        policy = [(action, (1.0 if action == mcts_action else 0.0))
                  for action in state.legal_actions(state.current_player())]

        return policy, mcts_action

    def step(self, state):
        root = self.mcts_search(state)
        return root.best_child().action

    # Применение UCT политики до достижения листовой вершины
    def _apply_tree_policy(self, root, state):
        """ Листовая вершина - терминальная или еще неоцененная вершина.

        Args:
            root: The root node in the search tree.
            state: The state of the game at the root node.

        Returns:
            visit_path: Лист вершин от корня к конечному узлу
            working_state: The state of the game at the leaf node.
        """

        visit_path = [root]
        working_state = state.clone()
        current_node = root
        while (not working_state.is_terminal() and current_node.explore_count > 0):
            # If it reaches a node that has been evaluated before but hasn't been expanded,
            # then expand it's children and continue.
            if not current_node.children:
                # For a new node, initialize its state, then choose a child as normal.
                legal_actions = self.evaluator.prior(working_state)
                if current_node is root and self._dirichlet_noise:
                    epsilon, alpha = self._dirichlet_noise
                    noise = self._random_state.dirichlet([alpha] * len(legal_actions))
                    legal_actions = [(a, (1 - epsilon) * p + epsilon * n)
                                    for (a, p), n in zip(legal_actions, noise)]
                # Reduce bias from move generation order.
                self._random_state.shuffle(legal_actions)
                player = working_state.current_player()
                current_node.children = [SearchNode(action, player, prior) for action, prior in legal_actions]

            chosen_child = max(
                current_node.children,
                key=lambda c: self._child_selection_fn(c, current_node.explore_count, self.uct_c)
            )

            working_state.apply_action(chosen_child.action)
            current_node = chosen_child
            visit_path.append(current_node)

        return visit_path, working_state

    # Ванильный Monte-Carlo Tree Search алгоритм
    def mcts_search(self, state):
        root = SearchNode(None, state.current_player(), 1)
        for _ in range(self.max_simulations):
            visit_path, working_state = self._apply_tree_policy(root, state)
            if working_state.is_terminal():
                returns = working_state.returns()
                visit_path[-1].outcome = returns
                solved = self.solve
            else:
                returns = self.evaluator.evaluate(working_state)
                solved = False

            while visit_path:
                node = visit_path.pop()
                node.total_reward += returns[node.player]
                node.explore_count += 1

                if solved and node.children:
                    # If any have max utility (won?), or all children are solved,
                    # choose the one best for the player choosing.
                    best = None
                    all_solved = True
                    for child in node.children:
                        if child.outcome is None:
                            all_solved = False
                        elif best is None or child.outcome[node.player] > best.outcome[node.player]:
                            best = child
                    if (best is not None and (all_solved or best.outcome[node.player] == self.max_utility)):
                        node.outcome = best.outcome
                    else:
                        solved = False
            if root.outcome is not None:
                break

        return root

## Тест

In [None]:
eval = RandomRolloutEvaluator(5)

In [None]:
state = State_TicTacToe(3, 3)

In [None]:
bot = MCTSBot(uct_c=1,
              max_simulations=200,
              evaluator=eval,
              solve=True,
              child_selection_fn=SearchNode.uct_value)

In [None]:
root = bot.mcts_search(state)

In [None]:
state.restart()
state.visualize_state()

while not state.is_terminal():
    #action = bot.step(state)

    root = bot.mcts_search(state)
    for child in root.children:
        print(child.action, child.total_reward, child.outcome, child.explore_count)
    action = root.best_child().action

    state.apply_action(action)
    state.visualize_state()

player -1's turn:
. . .
. . .
. . .
(1, 0) -0.4000000000000001 None 7
(0, 2) 14.000000000000002 None 32
(2, 2) 9.800000000000002 None 25
(0, 0) 10.799999999999997 None 27
(1, 2) 3.0 None 15
(0, 1) 4.6 None 16
(2, 0) 5.6 None 18
(1, 1) 22.6 None 45
(2, 1) 3.2000000000000006 None 14
player 1's turn:
 .  .  .
 .  X  .
 .  .  .
(0, 2) -17.600000000000005 None 44
(1, 0) -9.0 None 12
(2, 0) -15.8 None 35
(0, 0) -11.6 None 19
(0, 1) -8.399999999999999 None 11
(2, 1) -12.0 None 21
(1, 2) -11.2 None 18
(2, 2) -16.6 None 39
player -1's turn:
 .  .  O
 .  X  .
 .  .  .
(0, 0) 11.600000000000001 None 31
(2, 0) 5.6 None 21
(1, 2) 13.4 None 33
(2, 1) 1.6 None 12
(2, 2) 11.599999999999998 None 30
(0, 1) 27.0 None 53
(1, 0) 5.200000000000001 None 19
player 1's turn:
 .  X  O
 .  X  .
 .  .  .
(2, 0) -9.8 None 13
(1, 0) -12.6 None 20
(2, 2) -17.599999999999998 None 37
(1, 2) -21.600000000000005 None 50
(0, 0) -10.4 None 15
(2, 1) -25.4 None 64
player -1's turn:
 .  X  O
 .  X  .
 .  O  .
(1, 0) 8.39999

#Модель

In [5]:
"""An AlphaZero style model with a policy and value head"""

import collections
import functools
import os
from typing import Sequence

import numpy as np
import tensorflow.compat.v1 as tf

In [6]:
def cascade(x, fns):
    for fn in fns:
        x = fn(x)
    return x

tfkl = tf.keras.layers
conv_2d = functools.partial(tfkl.Conv2D, padding="same")


def batch_norm(training, updates, name):
    """A batch norm layer.

    Args:
      training: A placeholder of whether this is done in training or not.
      updates: A list to be extended with this layer's updates.
      name: Name of the layer.

    Returns:
      A function to apply to the previous layer.
    """
    bn = tfkl.BatchNormalization(name=name)
    def batch_norm_layer(x):
        # This emits a warning that training is a placeholder instead of a concrete
        # bool, but seems to work anyway.
        applied = bn(x, training)
        updates.extend(bn.updates)
        return applied
    return batch_norm_layer


def residual_layer(inputs, num_filters, kernel_size, training, updates, name):
  return cascade(inputs, [
      conv_2d(num_filters, kernel_size, name=f"{name}_res_conv1"),
      batch_norm(training, updates, f"{name}_res_batch_norm1"),
      tfkl.Activation("relu"),
      conv_2d(num_filters, kernel_size, name=f"{name}_res_conv2"),
      batch_norm(training, updates, f"{name}_res_batch_norm2"),
      lambda x: tfkl.add([x, inputs]),
      tfkl.Activation("relu"),
  ])

# Inputs for training the Model
class TrainInput(collections.namedtuple("TrainInput", "observation legals_mask policy value")):
    @staticmethod
    def stack(train_inputs):
        observation, legals_mask, policy, value = zip(*train_inputs)
        return TrainInput(
            np.array(observation, dtype=np.float32),
            np.array(legals_mask, dtype=bool),
            np.array(policy),
            np.expand_dims(value, 1)
        )


# An AlphaZero style model with a policy and value head
class Model(object):
    # Init a model. Use build_model, from_checkpoint or from_graph instead
    def __init__(self, session, saver, path):
        self._session = session
        self._saver = saver
        self._path = path

        def get_var(name):
            return self._session.graph.get_tensor_by_name(name + ":0")

        self._input = get_var("input")
        self._legals_mask = get_var("legals_mask")
        self._training = get_var("training")
        self._value_out = get_var("value_out")
        self._policy_softmax = get_var("policy_softmax")
        self._policy_loss = get_var("policy_loss")
        self._value_loss = get_var("value_loss")
        self._l2_reg_loss = get_var("l2_reg_loss")
        self._policy_targets = get_var("policy_targets")
        self._value_targets = get_var("value_targets")
        self._train = self._session.graph.get_operation_by_name("train")

    # Build a model with the specified params
    @classmethod
    def build_model(cls, model_type, input_shape, output_size, nn_width, nn_depth,
                    weight_decay, learning_rate, path):
        g = tf.Graph()  # Allow multiple independent models and graphs.
        with g.as_default():
            cls._define_graph(model_type, input_shape, output_size, nn_width,
                              nn_depth, weight_decay, learning_rate)
            init = tf.variables_initializer(tf.global_variables(), name="init_all_vars_op")
            with tf.device("/cpu:0"):  # Saver only works on CPU.
                saver = tf.train.Saver(max_to_keep=10000, sharded=False, name="saver")
        session = tf.Session(graph=g)
        session.__enter__()
        session.run(init)
        return cls(session, saver, path)

    # Load a model from a checkpoint
    @classmethod
    def from_checkpoint(cls, checkpoint, path=None):
        model = cls.from_graph(checkpoint, path)
        model.load_checkpoint(checkpoint)
        return model

    # Load only the model from a graph or checkpoint
    @classmethod
    def from_graph(cls, metagraph, path=None):
        if not os.path.exists(metagraph):
            metagraph += ".meta"
        if not path:
            path = os.path.dirname(metagraph)
        g = tf.Graph()  # Allow multiple independent models and graphs.
        with g.as_default():
            saver = tf.train.import_meta_graph(metagraph)
        session = tf.Session(graph=g)
        session.__enter__()
        session.run("init_all_vars_op")
        return cls(session, saver, path)

    def __del__(self):
        if hasattr(self, "_session") and self._session:
            self._session.close()

    # Define the model graph
    @staticmethod
    def _define_graph(model_type, input_shape, output_size, nn_width, nn_depth, weight_decay, learning_rate):
        # Inference inputs
        input_size = int(np.prod(input_shape))
        observations = tf.placeholder(tf.float32, [None, input_size], name="input")
        legals_mask = tf.placeholder(tf.bool, [None, output_size], name="legals_mask")
        training = tf.placeholder(tf.bool, name="training")

        bn_updates = []

        # Main torso of the network
        if model_type == "resnet":
            torso = cascade(observations, [
                tfkl.Reshape(input_shape),
                conv_2d(nn_width, 3, name="torso_in_conv"),
                batch_norm(training, bn_updates, "torso_in_batch_norm"),
                tfkl.Activation("relu"),
            ])
            for i in range(nn_depth):
                torso = residual_layer(torso, nn_width, 3, training, bn_updates, f"torso_{i}")

        policy_head = cascade(torso, [
            conv_2d(filters=2, kernel_size=1, name="policy_conv"),
            batch_norm(training, bn_updates, "policy_batch_norm"),
            tfkl.Activation("relu"),
            tfkl.Flatten(),
        ])
        policy_logits = tfkl.Dense(output_size, name="policy")(policy_head)
        policy_logits = tf.where(legals_mask, policy_logits, -1e32 * tf.ones_like(policy_logits))
        unused_policy_softmax = tf.identity(tfkl.Softmax()(policy_logits), name="policy_softmax")
        policy_targets = tf.placeholder(shape=[None, output_size], dtype=tf.float32, name="policy_targets")
        policy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=policy_logits, labels=policy_targets),
            name="policy_loss"
        )

        value_head = cascade(torso, [
            conv_2d(filters=1, kernel_size=1, name="value_conv"),
            batch_norm(training, bn_updates, "value_batch_norm"),
            tfkl.Activation("relu"),
            tfkl.Flatten(),
        ])
        value_out = cascade(value_head, [
            tfkl.Dense(nn_width, name="value_dense"),
            tfkl.Activation("relu"),
            tfkl.Dense(1, name="value"),
            tfkl.Activation("tanh"),
        ])
        # Need the identity to name the single value output from the dense layer.
        value_out = tf.identity(value_out, name="value_out")
        value_targets = tf.placeholder(shape=[None, 1], dtype=tf.float32, name="value_targets")
        value_loss = tf.identity(tf.losses.mean_squared_error(value_out, value_targets), name="value_loss")

        l2_reg_loss = tf.add_n([
            weight_decay * tf.nn.l2_loss(var)
            for var in tf.trainable_variables()
            if "/bias:" not in var.name
        ], name="l2_reg_loss")

        total_loss = policy_loss + value_loss + l2_reg_loss
        optimizer = tf.train.AdamOptimizer(learning_rate)
        with tf.control_dependencies(bn_updates):
            unused_train = optimizer.minimize(total_loss, name="train")

    @property
    def num_trainable_variables(self):
        return sum(np.prod(v.shape) for v in tf.trainable_variables())

    def inference(self, observation, legals_mask):
        return self._session.run(
            [self._value_out, self._policy_softmax],
            feed_dict={self._input: np.array(observation, dtype=np.float32),
                       self._legals_mask: np.array(legals_mask, dtype=bool),
                       self._training: False})

    # Runs a training step
    def update(self, train_inputs: Sequence[TrainInput]):
        batch = TrainInput.stack(train_inputs)

        # Run a training step and get the losses.
        _, policy_loss, value_loss, l2_reg_loss = self._session.run(
            [self._train, self._policy_loss, self._value_loss, self._l2_reg_loss],
            feed_dict={self._input: batch.observation,
                       self._legals_mask: batch.legals_mask,
                       self._policy_targets: batch.policy,
                       self._value_targets: batch.value,
                       self._training: True})

        return policy_loss, value_loss, l2_reg_loss

    def save_checkpoint(self, step):
        return self._saver.save(
            self._session,
            os.path.join(self._path, "checkpoint"),
            global_step=step)

    def load_checkpoint(self, path):
        return self._saver.restore(self._session, path)

# Alpha Zero

In [None]:
!pip install open_spiel

In [7]:
import numpy as np
import collections
import traceback
import functools
import itertools
import datetime
import random
import time
import json
import os

#from open_spiel.python.algorithms.alpha_zero import model as model_lib
import pyspiel

from open_spiel.python.utils import lru_cache
from open_spiel.python.utils import file_logger
from open_spiel.python.utils import spawn
from open_spiel.python.utils import stats

In [8]:
# An AlphaZero MCTS Evaluator
class AlphaZeroEvaluator(object):
    def __init__(self, model, cache_size=2**16):
        self._model = model
        self._cache = lru_cache.LRUCache(cache_size)

    def cache_info(self):
        return self._cache.info()

    def clear_cache(self):
        self._cache.clear()

    def _inference(self, state):
        # Make a singleton batch
        obs = np.expand_dims(state.observation_tensor(), 0)
        mask = np.expand_dims(state.legal_actions_mask(), 0)

        # ndarray isn't hashable
        cache_key = obs.tobytes() + mask.tobytes()
        value, policy = self._cache.make(cache_key, lambda: self._model.inference(obs, mask))
        return value[0, 0], policy[0]  # Unpack batch

    # Returns a value for the given state
    def evaluate(self, state):
        value, _ = self._inference(state)
        return np.array([value, -value])

    # Returns the probabilities for all actions.
    def prior(self, state):
        _, policy = self._inference(state)
        return [(action, policy[action]) for action in state.legal_actions()]

In [9]:
# Time to wait for processes to join.
JOIN_WAIT_DELAY = 0.001


# A particular point along a trajectory
class TrajectoryState(object):
    def __init__(self, observation, current_player, legals_mask, action, policy, value):
        self.observation = observation
        self.current_player = current_player
        self.legals_mask = legals_mask
        self.action = action
        self.policy = policy
        self.value = value


# A sequence of (observations, actions and policies), and the outcomes
class Trajectory(object):
    def __init__(self):
        self.states = []
        self.returns = None


# A fixed size buffer that keeps the newest values
class Buffer(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.data = []
        self.total_seen = 0  # The number of items that have passed through.

    def __len__(self):
        return len(self.data)

    def append(self, val):
        return self.extend([val])

    def extend(self, batch):
        batch = list(batch)
        self.total_seen += len(batch)
        self.data.extend(batch)
        self.data[:-self.max_size] = []

    def sample(self, count):
        return random.sample(self.data, count)


# A config for the model/experiment
class Config(collections.namedtuple(
    "Config", [
        "game",
        "path",
        "learning_rate",
        "weight_decay",
        "train_batch_size",
        "replay_buffer_size",
        "replay_buffer_reuse",
        "max_steps",
        "checkpoint_freq",
        "actors",
        "evaluators",
        "evaluation_window",
        "eval_levels",

        "uct_c",
        "max_simulations",
        "policy_alpha",
        "policy_epsilon",
        "temperature",
        "temperature_drop",

        "nn_model",
        "nn_width",
        "nn_depth",
        "observation_shape",
        "output_size",

        "quiet",
    ])):
  pass


def _init_model_from_config(config):
    return Model.build_model(
        config.nn_model,
        config.observation_shape,
        config.output_size,
        config.nn_width,
        config.nn_depth,
        config.weight_decay,
        config.learning_rate,
        config.path)


# A decorator to fn/processes that gives a logger and logs exceptions
def watcher(fn):
    # Wrap the decorated function
    @functools.wraps(fn)
    def _watcher(*, config, num=None, **kwargs):
        name = fn.__name__
        if num is not None:
            name += "-" + str(num)
        with file_logger.FileLogger(config.path, name, config.quiet) as logger:
            print(f'{name} started')
            logger.print(f"{name} started")
            try:
                return fn(config=config, logger=logger, **kwargs)
            except Exception as e:
                logger.print(f"\n{' Exception caught '.center(60, '=')}{traceback.format_exc()}{'=' * 60}")
                print(f"Exception caught in {name}: {e}")
                raise
            finally:
                logger.print(f"{name} exiting")
                print(f"{name} exiting", end='\n\n')
    return _watcher


# Initializes a bot
def _init_bot(config, game, evaluator_, evaluation):
    noise = None if evaluation else (config.policy_epsilon, config.policy_alpha)
    return MCTSBot(
        config.uct_c,
        config.max_simulations,
        evaluator_,
        solve=False,
        dirichlet_noise=noise,
        child_selection_fn=SearchNode.puct_value,
        verbose=False)

In [10]:
# Play one game, return the trajectory
def _play_game(logger, game_num, game, bots, temperature, temperature_drop):
    trajectory = Trajectory()
    actions = []
    state = game.new_initial_state()
    random_state = np.random.RandomState()
    #logger.opt_print(" Starting game {} ".format(game_num).center(60, "-"))
    #logger.opt_print("Initial state:\n{}".format(state))
    while not state.is_terminal():
        root = bots[state.current_player()].mcts_search(state)
        policy = np.zeros(game.num_distinct_actions())
        for c in root.children:
            policy[c.action] = c.explore_count
        policy = policy**(1 / temperature)
        policy /= policy.sum()
        if len(actions) >= temperature_drop:
            action = root.best_child().action
        else:
            action = np.random.choice(len(policy), p=policy)
        trajectory.states.append(
            TrajectoryState(state.observation_tensor(), state.current_player(),
                            state.legal_actions_mask(), action, policy,
                            root.total_reward / root.explore_count))
        action_str = state.action_to_string(state.current_player(), action)
        actions.append(action_str)
        #logger.opt_print("Player {} sampled action: {}".format(state.current_player(), action_str))
        state.apply_action(action)
    #logger.opt_print("Next state:\n{}".format(state))

    trajectory.returns = state.returns()
    #logger.print("Game {}: Returns: {}; Actions: {}".format(game_num, " ".join(map(str, trajectory.returns)), " ".join(actions)))
    return trajectory


# Read the queue for a checkpoint to load, or an exit signal
def update_checkpoint(logger, queue, model, az_evaluator):
    path = None
    while True:  # Get the last message, ignore intermediate ones.
        try:
            path = queue.get_nowait()
        except spawn.Empty:
            break
    if path:
        logger.print("Inference cache:", az_evaluator.cache_info())
        logger.print("Loading checkpoint", path)
        model.load_checkpoint(path)
        az_evaluator.clear_cache()
    elif path is not None:  # Empty string means stop this process.
        return False
    return True


# An actor process runner that generates games and returns trajectories
@watcher
def actor(*, config, game, logger, queue):
    logger.print("Initializing model")
    model = _init_model_from_config(config)

    logger.print("Initializing bots")
    az_evaluator = AlphaZeroEvaluator(model)
    bots = [_init_bot(config, game, az_evaluator, False),
            _init_bot(config, game, az_evaluator, False)]

    for game_num in itertools.count():
        if not update_checkpoint(logger, queue, model, az_evaluator):
            return
        queue.put(_play_game(logger, game_num, game, bots, config.temperature, config.temperature_drop))


# A process that plays the latest checkpoint vs standard MCTS
@watcher
def evaluator(*, game, config, logger, queue):
    results = Buffer(config.evaluation_window)

    logger.print("Initializing model")
    model = _init_model_from_config(config)

    logger.print("Initializing bots")
    az_evaluator = AlphaZeroEvaluator(model)
    random_evaluator = RandomRolloutEvaluator()

    for game_num in itertools.count():
        if not update_checkpoint(logger, queue, model, az_evaluator):
            return

        az_player = game_num % 2
        difficulty = (game_num // 2) % config.eval_levels
        max_simulations = int(config.max_simulations * (10 ** (difficulty / 2)))
        bots = [
            _init_bot(config, game, az_evaluator, True),
            MCTSBot(
                config.uct_c,
                max_simulations,
                random_evaluator,
                solve=True,
                verbose=False)
        ]
        if az_player == 1:
            bots = list(reversed(bots))

        trajectory = _play_game(logger, game_num, game, bots, temperature=1, temperature_drop=0)
        results.append(trajectory.returns[az_player])

        logger.print(f"AZ: {trajectory.returns[az_player]}, MCTS: {trajectory.returns[1 - az_player]}, AZ avg/{len(results)}: {np.mean(results.data):.3f}")

In [11]:
# A learner that consumes the replay buffer and trains the network
@watcher
def learner(*, game, config, actors, evaluators, broadcast_fn, logger):
    logger.also_to_stdout = True
    replay_buffer = Buffer(config.replay_buffer_size)
    learn_rate = config.replay_buffer_size // config.replay_buffer_reuse

    logger.print("Initializing model")
    model = _init_model_from_config(config)
    logger.print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width, config.nn_depth))
    logger.print("Model size:", model.num_trainable_variables, "variables")

    save_path = model.save_checkpoint(0)
    logger.print("Initial checkpoint:", save_path)
    broadcast_fn(save_path)

    game_lengths = stats.BasicStats()
    game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
    outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
    total_trajectories = 0

    # Merge all the actor queues into a single generator
    def trajectory_generator():
        while True:
            found = 0
            for actor_process in actors:
                try:
                    yield actor_process.queue.get_nowait()
                except spawn.Empty:
                    pass
                else:
                    found += 1
            if found == 0:
                time.sleep(0.01)  # 10ms

    # Collects the trajectories from actors into the replay buffer
    def collect_trajectories():
        num_trajectories = 0
        num_states = 0
        for trajectory in trajectory_generator():
            num_trajectories += 1
            num_states += len(trajectory.states)
            game_lengths.add(len(trajectory.states))
            game_lengths_hist.add(len(trajectory.states))

            p1_outcome = trajectory.returns[0]
            outcomes.add(0*(p1_outcome > 0) + 1*(p1_outcome < 0) + 2*(p1_outcome == 0))

            replay_buffer.extend(
                TrainInput(s.observation, s.legals_mask, s.policy, p1_outcome)
                for s in trajectory.states)

            if num_states >= learn_rate:
                break
        return num_trajectories, num_states

    # Sample from the replay buffer, update weights and save a checkpoint
    def learn(step):
        losses = []
        for _ in range(len(replay_buffer) // config.train_batch_size):
            data = replay_buffer.sample(config.train_batch_size)
            policy_loss, value_loss, l2_reg_loss = model.update(data)
            losses.append([policy_loss, value_loss, l2_reg_loss])

        # Always save a checkpoint, either for keeping or for loading the weights to
        # the actors. It only allows numbers, so use -1 as "latest".
        save_path = model.save_checkpoint(
            step if step % config.checkpoint_freq == 0 else -1)

        policy_loss, value_loss, l2_reg_loss = np.mean(losses, axis=0)
        total_loss = policy_loss + value_loss + l2_reg_loss
        logger.print(f"Losses(total: {total_loss:.3f}, policy: {policy_loss:.3f}, value: {value_loss:.3f}, l2: {l2_reg_loss:.3f})")
        logger.print("Checkpoint saved:", save_path)
        return save_path, (policy_loss, value_loss, l2_reg_loss)

    last_time = time.time() - 60
    for step in itertools.count(1):
        game_lengths.reset()
        game_lengths_hist.reset()
        outcomes.reset()

        num_trajectories, num_states = collect_trajectories()
        total_trajectories += num_trajectories
        now = time.time()
        seconds = now - last_time
        last_time = now

        save_path, losses = learn(step)

        logger.print(f"Step: {step}")
        logger.print(("Collected {:5} states from {:3} games, {:.1f} states/s. "
                      "{:.1f} states/(s*actor), game length: {:.1f}").format(
                num_states, num_trajectories, num_states / seconds,
                num_states / (config.actors * seconds),
                num_states / num_trajectories))
        logger.print("game_length_hist:", game_lengths_hist.data)
        logger.print("outcomes", outcomes.data)
        logger.print(f"Buffer size: {len(replay_buffer)}. States seen: {replay_buffer.total_seen}")
        logger.print()

        if config.max_steps > 0 and step >= config.max_steps:
            break

        broadcast_fn(save_path)

In [12]:
# Start all the worker processes for a full alphazero setup
def alpha_zero(config: Config):
    # --- Загрузка игры ---
    game = pyspiel.load_game(config.game)
    config = config._replace(
        observation_shape=game.observation_tensor_shape(),
        output_size=game.num_distinct_actions())
    print("Starting game", config.game)

    # --- Создание папки для сохранения logs, checkpoints, config ---
    if not os.path.exists(config.path):
        os.makedirs(config.path)
    print("Writing logs and checkpoints to:", config.path)

    with open(os.path.join(config.path, "config.json"), "w") as fp:
        fp.write(json.dumps(config._asdict(), indent=2, sort_keys=True) + "\n")

    print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width, config.nn_depth))

    actors = [spawn.Process(actor, kwargs={"game": game, "config": config, "num": i})
              for i in range(config.actors)]
    evaluators = [spawn.Process(evaluator, kwargs={"game": game, "config": config, "num": i})
                  for i in range(config.evaluators)]

    def broadcast(msg):
        for proc in actors + evaluators:
            proc.queue.put(msg)

    try:
        learner(game=game, config=config, actors=actors, evaluators=evaluators, broadcast_fn=broadcast)
    except (KeyboardInterrupt, EOFError):
        print("Caught a KeyboardInterrupt, stopping early.")
    finally:
        broadcast("")
        # for actor processes to join we have to make sure that their q_in is empty, including backed up items
        for proc in actors:
            while proc.exitcode is None:
                while not proc.queue.empty():
                    proc.queue.get_nowait()
                proc.join(JOIN_WAIT_DELAY)
        for proc in evaluators:
            proc.join()

In [13]:
config = Config(
    game="tic_tac_toe",
    path="az-{}".format(datetime.datetime.now().strftime("%d.%m.%y-%H:%M")),
    learning_rate=0.01,
    weight_decay=1e-4,
    train_batch_size=128,
    replay_buffer_size=2**14,
    replay_buffer_reuse=4,
    max_steps=25,
    checkpoint_freq=25,

    actors=4,
    evaluators=4,
    uct_c=1,
    max_simulations=20,
    policy_alpha=0.25,
    policy_epsilon=1,
    temperature=1,
    temperature_drop=4,
    evaluation_window=50,
    eval_levels=7,

    nn_model="resnet",
    nn_width=128,
    nn_depth=2,
    observation_shape=None,
    output_size=None,

    quiet=True,
)

In [None]:
# НА CPU (50 минут)

alpha_zero(config)

Starting game tic_tac_toe
Writing logs and checkpoints to: az-03.07.24-21:37
Model type: resnet(128, 2)


  self.pid = os.fork()


actor-0 started
actor-1 started
actor-2 started
actor-3 started
evaluator-0 started
evaluator-1 started
evaluator-2 startedlearner started
[2024-07-03 21:37:28.914] Initializing model

evaluator-3 started


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.


[2024-07-03 21:38:00.227] Model type: resnet(128, 2)
[2024-07-03 21:38:00.240] Model size: 597173 variables
[2024-07-03 21:38:05.198] Initial checkpoint: az-03.07.24-21:37/checkpoint-0
[2024-07-03 21:39:34.747] Losses(total: 2.252, policy: 1.614, value: 0.519, l2: 0.119)
[2024-07-03 21:39:34.762] Checkpoint saved: az-03.07.24-21:37/checkpoint--1
[2024-07-03 21:39:34.773] Step: 1
[2024-07-03 21:39:34.778] Collected  4103 states from 557 games, 39.2 states/s. 9.8 states/(s*actor), game length: 7.4
[2024-07-03 21:39:34.790] game_length_hist: [0, 0, 0, 0, 0, 118, 43, 127, 55, 214]
[2024-07-03 21:39:34.815] outcomes {'counts': [273, 98, 186], 'names': ['Player1', 'Player2', 'Draw']}
[2024-07-03 21:39:34.833] Buffer size: 4103. States seen: 4103
[2024-07-03 21:39:34.847] Loss. policy: 1.6136566027998924. value: 0.519357968121767. l2reg: 0.11893886514008045. sum: 2.25195343606174
[2024-07-03 21:39:34.858]
[2024-07-03 21:40:53.072] Losses(total: 2.216, policy: 1.590, value: 0.488, l2: 0.139)
[

In [None]:
# НА GPU (30 минут)

alpha_zero(config)

Starting game tic_tac_toe
Writing logs and checkpoints to: az-06.07.24-07:01
Model type: resnet(128, 2)


  self.pid = os.fork()


actor-0 started
actor-1 started
actor-2 started
actor-3 started
evaluator-0 startedevaluator-1 started

evaluator-2 startedlearner started
[2024-07-06 07:01:23.809] Initializing model

evaluator-3 started


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Colocations handled automatically by placer.


[2024-07-06 07:01:48.499] Model type: resnet(128, 2)
[2024-07-06 07:01:48.504] Model size: 597173 variables
[2024-07-06 07:01:52.839] Initial checkpoint: az-06.07.24-07:01/checkpoint-0
[2024-07-06 07:04:10.161] Losses(total: 2.274, policy: 1.626, value: 0.530, l2: 0.118)
[2024-07-06 07:04:10.182] Checkpoint saved: az-06.07.24-07:01/checkpoint--1
[2024-07-06 07:04:10.186] Step: 1
[2024-07-06 07:04:10.204] Collected  4102 states from 561 games, 26.3 states/s. 6.6 states/(s*actor), game length: 7.3
[2024-07-06 07:04:10.213] game_length_hist: [0, 0, 0, 0, 0, 136, 33, 118, 68, 206]
[2024-07-06 07:04:10.218] outcomes {'counts': [278, 101, 182], 'names': ['Player1', 'Player2', 'Draw']}
[2024-07-06 07:04:10.226] Buffer size: 4102. States seen: 4102
[2024-07-06 07:04:10.245]
[2024-07-06 07:04:25.080] Losses(total: 2.204, policy: 1.600, value: 0.470, l2: 0.134)
[2024-07-06 07:04:25.100] Checkpoint saved: az-06.07.24-07:01/checkpoint--1
[2024-07-06 07:04:25.106] Step: 2
[2024-07-06 07:04:25.113] 

##Игра с обученной моделью (CPU)

In [None]:
game = pyspiel.load_game("tic_tac_toe")
config = config._replace(
    observation_shape=game.observation_tensor_shape(),
    output_size=game.num_distinct_actions())

In [None]:
config

Config(game='tic_tac_toe', path='az-04.07.24-11:46', learning_rate=0.01, weight_decay=0.0001, train_batch_size=128, replay_buffer_size=16384, replay_buffer_reuse=4, max_steps=25, checkpoint_freq=25, actors=4, evaluators=4, evaluation_window=50, eval_levels=7, uct_c=1, max_simulations=20, policy_alpha=0.25, policy_epsilon=1, temperature=1, temperature_drop=4, nn_model='resnet', nn_width=128, nn_depth=2, observation_shape=[3, 3, 3], output_size=9, quiet=True)

In [None]:
model = _init_model_from_config(config)
model.load_checkpoint("/content/drive/MyDrive/az-03.07.24-21:37/checkpoint-25")

az_evaluator = AlphaZeroEvaluator(model)
az = _init_bot(config, game, az_evaluator, True)

In [None]:
state = game.new_initial_state()

In [None]:
def print_stats(root):
    print("action \t  reward       outcome\t explore_count")
    for child in root.children:
        mean_reward = child.total_reward / child.explore_count if child.explore_count > 0 else 0
        print(f"{child.action}\t {round(mean_reward, 5):7}\t {child.outcome}\t {child.explore_count}")

In [None]:
# Ход соперника
root = az.mcts_search(state)
print_stats(root)
state.apply_action(root.best_child().action)
state

action 	  reward       outcome	 explore_count
6	       0	 None	 0
4	  0.4174	 None	 11
5	       0	 None	 0
3	       0	 None	 0
2	 0.47469	 None	 7
1	       0	 None	 0
0	 0.39978	 None	 1
8	       0	 None	 0
7	       0	 None	 0


...
.x.
...

In [None]:
# Мой ход
print_stats(az.mcts_search(state))
state.apply_action(0)
state

action 	  reward       outcome	 explore_count
0	 -0.14019	 None	 5
7	 -0.6125	 None	 1
3	 -0.79627	 None	 1
8	 -0.23917	 None	 2
5	 -0.71531	 None	 1
2	 -0.21865	 None	 5
6	 -0.10097	 None	 3
1	 -0.4682	 None	 1


o..
.x.
...

In [None]:
# Ход соперника
root = az.mcts_search(state)
print_stats(root)
state.apply_action(root.best_child().action)
state

action 	  reward       outcome	 explore_count
5	       0	 None	 0
2	 -0.02314	 None	 1
3	 0.19252	 None	 2
6	 0.06331	 None	 2
7	  0.3599	 None	 12
8	       0	 None	 0
1	 0.12811	 None	 2


o..
.x.
.x.

In [None]:
# Мой ход
print_stats(az.mcts_search(state))
state.apply_action(1)
state

action 	  reward       outcome	 explore_count
1	  0.0342	 None	 14
6	 -0.74291	 None	 1
2	 -0.88804	 None	 1
5	 -0.84577	 None	 1
8	 -0.98804	 None	 1
3	 -0.93622	 None	 1


oo.
.x.
.x.

In [None]:
# Ход соперника
root = az.mcts_search(state)
print_stats(root)
state.apply_action(root.best_child().action)
state

action 	  reward       outcome	 explore_count
8	 -0.51336	 None	 1
6	 -0.58918	 None	 1
5	 -0.48038	 None	 2
2	 0.03918	 None	 14
3	 -0.01541	 None	 1


oox
.x.
.x.

In [None]:
# Мой ход
print_stats(az.mcts_search(state))
state.apply_action(6)
state

action 	  reward       outcome	 explore_count
8	 -0.73398	 None	 1
3	 -0.87934	 None	 1
5	 -0.22724	 None	 1
6	 0.03032	 None	 16


oox
.x.
ox.

In [None]:
# Ход соперника
root = az.mcts_search(state)
print_stats(root)
state.apply_action(root.best_child().action)
state

action 	  reward       outcome	 explore_count
5	 -0.51312	 None	 2
8	 -0.17318	 None	 1
3	 0.04609	 None	 16


oox
xx.
ox.

In [None]:
# Мой ход
print_stats(az.mcts_search(state))
state.apply_action(5)
state

action 	  reward       outcome	 explore_count
8	 -0.78368	 None	 1
5	 0.00129	 None	 18


oox
xxo
ox.

In [None]:
# Ход соперника
root = az.mcts_search(state)
print_stats(root)
state.apply_action(root.best_child().action)
state

action 	  reward       outcome	 explore_count
8	     0.0	 [0.0, 0.0]	 19


oox
xxo
oxx

In [None]:
state = game.new_initial_state()

In [None]:
root = SearchNode(None, state.current_player(), 1)

In [None]:
visit_path, working_state = az._apply_tree_policy(root, state)

In [None]:
visit_path

[<__main__.SearchNode at 0x7c5e8a2141c0>]

In [None]:
az.evaluator.evaluate(working_state)

array([ 0.40663803, -0.40663803], dtype=float32)

# AlphaZero (OLD)

In [None]:
"""
This implements the AlphaZero training algorithm. It spawns N actors which feed
trajectories into a replay buffer which are consumed by a learner. The learner
generates new weights, saves a checkpoint, and tells the actors to update. There
are also M evaluators running games continuously against a standard MCTS+Solver,
though each at a different difficulty (ie number of simulations for MCTS).

Due to the multi-process nature of this algorithm the logs are written to files,
one per process. The learner logs are also output to stdout. The checkpoints are
also written to the same directory.

Links to relevant articles/papers:
  https://deepmind.com/blog/article/alphago-zero-starting-scratch has an open
    access link to the AlphaGo Zero nature paper.
  https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go
    has an open access link to the AlphaZero science paper.
"""

'\nThis implements the AlphaZero training algorithm. It spawns N actors which feed\ntrajectories into a replay buffer which are consumed by a learner. The learner\ngenerates new weights, saves a checkpoint, and tells the actors to update. There\nare also M evaluators running games continuously against a standard MCTS+Solver,\nthough each at a different difficulty (ie number of simulations for MCTS).\n\nDue to the multi-process nature of this algorithm the logs are written to files,\none per process. The learner logs are also output to stdout. The checkpoints are\nalso written to the same directory.\n\nLinks to relevant articles/papers:\n  https://deepmind.com/blog/article/alphago-zero-starting-scratch has an open\n    access link to the AlphaGo Zero nature paper.\n  https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go\n    has an open access link to the AlphaZero science paper.\n'

In [None]:
import collections
import datetime
import functools
import itertools
import json
import os
import random
import sys
import tempfile
import time
import traceback

import pyspiel
from open_spiel.python.utils import stats
from open_spiel.python.utils import spawn
from open_spiel.python.algorithms.alpha_zero import model as model_lib



In [None]:
"""An MCTS Evaluator for an AlphaZero model."""

from open_spiel.python.utils import lru_cache

class AlphaZeroEvaluator(object):
    def __init__(self, model, cache_size=2**16):
        self._model = model
        self._cache = lru_cache.LRUCache(cache_size)

    def cache_info(self):
        return self._cache.info()

    def clear_cache(self):
        self._cache.clear()

    def _inference(self, state):
        # Make a singleton batch
        obs = np.expand_dims(state.observation_tensor(), 0)
        mask = np.expand_dims(state.legal_actions_mask(), 0)

        # ndarray isn't hashable
        cache_key = obs.tobytes() + mask.tobytes()

        value, policy = self._cache.make(cache_key, lambda: self._model.inference(obs, mask))

        return value[0, 0], policy[0]  # Unpack batch

    # Returns a value for the given state
    def evaluate(self, state):
        value, _ = self._inference(state)
        return np.array([value, -value])

    def prior(self, state):
        # Returns the probabilities for all actions.
        _, policy = self._inference(state)
        return [(action, policy[action]) for action in state.legal_actions()]

In [None]:
# A particular point along a trajectory
class TrajectoryState(object):
    def __init__(self, observation, current_player, legals_mask, action, policy, value):
        self.observation = observation
        self.current_player = current_player
        self.legals_mask = legals_mask
        self.action = action
        self.policy = policy
        self.value = value

In [None]:
# A sequence of observations, actions and policies, and the outcomes
class Trajectory(object):
    def __init__(self):
        self.states = []
        self.returns = None

    def add(self, information_state, action, policy):
        self.states.append((information_state, action, policy))

In [None]:
# A fixed size buffer that keeps the newest values
class Buffer(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.data = []
        self.total_seen = 0  # The number of items that have passed through.

    def __len__(self):
        return len(self.data)

    def __bool__(self):
        return bool(self.data)

    def append(self, val):
        return self.extend([val])

    def extend(self, batch):
        batch = list(batch)
        self.total_seen += len(batch)
        self.data.extend(batch)
        self.data[:-self.max_size] = []

    def sample(self, count):
        return random.sample(self.data, count)

In [None]:
# A config for the model/experiment
class Config(collections.namedtuple(
    "Config", [
        "game",
        "path",
        "learning_rate",
        "weight_decay",
        "train_batch_size",
        "replay_buffer_size",
        "replay_buffer_reuse",
        "max_steps",
        "checkpoint_freq",
        "actors",
        "evaluators",
        "evaluation_window",
        "eval_levels",

        "uct_c",
        "max_simulations",
        "policy_alpha",
        "policy_epsilon",
        "temperature",
        "temperature_drop",

        "nn_model",
        "nn_width",
        "nn_depth",
        "observation_shape",
        "output_size",

        "quiet",
    ])):
    pass

In [None]:
def _init_model_from_config(config):
  return model_lib.Model.build_model(
      config.nn_model,
      config.observation_shape,
      config.output_size,
      config.nn_width,
      config.nn_depth,
      config.weight_decay,
      config.learning_rate,
      config.path)

In [None]:
import datetime
import os

# A logger to print stuff to a file
class FileLogger(object):

  def __init__(self, path, name, quiet=False, also_to_stdout=False):
    self._fd = open(os.path.join(path, "log-{}.txt".format(name)), "w")
    self._quiet = quiet
    self.also_to_stdout = also_to_stdout

  def print(self, *args):
    # Date/time with millisecond precision.
    date_prefix = "[{}]".format(datetime.datetime.now().isoformat(" ")[:-3])
    print(date_prefix, *args, file=self._fd, flush=True)
    if self.also_to_stdout:
      print(date_prefix, *args, flush=True)

  def opt_print(self, *args):
    if not self._quiet:
      self.print(*args)

  def __enter__(self):
    return self

  def __exit__(self, unused_exception_type, unused_exc_value, unused_traceback):
    self.close()

  def close(self):
    if self._fd:
      self._fd.close()
      self._fd = None

  def __del__(self):
    self.close()

In [None]:
"""Log data to a jsonl file."""

import datetime
import json
import os
import time
from typing import Any, Dict, Text

from open_spiel.python.utils import gfile


class DataLoggerJsonLines:
  """Log data to a jsonl file."""

  def __init__(self, path: str, name: str, flush=True):
    self._fd = gfile.Open(os.path.join(path, name + ".jsonl"), "w")
    self._flush = flush
    self._start_time = time.time()

  def __del__(self):
    self.close()

  def close(self):
    if hasattr(self, "_fd") and self._fd is not None:
      self._fd.flush()
      self._fd.close()
      self._fd = None

  def flush(self):
    self._fd.flush()

  def write(self, data: Dict[Text, Any]):
    now = time.time()
    data["time_abs"] = now
    data["time_rel"] = now - self._start_time
    dt_now = datetime.datetime.utcfromtimestamp(now)
    data["time_str"] = dt_now.strftime("%Y-%m-%d %H:%M:%S.%f +0000")
    self._fd.write(json.dumps(data))
    self._fd.write("\n")
    if self._flush:
      self.flush()

In [None]:
# Декоратор для fn/processes that gives a logger and logs exceptions
def watcher(fn):
    # Wrap the decorated function
    @functools.wraps(fn)
    def _watcher(*, config, num=None, **kwargs):
        name = fn.__name__
        if num is not None:
            name += "-" + str(num)
        with FileLogger(config.path, name, config.quiet) as logger:
            print("{} started".format(name))
            logger.print("{} started".format(name))
            try:
                return fn(config=config, logger=logger, **kwargs)
            except Exception as e:
                logger.print("\n".join([
                    "",
                    " Exception caught ".center(60, "="),
                    traceback.format_exc(),
                    "=" * 60,
                ]))
                print("Exception caught in {}: {}".format(name, e))
                raise
            finally:
                logger.print("{} exiting".format(name))
                print("{} exiting".format(name))
    return _watcher

In [None]:
# Инициализация бота
def _init_bot(config, game, evaluator_, evaluation):
    noise = None if evaluation else (config.policy_epsilon, config.policy_alpha)
    return MCTSBot(
        game,
        config.uct_c,
        config.max_simulations,
        evaluator_,
        solve=False,
        dirichlet_noise=noise,
        child_selection_fn=SearchNode.puct_value,
        verbose=False,
        dont_return_chance_node=True)

In [None]:
# Play one game, return the trajectory
def _play_game(logger, game_num, game, bots, temperature, temperature_drop):
    trajectory = Trajectory()
    actions = []
    state = game.new_initial_state()
    random_state = np.random.RandomState()
    logger.opt_print(" Starting game {} ".format(game_num).center(60, "-"))
    logger.opt_print("Initial state:\n{}".format(state))
    while not state.is_terminal():
        root = bots[state.current_player()].mcts_search(state)
        policy = np.zeros(game.num_distinct_actions())
        for c in root.children:
            policy[c.action] = c.explore_count
        policy = policy**(1 / temperature)
        policy /= policy.sum()
        if len(actions) >= temperature_drop:
            action = root.best_child().action
        else:
            action = np.random.choice(len(policy), p=policy)
        trajectory.states.append(
            TrajectoryState(state.observation_tensor(), state.current_player(),
                            state.legal_actions_mask(), action, policy,
                            root.total_reward / root.explore_count)
        )
        action_str = state.action_to_string(state.current_player(), action)
        actions.append(action_str)
        logger.opt_print("Player {} sampled action: {}".format(state.current_player(), action_str))
        state.apply_action(action)
    logger.opt_print("Next state:\n{}".format(state))

    trajectory.returns = state.returns()
    logger.print("Game {}: Returns: {}; Actions: {}".format(
        game_num, " ".join(map(str, trajectory.returns)), " ".join(actions)))
    return trajectory

In [None]:
# Read the queue for a checkpoint to load, or an exit signal
def update_checkpoint(logger, queue, model, az_evaluator):
    path = None
    while True:  # Get the last message, ignore intermediate ones.
        try:
            path = queue.get_nowait()
        except spawn.Empty:
            break
    if path:
        logger.print("Inference cache:", az_evaluator.cache_info())
        logger.print("Loading checkpoint", path)
        model.load_checkpoint(path)
        az_evaluator.clear_cache()
    elif path is not None:  # Empty string means stop this process.
        return False
    return True

In [None]:
# An actor process runner that generates games and returns trajectories
@watcher
def actor(*, config, game, logger, queue):
    logger.print("Initializing model")
    model = _init_model_from_config(config)

    logger.print("Initializing bots")
    az_evaluator = AlphaZeroEvaluator(game, model)

    bots = [_init_bot(config, game, az_evaluator, False),
            _init_bot(config, game, az_evaluator, False)]
    for game_num in itertools.count():
        if not update_checkpoint(logger, queue, model, az_evaluator):
            return
        queue.put(_play_game(logger, game_num, game, bots, config.temperature, config.temperature_drop))

In [None]:
# A process that plays the latest checkpoint vs standard MCTS
@watcher
def evaluator(*, game, config, logger, queue):
    results = Buffer(config.evaluation_window)
    logger.print("Initializing model")

    model = _init_model_from_config(config)
    logger.print("Initializing bots")

    az_evaluator = AlphaZeroEvaluator(game, model)
    random_evaluator = RandomRolloutEvaluator()

    for game_num in itertools.count():
        if not update_checkpoint(logger, queue, model, az_evaluator):
            return

        az_player = game_num % 2
        difficulty = (game_num // 2) % config.eval_levels
        max_simulations = int(config.max_simulations * (10 ** (difficulty / 2)))
        bots = [
            _init_bot(config, game, az_evaluator, True),
            MCTSBot(
                game,
                config.uct_c,
                max_simulations,
                random_evaluator,
                solve=True,
                verbose=False,
                dont_return_chance_node=True)
        ]
        if az_player == 1:
            bots = list(reversed(bots))

        trajectory = _play_game(logger, game_num, game, bots, temperature=1, temperature_drop=0)
        results.append(trajectory.returns[az_player])
        queue.put((difficulty, trajectory.returns[az_player]))

        logger.print("AZ: {}, MCTS: {}, AZ avg/{}: {:.3f}".format(
            trajectory.returns[az_player],
            trajectory.returns[1 - az_player],
            len(results), np.mean(results.data))
        )

In [None]:
# Learner, выполняющий обучение нейронной сети на опыте из replay buffer'а
@watcher
def learner(*, game, config, actors, evaluators, broadcast_fn, logger):
    logger.also_to_stdout = True
    replay_buffer = Buffer(config.replay_buffer_size)
    learn_rate = config.replay_buffer_size // config.replay_buffer_reuse
    logger.print("Initializing model")
    model = _init_model_from_config(config)
    logger.print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width, config.nn_depth))
    logger.print("Model size:", model.num_trainable_variables, "variables")
    save_path = model.save_checkpoint(0)
    logger.print("Initial checkpoint:", save_path)
    broadcast_fn(save_path)

    data_log = DataLoggerJsonLines(config.path, "learner", True)

    stage_count = 7
    value_accuracies = [stats.BasicStats() for _ in range(stage_count)]
    value_predictions = [stats.BasicStats() for _ in range(stage_count)]
    game_lengths = stats.BasicStats()
    game_lengths_hist = stats.HistogramNumbered(game.max_game_length() + 1)
    outcomes = stats.HistogramNamed(["Player1", "Player2", "Draw"])
    evals = [Buffer(config.evaluation_window) for _ in range(config.eval_levels)]
    total_trajectories = 0


    # Merge all the actor queues into a single generator
    def trajectory_generator():
        while True:
          found = 0
          for actor_process in actors:
              try:
                  yield actor_process.queue.get_nowait()
              except spawn.Empty:
                  pass
              else:
                  found += 1
          if found == 0:
              time.sleep(0.01)  # 10ms

    # Collects the trajectories from actors into the replay buffer
    def collect_trajectories():
      num_trajectories = 0
      num_states = 0
      for trajectory in trajectory_generator():
          num_trajectories += 1
          num_states += len(trajectory.states)
          game_lengths.add(len(trajectory.states))
          game_lengths_hist.add(len(trajectory.states))

          p1_outcome = trajectory.returns[0]
          if p1_outcome > 0:
              outcomes.add(0)
          elif p1_outcome < 0:
              outcomes.add(1)
          else:
              outcomes.add(2)

          replay_buffer.extend(
              model_lib.TrainInput(
                  s.observation, s.legals_mask, s.policy, p1_outcome)
              for s in trajectory.states)

          for stage in range(stage_count):
              # Scale for the length of the game
              index = (len(trajectory.states) - 1) * stage // (stage_count - 1)
              n = trajectory.states[index]
              accurate = (n.value >= 0) == (trajectory.returns[n.current_player] >= 0)
              value_accuracies[stage].add(1 if accurate else 0)
              value_predictions[stage].add(abs(n.value))

          if num_states >= learn_rate:
              break
      return num_trajectories, num_states


    # Sample from the replay buffer, update weights and save a checkpoint
    def learn(step):
        losses = []
        for _ in range(len(replay_buffer) // config.train_batch_size):
            data = replay_buffer.sample(config.train_batch_size)
            losses.append(model.update(data))

        # Always save a checkpoint, either for keeping or for loading the weights to
        # the actors. It only allows numbers, so use -1 as "latest".
        save_path = model.save_checkpoint(
            step if step % config.checkpoint_freq == 0 else -1)
        losses = sum(losses, model_lib.Losses(0, 0, 0)) / len(losses)
        logger.print(losses)
        logger.print("Checkpoint saved:", save_path)
        return save_path, losses

    last_time = time.time() - 60
    for step in itertools.count(1):
        for value_accuracy in value_accuracies:
            value_accuracy.reset()
        for value_prediction in value_predictions:
            value_prediction.reset()
        game_lengths.reset()
        game_lengths_hist.reset()
        outcomes.reset()

        num_trajectories, num_states = collect_trajectories()
        total_trajectories += num_trajectories
        now = time.time()
        seconds = now - last_time
        last_time = now

        logger.print("Step:", step)
        logger.print(
            ("Collected {:5} states from {:3} games, {:.1f} states/s. "
            "{:.1f} states/(s*actor), game length: {:.1f}").format(
                num_states, num_trajectories, num_states / seconds,
                num_states / (config.actors * seconds),
                num_states / num_trajectories))
        logger.print("Buffer size: {}. States seen: {}".format(
            len(replay_buffer), replay_buffer.total_seen))

        save_path, losses = learn(step)

        for eval_process in evaluators:
            while True:
                try:
                    difficulty, outcome = eval_process.queue.get_nowait()
                    evals[difficulty].append(outcome)
                except spawn.Empty:
                    break

        batch_size_stats = stats.BasicStats()  # Only makes sense in C++.
        batch_size_stats.add(1)
        data_log.write({
            "step": step,
            "total_states": replay_buffer.total_seen,
            "states_per_s": num_states / seconds,
            "states_per_s_actor": num_states / (config.actors * seconds),
            "total_trajectories": total_trajectories,
            "trajectories_per_s": num_trajectories / seconds,
            "queue_size": 0,  # Only available in C++.
            "game_length": game_lengths.as_dict,
            "game_length_hist": game_lengths_hist.data,
            "outcomes": outcomes.data,
            "value_accuracy": [v.as_dict for v in value_accuracies],
            "value_prediction": [v.as_dict for v in value_predictions],
            "eval": {
                "count": evals[0].total_seen,
                "results": [sum(e.data) / len(e) if e else 0 for e in evals],
            },
            "batch_size": batch_size_stats.as_dict,
            "batch_size_hist": [0, 1],
            "loss": {
                "policy": losses.policy,
                "value": losses.value,
                "l2reg": losses.l2,
                "sum": losses.total,
            },
            "cache": {  # Null stats because it's hard to report between processes.
                "size": 0,
                "max_size": 0,
                "usage": 0,
                "requests": 0,
                "requests_per_s": 0,
                "hits": 0,
                "misses": 0,
                "misses_per_s": 0,
                "hit_rate": 0,
            },
        })
        logger.print()

        if config.max_steps > 0 and step >= config.max_steps:
            break

        broadcast_fn(save_path)

In [None]:
JOIN_WAIT_DELAY = 0.001

# Start all the worker processes for a full alphazero setup
def alpha_zero(config: Config):
    game = pyspiel.load_game(config.game)
    config = config._replace(
        observation_shape=game.observation_tensor_shape(),
        output_size=game.num_distinct_actions()
    )

    path = config.path
    if not path:
        path = tempfile.mkdtemp(prefix="az-{}-{}-".format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M"), config.game))
        config = config._replace(path=path)

    if not os.path.exists(path):
        os.makedirs(path)
    if not os.path.isdir(path):
        sys.exit("{} isn't a directory".format(path))
    print("Writing logs and checkpoints to:", path)
    print("Model type: %s(%s, %s)" % (config.nn_model, config.nn_width, config.nn_depth))

    with open(os.path.join(config.path, "config.json"), "w") as fp:
        fp.write(json.dumps(config._asdict(), indent=2, sort_keys=True) + "\n")

    actors = [spawn.Process(actor, kwargs={"game": game, "config": config, "num": i})
              for i in range(config.actors)]
    evaluators = [spawn.Process(evaluator, kwargs={"game": game, "config": config, "num": i})
                  for i in range(config.evaluators)]

    def broadcast(msg):
        for proc in actors + evaluators:
          proc.queue.put(msg)

    try:
        learner(game=game, config=config, actors=actors,  # pylint: disable=missing-kwoa
                evaluators=evaluators, broadcast_fn=broadcast)
    except (KeyboardInterrupt, EOFError):
        print("Caught a KeyboardInterrupt, stopping early.")
    finally:
        broadcast("")
        # for actor processes to join we have to make sure that their q_in is empty, including backed up items
        for proc in actors:
            while proc.exitcode is None:
                while not proc.queue.empty():
                    proc.queue.get_nowait()
                proc.join(JOIN_WAIT_DELAY)
        for proc in evaluators:
            proc.join()