In [None]:
!pip install swig
!pip install gym[box2d]

In [None]:
import math
import numpy as np
import random
from copy import deepcopy
# Exploration constant read by Node.UCB_score.
# NOTE(review): rollout() returns an unnormalized sum of rewards (up to 200 per
# episode for CartPole-v0), so C probably needs tuning to that reward scale.
C = 1.0

class Node:
  """A node of a Monte Carlo Tree Search tree over gym environment states.

  Each node owns a private copy of the environment (`game_state`) positioned at
  the state reached by taking `action` from the parent's state.

  Attributes:
    parent: parent Node, or None for the root.
    children: list of child Nodes once expanded, otherwise None.
    T: sum of rollout rewards backpropagated through this node.
    N: number of visits backpropagated through this node.
    action: action that led from the parent's state to this node.
    game_state: environment copy positioned at this node's state.
    done: True when the step that produced this node ended the episode.
  """

  def __init__(self, parent, action, game_state):
    self.parent = parent
    self.children = None
    self.T = 0
    self.N = 0
    self.action = action
    self.game_state = game_state
    self.done = False

  def __str__(self, level=0):
    """Return an indented, multi-line dump of this subtree (one node per line)."""
    ret = "   " * level + repr(self) + "\n"
    if self.children:
      for child in self.children:
        ret += child.__str__(level + 1)
    return ret

  def __repr__(self):
    return f"N: {self.N}, T: {self.T}, action: {self.action}"

  @staticmethod
  def _step(env, action):
    """Step `env` with `action` and return (reward, done).

    Accepts both gym step() signatures: the legacy 4-tuple
    (obs, reward, done, info) and the 5-tuple
    (obs, reward, terminated, truncated, info) used by gym >= 0.26.
    """
    outcome = env.step(action)
    if len(outcome) == 5:
      _, reward, terminated, truncated, _ = outcome
      return reward, bool(terminated or truncated)
    _, reward, done, _ = outcome
    return reward, bool(done)

  def UCB_score(self):
    """Return the UCB1 score used to choose which child to descend into.

    score = T/N + C * sqrt(ln(N_parent) / N). Unvisited nodes score +inf so that
    every child is tried at least once. A node without a parent uses its own N.
    """
    if self.N == 0:
      return float('inf')
    # Exploration is measured against the parent's visit count.
    reference = self.parent if self.parent else self
    exploitation = self.T / self.N
    exploration = C * math.sqrt(math.log(reference.N) / self.N)
    # The two terms are summed. Multiplying them (as before) zeroed the score
    # whenever T == 0 or the parent had one visit, and made C irrelevant.
    return exploitation + exploration

  def add_children(self):
    """Expand this node with one child per action in range(GAME_ACTIONS).

    Each child steps its own deepcopy of this node's environment, so
    `self.game_state` is left untouched. A child is marked `done` when that step
    ended the episode, so `rollout` scores it 0 instead of stepping a finished
    environment.
    """
    self.children = []
    for action in range(GAME_ACTIONS):
      game_state = deepcopy(self.game_state)
      _, done = Node._step(game_state, action)
      child = Node(self, action, game_state)
      child.done = done  # previously computed but discarded
      self.children.append(child)

  def explore(self):
    """Descend from this node to a leaf, taking the max-UCB child at each level.

    Ties go to the first child in `children` order.
    """
    node = self
    while node.children:
      node = max(node.children, key=lambda child: child.UCB_score())
    return node

  def rollout(self):
    """Play uniformly random actions from this node's state until the episode ends.

    Works on a deepcopy so the tree's environment is not advanced.

    Returns:
      The summed reward of the random playout, or 0 for a terminal node.
    """
    if self.done:
      return 0
    tot_reward = 0
    new_game = deepcopy(self.game_state)
    done = False
    # NOTE(review): no step cap; this relies on the environment terminating
    # (presumably CartPole-v0's time limit bounds long episodes).
    while not done:
      action = random.randint(0, GAME_ACTIONS - 1)
      reward, done = Node._step(new_game, action)
      tot_reward += reward
    new_game.close()  # the copy is discarded; release its resources
    return tot_reward

  def backpropagate(self, reward):
    """Add `reward` to T and 1 to N for this node and every ancestor up to the root."""
    node = self
    while node is not None:
      node.T += reward
      node.N += 1
      node = node.parent

