In [1]:
import gym
from gym.spaces.utils import flatdim
import numpy as np
import scipy
import matplotlib.pyplot as plt
# from tqdm import trange
from tqdm.auto import tqdm  # notebook compatible
from minigrid.envs import FourRoomsEnv
import matplotlib.pyplot as plt
import graphviz
from copy import deepcopy
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Node:
    """
    One vertex of the search tree: the state reached from the parent node by taking `action`.

    Attributes set here:
        state / action: the state held by this node and the action that led to it.
        reward / done:  transition outcome, filled in by whoever creates the node.
        parent / children: tree links; `children` maps action -> Node.
        value / visits / prior: running mean of backed-up values, visit count, prior score.
    """
    def __init__(self, state, action=None, prior=0):
        # Transition that produced this node
        self.state, self.action = state, action
        self.reward, self.done = 0, False

        # Links to the rest of the tree
        self.parent, self.children = None, {}

        # Search statistics
        self.value, self.visits, self.prior = 0.0, 0, prior

    def __repr__(self):
        summary = f"s:{self.state}, a:{self.action}, r:{self.reward}, d:{self.done}, p:{self.prior}"
        return f"Node: ({summary})"

    def is_expanded(self):
        """True once at least one child has been attached."""
        return bool(self.children)

    def update_stats(self, val):
        """Count one more visit and fold `val` into the running mean stored in `value`."""
        self.visits += 1
        delta = val - self.value
        self.value += delta / float(self.visits)

class MCTS_AlphaZero:
    """
    MCTS using a similar approach to AlphaZero: https://arxiv.org/pdf/1712.01815.pdf

    Args:
        cur_state: State the search tree is rooted at.
        model: Environment model exposing `actions(state)` and
            `step(state, action) -> (next_state, reward, done, info)`.
        network: Callable mapping a state tensor to `(policy_scores, value)`.
        hparams: Dict with the required key "search_iters" and the optional keys
            "discount" (default 0.99) and "c_puct" (default 1).
    """
    def __init__(self, cur_state, model, network, hparams):
        self.cur_state = cur_state
        self.model = model
        self.network = network
        self.iters: int = hparams["search_iters"]
        self.discount = hparams.get("discount", 0.99)
        # Exploration constant; previously hard-coded to 1, which stays the default.
        self.c_puct = hparams.get("c_puct", 1)

    def search(self):
        """
        Run `self.iters` simulations from the root.

        Returns:
            (action, root): the action of the most visited root child and the root node.
        """
        root = Node(self.cur_state)
        self._expand(root)
        #root.prior += dirichlet_noise()
        for _ in range(self.iters):
            # Tree Policy
            # Walk down the tree, always following the child with the best `PUCT` score
            leaf = root
            while leaf.is_expanded():
                leaf = self._PUCT(leaf)

            if leaf.done:
                # Terminal state: there is nothing to expand, so back up the reward that was
                # observed on entering it instead of a network estimate.
                value = float(leaf.reward)
            else:
                # Expand the leaf: attach one child per action and get the network's value estimate
                value = self._expand(leaf)

            # TODO: rewards of non-terminal transitions are still not part of the backed-up value
            self._backup(leaf, value)

        return self._best_action(root), root

    # "Most Robust Child" selection: http://www.incompleteideas.net/609%20dropbox/other%20readings%20and%20resources/MCTS-survey.pdf
    def _best_action(self, root):
        """Return the action of the root child with the highest visit count."""
        return max(root.children.values(), key = lambda child: child.visits).action

    def _expand(self, node: "Node") -> float:
        """
        Attach one child per available action to `node` and return the network's value
        estimate for `node.state`.
        """
        # Expand and add children with predicted prior
        prior, value = self.network(torch.Tensor([node.state]))
        prior, value = prior.detach().numpy(), value.detach().numpy()[0]

        prior = self._normalize(prior)
        for action in self.model.actions(node.state):
            next_obs, r, done, _ = self.model.step(node.state, action)

            # Update tree with transition
            next_node = Node(next_obs, action=action, prior=prior[action])
            next_node.parent = node
            next_node.reward = r
            next_node.done = done
            node.children[action] = next_node

        return value


    # Detailed here: https://web.stanford.edu/~surag/posts/alphazero.html
    def _PUCT(self, node: "Node") -> "Node":
        """
        Return the child of `node` with the highest PUCT score (ties broken uniformly at random).
        """
        children = list(node.children.values())

        # State-action value estimates
        q_vals = np.array([child.value for child in children], dtype=float)

        # Exploration bonus: prior weighted by parent/child visitation counts
        bonus = np.array(
            [child.prior * np.sqrt(node.visits / (1 + child.visits)) for child in children],
            dtype=float,
        )
        puct_vals = q_vals + self.c_puct * bonus

        # Select the best child. Sampling proportionally to min-shifted scores (the previous
        # behavior) could never pick the lowest-scoring child and produced NaN probabilities
        # whenever all scores were equal and non-zero.
        best = np.flatnonzero(puct_vals == puct_vals.max())
        return children[int(np.random.choice(best))]



    def _backup(self, node: "Node", value: float) -> None:
        """
        Propagate `value` from `node` up to the root, discounting once per level so an
        ancestor `k` steps above the leaf receives `discount**k * value`.
        """
        node.update_stats(value)

        while node.parent is not None:
            node = node.parent
            value = self.discount * value
            node.update_stats(value)

    def _normalize(self, arr):
        """
        Shift `arr` so its minimum is zero and scale it to sum to one.

        Falls back to a uniform distribution when the shifted values sum to zero (all entries
        equal) instead of dividing by zero and returning NaNs.
        """
        arr = np.asarray(arr, dtype=float)
        shifted = arr - np.min(arr)
        total = np.sum(shifted)
        if not total > 0:
            return np.full(arr.shape, 1.0 / arr.size)
        return shifted / total


