<a href="https://colab.research.google.com/github/fezilemahlangu/Reinforcement-Learning-Project/blob/master/MCTS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installing Dependencies

In [None]:
# Reference (deep RL + Monte Carlo Tree Search with Connect 4): https://towardsdatascience.com/deep-reinforcement-learning-and-monte-carlo-tree-search-with-connect-4-ba22a4713e7a


In [None]:
# System packages: cmake, a C/C++ toolchain, autotools, flex/bison and bzip2
# headers -- presumably needed to build NLE (NetHack Learning Environment).
# Every apt-get install carries -y: a notebook cell cannot answer apt's
# "Do you want to continue? [Y/n]" prompt, so without it the install aborts.
!apt-get update
!apt-get install -y cmake
!apt-get install -y build-essential autoconf libtool pkg-config
!apt-get install -y flex bison libbz2-dev
!pip install nle
!pip install minihack
# Presumably prints the registered MiniHack environments (install sanity check).
!python -m minihack.scripts.env_list
# Quoted so the shell never tries to glob-expand the [extras] brackets.
!pip install "gym[atari,accept-rom-license]"

# Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gym
import nle
import minihack
from gym import spaces
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch
import random
from collections import defaultdict
import warnings
warnings.filterwarnings("ignore")

# MCTS

# Node

In [None]:
class MCTS():
  '''
  Node of a Monte Carlo Tree Search tree.

  A node stores the observation it was reached with, a link to its parent,
  the action the parent took to reach it, its expanded children, a visit
  counter and a histogram of the rewards back-propagated through it.
  '''
  def __init__(self, state, parent, parent_action, done=False):
    '''
    state: observation returned by env.step (or env.reset for the root).
    parent: parent MCTS node, or None for the root.
    parent_action: action the parent took to reach this node (None for root).
    done: True when env.step reported episode termination on reaching this
      node; tree_policy never expands such a node.
    '''
    self.state = state
    self.parent = parent # parent node in MCTS
    self.parent_action = parent_action # action that the parent took
    self.done = done
    self.children = [] # children of this node, in expansion order
    # Starts at 1 (not 0) so UCB1 never divides by zero for a freshly
    # expanded child and log(parent visits) is always defined.
    self.visit_count = 1
    # reward value -> number of times that reward was back-propagated here
    self.rewards = defaultdict(int)

  def _q_value(self):
    '''Total back-propagated reward: sum of reward * times it was seen.'''
    return sum(r * count for r, count in self.rewards.items())

  def expand(self, env):
    '''
    Expand this node by taking one not-yet-tried action in env.

    The action is drawn uniformly at random from unexplored_actions(env) and
    executed with env.step, so env is advanced by one step. Assumes env is
    currently in the state this node represents -- TODO confirm at call sites.

    Returns (child_node, env).
    Raises ValueError if every action of env.action_space was already tried.
    '''
    untried = self.unexplored_actions(env)
    if not untried:
      raise ValueError("expand called on a fully expanded node")
    action = random.choice(untried)

    # The step reward is not stored; rewards enter the tree via back_propagate.
    obs, _, done, _ = env.step(action)

    child_node = MCTS(obs, self, action, done=done)
    self.children.append(child_node)
    return child_node, env

  def back_propagate(self, reward):
    '''
    Record `reward` and one extra visit on this node and every ancestor up to
    the root.

    Iterative rather than recursive so deep trees cannot hit Python's
    recursion limit.
    '''
    node = self
    while node is not None:
      node.visit_count += 1
      node.rewards[reward] += 1
      node = node.parent

  def unexplored_actions(self, env=None):
    '''
    Return the actions in range(env.action_space.n) that no child of this
    node was created with yet, in ascending order.

    env: environment with a discrete action space (`action_space.n`).
    Raises ValueError when env is not supplied.
    '''
    if env is None:
      raise ValueError("unexplored_actions needs env to know the action space")
    tried = {child.parent_action for child in self.children}
    return [a for a in range(env.action_space.n) if a not in tried]

  def rollout_policy(self, possible_moves=None):
    '''
    Uniform-random rollout policy: return one of possible_moves at random,
    or None when no moves are given.
    '''
    moves = [] if possible_moves is None else list(possible_moves)
    if not moves:
      return None
    return random.choice(moves)

  def tree_policy(self, env):
    '''
    Selection + expansion phase of MCTS.

    Starting at this node, keep descending to the UCT-best child (c_p=0.1),
    until reaching a terminal node or a node that still has untried actions.
    Each descent steps env with the chosen child's action. A node with untried
    actions is expanded once and the new child is returned.

    Returns (node, env), with env advanced along the chosen path.
    Assumes env starts in this node's state and that replaying actions
    reproduces the tree's transitions -- TODO confirm for stochastic envs.
    '''
    curr = self
    while not curr.done:
      if curr.unexplored_actions(env):
        return curr.expand(env)
      curr = curr.best_child(c_p=0.1)
      # keep env in sync with the node we just moved to
      env.step(curr.parent_action)

    return curr, env

  def best_action(self):
    '''
    find next best action
    TODO: not implemented yet; currently returns None.
    '''

  def best_child(self, c_p):
    '''
    Return the child maximizing UCB1:

        Q_i / n_i + c_p * sqrt(2 * ln(N) / n_i)

    Here N is this node's visit count, n_i the child's visit count and Q_i
    the total reward back-propagated through the child. c_p = 0 gives pure
    exploitation. Ties go to the earliest-expanded child (np.argmax).

    Raises ValueError if the node has no children.
    '''
    if not self.children:
      raise ValueError("best_child called on a node with no children")

    parent_visits = self.visit_count
    child_visits = np.array([c.visit_count for c in self.children], dtype=float)
    child_q = np.array([c._q_value() for c in self.children], dtype=float)

    exploit = child_q / child_visits
    explore = c_p * np.sqrt((2 * np.log(parent_visits)) / child_visits)

    return self.children[int(np.argmax(exploit + explore))]
    



# Main

In [None]:
# Create the MiniHack "Quest-Hard" task and start a fresh episode.
env = gym.make(id="MiniHack-Quest-Hard-v0")
# First observation of the episode -- presumably meant as the MCTS root state.
init_state = env.reset()