In [442]:
import warnings ; warnings.filterwarnings('ignore')

import itertools
import gym, gym_walk, gym_aima
import numpy as np
from tabulate import tabulate
from pprint import pprint
from tqdm import tqdm_notebook as tqdm

from itertools import cycle, count

import random
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
SEEDS = (12, 34, 56, 78, 90)

%matplotlib inline

In [443]:
# Global display configuration for every figure and array printout below.
plt.style.use('fivethirtyeight')

# rcParams overrides: a large canvas and enlarged fonts for readability.
# Keys are inserted in the same order they are applied.
params = dict()
params['figure.figsize'] = (15, 8)
params['font.size'] = 24
params['legend.fontsize'] = 20
params['axes.titlesize'] = 28
params['axes.labelsize'] = 24
params['xtick.labelsize'] = 20
params['ytick.labelsize'] = 20
pylab.rcParams.update(params)

# Print small floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

In [444]:
def value_iteration(P, gamma=1.0, theta=1e-10):
    """
    Value Iteration algorithm for Markov Decision Processes (MDPs).

    Repeatedly applies the Bellman optimality backup
        Q(s, a) = sum_{s'} p * (r + gamma * V(s') * (not done))
        V(s)    = max_a Q(s, a)
    until the largest absolute change of V between two sweeps is below theta.

    Parameters:
    P : dict
        Transition model indexed as ``P[s][a]``, each entry a list of
        ``(prob, next_state, reward, done)`` tuples. Plain
        ``(prob, next_state, reward)`` 3-tuples are also accepted and are
        treated as non-terminal (``done=False``).
        States are assumed to be ``0..len(P)-1``, and every state is assumed
        to have the same number of actions as state 0.
    gamma : float
        Discount factor.
    theta : float
        Convergence threshold on the max absolute change of V per sweep.

    Returns:
    V : np.ndarray
        Optimal state-value function, shape ``(len(P),)``.
    pi : callable
        Greedy policy: ``pi(s)`` returns the action index maximizing ``Q[s]``.
    Q : np.ndarray
        Optimal action-value function, shape ``(len(P), len(P[0]))``.
    """
    n_states, n_actions = len(P), len(P[0])
    V = np.zeros(n_states, dtype=np.float64)

    while True:
        Q = np.zeros((n_states, n_actions), dtype=np.float64)
        for s in range(n_states):
            for a in range(len(P[s])):
                for transition in P[s][a]:
                    prob, next_state, reward = transition[:3]
                    done = transition[3] if len(transition) > 3 else False
                    # Terminal transitions do not bootstrap from the successor state.
                    Q[s, a] += prob * (reward + gamma * V[next_state] * (not done))

        new_V = np.max(Q, axis=1)
        converged = np.max(np.abs(V - new_V)) < theta
        # Bellman optimality update: V must move toward max_a Q each sweep.
        V = new_V
        if converged:
            break

    # Freeze the greedy actions so pi does not depend on later mutation of Q.
    greedy_actions = np.argmax(Q, axis=1)
    pi = lambda s: greedy_actions[s]
    return V, pi, Q


In [445]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='Policy:'):
    """
    Print a policy as a grid of action symbols, ``n_cols`` states per row.

    Parameters:
    pi : callable
        Policy; ``pi(s)`` returns an action index into ``action_symbols``.
    P : dict
        Transition model ``P[s][a]`` -> list of ``(prob, next_state, reward, done)``.
        States whose every transition has ``done`` set are drawn as blank cells,
        and the policy is not queried for them.
    action_symbols : sequence of str
        Symbol displayed for each action index.
    n_cols : int
        Number of states per printed row.
    title : str
        Heading printed before the grid.
    """
    print(title)
    arrs = dict(enumerate(action_symbols))
    for s in range(len(P)):
        print("| ", end="")
        # A state is drawn blank when all of its transitions end the episode.
        if all(done for action in P[s].values() for _, _, _, done in action):
            print("".rjust(9), end=" ")
        else:
            a = pi(s)
            print(str(s).zfill(2), arrs[a].rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")
    # Close a trailing partial row so subsequent output starts on a new line.
    if len(P) % n_cols != 0:
        print("|")

In [446]:
def print_state_value_function(V, P, n_cols=4, prec=3, title='State-value function:'):
    """
    Print a state-value function as a grid, ``n_cols`` states per row.

    Parameters:
    V : array-like
        State values indexed by state.
    P : dict
        Transition model ``P[s][a]`` -> list of ``(prob, next_state, reward, done)``.
        States whose every transition has ``done`` set are drawn as blank cells.
    n_cols : int
        Number of states per printed row.
    prec : int
        Number of decimals each value is rounded to before printing.
    title : str
        Heading printed before the grid.
    """
    print(title)
    for s in range(len(P)):
        print("| ", end="")
        # A state is drawn blank when all of its transitions end the episode.
        if all(done for action in P[s].values() for _, _, _, done in action):
            print("".rjust(9), end=" ")
        else:
            v = V[s]
            print(str(s).zfill(2), '{}'.format(np.round(v, prec)).rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")
    # Close a trailing partial row so subsequent output starts on a new line.
    if len(P) % n_cols != 0:
        print("|")

In [447]:
def print_action_value_function(Q,
                                optimal_Q=None,
                                action_symbols=('<', '>'),
                                prec=3,
                                title='Action-value function:'):
    """
    Print an action-value table, optionally next to a reference Q and the error.

    Parameters:
    Q : np.ndarray
        Action values, one row per state and one column per action.
    optimal_Q : np.ndarray or None
        Reference action values of the same shape as Q. When given, the table
        also shows these values (columns prefixed '*') and ``optimal_Q - Q``
        (columns prefixed 'err').
    action_symbols : sequence of str
        Column label used for each action.
    prec : int
        Number of decimals values are rounded to.
    title : str
        Heading printed before the table.
    """
    with_reference = optimal_Q is not None
    column_groups = ('', '*', 'err') if with_reference else ('',)
    headers = ['s'] + [' '.join(pair)
                       for pair in itertools.product(column_groups, action_symbols)]
    print(title)
    blocks = [np.arange(len(Q))[..., np.newaxis], np.round(Q, prec)]
    if with_reference:
        blocks.append(np.round(optimal_Q, prec))
        blocks.append(np.round(optimal_Q - Q, prec))
    table = np.hstack(blocks)
    print(tabulate(table, headers, tablefmt="fancy_grid"))

In [None]:
def get_metrics_from_tracks(env, gamma, goal_state, optimal_Q, pi_track, coverage = 0.1):
    """
    Calculate metrics from the tracks of the agent's performance.

    Parameters:
    env : gym.Env
        The environment.
    gamma : float
        Discount factor.
    goal_state : int
        The goal state index.
    optimal_Q : np.ndarray
        Optimal action-value function.
    pi_track : list
        List of policies tracked during the episodes.
    coverage : float
        Coverage threshold for the policy.

    Returns:
    metrics : dict
        Dictionary containing various performance metrics.
    """
    