def episode(env, model, network, memory, config):
    """
    Play one episode with MCTS-guided actions and push its transitions into `memory`.

    For every step the stored tuple is (observation the search was rooted at, search policy
    over the root's children, discounted return from that step).

    Args:
        env: Environment whose `reset()` returns an observation and whose `step(action)`
            returns `(obs, reward, done, info)`.
        model: Environment model handed to `MCTS_AlphaZero`.
        network: Policy/value network; switched to evaluation mode here.
        memory: Buffer exposing `store_transition(obs, action_prob, r)`.
        config: Hyper-parameter dict; "discount" defaults to 0.99 when absent.
    """
    # Run network in evaluation mode
    network.eval()

    obs = env.reset()
    done = False

    rewards = []
    policies = []
    ep_obs = []
    while not done:
        action, root = MCTS_AlphaZero(obs, model, network, config).search()

        # Record the observation the search started from *before* stepping; storing the
        # post-step observation would pair each policy target with the wrong state.
        ep_obs.append(obs)

        # Policy target = normalized visit counts of the root's children (the search
        # probabilities), not the network's own prior, which carries no new information.
        visits = np.array([child.visits for child in root.children.values()], dtype=np.float32)
        total = visits.sum()
        if total > 0:
            policies.append(visits / total)
        else:
            policies.append(np.full(len(visits), 1.0 / max(len(visits), 1), dtype=np.float32))

        obs, r, done, _ = env.step(action)
        rewards.append(r)

    # Same default as MCTS_AlphaZero uses for the discount
    returns = compute_returns(rewards, config.get("discount", 0.99))

    for i in range(len(returns)):
        memory.store_transition(ep_obs[i], policies[i], returns[i])

