In [1]:
import gym
from gym.spaces.utils import flatdim
import numpy as np
import scipy
import matplotlib.pyplot as plt
# from tqdm import trange
from tqdm.auto import tqdm  # notebook compatible
from minigrid.envs import FourRoomsEnv
import matplotlib.pyplot as plt
import graphviz
from copy import deepcopy
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Node:
    """
    Transition node. Represents the child of a node as a result of an action being taken. 
    """
    def __init__(self, state, action=None, prior=0):
        self.state = state
        self.action = action
        self.reward = 0
        self.done = False

        self.parent = None
        self.children = {}

        self.value = 0.0
        self.visits = 0
        self.prior = prior
    
    def __repr__(self):
        return f"Node: (s:{self.state}, a:{self.action}, r:{self.reward}, d:{self.done}, p:{self.prior})"
    
    def is_expanded(self):
        return len(self.children) > 0

    def update_stats(self, val):
        self.visits += 1
        self.value += (val - self.value) / float(self.visits)

class MCTS_AlphaZero:
    """
    MCTS using a similar approach to AlphaZero: https://arxiv.org/pdf/1712.01815.pdf
    """
    def __init__(self, cur_state, model, network, hparams):
        self.cur_state = cur_state
        self.model = model
        self.network = network
        self.iters: int = hparams["iters"]
        self.discount = hparams.get("discount", 0.99) 
        self.sim_depth = hparams.get("sim_depth", 1000)
        self.c_puct = 1
    
    def search(self):
        root = Node(self.cur_state)
        self._expand(root)
        #root.prior += dirichlet_noise()
        for _ in range(self.iters):
            # Tree Policy
            # Run through the tree and recursively select the best nodes with respect to their `PUCT` values
            next_node = root
            while next_node.is_expanded():
                next_node = self._PUCT(next_node)

            # Expand the leaf node by evaluating network for policy probs and value at state, sample most likely action
            value = self._expand(next_node)

            # Backup the value of this node or the reward if its a terminal state (TODO: Seems suspicious for non terminal rewards)
            self._backup(next_node, value)
        
        return self._best_action(root), root

    # "Most Robust Child" selection: http://www.incompleteideas.net/609%20dropbox/other%20readings%20and%20resources/MCTS-survey.pdf
    def _best_action(self, root):
        return max(root.children.values(), key = lambda child: child.visits).action

    def _expand(self, node: Node) -> Node:

        # Expand and add children with predicted prior
        prior, value = self.network(node.state)
        prior /= np.sum(prior)

        for action in self.model.actions(node.state):
            next_obs, r, done, _ = self.model.step(node.state, action)

            # Update tree with transition
            next_node = Node(next_obs, action=action, prior=prior[action])
            next_node.parent = node
            next_node.reward = r
            next_node.done = done
            node.children[action] = next_node

        return value


    # Detailed here: https://web.stanford.edu/~surag/posts/alphazero.html
    def _PUCT(self, node: Node) -> Node:

        # Get children and compute state-action values
        children: list[Node] = list(node.children.values())
        q_vals = np.array([child.value for child in children])

        # PUCT takes into account model probs + visitation counts
        puct_vals = q_vals +  self.c_puct * np.array([node.children[a].prior * np.sqrt(node.visits / (1 + node.children[a].visits)) for a in node.children.keys()])
        return np.random.choice(children, p=puct_vals/sum(puct_vals)) if np.any(puct_vals) else np.random.choice(children)



    def _backup(self, node: Node, value: float) -> None:
        node.update_stats(value)

        while node.parent:
            node = node.parent
            node.update_stats(self.discount * value)

def episode(env, model, network, config):
    obs = env.reset()
    done = False
    rewards = []
    states = [obs]
    logits = []
    while not done:
        action, root = MCTS_AlphaZero(obs, model, network, config).search()
        obs, r, done, _ = env.step(action)

        rewards.append(r)
        obs.append(obs)
        logits.append([child.prior for child in root.children.values()])

    returns = compute_returns(rewards, config["discount"])

    return states, logits, returns



# Compute discounts
# https://stackoverflow.com/questions/47970683/vectorize-a-numpy-discount-calculation
def compute_returns(rewards, discount):
    """
    C[i] = R[i] + discount * C[i+1]
    signal.lfilter(b, a, x, axis=-1, zi=None)
    a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
                          - a[1]*y[n-1] - ... - a[N]*y[n-N]
    """
    r = rewards[::-1]
    a = [1, -discount]
    b = [1]
    y = scipy.signal.lfilter(b, a, x=r)
    return y[::-1]

# For viz purposes
def unpack(root: Node, graph, ignore, action_map=None):
    if (not root.parent):
        graph.node(str(root.state))

    if len(root.children.values()) == 0:
        return graph
    
    for child in root.children.values():
        if (not (str(root.state), str(child.state), child.action) in ignore):
            graph.edge(str(root.state), str(child.state), label=str(child.action if not action_map else action_map[child.action]))
            ignore.add((str(root.state), str(child.state), child.action))

        graph = unpack(child, graph, ignore, action_map=action_map)
    
    return graph

# Visualize visitation frequency
def unpack_visits(root: Node, visit_freqs=defaultdict(lambda: 0)):
    if len(root.children.values()) == 0:
        visit_freqs[root.state] += root.visits
    
    for child in root.children.values():
        visit_freqs = unpack_visits(child, visit_freqs)
    
    return visit_freqs

