[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/real-itu/modern-ai-course/blob/master/lecture-03/lab.ipynb)

# Lab 3 - Monte Carlo Tree Search

In this exercise we will use the same game as in the previous exercise, namely, Connect 4.
([Connect 4](https://en.wikipedia.org/wiki/Connect_Four)). You should implement the MCTS algorithm to play the game.

As before, the game is implemented below. It will play a game where both players take random (legal) actions. The MAX player is represented with a X and the MIN player with an O. The MAX player starts. Execute the code.

In [8]:
import random
from copy import deepcopy
from typing import Sequence

NONE = '.'
MAX = 'X'
MIN = 'O'
COLS = 7
ROWS = 6
N_WIN = 4


class ArrayState:
    def __init__(self, board, heights, n_moves):
        self.board = board
        self.heights = heights
        self.n_moves = n_moves

    @staticmethod
    def init():
        board = [[NONE] * ROWS for _ in range(COLS)]
        return ArrayState(board, [0] * COLS, 0)


def result(state: ArrayState, action: int) -> ArrayState:
    """Insert in the given column."""
    assert 0 <= action < COLS, "action must be a column number"

    if state.heights[action] >= ROWS:
        raise Exception('Column is full')

    player = MAX if state.n_moves % 2 == 0 else MIN

    board = deepcopy(state.board)
    board[action][ROWS - state.heights[action] - 1] = player

    heights = deepcopy(state.heights)
    heights[action] += 1

    return ArrayState(board, heights, state.n_moves + 1)


def actions(state: ArrayState) -> Sequence[int]:
    return [i for i in range(COLS) if state.heights[i] < ROWS]


def branch_states(state: ArrayState) -> Sequence[ArrayState]:
    """get all reachable states from the current state:
        useful for MCTS
    """
    return [result(state, a) for a in actions(state)]
    

def utility(state: ArrayState) -> float:
    """Get the winner on the current board."""

    board = state.board

    def diagonalsPos():
        """Get positive diagonals, going from bottom-left to top-right."""
        for di in ([(j, i - j) for j in range(COLS)] for i in range(COLS + ROWS - 1)):
            yield [board[i][j] for i, j in di if i >= 0 and j >= 0 and i < COLS and j < ROWS]

    def diagonalsNeg():
        """Get negative diagonals, going from top-left to bottom-right."""
        for di in ([(j, i - COLS + j + 1) for j in range(COLS)] for i in range(COLS + ROWS - 1)):
            yield [board[i][j] for i, j in di if i >= 0 and j >= 0 and i < COLS and j < ROWS]

    lines = board + \
            list(zip(*board)) + \
            list(diagonalsNeg()) + \
            list(diagonalsPos())

    max_win = MAX * N_WIN
    min_win = MIN * N_WIN
    for line in lines:
        str_line = "".join(line)
        if max_win in str_line:
            return 1
        elif min_win in str_line:
            return -1
    return 0


def terminal_test(state: ArrayState) -> bool:
    return state.n_moves >= COLS * ROWS or utility(state) != 0


def printBoard(state: ArrayState):
    board = state.board
    """Print the board."""
    print('  '.join(map(str, range(COLS))))
    for y in range(ROWS):
        print('  '.join(str(board[x][y]) for x in range(COLS)))
    print()



s = ArrayState.init()
while not terminal_test(s):
    a = random.choice(actions(s))
    s = result(s, a)
    printBoard(s)
print(utility(s))


0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  O  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
X  O  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  .  .  .
X  O  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  X  O  .  .  .  .
X  O  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  .  .  .
.  X  O  .  .  .  .
X  O  X  .  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  .  .  .
.  X  O  .  .  .  .
X  O  X  .  .  .  X

0  1  2  3  4

The last number 0, -1 or 1 is the utility or score of the game. 0 means it was a draw, 1 means MAX player won and -1 means MIN player won.

### Exercise 1 (Transfer code from the previous exercise)

Modify the code so that you can play manually as the MIN player against the random AI.

In [6]:
if __name__ == '__main__':
    s = ArrayState.init()
    player = random.choice(['AI','Programmer'])
    while not terminal_test(s):
        print(utility(s))
        if player == 'AI':
            a = random.choice(actions(s))
            s = result(s, a)
            printBoard(s)
            player = 'Programmer'
        else:
            a = input()
            a = int(a)
            while not (a >= 0 and a <= 6):
                print('Invalid value, insert new value')
                a = input()
                a = int(a)
            s = result(s, a)
            printBoard(s)
            player = 'AI'
    print(utility(s))


0
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  X  .

0
3
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  O  .  X  .

0
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  X  .
.  .  .  O  .  X  .

0
h


ValueError: invalid literal for int() with base 10: 'h'

### Exercise 2
Implement the standard MCTS algorithm.

In [40]:
float(inf)

NameError: name 'inf' is not defined

In [46]:
import time
import numpy as np

class MCTS:
    def __init__(self, exploration_weight = 0.1, thinking_time = 30, is_maxplayer = True):
        self.children = {}
        self.N = {}
        self.Q = {}
        self.parents = []
        self.thinking_time = thinking_time
        self.e = exploration_weight
        self.is_maxplayer = is_maxplayer
        
        
    def search(self, node):
        start_time = time.time()
        
        # add node to children
        if str(node.board) not in self.children:
            self.children[str(node.board)] = {}
        if str(node.board) not in self.N:
            self.N[str(node.board)] = 1
            self.Q[str(node.board)] = 0
        
        while time.time() - start_time < self.thinking_time:
            v1 = self.tree_policy(node)
            delta = self.defaultpolicy(v1)
            self.backup(v1, delta)
            
        return self.best_child(node)
            
        
    def tree_policy(self, node):
        while not terminal_test(node):
            # not all possible actions have been tried
            if len(self.children[str(node.board)]) < len(actions(node)):
                return self.expand(node)
                
            else:
                node = self.best_child(node)
                self.children[str(node.board)] = {}
        
        return node
        
        
    def expand(self, node):
        for action in actions(node):
            child = result(node, action)
            if str(child.board) not in self.children[str(node.board)]:
                # add child
                self.children[str(node.board)][str(child.board)] = child
                
                return child
      
    
    def best_child(self, node):
        ln_N = 2 * np.log(self.N[str(node.board)])
        
        def UCB(child):
            v = str(child.board)
            if self.N[v] == 0:
                return - float('inf')
            
            return self.Q[v]/self.N[v] + self.e * np.sqrt(ln_N / self.N[v])
            

            
        return max(self.children[str(node.board)].values(), key=UCB)
    
    
    def defaultpolicy(self, node):
        self.parents.append(str(node.board))
        while not terminal_test(node):
            move = random.choice(actions(node))
            node = result(node, move)
            self.parents.append(str(node.board))
            
        return utility(node)
    
    
    def backup(self, node, delta):
        self.parents.reverse()
        
        if delta < 0:
            delta = 0
        
        if not self.is_maxplayer:
            #make negative positive
            delta *= -1
        for parent in self.parents:
            if parent not in self.N:
                self.N[parent] = 1
                self.Q[parent] = delta
            else:
                self.N[parent] += 1
                self.Q[parent] += delta
            # change value of delta
            delta = 1 - delta
            
        # Clean parents
        self.parents = []
        
    
    
    

In [47]:
if __name__ == '__main__':
    s = ArrayState.init()
    thinking_time = 10
    player = random.choice(['AI','Programmer'])
    if player == 'AI':
        MCTS_player = MCTS(is_maxplayer = True, thinking_time = thinking_time)
    else:
        MCTS_player = MCTS(is_maxplayer = False, thinking_time = thinking_time)
        
    while not terminal_test(s):
        if player == 'AI':
            s = MCTS_player.search(s)
            printBoard(s)
            player = 'Programmer'
        else:
            a = input()
            a = int(a)
            while not (a >= 0 and a <= 6):
                print('Invalid value, insert new value')
                a = input()
                a = int(a)
            s = result(s, a)
            printBoard(s)
            player = 'AI'
    print(utility(s))


3
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  X  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  X  .  .  .

5
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  X  .  X  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  X  O  X  .

4
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  X  .  .
.  .  O  X  O  X  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  X  .  .
.  .  O  X  O  X  O

2
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  X  .  X  .  .
.  .  O  X  O  X  O

0  1 

In [70]:
[3,2,4,2,4,9,9,9][:1]

[3]

## Return move

In [112]:
import time
import numpy as np

class MCTS:
    def __init__(self, exploration_weight = 0.1, thinking_time = 30, is_maxplayer = True):
        self.children = {}
        self.N = {}
        self.Q = {}
        self.parents = []
        self.thinking_time = thinking_time
        self.e = exploration_weight
        self.is_maxplayer = is_maxplayer
        
        
    def search(self, node):
        start_time = time.time()
        
        # add node to children
        if str(node.board) not in self.children:
            self.children[str(node.board)] = []
        if str(node.board) not in self.N:
            self.N[str(node.board)] = 1
        if str(node.board) not in self.Q:
            self.Q[str(node.board)] = 0
            
                    
        
        while time.time() - start_time < self.thinking_time:
            move = self.tree_policy(node)
            delta = self.defaultpolicy(node, move)
            self.backup(delta)
            
        return self.best_child(node)
            
        
    def tree_policy(self, node):
        while not terminal_test(node):
            # not all possible actions have been tried
            if len(self.children[str(node.board)]) < len(actions(node)):
                return self.expand(node)
                
            else:
                move = self.best_child(node)
                node = result(node, move)
                self.children[str(node.board)] = []
        
        return node
        
        
    def expand(self, node):
        for action in actions(node):
            if action not in self.children[str(node.board)]:
                # add child 
                self.children[str(node.board)].append(action)
                
                return action
      
    
    def best_child(self, node):
        ln_N = 2 * np.log(self.N[str(node.board)])
        
        def UCB(move):
            child = result(node, move)
            v = str(child.board)
            if not self.N[v]: # if zero or none
                return - float('inf')
            
            return self.Q[v]/self.N[v] + self.e * np.sqrt(ln_N / self.N[v])
            
            
        return max(self.children[str(node.board)], key=UCB)
    
    
    def defaultpolicy(self, node, move):
        node = result(node, move)
        self.parents.append(str(node.board))
        while not terminal_test(node):
            move = random.choice(actions(node))
            node = result(node, move)
            self.parents.append(str(node.board))
            
        return utility(node)
    
    
    def backup(self, delta):
    
        if not self.is_maxplayer:
            #make negative positive
            delta *= -1
        if delta < 0:
            delta = 0

        while self.parents:
            parent = self.parents.pop()
            
            if parent not in self.N:
                self.N[parent] = 1
                self.Q[parent] = delta
            else:
                self.N[parent] += 1
                self.Q[parent] += delta
            # change value of delta
            delta = 1 - delta


In [108]:
if __name__ == '__main__':
    s = ArrayState.init()
    thinking_time = 20
    player = random.choice(['AI','Programmer'])
    if player == 'AI':
        MCTS_player = MCTS(is_maxplayer = True, thinking_time = thinking_time)
    else:
        MCTS_player = MCTS(is_maxplayer = False, thinking_time = thinking_time)
        
    while not terminal_test(s):
        if player == 'AI':
            a = MCTS_player.search(s)
            print(a)
            s = result(s, a)
            printBoard(s)
            player = 'Programmer'
        else:
            a = input()
            a = int(a)
            while not (a >= 0 and a <= 6):
                print('Invalid value, insert new value')
                a = input()
                a = int(a)
            s = result(s, a)
            printBoard(s)
            player = 'AI'
    print(utility(s))


3
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  X  .  .  .

Last
["[['.', '.', '.', 'O', 'X', 'O'], ['.', 'X', 'O', 'X', 'O', 'X'], ['.', '.', 'X', 'X', 'X', 'O'], ['.', 'O', 'O', 'X', 'O', 'X'], ['.', 'O', 'O', 'O', 'O', 'X'], ['.', '.', '.', '.', '.', '.'], ['.', '.', 'X', 'O', 'X', 'X']]"]
-1
1
4
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  X  O  .  .

4
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  X  .  .
.  .  .  X  O  .  .

Last
["[['.', '.', '.', '.', 'O', 'O'], ['.', '.', 'X', 'X', 'O', 'O'], ['.', '.', '.', '.', '.', '.'], ['X', 'O', 'X', 'X', 'O', 'X'], ['.', 'O', 'O', 'X', 'X', 'O'], ['.', '.', '.', '.', 'X', 'O'], ['.', '.', 'X', 'X', 'O', 'X']]"]
1
0
4
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  

KeyboardInterrupt: Interrupted by user

### Exercise 3
Implement the loop where you play against your MCTS agent. Either train the agent at each step while you play against it, or pretrain it with more rollouts and play agaist it after training.

In [113]:
def train_model(num_rollouts : int, thinking_time : int):
    MCTS_player1 = MCTS(is_maxplayer = True, thinking_time = thinking_time)
    MCTS_player2 = MCTS(is_maxplayer = False, thinking_time = thinking_time)
    
    for i in range(num_rollouts):
        s = ArrayState.init()
        while not terminal_test(s):
            a = MCTS_player1.search(s)
            s = result(s, a)

            a = MCTS_player2.search(s)
            s = result(s, a)
    
    return MCTS_player1, MCTS_player2
    
    

In [114]:
if __name__ == '__main__':
    s = ArrayState.init()
    
    player = random.choice(['AI','Programmer'])
    if player == 'AI':
        MCTS_player, _ = train_model(3, 10)
    else:
        _, MCTS_player = train_model(3, 10)
        
    while not terminal_test(s):
        if player == 'AI':
            s = MCTS_player.search(s)
            printBoard(s)
            player = 'Programmer'
        else:
            a = input()
            a = int(a)
            while not (a >= 0 and a <= 6):
                print('Invalid value, insert new value')
                a = input()
                a = int(a)
            s = result(s, a)
            printBoard(s)
            player = 'AI'
    print(utility(s))


TypeError: '<=' not supported between instances of 'int' and 'ArrayState'


### Exercise 4

Add move ordering. The middle columns are often "better" since there's more winning positions that contain them. Increase the probability to choose middle columns when randomly executing rollouts: [3,2,4,1,5,0,6]. See if your connect4 AI can beat you.


In [48]:
def actions(state: ArrayState) -> Sequence[int]:
    act = [i for i in range(COLS) if state.heights[i] < ROWS]
    sort_act = [3, 4, 2, 5, 1, 6, 0]
    return [a for a in sort_act if a in act]

In [49]:
if __name__ == '__main__':
    s = ArrayState.init()
    thinking_time = 10
    player = random.choice(['AI','Programmer'])
    if player == 'AI':
        MCTS_player = MCTS(is_maxplayer = True, thinking_time = thinking_time)
    else:
        MCTS_player = MCTS(is_maxplayer = False, thinking_time = thinking_time)
        
    while not terminal_test(s):
        if player == 'AI':
            s = MCTS_player.search(s)
            printBoard(s)
            player = 'Programmer'
        else:
            a = input()
            a = int(a)
            while not (a >= 0 and a <= 6):
                print('Invalid value, insert new value')
                a = input()
                a = int(a)
            s = result(s, a)
            printBoard(s)
            player = 'AI'
    print(utility(s))


0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  X  .  .  .  .

3
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  X  O  .  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  X  O  X  .  .

2
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  .  .  .
.  .  X  O  X  .  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  .  .  .
.  .  X  O  X  X  .

4
0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  O  .  O  .  .
.  .  X  O  X  X  .

0  1  2  3  4  5  6
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  .  .  .
.  .  .  .  X  .  .
.  .  O  .  O  .  .
.  .  X  O  X  X  .

3
0  1 

### Exercise 5 - Optional

Pit your MCTS agent against the one from the previous exercise.
* Which one wins more often?
* Which one takes more time to run per step once it is at a point that it can beat you?