In [1]:
import numpy as np

In [2]:
class GridWorld:
    """Square grid maze with random walls, a fixed start and a fixed goal.

    Cell codes stored in ``grid``: ``1`` = free cell, ``0`` = wall,
    ``2`` = goal. The start is the top-left corner ``(0, 0)`` and the goal is
    the bottom-right corner ``(size - 1, size - 1)``. States are
    ``(row, col)`` tuples; actions are ``'u'``, ``'d'``, ``'l'`` and ``'r'``.
    """

    def __init__(self, size):
        """Create a ``size`` x ``size`` grid with ``size // 2`` wall cells.

        Walls are sampled with ``np.random``; seed it beforehand for a
        reproducible layout.
        """
        self.size = size
        self.grid = np.ones((size, size))
        self.start = (0, 0)
        self.goal = (size - 1, size - 1)
        # Indexing with the tuple itself is equivalent to ``grid[*goal]`` and
        # does not require the Python 3.11+ star-subscript syntax.
        self.grid[self.goal] = 2
        self.num_walls = size // 2
        self._place_walls()

    def _place_walls(self):
        """Place exactly ``num_walls`` distinct walls, keeping the goal reachable."""
        while True:
            wall_count = 0
            while wall_count != self.num_walls:
                cell = (np.random.randint(self.size), np.random.randint(self.size))
                # Skip start/goal and cells that already are walls, so that a
                # repeated draw is not counted as a new wall.
                if cell == self.start or cell == self.goal or self.grid[cell] == 0:
                    continue
                self.grid[cell] = 0
                wall_count += 1
            if self._goal_reachable():
                return
            # The layout cuts the goal off from the start: a random rollout
            # would never finish, so clear the walls and sample again.
            self.grid[self.grid == 0] = 1

    def _goal_reachable(self):
        """Return True if a wall-free path connects ``start`` to ``goal``."""
        seen = {self.start}
        stack = [self.start]
        while stack:
            row, col = stack.pop()
            if (row, col) == self.goal:
                return True
            for nxt in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
                inside = 0 <= nxt[0] < self.size and 0 <= nxt[1] < self.size
                if inside and nxt not in seen and self.grid[nxt] != 0:
                    seen.add(nxt)
                    stack.append(nxt)
        return False

    def step(self, state, action):
        """Apply ``action`` in ``state`` and return ``(next_state, reward, done)``.

        Moves are clamped to the grid border, and moving into a wall leaves
        the agent where it is. Every move costs ``-1``; stepping from the goal
        itself returns reward ``0`` with ``done=True``. An unrecognised action
        keeps the agent in place and still costs ``-1``.

        Raises:
            ValueError: if ``state`` is a wall cell.
        """
        # tuple(...) keeps list-like states working as a single cell index.
        if self.grid[tuple(state)] == 0:
            raise ValueError("Not a valid State")
        if state == self.goal:
            return state, 0, True

        row, col = state
        if action == 'u':
            next_state = (max(0, row - 1), col)
        elif action == 'l':
            next_state = (row, max(0, col - 1))
        elif action == 'd':
            next_state = (min(self.size - 1, row + 1), col)
        elif action == 'r':
            next_state = (row, min(self.size - 1, col + 1))
        else:
            next_state = state

        target = self.grid[tuple(next_state)]
        if target == 0:
            # Bumped into a wall: stay put but still pay the step cost.
            return state, -1, False
        return next_state, -1, bool(target == 2)

    def get_possibilities(self, state):
        """Return ``(actions, next_states)`` for all four actions from ``state``."""
        possible_actions = ['l', 'r', 'u', 'd']
        return possible_actions, [self.step(state, action)[0] for action in possible_actions]

    def get_roll_out(self, state, max_steps=None):
        """Play uniformly random actions from ``state`` and return the summed reward.

        The episode ends when the goal is reached. ``max_steps`` optionally
        caps the number of moves (``None`` means no cap, the original
        behaviour); when the cap is hit the reward gathered so far is returned.
        """
        current_state = state
        score = 0
        steps = 0
        done = False
        while not done and (max_steps is None or steps < max_steps):
            action = np.random.choice(['l', 'r', 'u', 'd'])
            current_state, reward, done = self.step(current_state, action)
            score += reward
            steps += 1
        return score
            
            
        
            
# Smoke test: build a small 3x3 world and display its grid
# (1 = free cell, 0 = wall, 2 = goal).
env = GridWorld(3)
env.grid

array([[1., 1., 1.],
       [1., 1., 0.],
       [1., 1., 2.]])