# Compute discounts
# https://stackoverflow.com/questions/47970683/vectorize-a-numpy-discount-calculation
def compute_returns(rewards, discount):
    """
    Discounted return for every step of an episode.

    C[i] = R[i] + discount * C[i+1]
    signal.lfilter(b, a, x, axis=-1, zi=None)
    a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
                          - a[1]*y[n-1] - ... - a[N]*y[n-N]

    Args:
        rewards: Sequence of per-step rewards in chronological order.
        discount: Discount factor applied per step.

    Returns:
        1-D float array with the same length as `rewards` (empty for an empty episode).
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee that
    # `scipy.signal` has been loaded.
    from scipy.signal import lfilter

    rewards = np.asarray(rewards, dtype=float)
    if rewards.size == 0:
        return rewards

    # Filtering the time-reversed rewards with y[n] = x[n] + discount * y[n-1] yields the
    # returns in reverse order; flip them back at the end.
    reversed_returns = lfilter([1], [1, -discount], rewards[::-1])
    return reversed_returns[::-1]

In [3]:
# Env Models
class FrozenLakeModel:
    """
    Lookup model over a nested transition table indexed as `transitions[obs][action]`,
    where each entry is a list of 4-tuples; only the first tuple of each list is used.
    """
    def __init__(self, transitions):
        self.model = transitions

    def step(self, obs, action):
        """Return `(next_obs, reward, done, first_field)` taken from the first listed transition."""
        first_transition = self.model[obs][action][0]
        leading, next_obs, r, done = first_transition
        return next_obs, r, done, leading

    def actions(self, obs):
        """Return the action keys available at `obs` as a list."""
        return [a for a in self.model[obs].keys()]

class FourRoomsModel:
    """
    Model of the FourRooms grid world that simulates one step from an arbitrary agent state
    by building a fresh `FourRoomsEnvPos` at that state.

    Args:
        goal_pos: Goal cell passed to every simulated environment.
        seed: Seed passed to `reset` of every simulated environment.
    """
    # Movement actions only: 0 "left", 1 "right", 2 "forward". The remaining four of the
    # seven actions (pickup, drop, toggle, done) are excluded, matching the previous
    # `action_space.n - 4`, which was read from a global `env` that is not necessarily a
    # FourRooms environment.
    NAVIGATION_ACTIONS = (0, 1, 2)

    def __init__(self, goal_pos=(16,16), seed=42):
        self.goal_pos = goal_pos
        self.seed = seed

    def step(self, agent_state, action) -> tuple:
        """
        Simulate `action` from `agent_state = (agent_pos, agent_dir)`.

        Returns whatever `FourRoomsEnvPos.step` returns for that action.
        """
        agent_pos, agent_dir = agent_state
        env = FourRoomsEnvPos(agent_pos=agent_pos, goal_pos=self.goal_pos)
        env.reset(seed=self.seed)
        # Direction is set after the reset so the reset cannot overwrite it
        env.agent_dir = agent_dir
        return env.step(action)

    def actions(self, obs) -> list[int]:
        """Return the movement actions; the same list for every `obs`."""
        return list(self.NAVIGATION_ACTIONS)

class FourRoomsEnvPos(FourRoomsEnv):
    """
    FourRooms variant whose observations are the pair `(agent_pos, agent_dir)` and whose
    `step` returns a 4-tuple `(state, reward, done, info)`.
    """
    def __init__(self, agent_pos=None, goal_pos=None, **kwargs):
        super().__init__(agent_pos=agent_pos, goal_pos=goal_pos, **kwargs)
        # Override the step budget after the base constructor has run
        self.max_steps = 1000

    def step(self, action):
        """
        Advance the base environment and return
        `((agent_pos, agent_dir), reward, done, info)`, where `reward` is 1 when the agent
        position equals `_goal_default_pos` (else 0) and `done` merges terminated/truncated.
        The base environment's own observation and reward are discarded.
        """
        _base_obs, _base_reward, terminated, truncated, info = super().step(action)
        state = (self.agent_pos, self.agent_dir)
        at_goal = int(self.agent_pos == self._goal_default_pos)
        finished = terminated or truncated
        return state, at_goal, finished, info

    def reset(self, seed=None):
        """Reset the base environment and return `(agent_pos, agent_dir)`."""
        _base_obs, _info = super().reset(seed=seed)
        return self.agent_pos, self.agent_dir

# Human-readable names for the FrozenLake action indices (0..3).
lake_actions = dict(enumerate(("LEFT", "DOWN", "RIGHT", "UP")))

# Human-readable names for the FourRooms action indices (0..6).
room_actions = dict(enumerate((
    "left",
    "right",
    "forward",
    "pickup",
    "drop",
    "toggle",
    "done",
)))

In [4]:
import torch
from torch import nn

# Tools to train model
class DenseMCTSNet(nn.Module):
    """
    Small fully connected network with a shared trunk (1 -> 120 -> 84, ReLU after each
    layer) feeding two linear heads: a policy head with `env.action_space.n` outputs and a
    single-output value head.
    """
    def __init__(self, env):
        super().__init__()
        n_actions = env.action_space.n

        # Shared trunk; layers are created in the same order as they appear in the stack
        trunk_layers = [
            nn.Linear(1, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*trunk_layers)

        # Output heads
        self.policy_head = nn.Linear(84, n_actions)
        self.value_head = nn.Linear(84, 1)

    def forward(self, x):
        """Return `(policy_head(features), value_head(features))` for input `x`."""
        features = self.network(x)
        policy_out = self.policy_head(features)
        value_out = self.value_head(features)
        return policy_out, value_out

# Replay memory to store transitions in a cyclical buffer, so the oldest transitions are overwritten first (it could be interesting to try a prioritized buffer instead).
class ReplayMemory():
    """
    Fixed-size cyclic buffer of (observation, action probabilities, return) triples.

    Args:
        size: Maximum number of stored transitions; the write index wraps around after that.
        env: Object exposing `observation_space.shape` and `action_space.n`, used to size
            the storage arrays.
    """
    def __init__(self, size, env):
        self.size = size
        self.counter = 0  # total number of transitions ever stored

        obs_shape = env.observation_space.shape

        self.obs = np.zeros((size, *obs_shape), dtype=np.float32)
        # Probabilities must be stored as floats; using the action space's own (integer)
        # dtype would truncate every stored probability.
        self.action_probs = np.zeros((size, env.action_space.n), dtype=np.float32)
        self.returns = np.zeros((size, 1), dtype=np.float32)

    def store_transition(self, obs, action_prob, r):
        """Write one transition at the current cyclic position and advance the counter."""
        indx = self.counter % self.size

        self.obs[indx] = np.array(obs).copy()
        self.action_probs[indx] = np.array(action_prob).copy()
        self.returns[indx] = np.array(r).copy()

        self.counter += 1

    def sample_batch(self, batch_size):
        """
        Sample `batch_size` stored transitions uniformly with replacement.

        Returns:
            (obs, action_probs, returns) as torch tensors.

        Raises:
            ValueError: if nothing has been stored yet.
        """
        # Only draw from slots that have actually been written; the rest are still zeros.
        filled = min(self.counter, self.size)
        if filled == 0:
            raise ValueError("cannot sample from an empty ReplayMemory")

        batch = np.random.randint(0, filled, size=batch_size)

        obs = torch.as_tensor(self.obs[batch])
        action_probs = torch.as_tensor(self.action_probs[batch])
        returns = torch.as_tensor(self.returns[batch])

        return obs, action_probs, returns


def train_on_batch(network, optimizer, batch):
    """
    Run one optimization step on a sampled batch.

    Args:
        network: Module returning `(policy_logits, predicted_values)` for a batch of observations.
        optimizer: Optimizer over `network`'s parameters.
        batch: `(obs, action_probs, returns)` tensors; `action_probs` are the target
            class probabilities and `returns` the value targets.

    Returns:
        The same `network` object, so callers may write `network = train_on_batch(...)`.
    """
    # Run network in training mode
    network.train()

    obs, action_probs, returns = batch[0], batch[1], batch[2]

    # A batch of scalar observations arrives as shape (B,); give it an explicit feature
    # axis so it is treated as B samples with one feature each.
    if obs.dim() == 1:
        obs = obs.unsqueeze(-1)

    policy_logits, pred_vals = network(obs)

    value_loss = nn.MSELoss()(pred_vals, returns)
    # CrossEntropyLoss takes (input logits, target): the network output goes first and the
    # stored probabilities are the target.
    policy_loss = nn.CrossEntropyLoss()(policy_logits, action_probs)

    loss = value_loss + policy_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return network


def policy_iteration(env, model, network, config):
    """
    Alternate between generating episodes with MCTS and training the network on replayed data.

    Args:
        env: Environment used to play episodes and to size the replay memory.
        model: Environment model used by the tree search.
        network: Policy/value network, trained in place.
        config: Dict with "memory_size", "training_steps", "train_freq", "batch_size" and
            optionally "learning_rate" (default 1e-4).

    Returns:
        The trained `network` (the same object that was passed in).
    """
    memory = ReplayMemory(config["memory_size"], env)
    optimizer = torch.optim.Adam(
        network.parameters(),
        lr=config.get("learning_rate", 1e-4),  # honor the configured rate; same default as before
        weight_decay=1e-5,
    )
    for global_step in tqdm(range(config["training_steps"])):
        episode(env, model, network, memory, config)

        # Only train once buffer large enough
        if (memory.counter > memory.size and global_step % config["train_freq"] == 0):
            batch = memory.sample_batch(config["batch_size"])
            # The network is updated in place. Do not rebind `network` to the call's result:
            # it was `None`, which broke the next episode.
            train_on_batch(network, optimizer, batch)
    return network


# Default hyper-parameters consumed by the training loop and the tree search.
DEFAULT_CONFIG = dict(
    training_steps=250_000,
    batch_size=256,
    learning_rate=1e-4,
    memory_size=50_000,
    discount=0.99,
    train_freq=4,
    search_iters=800,
)

In [5]:
# Build the FrozenLake environment and a lookup model over its transition table `env.P`.
env = gym.make("FrozenLake-v1")
model = FrozenLakeModel(env.P)
# Policy/value network sized from the environment's action space.
net = DenseMCTSNet(env)

# Run episode generation + training with DEFAULT_CONFIG.
# NOTE(review): the recorded run below was interrupted after ~200 of 250,000 steps at
# roughly 2.6 s per step, so a full run with these defaults is impractical in a notebook;
# consider a config with fewer "training_steps" / "search_iters" for interactive use.
policy_iteration(env, model, net, DEFAULT_CONFIG)

  0%|          | 203/250000 [08:44<179:22:58,  2.59s/it]


KeyboardInterrupt: 