In [None]:
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from torch.distributions.categorical import Categorical
from typing import Any, Dict, List, Optional, Set, Tuple, Union

In [None]:
def mlp(sizes: List[int], activation=nn.ReLU, output_activation=nn.Identity):
    """Build a fully connected feed-forward network.

    Args:
        sizes: Layer widths, input width first and output width last. One
            ``nn.Linear`` is created per pair of consecutive entries.
        activation: Module class instantiated after every hidden linear layer.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations.
    """
    layers = []
    last_index = len(sizes) - 2
    for index, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(fan_in, fan_out))
        # Only the final linear layer is followed by `output_activation`.
        layers.append(activation() if index < last_index else output_activation())
    return nn.Sequential(*layers)


class CategorialPolicy:
    """Stochastic policy over a finite action set, parameterized by an MLP.

    The network maps a state tensor to one logit per action. Actions are
    identified by their position in ``self.actions``.
    """

    def __init__(self, sizes: List[int], actions: List):
        """Create the policy network.

        Args:
            sizes: Layer widths of the MLP; the last entry must equal the
                number of actions.
            actions: The discrete actions, ordered like the network outputs.
                They must be convertible to a float tensor.
        """
        assert sizes[-1] == len(actions)
        # Reseeds torch's global RNG so the weight initialization is reproducible.
        torch.manual_seed(1337)
        self.net = mlp(sizes=sizes)
        self.actions = actions
        # One row per action; used to map action values back to their index.
        self._actions_tensor = torch.as_tensor(actions, dtype=torch.float32).view(
            len(actions), -1
        )

    def _get_distribution(self, state: torch.Tensor) -> Categorical:
        """Calls the model and returns a categorical distribution over the actions."""
        logits = self.net(state)
        return Categorical(logits=logits)

    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> Any:
        """Choose an action for a single (unbatched) state.

        Args:
            state: State tensor fed to the network.
            deterministic: If true, return the most likely action instead of
                sampling from the distribution.

        Returns:
            The chosen element of ``self.actions`` (not a tensor).
        """
        policy = self._get_distribution(state)
        index = policy.mode if deterministic else policy.sample()
        return self.actions[index.item()]

    def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Returns the log-probability of taking each action in the corresponding state.

        Args:
            states: Batch of states, one row per sample.
            actions: Batch of action values, one entry/row per sample.

        Raises:
            ValueError: If an action is not a unique member of ``self.actions``.
        """
        return self._get_distribution(states).log_prob(
            self._get_action_id_from_action(actions)
        )

    def _get_action_id_from_action(self, actions: torch.Tensor) -> torch.Tensor:
        """Returns, for every passed action, its index in ``self.actions``.

        Args:
            actions: Batch of action values. A 1-D tensor is treated as a batch
                of scalar actions.

        Raises:
            ValueError: If any action matches no entry or more than one entry
                of ``self.actions``; the result could not be aligned with the
                batch in that case.
        """
        if actions.dim() != 2:
            # Flatten each action to one row, mirroring how `_actions_tensor` is built.
            actions = actions.reshape(actions.size(0), -1)
        # (batch, 1, dim) == (1, num_actions, dim) -> (batch, num_actions)
        matches = torch.all(
            actions.unsqueeze(1) == self._actions_tensor.unsqueeze(0), dim=-1
        )
        if not bool(torch.all(matches.sum(dim=1) == 1)):
            raise ValueError(
                "Every action must match exactly one entry of the policy's action set."
            )
        return torch.where(matches)[1]

In [None]:
# Keyword arguments for `MDP(...)`: a two-state toy problem.
# Every (state, action) pair maps to a list of (probability, next_state) tuples;
# "reward" holds the reward of each state.
SIMPLE_MDP_DICT = {
    "states": [1, 2],
    "actions": ["A", "B"],
    "initial_state": 1,
    "terminal_states": [2],
    "transition_probabilities": {
        (1, "A"): [(0.2, 1), (0.8, 2)],
        (1, "B"): [(0.5, 1), (0.5, 2)],
        (2, "A"): [(1.0, 1)],
        (2, "B"): [(0.3, 1), (0.7, 2)],
    },
    "reward": {1: -0.1, 2: -0.5},
}

# Keyword arguments for `GridMDP(...)`: a 4x3 grid world.
# - "grid" holds the reward of each cell, written top row first; `None` marks a
#   cell that is not a state. GridMDP reverses the rows, so states are (x, y)
#   tuples where y = 0 is the LAST row listed here.
# - Actions are (dx, dy) offsets. Each action maps to a list of
#   (probability, offset actually applied) tuples.
GRID_MDP_DICT = {
    "grid": [
        [-0.04, -0.04, -0.04, +1],
        [-0.04, None, -0.04, -1],
        [-0.04, -0.04, -0.04, -0.04],
    ],
    "initial_state": (1, 0),
    # (3, 2) is the +1 cell, (3, 1) is the -1 cell.
    "terminal_states": {(3, 2), (3, 1)},
    "transition_probabilities_per_action": {
        (0, 1): [(0.8, (0, 1)), (0.1, (1, 0)), (0.1, (-1, 0))],
        (0, -1): [(0.8, (0, -1)), (0.1, (1, 0)), (0.1, (-1, 0))],
        (1, 0): [(0.8, (1, 0)), (0.1, (0, 1)), (0.1, (0, -1))],
        (-1, 0): [(0.8, (-1, 0)), (0.1, (0, 1)), (0.1, (0, -1))],
    },
}

# Actions of the highway example as (dx, dy) offsets: every action advances x by
# one; the two lane-change actions additionally move y by +1 / -1.
LC_LEFT_ACTION, STAY_IN_LANE_ACTION, LC_RIGHT_ACTION = (1, 1), (1, 0), (1, -1)

# Keyword arguments for `GridMDP(...)`: a 10x4 highway grid (same conventions as
# GRID_MDP_DICT). The whole last column (x = 9) is terminal: reward 0 at (9, 0)
# and -50 in the other three rows.
HIGHWAY_MDP_DICT = {
    "grid": [
        [0, -1, -1, -1, -1, -1, -1, -1, -1, -50],
        [0, -2, -2, -2, -2, -2, -2, -2, -2, -50],
        [0, -3, -3, -3, -3, -3, -3, -3, -3, -50],
        [None, None, None, None, None, None, -2, -2, -2, 0],
    ],
    "initial_state": (0, 2),
    "terminal_states": {(9, 3), (9, 1), (9, 2), (9, 0)},
    "transition_probabilities_per_action": {
        STAY_IN_LANE_ACTION: [(1.0, STAY_IN_LANE_ACTION)],
        LC_LEFT_ACTION: [(0.5, LC_LEFT_ACTION), (0.5, STAY_IN_LANE_ACTION)],
        LC_RIGHT_ACTION: [(0.75, LC_RIGHT_ACTION), (0.25, STAY_IN_LANE_ACTION)],
    },
    # With False, an offset that leads to a non-existing cell leaves the state
    # unchanged instead of removing the action from that state.
    "restrict_actions_to_available_states": False,
}


class MDP:
    """A finite Markov decision process with per-state rewards."""

    def __init__(
        self,
        states: Set[Any],
        actions: Set[Any],
        initial_state: Any,
        terminal_states: Set[Any],
        transition_probabilities: Dict[Tuple[Any, Any], List[Tuple[float, Any]]],
        reward: Dict[Any, float],
    ) -> None:
        """A Markov decision process.

        Args:
            states: Set of states.
            actions: Set of actions.
            initial_state: Initial state.
            terminal_states: Set of terminal states.
            transition_probabilities: Dictionary of transition
                probabilities, mapping from tuple (state, action) to
                list of tuples (probability, next state).
            reward: Dictionary of rewards per state, mapping from state
                to reward.

        Raises:
            AssertionError: If the initial/terminal/next states are not in
                ``states``, if the probabilities of a (state, action) pair do
                not sum to one, or if a state has no reward.
        """
        self.states = states

        self.actions = actions

        assert initial_state in self.states
        self.initial_state = initial_state

        for terminal_state in terminal_states:
            assert (
                terminal_state in self.states
            ), f"The terminal state {terminal_state} is not in states {states}"
        self.terminal_states = terminal_states

        # Validate only the (state, action) pairs that are actually defined;
        # a missing pair means the action is unavailable in that state.
        for state in self.states:
            for action in self.actions:
                if (state, action) not in transition_probabilities:
                    continue
                total_prob = 0
                for prob, next_state in transition_probabilities[(state, action)]:
                    assert (
                        next_state in self.states
                    ), f"next_state={next_state} is not in states={states}"
                    total_prob += prob
                assert math.isclose(total_prob, 1), "Probabilities must add to one"
        self.transition_probabilities = transition_probabilities

        assert set(reward.keys()) == set(
            self.states
        ), "Rewards must be defined for every state in the set of states"
        for state in self.states:
            assert reward[state] is not None
        self.reward = reward

    def get_states(self) -> Set[Any]:
        """Get the set of states."""
        return self.states

    def get_actions(self, state) -> Set[Any]:
        """Get the set of actions available in a certain state, returns {None} for terminal states."""
        if self.is_terminal(state):
            return {None}
        return {a for a in self.actions if (state, a) in self.transition_probabilities}

    def get_reward(self, state) -> float:
        """Get the reward for a specific state."""
        return self.reward[state]

    def is_terminal(self, state) -> bool:
        """Return whether a state is a terminal state."""
        return state in self.terminal_states

    def get_transitions_with_probabilities(
        self, state, action
    ) -> List[Tuple[float, Any]]:
        """Get the list of transitions with their probability, returns [(0.0, state)] for terminal states."""
        if action is None or self.is_terminal(state):
            return [(0.0, state)]
        return self.transition_probabilities[(state, action)]

    def sample_next_state(self, state, action) -> Any:
        """Randomly sample the next state given the current state and taken action.

        Raises:
            ValueError: If ``state`` is terminal or ``action`` is None.
        """
        if self.is_terminal(state):
            raise ValueError("No next state for terminal states.")
        if action is None:
            raise ValueError("Action must not be None.")
        prob_per_transition = self.get_transitions_with_probabilities(state, action)
        # Sample an index rather than the state itself, since states may be tuples.
        choice = np.random.choice(
            len(prob_per_transition), p=[prob for prob, _ in prob_per_transition]
        )
        return prob_per_transition[choice][1]

    def execute_action(self, state, action) -> Tuple[Any, float, bool]:
        """Executes the action in the current state and returns the new state, obtained reward and terminal flag.

        The reward is the reward of the state that is reached.

        Raises:
            ValueError: If ``state`` is terminal or ``action`` is None.
        """
        new_state = self.sample_next_state(state=state, action=action)
        reward = self.get_reward(state=new_state)
        terminal = self.is_terminal(state=new_state)
        return new_state, reward, terminal


class GridMDP(MDP):
    """An MDP whose states are the (x, y) cells of a rectangular grid."""

    def __init__(
        self,
        grid: List[List[Union[float, None]]],
        initial_state: Tuple[int, int],
        terminal_states: Set[Tuple[int, int]],
        transition_probabilities_per_action: Dict[
            Tuple[int, int], List[Tuple[float, Tuple[int, int]]]
        ],
        restrict_actions_to_available_states: Optional[bool] = False,
    ) -> None:
        """A Markov decision process on a grid.

        Args:
            grid: List of lists, containing the rewards of the grid
                states or None. The first row is the top of the grid.
            initial_state: Initial state in the grid.
            terminal_states: Set of terminal states in the grid.
            transition_probabilities_per_action: Dictionary of
                transition probabilities per action, mapping from action
                to list of tuples (probability, offset actually applied).
            restrict_actions_to_available_states: Whether to restrict
                actions to those that result in valid next states.
        """
        states = set()
        reward = {}
        # Work on a reversed copy so that y = 0 is the bottom row (y-axis
        # pointing upwards); the caller's list is left untouched.
        grid = list(reversed(grid))
        rows = len(grid)
        cols = len(grid[0])
        self.grid = grid
        for x in range(cols):
            for y in range(rows):
                if grid[y][x] is not None:
                    states.add((x, y))
                    reward[(x, y)] = grid[y][x]

        transition_probabilities = {}
        for state in states:
            for action in transition_probabilities_per_action.keys():
                transition_probability_list = self._generate_transition_probability_list(
                    state=state,
                    action=action,
                    restrict_actions_to_available_states=restrict_actions_to_available_states,
                    states=states,
                    transition_probabilities_per_action=transition_probabilities_per_action,
                    next_state_fn=self._next_state_deterministic,
                )
                # An empty list means the action is not available in this state.
                if transition_probability_list:
                    transition_probabilities[
                        (state, action)
                    ] = transition_probability_list

        super().__init__(
            states=states,
            actions=set(transition_probabilities_per_action.keys()),
            initial_state=initial_state,
            terminal_states=terminal_states,
            transition_probabilities=transition_probabilities,
            reward=reward,
        )

    @staticmethod
    def _generate_transition_probability_list(
        state,
        action,
        restrict_actions_to_available_states,
        states,
        transition_probabilities_per_action,
        next_state_fn,
    ):
        """Generate the transition probability list of one (state, action) pair.

        Returns:
            List of (probability, next state) tuples, or an empty list if any
            outcome of the action has no valid next state (only possible when
            ``restrict_actions_to_available_states`` is true).
        """
        transition_probability_list = []
        for (
            probability,
            deterministic_action,
        ) in transition_probabilities_per_action[action]:
            next_state = next_state_fn(
                state,
                deterministic_action,
                states,
                output_none_if_non_existing_state=restrict_actions_to_available_states,
            )
            if next_state is None:
                # One invalid outcome disables the whole action in this state.
                return []
            transition_probability_list.append((probability, next_state))
        return transition_probability_list

    @staticmethod
    def _next_state_deterministic(
        state, action, states, output_none_if_non_existing_state=False
    ):
        """Output the next state given the action in a deterministic setting.
        Output None if next state not existing in case output_none_if_non_existing_state is True.
        Otherwise a move to a non-existing state leaves the state unchanged.
        """
        # `.tolist()` converts back to plain Python numbers so that states keep
        # the same element type as the tuples stored in `states`.
        next_state_candidate = tuple((np.array(state) + np.array(action)).tolist())
        if next_state_candidate in states:
            return next_state_candidate
        if output_none_if_non_existing_state:
            return None
        return state


def policy_gradient(
    *,
    mdp: MDP,
    pol: CategorialPolicy,
    lr: float = 1e-2,
    iterations: int = 50,
    batch_size: int = 5000,
    return_history: bool = False,
    use_random_init_state: bool = False,
    verbose: bool = True,
) -> Union[List[CategorialPolicy], CategorialPolicy]:
    """Train a parameterized policy using vanilla policy gradient.

    Every iteration collects complete episodes until more than ``batch_size``
    samples are available, weights the log-probability of each taken action
    with the total return of its episode and performs one Adam step on the
    negated mean. The policy ``pol`` is updated in place.

    Adapted from: https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py

    The MIT License (MIT)

    Copyright (c) 2018 OpenAI (http://openai.com)

    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

    Args:
        mdp: The underlying MDP.
        pol: The stochastic policy to be trained.
        lr: Learning rate.
        iterations: Number of iterations.
        batch_size: Number of samples generated for each policy update.
        return_history: Whether to return a snapshot of the policy for every
            iteration instead of just the final policy.
        use_random_init_state: If true, every episode starts in a randomly
            chosen non-terminal state instead of the MDP's initial state.
        verbose: If true, training progress is printed.

    Returns:
        The final policy, if return_history is false. The
        history of policies as list (untrained policy first), if
        return_history is true.
    """
    # Fixed seeds make the sampled experience reproducible.
    np.random.seed(1337)
    torch.manual_seed(1337)

    # The untrained policy is always the first checkpoint.
    checkpoints = [deepcopy(pol)]
    optimizer = torch.optim.Adam(pol.net.parameters(), lr=lr)
    start_candidates = [s for s in mdp.states if not mdp.is_terminal(s)]

    def pick_start_state():
        """Return the state in which the next episode begins."""
        if not use_random_init_state:
            return mdp.initial_state
        return start_candidates[np.random.choice(len(start_candidates))]

    for iteration in range(1, iterations + 1):
        # Per-iteration experience and statistics.
        visited_states, taken_actions, weights = [], [], []
        episode_returns, episode_lengths = [], []

        state = pick_start_state()
        rewards = []

        # Act in the MDP until the batch is full; only whole episodes are used.
        enough_samples = False
        while not enough_samples:
            visited_states.append(deepcopy(state))
            action = pol.get_action(state=torch.as_tensor(state, dtype=torch.float32))
            state, reward, done = mdp.execute_action(state=state, action=action)
            taken_actions.append(action)
            rewards.append(reward)

            if not done:
                continue

            # Episode finished: every sample of it is weighted with the
            # episode's total return.
            total = sum(rewards)
            steps = len(rewards)
            episode_returns.append(total)
            episode_lengths.append(steps)
            weights.extend([total] * steps)

            # Prepare the next episode before checking whether to stop.
            state = pick_start_state()
            rewards = []
            enough_samples = len(visited_states) > batch_size

        # Objective to minimize: negative return-weighted log-likelihood.
        log_probs = pol.get_log_prob(
            states=torch.as_tensor(visited_states, dtype=torch.float32),
            actions=torch.as_tensor(taken_actions, dtype=torch.float32),
        )
        weight_tensor = torch.as_tensor(weights, dtype=torch.float32)
        batch_loss = -(log_probs * weight_tensor).mean()

        # Single policy gradient update step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if verbose:
            print(
                "iteration: %3d;  return: %.3f;  episode_length: %.3f"
                % (iteration, np.mean(episode_returns), np.mean(episode_lengths))
            )
        if return_history:
            checkpoints.append(deepcopy(pol))

    return checkpoints if return_history else pol


def derive_deterministic_policy(mdp: MDP, pol: CategorialPolicy) -> Dict[Any, Any]:
    """Turn a stochastic policy into a deterministic state-to-action mapping.

    For every non-terminal state of the MDP the policy is queried in
    deterministic mode, i.e. its most likely action is selected.

    Args:
        mdp: The underlying MDP.
        pol: The stochastic policy.

    Returns:
        Policy, i.e. mapping from state to action. Terminal states are omitted.
    """
    return {
        state: pol.get_action(
            state=torch.as_tensor(state, dtype=torch.float32), deterministic=True
        )
        for state in mdp.get_states()
        if not mdp.is_terminal(state)
    }

In [None]:
def make_plot_policy_step_function(columns, rows, policy_over_time, show=True):
    """Create a function that plots the policy of a chosen iteration on a grid.

    Args:
        columns: Number of grid columns (x extent of the plot).
        rows: Number of grid rows (y extent of the plot).
        policy_over_time: Sequence of policies; each maps a (col, row) cell to
            an action given as (dx, dy), or to None.
        show: Whether the returned function calls ``plt.show()`` at the end.

    Returns:
        A function ``plot_grid_step(iteration)`` drawing one policy.
    """

    def plot_grid_step(iteration):
        policy = policy_over_time[iteration]
        cells = [(col, row) for row in range(rows) for col in range(columns)]
        for cell in cells:
            if cell not in policy:
                continue
            # Everything is drawn relative to the cell centre.
            center_x = cell[0] + 0.5
            center_y = cell[1] + 0.5
            action = policy[cell]
            if action is None:
                # Cells without an action are marked with a dot.
                plt.scatter([center_x], [center_y], color="black")
                continue
            dx, dy = action[0], action[1]
            # Normalize the arrow to a fixed length of 1 / 2.5 cell widths.
            scaling = np.sqrt(dx**2.0 + dy**2.0) * 2.5
            plt.arrow(
                center_x,
                center_y,
                dx / scaling,
                dy / scaling,
                shape="full",
                lw=1.0,
                length_includes_head=True,
                head_width=0.15,
            )
        plt.axis("equal")
        plt.xlim([0, columns])
        plt.ylim([0, rows])
        if show:
            plt.show()

    return plot_grid_step

## TOY EXAMPLE

In [None]:
# Build the 4x3 grid world from its dictionary description.
grid_mdp = GridMDP(**GRID_MDP_DICT)

In [None]:
# Policy network: the state coordinates are the input, followed by one hidden
# layer with 32 units and one output logit per action.
pol = CategorialPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
)

In [None]:
# Train for 100 iterations and keep a snapshot of the policy after every
# update; index 0 of the returned list is the untrained policy.
model_checkpoints = policy_gradient(
    mdp=grid_mdp,
    pol=pol,
    iterations=100,
    return_history=True,
)

In [None]:
# Most likely action per non-terminal state, for every checkpoint.
policy_array = [
    derive_deterministic_policy(mdp=grid_mdp, pol=model) for model in model_checkpoints
]

In [None]:
# Plot function for the toy grid; columns/rows must match the shape of the
# grid in GRID_MDP_DICT (4 columns, 3 rows).
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4, rows=3, policy_over_time=policy_array
)

In [None]:
# Optional interactive view: a slider selecting the checkpoint to plot.
# Disabled by default; enabling it requires ipywidgets. The same flag is read
# again by the corresponding cell of the highway example.
# NOTE(review): the exact meaning of the name `mkdocs_flag` is not visible here.
mkdocs_flag = False
if mkdocs_flag:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)

In [None]:
# Plot the policy of the last checkpoint. The index is derived from the history
# so it stays valid when the number of training iterations is changed.
plot_policy_step_grid_map(len(policy_array) - 1)

## HIGHWAY EXAMPLE

In [None]:
# Disabled experiment: lowers the success probability of the right lane change
# from 0.75 to 0.4 (otherwise the vehicle just stays in its lane).
if False:
    # we will change this to true later on, to see the effect
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]

In [None]:
# Build the highway MDP. The flag is set explicitly here, although
# HIGHWAY_MDP_DICT already defines it as False.
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)

In [None]:
# Fresh policy network for the highway MDP (state coordinates -> 32 hidden
# units -> one logit per action). Note that this rebinds `pol`.
pol = CategorialPolicy(
    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],
    actions=list(highway_mdp.actions),
)

In [None]:
# Train for 200 iterations and keep a snapshot of the policy after every
# update; index 0 of the returned list is the untrained policy.
model_checkpoints = policy_gradient(
    mdp=highway_mdp,
    pol=pol,
    iterations=200,
    return_history=True,
)

In [None]:
# Most likely action per non-terminal state, for every checkpoint.
policy_array = [
    derive_deterministic_policy(mdp=highway_mdp, pol=model)
    for model in model_checkpoints
]

In [None]:
# Plot function for the highway; columns/rows must match the shape of the grid
# in HIGHWAY_MDP_DICT (10 columns, 4 rows).
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10, rows=4, policy_over_time=policy_array
)

In [None]:
# Optional interactive slider over the highway checkpoints; relies on
# `mkdocs_flag` having been defined by an earlier cell.
if mkdocs_flag:
    import ipywidgets
    from IPython.display import display

    iteration_slider = ipywidgets.IntSlider(
        min=0, max=len(model_checkpoints) - 1, step=1, value=0
    )
    w = ipywidgets.interactive(plot_policy_step_grid_map, iteration=iteration_slider)
    display(w)

In [None]:
# Plot the policy of the last checkpoint. The index is derived from the history
# so it stays valid when the number of training iterations is changed.
plot_policy_step_grid_map(len(policy_array) - 1)