In [27]:
import time

def execute(root, step):
  """Run `step` MCTS iterations (select, expand, simulate, backpropagate) from `root`.

  Each iteration descends to a leaf with `explore()`. A leaf that was already
  visited is expanded, and one of its new children (chosen uniformly at random)
  is simulated instead. An unvisited leaf is simulated directly. A terminal leaf
  (`done`) is never expanded; it is rolled out and backpropagated as-is.

  Args:
    root: tree root; mutated in place.
    step: number of search iterations to run.

  Returns:
    `root`, for convenient chaining.
  """
  for _ in range(step):
    node = root.explore()
    # Expanding a terminal node would step an environment whose episode has
    # already ended, so only non-terminal visited leaves are expanded.
    if node.N != 0 and not node.done:
      node.add_children()
      node = random.choice(node.children)
    reward = node.rollout()
    node.backpropagate(reward)
  return root

def find_action(node):
  """Choose the child to actually play once the search is finished.

  Children are ranked by visit count N, then by total reward T. When both keys
  tie, the child appearing first in `node.children` wins.

  Returns:
    (best_child, best_child.action)
  """
  # sorted(..., reverse=True) is stable, so children with equal (N, T) keep
  # their original relative order and the earliest one ends up first.
  ranked = sorted(node.children, key=lambda child: (child.N, child.T), reverse=True)
  best = ranked[0]
  return best, best.action

In [None]:
import gym
import matplotlib.pyplot as plt

# Environment under study. The experiment reads GAME_NAME, and the MCTS code
# reads GAME_ACTIONS to enumerate and sample actions.
GAME_NAME = 'CartPole-v0'

# Temporary environment used only to inspect the action and observation spaces.
env = gym.make(GAME_NAME)
GAME_ACTIONS = env.action_space.n
GAME_OBS = env.observation_space.shape[0]

print(f'In the {GAME_NAME} environment there are: {GAME_ACTIONS} possible actions.')
print(f'In the {GAME_NAME} environment the observation is composed of: {GAME_OBS} values.')

# Finished with the inspection environment.
env.reset()
env.close()

In [None]:
# Experiment: play `episodes` games of GAME_NAME, choosing every real move with
# MCTS, and record each episode's total reward to check that the search works.
# For CartPole-v0, 200 is the maximum possible reward per episode.
episodes = 10
simulations_per_move = 100   # MCTS iterations run before each real move
moving_average_window = 20   # episodes averaged in the moving-average curve
rewards = []
moving_average = []

for e in range(episodes):

    reward_e = 0
    game = gym.make(GAME_NAME)
    try:
        observation = game.reset()
        done = False

        # The tree searches on its own copy so that planning never advances `game`.
        mytree = Node(None, 0, deepcopy(game))

        print('\n episode #' + str(e+1))

        while not done:

            print('.', end='')
            mytree = execute(mytree, simulations_per_move)
            mytree, action = find_action(mytree)
            # Re-root at the chosen child and drop the back-reference so the
            # rest of the old tree can be garbage-collected.
            mytree.parent = None

            step_result = game.step(action)
            # gym < 0.26 returns (obs, reward, done, info);
            # gym >= 0.26 returns (obs, reward, terminated, truncated, info).
            observation, reward = step_result[0], step_result[1]
            done = bool(step_result[2] or (len(step_result) == 5 and step_result[3]))

            reward_e = reward_e + reward

        print('\n reward_e ' + str(reward_e))
    finally:
        # Close the episode's environment even if the search raises.
        game.close()

    rewards.append(reward_e)
    # Previously `moving_average` was never filled, so its curve was empty.
    moving_average.append(np.mean(rewards[-moving_average_window:]))

fig, ax = plt.subplots()
ax.plot(rewards, label='episode reward')
ax.plot(moving_average, label=f'moving average (last {moving_average_window} episodes)')
ax.set(title=f'MCTS on {GAME_NAME}', xlabel='episode', ylabel='total reward')
ax.legend()
plt.show()
print('moving average: ' + str(np.mean(rewards[-moving_average_window:])))