In [1]:
# Imports
import os
import math
import random
import time
import torch
import logging
import coloredlogs

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm.notebook import tqdm

log = logging.getLogger(__name__)
coloredlogs.install(level='INFO')  # Change this to DEBUG to see more info.

import sys
sys.path.append('..')
sys.path.append('../nma_rl_games/alpha-zero')

In [2]:
from ws3_helper import set_seed
from ws3_helper import seed_worker
from ws3_helper import set_device

SEED = 2023
set_seed(seed=SEED)
DEVICE = set_device()

Random seed 2023 has been set.


In [3]:
import Arena

from utils import *
from Game import Game
from MCTS import MCTS
from NeuralNet import NeuralNet

# from othello.OthelloPlayers import *
from othello.OthelloLogic import Board
# from othello.OthelloGame import OthelloGame
from othello.pytorch.NNet import NNetWrapper as NNet

In [13]:
from ws3_helper import loadTrainExamples
from ws3_helper import save_model_checkpoint
from ws3_helper import load_model_checkpoint
from ws3_helper import OthelloGame
from ws3_helper import RandomPlayer
from ws3_helper import OthelloNNet
from ws3_helper import ValueNetwork
from ws3_helper import ValueBasedPlayer
from ws3_helper import PolicyNetwork
from ws3_helper import PolicyBasedPlayer
from ws3_helper import MonteCarlo
from ws3_helper import MonteCarloBasedPlayer

# Playing Go with Neural Networks and Monte Carlo Tree Search

- How did AI master the ancient game of Go?
- Discover the power of self-play and deep learning.
- See how Monte Carlo Tree Search (MCTS) guides decision-making in real time.
- Learn how AlphaZero taught itself to dominate Go without any human knowledge.

<img src="../images/go_game.JPG" alt="" width="600"/>

<img src="../images/alphago_paper.png" alt="" width="600"/>

The AlphaGo Algorithm

<img src="../images/alphazero_paper.png" alt="" width="600"/>

The AlphaZero Algorithm

<img src="../images/alphazero0.png" alt="Monte Carlo tree search in AlphaGo" width="600"/>

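In Monte Carlo Tree Search (MCTS), nodes represent game states, and edges represent actions.

- **Leaf Node**: A node that has not been expanded yet, i.e. a state whose children have not been added to the tree. In the code below, this is a state that the neural network has not evaluated yet.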
  
- **Terminal Node**: A node where the game ends (win, loss, or draw).


#### Key Terms for Understanding MCTS in AlphaZero

- **$Q(s, a)$ (Action Value)**: The estimated value of taking action $a$ in state $s$, i.e. the mean of the values backed up through that edge. It is updated as simulations progress.

- **$P(s, a)$ (Prior Probability)**: Probability of selecting action $a$ in state $s$, predicted by the policy network. Guides the tree search based on learned strategies.

- **$N(s, a)$ (Visit Count)**: Number of times action $a$ has been taken from state $s$ during MCTS simulations. Visit counts enter the exploration bonus, which is how the search balances exploration and exploitation.
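One way to picture these statistics is as a small record per edge, keyed by `(state, action)`. The sketch below is illustrative only: the `EdgeStats` class and its `W` field (running total of backed-up values) are hypothetical names, not part of this tutorial's code, which instead keeps separate dictionaries `Qsa`, `Nsa`, and `Ps`.

```python
from dataclasses import dataclass


@dataclass
class EdgeStats:
    """Hypothetical per-edge record for one (state, action) pair."""
    P: float        # prior probability from the policy network
    N: int = 0      # visit count N(s, a)
    W: float = 0.0  # sum of all values backed up through this edge

    @property
    def Q(self) -> float:
        # Action value: mean of the backed-up values (0 before the first visit)
        return self.W / self.N if self.N > 0 else 0.0

    def backup(self, v: float) -> None:
        # Record one more simulation that passed through this edge
        self.N += 1
        self.W += v


# Statistics for two edges leaving the same (made-up) state "s0"
tree = {("s0", 0): EdgeStats(P=0.7), ("s0", 1): EdgeStats(P=0.3)}
for v in (1.0, -1.0, 1.0):
    tree[("s0", 0)].backup(v)

print(tree[("s0", 0)].N, tree[("s0", 0)].Q)
```

Storing the total `W` and deriving `Q = W / N` is equivalent to keeping `Q` as a running mean directly, which is what the `MCTS` class below does.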


<img src="../images/alphago_fig3.png" alt="Monte Carlo tree search in AlphaGo" width="600"/>


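MCTS as depicted in the AlphaGo paper (AlphaZero uses the same selection, expansion, evaluation, and backup steps, but without rollouts and with a single network that outputs both the policy and the value)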

- **(a)**: **Selection**
  - Starting from the root, select at each node the edge with the highest score $Q(s, a) + u(P)$.
  - The edge's action value $Q(s, a)$ is combined with a bonus term $u(P)$ which is based on the stored prior probability $P(s, a)$ for that edge.
  - The goal is to balance exploration (using $u(P)$) and exploitation (using $Q(s, a)$) when choosing the next action:
    $$
    u(P) = c_{puct} \cdot P(s, a) \cdot \frac{\sqrt{\sum_b N(s, b)}}{1 + N(s, a)}
    $$
    where $c_{puct}$ is a constant controlling exploration, and $N(s, a)$ is the number of times the edge $(s, a)$ has been visited.

- **(b)**: **Leaf Node Expansion**
  - When a leaf node is reached, it may be expanded by adding a new node to the tree.
  - The policy network $p_\sigma$ processes the new node, generating a set of action probabilities.
  - These probabilities are stored as prior probabilities $P(s, a)$ for each possible action from that node, guiding future simulations.

- **(c)**: **Leaf Node Evaluation**
  - The value network $v_\theta$ evaluates the current state, producing a value estimate for the player at the node.

- **(d)**: **Action Value Update**
  - Update $Q(s, a)$ to reflect the mean value of all evaluations $v_\theta$ in the subtree below that action, and increment the visit count $N(s, a)$ of each edge traversed.
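A quick worked example shows how the exploration bonus shrinks as an edge is explored. Take, purely for illustration, $c_{puct} = 1$, $P(s, a) = 0.5$, and a parent whose edges have been visited $\sum_b N(s, b) = 16$ times in total:

$$
\begin{aligned}
N(s,a) = 0: &\quad u = 1 \cdot 0.5 \cdot \frac{\sqrt{16}}{1 + 0} = 2.0 \\
N(s,a) = 3: &\quad u = 1 \cdot 0.5 \cdot \frac{\sqrt{16}}{1 + 3} = 0.5 \\
N(s,a) = 15: &\quad u = 1 \cdot 0.5 \cdot \frac{\sqrt{16}}{1 + 15} = 0.125
\end{aligned}
$$

Unless $Q(s, a)$ stays high, the shrinking bonus lets less-visited edges overtake a heavily visited one.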



## Coding Exercise 1: MCTS planner

In building the MCTS planner, we will focus on the action selection part, particularly the objective function used. MCTS will use a combination of the current action-value function $Q$ and the policy prior as follows:

\begin{equation}
\underset{a}{\operatorname{argmax}} (Q(s_t, a)+u(s_t, a))
\end{equation}

with $u(s_t, a)=c_{puct} \cdot P(s_t,a) \cdot \frac{\sqrt{\sum_b N(s_t,b)}}{1+N(s_t,a)}$. This implements a variant of the Upper Confidence bound applied to Trees (UCT): it balances exploration and exploitation using the statistics stored in the tree so far, and the trade-off is parametrized by $c_{puct}$.
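The selection rule can be written as a small vectorized function for a single state. This is a sketch with hypothetical names (`puct_select` and its per-action arrays); the tutorial's own implementation loops over actions instead, as shown in Part 4.

```python
import numpy as np


def puct_select(Q, N, P, valid, c_puct=1.0):
    """Return argmax_a Q(s,a) + u(s,a) over the valid actions of one state s.

    Q, N, P and valid are 1-D sequences indexed by action, and
    u(s,a) = c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a)).
    """
    Q, N, P = (np.asarray(x, dtype=float) for x in (Q, N, P))
    u = c_puct * P * np.sqrt(N.sum()) / (1.0 + N)
    # Invalid actions get -inf so they can never be selected
    scores = np.where(np.asarray(valid, dtype=bool), Q + u, -np.inf)
    return int(np.argmax(scores))


# Action 0 has the best mean value but many visits; action 1 has a high
# prior and few visits, so with c_puct = 1 its bonus wins it the pick.
Q = [0.4, 0.1, -0.2]
N = [20, 2, 3]
P = [0.3, 0.6, 0.1]
print(puct_select(Q, N, P, valid=[1, 1, 1], c_puct=1.0))
print(puct_select(Q, N, P, valid=[1, 1, 1], c_puct=0.1))
```

Lowering $c_{puct}$ shifts the choice from the high-prior, rarely visited action back to the action with the best mean value.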

**Note**: Polynomial Upper Confidence Trees (PUCT) is the technical term for the algorithm below, in which we repeatedly run MCTS simulations and store and reuse information from previous simulations to explore and find optimal actions.

<br>

**Exercise**:
* Finish the MCTS planner by using UCT to select actions to build the tree.
* Deploy the MCTS planner to build a tree search for a given board position, producing value estimates and action counts for that position.

### Part 1: Initialization and Attribute Definitions

The `MCTS` class is initialized with three main arguments: `game`, `nnet`, and `args`. These provide the game environment, the neural network for evaluating game states, and the parameters (such as `cpuct`) that control the Monte Carlo Tree Search (MCTS) algorithm. Additionally, several dictionaries (`Qsa`, `Nsa`, `Ns`, `Ps`, `Es`, `Vs`) are defined to store information about the search tree.

```python
class MCTS():
    def __init__(self, game, nnet, args):
        """
        Args:
          game: OthelloGame instance
          nnet: OthelloNet instance
          args: dotdict with MCTS parameters (e.g. `cpuct`)
        """
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # Stores Q values for s,a
        self.Nsa = {}  # Stores #times edge s,a was visited
        self.Ns = {}  # Stores #times board s was visited
        self.Ps = {}  # Stores initial policy (returned by neural net)
        self.Es = {}  # Stores game.getGameEnded result for board s
        self.Vs = {}  # Stores game.getValidMoves for board s
```
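The dictionaries above are keyed by a string representation of the board because NumPy arrays are mutable and therefore not hashable, so they cannot be used as dictionary keys directly. The sketch below assumes a byte-string key built with `ndarray.tobytes()`; `state_key` is a hypothetical stand-in, and the exact format returned by `game.stringRepresentation` in this tutorial's helper code may differ.

```python
import numpy as np

# A 6x6 Othello-style starting position (1 and -1 for the two players)
board = np.zeros((6, 6), dtype=int)
board[2, 2] = board[3, 3] = 1
board[2, 3] = board[3, 2] = -1

try:
    {board: 0}  # ndarrays are unhashable, so this raises TypeError
except TypeError as err:
    print("ndarray as key:", err)


def state_key(b):
    # Hypothetical stand-in for game.stringRepresentation
    return b.tobytes()


Ns = {}
s = state_key(board)
Ns[s] = Ns.get(s, 0) + 1
# An identical board created separately maps to the same key
Ns[state_key(board.copy())] += 1
print(Ns[s])
```

Because `tobytes()` encodes only the raw buffer, boards must share the same shape and dtype for equal positions to produce equal keys.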


### Part 2: Search Method and Terminal State Check

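`search` method: performs one iteration (one simulation) of MCTS.

- Look up, or compute and cache in `self.Es`, whether the current state `s` is terminal.
  - If it is terminal, return the game outcome negated, so that it is expressed from the perspective of the player who moved into `s`.
  - If it is not terminal, continue with the leaf check (Part 3) and, for states that are already expanded, select an action and recursively call `search` on the resulting state (Parts 4 and 5).
- The recursion stops as soon as a leaf node or a terminal state is reached.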

```python
    def search(self, canonicalBoard):
        """
        Perform one iteration of MCTS.
        Args:
          canonicalBoard: Canonical Board of size n x n.
        Returns:
          float: The negative value of the current canonical board state.
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # Terminal node
            return -self.Es[s]
```
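Returning `-self.Es[s]` follows the negamax convention: every value is expressed from the point of view of the player to move, so a value passed one level up the tree changes sign. The toy tree below is hypothetical (plain dicts and exhaustive search rather than MCTS) and only illustrates that sign flip.

```python
def value_for_player_to_move(node):
    """Negamax value of `node` for the player whose turn it is there."""
    if "outcome" in node:
        # Terminal node: +1 win, -1 loss for the player to move (like Es[s])
        return node["outcome"]
    # A child's value is from the opponent's perspective, so negate it
    return max(-value_for_player_to_move(child) for child in node["children"])


# At the root, move 0 leads to a terminal position in which the opponent
# (now to move) has lost; move 1 leads to one in which the opponent has won.
root = {"children": [{"outcome": -1}, {"outcome": +1}]}
print(value_for_player_to_move(root))
```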

### Part 3: Leaf Node Detection and Neural Network Prediction

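If the state `s` is not terminal, the next step is to check whether it is a leaf node, i.e. a state the neural network has not evaluated yet (`s not in self.Ps`). If it is, the network predicts the policy (`Ps`) and the value (`v`) of the state. The policy is then multiplied by the valid-move mask so that invalid actions get probability zero, and renormalized to sum to 1. The negated value `-v` is returned immediately, without recursing further.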

```python
        if s not in self.Ps:
            # Leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # Masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # Renormalize
            else:
                # If all valid moves were masked make all valid moves equally probable
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0

            return -v
```
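The masking and renormalization step can be tried in isolation. The helper `mask_and_renormalize` below is a hypothetical standalone version of the lines above, including the uniform fallback used when the network assigns all of its probability to invalid moves.

```python
import numpy as np


def mask_and_renormalize(policy, valids):
    """Zero out invalid moves and rescale so the rest sums to 1."""
    p = np.asarray(policy, dtype=float) * np.asarray(valids, dtype=float)
    total = p.sum()
    if total > 0:
        return p / total
    # All probability mass was on invalid moves: uniform over valid moves
    v = np.asarray(valids, dtype=float)
    return v / v.sum()


# Typical case: half of the mass sat on invalid moves 0 and 3
print(mask_and_renormalize([0.5, 0.2, 0.2, 0.1], [0, 1, 1, 0]))
# Degenerate case: the network put everything on an invalid move
print(mask_and_renormalize([1.0, 0.0, 0.0, 0.0], [0, 1, 1, 1]))
```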


### Part 4: Action Selection Using Upper Confidence Bound

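After handling leaf nodes, the algorithm selects the next action to explore using the upper confidence bound $Q(s, a) + u(s, a)$ defined above. This score balances exploration (favoring actions with a high prior probability and few visits) and exploitation (favoring actions with high Q-values), and the action with the highest score is selected. For an edge that has not been visited yet, both $Q(s, a)$ and $N(s, a)$ are taken to be 0, so only the prior term remains.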

```python
        valids = self.Vs[s]
        cur_best = -float("inf")
        best_act = -1

        # Pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                        1 + self.Nsa[(s, a)]
                    )
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8)

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
```
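The `1e-8` in the unvisited branch matters on the first selection step after a node is expanded: at that point `Ns[s]` is still 0, so without it every unvisited action would score exactly 0 and the strict `>` comparison would keep the first valid action regardless of its prior. The sketch below isolates that branch; `pick_unvisited` is a hypothetical helper that mirrors the loop above for a node none of whose edges have been visited.

```python
import math


def pick_unvisited(priors, valids, ns, c_puct=1.0, eps=1e-8):
    """Mirror of the Part 4 loop when no edge (s, a) is in Qsa yet."""
    cur_best, best_act = -float("inf"), -1
    for a, p in enumerate(priors):
        if valids[a]:
            u = c_puct * p * math.sqrt(ns + eps)
            if u > cur_best:
                cur_best, best_act = u, a
    return best_act


priors = [0.1, 0.0, 0.6, 0.3]
valids = [1, 0, 1, 1]
print(pick_unvisited(priors, valids, ns=0, eps=0.0))  # ties at 0: first valid action
print(pick_unvisited(priors, valids, ns=0))           # highest prior wins
```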

### Part 5: Recursion and Value Propagation

Once the best action is chosen, the algorithm moves to the next state by calling the game environment. The search continues recursively from this next state. Once a value `v` is returned from the deeper searches, the Q-values and visit counts are updated along the path.

```python
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v
```
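The update `(N * Q + v) / (N + 1)` is the incremental form of an ordinary mean, so after any sequence of backups `Qsa[(s, a)]` equals the average of all values that passed through the edge. A quick check with made-up values, using a hypothetical helper `running_mean_update`; with `n = 0` the first call returns `v`, which matches the `else` branch above.

```python
def running_mean_update(q, n, v):
    """One backup step as in Part 5: return the new (Q, N) for an edge."""
    return (n * q + v) / (n + 1), n + 1


values = [1.0, -1.0, 0.5, 0.25]  # made-up backed-up values
q, n = 0.0, 0
for v in values:
    q, n = running_mean_update(q, n, v)

print(n, q, sum(values) / len(values))
```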

In [5]:
# to_remove solution
class MCTS:

    def __init__(self, game, nnet, args):
        """
        Args:
          game: OthelloGame instance
            Instance of the OthelloGame class above;
          nnet: OthelloNet instance
            Instance of the OthelloNNet class above;
          args: dotdict
            MCTS parameters; `search` reads `args.cpuct`, the exploration
            constant c_puct in the selection rule.
        """
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # Stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # Stores #times edge s,a was visited
        self.Ns = {}  # Stores #times board s was visited
        self.Ps = {}  # Stores initial policy (returned by neural net)
        self.Es = {}  # Stores game.getGameEnded result for board s
        self.Vs = {}  # Stores game.getValidMoves for board s

    def search(self, canonicalBoard):
        """
        Perform one iteration of MCTS.

        It is recursively called till a leaf node is found. The action chosen at
        each node is one that has the maximum upper confidence bound.
        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
        updated.
        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.

        Args:
          canonicalBoard: np.ndarray
            Canonical Board of size n x n [6x6 in this case]

        Returns:
            v: Float
              The negative of the value of the current canonicalBoard
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # Terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # Leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # Masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # Renormalize
            else:
                # If all valid moves were masked make all valid moves equally probable
                # NB! All valid moves may be masked if your NNet architecture is
                # insufficient, if it has overfit, or for some other reason.
                # If you see dozens or hundreds of these messages, you should
                # pay attention to your NNet and/or training process.
                log = logging.getLogger(__name__)
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0

            return -v

        valids = self.Vs[s]
        cur_best = -float("inf")
        best_act = -1

        # Pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                        1 + self.Nsa[(s, a)]
                    )
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8)

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s)

        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1

        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def getNsa(self):
        return self.Nsa

---
# Section 2: Use MCTS to play games

*Time estimate: ~10 mins*


**Goal:** Learn how to use the results of MCTS to play games.

**Exercise:**
* Plug the MCTS planner into an agent.
* Play games against other agents.
* Explore how the prior (policy) network, the value function, the number of simulations (i.e. thinking time per move), and the exploration/exploitation parameter $c_{puct}$ contribute to playing strength.
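After the simulations, a player has to turn the root visit counts $N(s, a)$ into a move. The `play` method below raises the counts to the power `1/temp` and normalizes them; the sketch here isolates that step. `visit_counts_to_policy` is a hypothetical helper: for `temp == 0` it splits probability evenly over tied maxima, whereas `play` picks one of the tied actions at random.

```python
import numpy as np


def visit_counts_to_policy(counts, temp):
    """Convert root visit counts into move probabilities pi(a) ~ N(s,a)^(1/temp)."""
    counts = np.asarray(counts, dtype=float)
    if temp == 0:
        # Greedy limit: all mass on the most-visited action(s)
        best = np.flatnonzero(counts == counts.max())
        probs = np.zeros_like(counts)
        probs[best] = 1.0 / len(best)
        return probs
    # Very small temp can overflow here; real code often handles it separately
    scaled = counts ** (1.0 / temp)
    return scaled / scaled.sum()


counts = [10, 30, 0, 10]
print(visit_counts_to_policy(counts, temp=1.0))  # proportional to visits
print(visit_counts_to_policy(counts, temp=0.5))  # sharper
print(visit_counts_to_policy(counts, temp=0))    # greedy
```

Higher temperatures keep more exploration in the chosen moves, while `temp = 0` always plays a most-visited action.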

In [6]:
# Location of the pretrained MCTS network checkpoint in the repository
mcts_model_save_name = 'MCTS.pth.tar'
path = "../nma_rl_games/alpha-zero/pretrained_models/models/"

In [7]:
# to_remove solution
class MonteCarloTreeSearchBasedPlayer():

  def __init__(self, game, nnet, args):
    """
    Args:
      game: OthelloGame instance
        Instance of the OthelloGame class above;
      nnet: OthelloNet instance
        Instance of the OthelloNNet class above;
      args: dotdict
        MCTS parameters: `numMCTSSims` (number of simulations per move)
        and `cpuct` (exploration constant used by MCTS).
    """
    self.game = game
    self.nnet = nnet
    self.args = args
    self.mcts = MCTS(game, nnet, args)

  def play(self, canonicalBoard, temp=1):
    """
    Run `numMCTSSims` MCTS simulations from `canonicalBoard` and pick a move.

    Args:
      canonicalBoard: np.ndarray
        Canonical Board of size n x n [6x6 in this case]
      temp: float
        Temperature. 0 means greedy play with random tie-breaking; any
        positive value scales the visit counts by 1/temp before normalizing.

    Returns:
      Index of the selected action: a most-visited action if temp is 0,
      otherwise the action with the highest temperature-scaled probability
    """
    for i in range(self.args.numMCTSSims):
      self.mcts.search(canonicalBoard)

    s = self.game.stringRepresentation(canonicalBoard)
    self.Nsa = self.mcts.getNsa()
    # Visit count of every action from the root (0 if it was never tried)
    self.counts = np.array([self.Nsa[(s, a)] if (s, a) in self.Nsa else 0
                            for a in range(self.game.getActionSize())])

    if temp == 0:
      # Greedy: choose uniformly at random among the most-visited actions
      bestAs = np.flatnonzero(self.counts == np.max(self.counts))
      return int(np.random.choice(bestAs))

    scaled = self.counts ** (1. / temp)
    probs = scaled / float(np.sum(scaled))
    return int(np.argmax(probs))

  def getActionProb(self, canonicalBoard, temp=1):
    """
    Args:
      canonicalBoard: np.ndarray
        Canonical Board of size n x n [6x6 in this case]
      temp: float
        Temperature; currently not forwarded to `play`, which runs with its
        default temp=1 and therefore returns the most-visited action

    Returns:
      action_probs: np.ndarray
        One-hot vector over all actions marking the selected move
    """
    action_probs = np.zeros((self.game.getActionSize()))
    best_action = self.play(canonicalBoard)
    action_probs[best_action] = 1

    return action_probs


set_seed(seed=SEED)
game = OthelloGame(6)
rp = RandomPlayer(game).play  # Random player
num_games = 20  # Number of games per match
n1 = NNet(game)  # Network used by the MCTS player
n1.load_checkpoint(folder=path, filename=mcts_model_save_name)
args1 = dotdict({'numMCTSSims': 50, 'cpuct': 1.0})

# Play the MCTS-based agent against the random player
print('\n******MCTS player versus random player******')
mcts1 = MonteCarloTreeSearchBasedPlayer(game, n1, args1)
n1p = lambda x: np.argmax(mcts1.getActionProb(x, temp=0))
arena = Arena.Arena(n1p, rp, game, display=OthelloGame.display)
MCTS_result = arena.playGames(num_games, verbose=False)
print(f"\nNumber of games won by player1 = {MCTS_result[0]}, "
      f"number of games won by player2 = {MCTS_result[1]}, out of {num_games} games")
win_rate_player1 = MCTS_result[0]/num_games
print(f"\nWin rate for player1 over {num_games} games: {round(win_rate_player1*100, 1)}%")


Random seed 2023 has been set.





******MCTS player versus random player******


Arena.playGames (1): 100%|██████████| 10/10 [00:17<00:00,  1.79s/it]
Arena.playGames (2): 100%|██████████| 10/10 [00:16<00:00,  1.60s/it]


Number of games won by player1 = 19, number of games won by player2 = 1, out of 20 games

Win rate for player1 over 20 games: 95.0%





In [8]:
# @title Load in trained value and policy networks
path = "../nma_rl_games/alpha-zero/pretrained_models/models/"

if os.path.isdir(path) and os.listdir(path):
  # Pretrained checkpoints are available: load them
  set_seed(seed=SEED)
  game = OthelloGame(6)
  vnet = ValueNetwork(game)
  vnet.load_checkpoint(folder=path, filename='ValueNetwork.pth.tar')

  set_seed(seed=SEED)
  game = OthelloGame(6)
  pnet = PolicyNetwork(game)
  pnet.load_checkpoint(folder=path, filename='PolicyNetwork.pth.tar')
else:
  # Fallback if downloading the trained models didn't work:
  # train both networks from the saved self-play examples instead
  data_path = "../nma_rl_games/alpha-zero/pretrained_models/data/"
  loaded_games = loadTrainExamples(folder=data_path, filename='checkpoint_1.pth.tar')

  set_seed(seed=SEED)
  game = OthelloGame(6)
  vnet = ValueNetwork(game)
  vnet.train(loaded_games)

  set_seed(seed=SEED)
  game = OthelloGame(6)
  pnet = PolicyNetwork(game)
  pnet.train(loaded_games)



Random seed 2023 has been set.
Random seed 2023 has been set.


### MCTS player against Value-based player

In [9]:
print('\n******MCTS player versus value-based player******')
set_seed(seed=SEED)
vp = ValueBasedPlayer(game, vnet).play  # Value-based player
arena = Arena.Arena(n1p, vp, game, display=OthelloGame.display)
MC_result = arena.playGames(num_games, verbose=False)

print(f"\nNumber of games won by player1 = {MC_result[0]}, "
      f"number of games won by player2 = {MC_result[1]}, out of {num_games} games")
win_rate_player1 = MC_result[0]/num_games
print(f"\nWin rate for player1 over {num_games} games: {round(win_rate_player1*100, 1)}%")


******MCTS player versus value-based player******
Random seed 2023 has been set.


Arena.playGames (1): 100%|██████████| 10/10 [00:17<00:00,  1.79s/it]
Arena.playGames (2): 100%|██████████| 10/10 [00:17<00:00,  1.79s/it]


Number of games won by player1 = 17, number of games won by player2 = 3, out of 20 games

Win rate for player1 over 20 games: 85.0%





### MCTS player against Policy-based player

In [10]:
print('\n******MCTS player versus policy-based player******')
set_seed(seed=SEED)
pp = PolicyBasedPlayer(game, pnet).play  # Policy-based player
arena = Arena.Arena(n1p, pp, game, display=OthelloGame.display)
MC_result = arena.playGames(num_games, verbose=False)

print(f"\nNumber of games won by player1 = {MC_result[0]}, "
      f"number of games won by player2 = {MC_result[1]}, out of {num_games} games")
win_rate_player1 = MC_result[0]/num_games
print(f"\nWin rate for player1 over {num_games} games: {round(win_rate_player1*100, 1)}%")


******MCTS player versus policy-based player******
Random seed 2023 has been set.


Arena.playGames (1): 100%|██████████| 10/10 [00:19<00:00,  1.92s/it]
Arena.playGames (2): 100%|██████████| 10/10 [00:16<00:00,  1.61s/it]


Number of games won by player1 = 20, number of games won by player2 = 0, out of 20 games

Win rate for player1 over 20 games: 100.0%





### MCTS player against Monte-Carlo player

In [11]:
mc_model_save_name = 'MC.pth.tar'
path = "../nma_rl_games/alpha-zero/pretrained_models/models/"

n2 = NNet(game)  # Network used by the Monte Carlo player
n2.load_checkpoint(folder=path, filename=mc_model_save_name)
args2 = dotdict({'numMCsims': 10, 'maxRollouts':5, 'maxDepth':5, 'mc_topk': 3})

In [12]:
print('\n******MCTS player versus MC player******')
set_seed(seed=SEED)
mc = MonteCarloBasedPlayer(game, n2, args2)
n2p = lambda x: np.argmax(mc.getActionProb(x))
arena = Arena.Arena(n1p, n2p, game, display=OthelloGame.display)
MC_result = arena.playGames(num_games, verbose=False)

print(f"\nNumber of games won by player1 = {MC_result[0]}, "
      f"number of games won by player2 = {MC_result[1]}, out of {num_games} games")
win_rate_player1 = MC_result[0]/num_games
print(f"\nWin rate for player1 over {num_games} games: {round(win_rate_player1*100, 1)}%")


******MCTS player versus MC player******
Random seed 2023 has been set.


Arena.playGames (1): 100%|██████████| 10/10 [01:40<00:00, 10.08s/it]
Arena.playGames (2): 100%|██████████| 10/10 [01:43<00:00, 10.35s/it]


Number of games won by player1 = 16, number of games won by player2 = 4, out of 20 games

Win rate for player1 over 20 games: 80.0%





---
# Summary

In this tutorial, you built a player around a Monte Carlo Tree Search planner and compared it with random, value-based, policy-based, and Monte Carlo players. In these 20-game matches, the MCTS player won 95%, 85%, 100%, and 80% of the games, respectively.