In [3]:
# Env Models
class FrozenLakeModel:
    def __init__(self, transitions):
        self.model = transitions

    def step(self, obs, action):
        _, next_obs, r, done =  self.model[obs][action][0]
        return next_obs, r, done, _
    
    def actions(self, obs):
        return list(self.model[obs].keys())

class FourRoomsModel:
    def __init__(self, goal_pos=(16,16), seed=42):
        self.goal_pos = goal_pos
        self.seed = seed

    def step(self, agent_state, action) -> tuple[_, float]:
        agent_pos, agent_dir = agent_state
        env = FourRoomsEnvPos(agent_pos=agent_pos, goal_pos=self.goal_pos)
        env.reset(seed=self.seed)
        env.agent_dir = agent_dir
        return env.step(action)

    def actions(self, obs) -> list[int]:
        return list(range(env.action_space.n - 4))

class FourRoomsEnvPos(FourRoomsEnv):
    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
        super().__init__(agent_pos=agent_pos, goal_pos=goal_pos, **kwargs)
        self.max_steps = 1000
    
    def step(self, action):
        _, r, terminated, truncated, _ = super().step(action)
        return (self.agent_pos, self.agent_dir), int(self.agent_pos == self._goal_default_pos), terminated or truncated, _

    def reset(self, seed=None):
        obs, _ = super().reset(seed=seed)
        return self.agent_pos, self.agent_dir

lake_actions = {
    0: "LEFT",
    1: "DOWN", 
    2: "RIGHT",
    3: "UP"
}

room_actions = {
    0: "left",
    1: "right",
    2: "forward",
    3: "pickup",
    4: "drop",
    5: "toggle",
    6: "done" 
}

In [6]:
env = gym.make("FrozenLake-v1", is_slippery=False)
model = FrozenLakeModel(env.P)
class FakeNetwork:
    def __init__(self, nA):
        self.act_probs = np.ones(nA) / sum(np.ones(nA))
        self.value = np.random.rand()
    
    def __call__(self, obs):
        return self.act_probs, self.value

obs = env.reset()


obs = env.reset()
done = False
ep_r = 0
while not done:
    action, root = MCTS_AlphaZero(obs, model, FakeNetwork(env.action_space.n), {"iters": 1000}).search()
    print([(lake_actions[child.action], child.value) for child in root.children.values()])
    print(f"Moving {lake_actions[action]}")
    obs, r, done, _ = env.step(action)
    ep_r += r

print(f"Episode return: {ep_r}")

[('LEFT', 0.6662871439999686), ('DOWN', 0.6662886406757791), ('RIGHT', 0.6662873477884159), ('UP', 0.6662888686228624)]
Moving LEFT
[('LEFT', 0.9646975123523664), ('DOWN', 0.9646967659579148), ('RIGHT', 0.9646973607226275), ('UP', 0.964700462357331)]
Moving DOWN
[('LEFT', 0.13936913354720237), ('DOWN', 0.13936945684088048), ('RIGHT', 0.13936948138260022), ('UP', 0.13936886370336188)]
Moving UP
[('LEFT', 0.7474842951844973), ('DOWN', 0.7474850200932024), ('RIGHT', 0.7474831605048385), ('UP', 0.7474863098393514)]
Moving RIGHT
[('LEFT', 0.509061998711699), ('DOWN', 0.5090631362076741), ('RIGHT', 0.5090624218659394), ('UP', 0.5090603223338668)]
Moving UP
[('LEFT', 0.48741800733446816), ('DOWN', 0.48741889185075565), ('RIGHT', 0.48741698699358216), ('UP', 0.48741923376461466)]
Moving RIGHT
[('LEFT', 0.5445073110211908), ('DOWN', 0.5445052745593671), ('RIGHT', 0.544505445057249), ('UP', 0.5445049414631195)]
Moving UP
[('LEFT', 0.1710495733766843), ('DOWN', 0.17104994327389064), ('RIGHT', 0.1

In [3]:
class FourRoomsModel:
    def __init__(self, goal_pos=(16,16), seed=42):
        self.goal_pos = goal_pos
        self.seed = seed

    def step(self, agent_state, action) -> tuple[_, float]:
        agent_pos, agent_dir = agent_state
        env = FourRoomsEnvPos(agent_pos=agent_pos, goal_pos=self.goal_pos)
        env.reset(seed=self.seed)
        env.agent_dir = agent_dir
        return env.step(action)

    def actions(self, obs) -> list[int]:
        return list(range(env.action_space.n - 4))

class FourRoomsEnvPos(FourRoomsEnv):
    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
        super().__init__(agent_pos=agent_pos, goal_pos=goal_pos, **kwargs)
        self.max_steps = 1000
    
    def step(self, action):
        _, r, terminated, truncated, _ = super().step(action)
        return (self.agent_pos, self.agent_dir), int(self.agent_pos == self._goal_default_pos), terminated or truncated, _

    def reset(self, seed=None):
        obs, _ = super().reset(seed=seed)
        return self.agent_pos, self.agent_dir


action_map = {
    0: "left",
    1: "right",
    2: "forward",
    3: "pickup",
    4: "drop",
    5: "toggle",
    6: "done" 
}

In [4]:
class Fake:
    def __init__(self):
        self.string="hi"
    
    def __call__(self):
        return self.string

a = Fake()
print(a())

hi