In [3]:
class Node:
    """One node of the search tree: a grid state plus its visit statistics.

    Attributes:
        state: the grid cell this node represents.
        parent: the node this one was expanded from (``None`` for the root).
        children: mapping from action to child ``Node``.
        N: number of times the node was visited during back-propagation.
        Q: sum of the rollout rewards back-propagated through the node.
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = {}
        self.N = 0
        self.Q = 0

    def is_fully_expanded(self, env):
        """Return True once the node has one child per action offered by ``env``."""
        available_actions = env.get_possibilities(self.state)[0]
        return len(available_actions) == len(self.children)

    def get_best_child(self, c):
        """Return the child with the highest UCT score for exploration weight ``c``."""
        return max(self.children.values(), key=lambda child: child.get_uct(c))

    def get_uct(self, c):
        """Return the node's UCT score; a node without a parent scores ``inf``."""
        parent = self.parent
        if parent is None:
            return float('inf')
        # The small epsilon keeps both divisions defined for unvisited nodes.
        visits = self.N + 1e-6
        exploitation = self.Q / visits
        exploration = c * np.sqrt(np.log(parent.N) / visits)
        return exploitation + exploration
        

In [4]:
class MCTS:
    """Monte Carlo Tree Search planner for a grid environment.

    Args:
        env: environment exposing ``goal``, ``get_possibilities(state)`` and
            ``get_roll_out(state)``.
        num_iterations: number of select/expand/simulate/back-propagate
            cycles run by each call to ``search``.
        c: exploration weight passed to the UCT score during selection.
    """

    def __init__(self, env, num_iterations, c):
        self.env = env
        self.num_iterations = num_iterations
        self.c = c

    def is_terminal(self, node):
        """Return True if ``node`` sits on the environment's goal cell."""
        return node.state == self.env.goal

    def search(self, root_state):
        """Run the search from ``root_state`` and return ``(action, next_state)``.

        Raises:
            ValueError: if the root ends up without children, i.e. when
                ``root_state`` is the goal or ``num_iterations`` is not positive.
        """
        root = Node(root_state)
        for _ in range(self.num_iterations):
            node = self.selection(root)
            if not self.is_terminal(node):
                node = self.expansion(node)
            reward = self.env.get_roll_out(node.state)
            self.back_propagate(node, reward)
        action = self._best_action(root)
        return action, root.children[action].state

    def _best_action(self, root):
        """Pick the final move: the most-visited child, ties broken by mean return.

        The exploration bonus is only meant to steer the search; including it
        in the final decision would favour rarely tried moves.
        """
        if not root.children:
            raise ValueError(
                "search produced no children for the root "
                "(root is terminal or num_iterations <= 0)"
            )

        def robustness(action):
            child = root.children[action]
            mean_return = child.Q / child.N if child.N else float('-inf')
            return child.N, mean_return

        return max(root.children, key=robustness)

    def back_propagate(self, node, reward):
        """Add one visit and ``reward`` to ``node`` and to each of its ancestors."""
        while node is not None:
            node.N += 1
            node.Q += reward
            node = node.parent

    def selection(self, node):
        """Descend by best UCT child until a terminal or not fully expanded node."""
        while not self.is_terminal(node) and node.is_fully_expanded(self.env):
            node = node.get_best_child(self.c)
        return node

    def expansion(self, node):
        """Create one child per action under ``node`` and return ``node`` itself.

        Because the expanded node (not a child) is returned, the following
        rollout in ``search`` starts from this node's own state.
        """
        possible_actions, next_states = self.env.get_possibilities(node.state)
        node.children = {
            action: Node(state, parent=node)
            for action, state in zip(possible_actions, next_states)
        }
        return node

In [5]:
# Environment used for the MCTS run below: a 5x5 grid with num_walls = 5 // 2 = 2.
env = GridWorld(5)

In [12]:
# Walk the agent from the start cell to the goal, re-planning with a fresh
# MCTS search at every step and marking each visited cell with -1.
NUM_ITERATIONS = 100  # MCTS iterations per move
EXPLORATION_C = 2     # UCT exploration weight
MAX_STEPS = 1000      # safety cap so an unlucky run cannot loop forever

state = env.start
path = env.grid.copy()
mcts = MCTS(env, NUM_ITERATIONS, EXPLORATION_C)
steps = 0
while state != env.goal and steps < MAX_STEPS:
    path[state] = -1
    action, next_state = mcts.search(state)
    print(state, action)
    state = next_state
    steps += 1
if state != env.goal:
    print(f"Stopped after {MAX_STEPS} steps without reaching the goal")
print(path)

(0, 0) l
(0, 0) r
(0, 1) u
(0, 1) l
(0, 0) u
(0, 0) r
(0, 1) r
(0, 2) u
(0, 2) r
(0, 3) l
(0, 2) l
(0, 1) d
(1, 1) d
(2, 1) d
(3, 1) r
(3, 2) r
(3, 3) d
(4, 3) r
[[-1. -1. -1. -1.  1.]
 [ 1. -1.  1.  1.  1.]
 [ 1. -1.  1.  1.  0.]
 [ 1. -1. -1. -1.  1.]
 [ 0.  1.  1. -1.  2.]]


In [13]:
# Show the visited-cell map again (-1 marks cells the agent stepped on).
print(path)

[[-1. -1. -1. -1.  1.]
 [ 1. -1.  1.  1.  1.]
 [ 1. -1.  1.  1.  0.]
 [ 1. -1. -1. -1.  1.]
 [ 0.  1.  1. -1.  2.]]
