In [1]:
from typing import Mapping, Optional, Set, Callable
import numpy as np
from utils.generic_typevars import S, A
from utils.mp_funcs import get_rv_gen_func_single
from utils.mp_funcs import get_expected_action_value

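### Prove the $\epsilon$-Greedy Policy Improvement Theorem (we sketched the proof in class)

**Claim.** For any $\epsilon$-greedy policy $\pi$, the $\epsilon$-greedy policy $\pi'$ with respect to $Q_{\pi}$ is an improvement, i.e. $V_{\pi'}(s) \geq V_{\pi}(s)$ for all $s$.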

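Let $m = |A|$ be the number of actions. For any state $s$,

$$ \begin{split} Q_{\pi}(s, \pi'(s)) &= \sum_{a \in A} \pi'(a|s)Q_{\pi}(s, a) \\
&= \frac{\epsilon}{m}\sum_{a \in A} Q_{\pi}(s,a) + (1-\epsilon)\max_{a \in A}Q_{\pi} (s,a)\\
&\geq \frac{\epsilon}{m}\sum_{a \in A} Q_{\pi}(s,a) + (1-\epsilon)\sum_{a \in A} \frac{\pi (a|s)-\frac{\epsilon}{m}}{1-\epsilon} Q_{\pi}(s,a) \\
&= \sum_{a \in A} \pi(a|s)Q_{\pi}(s,a) \\
&= V_{\pi}(s)
\end{split} $$

The inequality holds because the weights $\frac{\pi(a|s)-\epsilon/m}{1-\epsilon}$ are non-negative and sum to $1$. They are non-negative because $\pi$ is $\epsilon$-greedy, so $\pi(a|s) \geq \epsilon/m$. A weighted average of $Q_{\pi}(s,a)$ therefore cannot exceed $\max_{a} Q_{\pi}(s,a)$. Since $Q_{\pi}(s, \pi'(s)) \geq V_{\pi}(s)$ for every $s$, the policy improvement theorem gives $V_{\pi'}(s) \geq V_{\pi}(s)$.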

### Provide (with clear mathematical notation) the definition of GLIE (Greedy in the Limit with Infinite Exploration)

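A sequence of policies $\pi_1, \pi_2, \dots$ (indexed by episode $k$) is GLIE if both of the following hold.

1. All state-action pairs are explored infinitely many times, where $N_k(s,a)$ is the number of visits to $(s,a)$ in the first $k$ episodes:
$$\lim_{k \rightarrow \infty} N_k(s,a) = \infty \quad \text{for all } s, a$$
2. The policy converges to a greedy policy:
$$ \lim_{k \rightarrow \infty } \pi_k (a|s) = \mathbf{1}\left(a=\arg\max_{a' \in A} Q_k(s, a')\right) $$

For example, $\epsilon$-greedy exploration with $\epsilon_k = 1/k$ is GLIE. Under GLIE, Monte-Carlo control converges to the optimal action-value function:
$$ Q(s,a) \rightarrow q^*(s, a)$$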

### Implement the tabular SARSA and tabular SARSA(Lambda) algorithms

In [2]:
class MDPforRLTab:
    '''
    Tabular MDP described only through sampling functions.

    It has no explicit transition-probability or reward model.

    Attributes:
        policy: policy object supplied by the caller; stored but not used by this class.
        actions: mapping from each state to the collection of actions available in it.
        terminal_states: set of terminal states.
        state_reward_gen_dict: ``state_reward_gen_dict[s][a]`` is a zero-argument
            callable that samples a ``(next_state, reward)`` pair.
        gamma: discount factor.
    '''
    def __init__(self, policy, actions: Mapping[S, Set[A]], terminal_states: Set[S], state_reward_gen_dict, gamma: float):
        self.policy = policy
        self.actions = actions
        self.terminal_states = terminal_states
        self.state_reward_gen_dict = state_reward_gen_dict  # s -> a -> () -> (next_state, reward)
        self.gamma = gamma

    def get_actions(self, s):
        '''Return the actions available in state ``s``.'''
        return self.actions[s]

    def get_terminal_states(self, s) -> bool:
        '''Return True if ``s`` is a terminal state.

        Despite the name, this is a predicate.
        '''
        return s in self.terminal_states

    def get_state_reward_gen_func(self, s, a):
        '''Sample one ``(next_state, reward)`` pair for taking action ``a`` in state ``s``.'''
        return self.state_reward_gen_dict[s][a]()

    def init_state_gen(self):
        '''Return a sampler (built by ``get_rv_gen_func_single``) over start states.

        The sampler is uniform over the keys of ``self.actions``.
        '''
        prob = 1. / len(self.actions)
        return get_rv_gen_func_single({s: prob for s in self.actions})

    def init_state_action_gen(self):
        '''Return a sampler (built by ``get_rv_gen_func_single``) over ``(state, action)`` pairs.

        The sampler is uniform over every pair listed in ``self.actions``.
        '''
        pairs = [(s, a) for s, acts in self.actions.items() for a in acts]
        # Normalise once rather than recomputing the total for every pair.
        prob = 1. / len(pairs)
        return get_rv_gen_func_single({pair: prob for pair in pairs})
                
class RLTabInterface:
    '''
    A model-free RL interface that does not need the state-transition probability model or the reward model.

    All environment interaction goes through the wrapped ``mdp`` object.
    The algorithm settings passed to the constructor are stored as attributes
    for use by algorithms built on top of this interface.
    '''

    def __init__(self, mdp: MDPforRLTab, exploring_start: bool, softmax: bool, epsilon: float,
                 epsilon_half_life: float, num_episodes: int, max_steps: int):
        self.mdp = mdp
        # Previously these arguments were accepted and silently discarded.
        self.exploring_start = exploring_start
        self.softmax = softmax
        self.epsilon = epsilon
        self.epsilon_half_life = epsilon_half_life
        self.num_episodes = num_episodes
        self.max_steps = max_steps

    def get_actions(self) -> Mapping[S, Set[A]]:
        '''Return the state -> available-actions mapping of the wrapped MDP.'''
        return self.mdp.actions

    def get_terminal_states(self, s) -> bool:
        '''Return True if ``s`` is terminal in the wrapped MDP.'''
        return self.mdp.get_terminal_states(s)

    def get_next_pair(self, s, a):
        '''Sample and return a ``(next_state, reward)`` pair for action ``a`` in state ``s``.'''
        next_state, reward = self.mdp.get_state_reward_gen_func(s, a)
        return next_state, reward
        

In [5]:
class Sarsa:
    '''
    Tabular SARSA(lambda) with accumulating eligibility traces.

    With ``lamb == 0`` this reduces to one-step SARSA. The TD target
    bootstraps from Q(s', a'), where a' is the action actually sampled in s'.

    Args:
        mdp: sample-based MDP. Reads ``actions``, ``terminal_states``,
            ``state_reward_gen_dict``, ``gamma`` and ``init_state_gen()``.
            ``init_state_gen()`` is expected to return a sampler.
        epsilon: exploration rate of the epsilon-greedy policy used when
            ``get_q_values`` is called without a policy.
        alpha: base learning rate.
        lamb: trace-decay parameter lambda; traces are scaled by gamma * lamb each step.
        num_episodes: number of episodes to run.
        max_steps: maximum number of transitions per episode.
        learning_rate_decay: the step size at update n is
            ``alpha * (n / learning_rate_decay + 1) ** -0.5``.
    '''
    def __init__(self, mdp: MDPforRLTab, epsilon: float, alpha: float, lamb: float, num_episodes: int, max_steps: int, learning_rate_decay: float):
        # No trailing commas: ``x = y,`` would silently store a 1-tuple.
        self.mdp = mdp
        self.lamb = lamb
        self.epsilon = epsilon
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.alpha = alpha
        self.learning_rate_decay = learning_rate_decay

    def _action_probs(self, pol, state, q_values):
        '''Action distribution in ``state``.

        Uses ``pol`` if given, otherwise epsilon-greedy w.r.t. ``q_values``.
        '''
        if pol is not None:
            return pol.get_state_probabilities(state)
        q_s = q_values[state]
        greedy = max(q_s, key=q_s.get)
        probs = {a: self.epsilon / len(q_s) for a in q_s}
        probs[greedy] += 1.0 - self.epsilon
        return probs

    def _sample_action(self, pol, state, q_values):
        '''Draw one action for ``state`` from ``_action_probs``.'''
        return get_rv_gen_func_single(self._action_probs(pol, state, q_values))()

    def get_q_values(self, pol=None):
        '''
        Run SARSA(lambda) and return the learned action values.

        Args:
            pol: policy exposing ``get_state_probabilities(state) -> {action: prob}``.
                When given, it is followed for all episodes, so its own Q values
                are estimated. When None, actions are drawn epsilon-greedily from
                the current estimates (SARSA control).

        Returns:
            dict mapping state -> {action: Q value} for every state in ``mdp.actions``.
        '''
        q_values = {s: {a: 0.0 for a in acts} for s, acts in self.mdp.actions.items()}
        terminal = self.mdp.terminal_states
        gamma = self.mdp.gamma
        trace_decay = gamma * self.lamb
        updates = 0

        for _episode in range(self.num_episodes):
            # init_state_gen() returns a sampler; call it to draw the start state.
            state = self.mdp.init_state_gen()()
            if state in terminal:
                continue
            action = self._sample_action(pol, state, q_values)
            traces = {}
            for _step in range(self.max_steps):
                next_state, reward = self.mdp.state_reward_gen_dict[state][action]()
                done = next_state in terminal
                if done:
                    # A terminal state has value 0 and no next action to take.
                    next_action, next_q = None, 0.0
                else:
                    next_action = self._sample_action(pol, next_state, q_values)
                    next_q = q_values[next_state][next_action]
                delta = reward + gamma * next_q - q_values[state][action]
                traces[(state, action)] = traces.get((state, action), 0.0) + 1.0
                lr = self.alpha * (updates / self.learning_rate_decay + 1) ** -0.5
                for (s, a), e in traces.items():
                    q_values[s][a] += lr * delta * e
                if trace_decay:
                    traces = {key: e * trace_decay for key, e in traces.items()}
                else:
                    traces.clear()  # lambda == 0 (or gamma == 0): plain one-step SARSA
                updates += 1
                if done:
                    break
                state, action = next_state, next_action

        return q_values


### Implement the tabular Q-Learning algorithm

In [None]:
class QLearning:
    '''
    Tabular one-step Q-learning.

    Behaviour actions come from the supplied policy, or from an epsilon-greedy
    policy over the current estimates when none is given. The TD target always
    bootstraps from the greedy value max_a Q(s', a).

    Args:
        mdp: sample-based MDP. Reads ``actions``, ``terminal_states``,
            ``state_reward_gen_dict``, ``gamma`` and ``init_state_gen()``.
            ``init_state_gen()`` is expected to return a sampler.
        epsilon: exploration rate of the epsilon-greedy behaviour policy used
            when ``get_q_values`` is called without a policy.
        alpha: base learning rate.
        lamb: stored for interface parity with ``Sarsa``; not used by one-step Q-learning.
        num_episodes: number of episodes to run.
        max_steps: maximum number of transitions per episode.
        learning_rate_decay: the step size at update n is
            ``alpha * (n / learning_rate_decay + 1) ** -0.5``.
    '''
    def __init__(self, mdp: MDPforRLTab, epsilon: float, alpha: float, lamb: float, num_episodes: int, max_steps: int, learning_rate_decay: float):
        # No trailing commas: ``x = y,`` would silently store a 1-tuple.
        self.mdp = mdp
        self.lamb = lamb
        self.epsilon = epsilon
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.alpha = alpha
        self.learning_rate_decay = learning_rate_decay

    def _action_probs(self, pol, state, q_values):
        '''Behaviour distribution in ``state``.

        Uses ``pol`` if given, otherwise epsilon-greedy w.r.t. ``q_values``.
        '''
        if pol is not None:
            return pol.get_state_probabilities(state)
        q_s = q_values[state]
        greedy = max(q_s, key=q_s.get)
        probs = {a: self.epsilon / len(q_s) for a in q_s}
        probs[greedy] += 1.0 - self.epsilon
        return probs

    def _sample_action(self, pol, state, q_values):
        '''Draw one behaviour action for ``state``.'''
        return get_rv_gen_func_single(self._action_probs(pol, state, q_values))()

    def get_q_values(self, pol=None):
        '''
        Run Q-learning and return the learned action values.

        Args:
            pol: behaviour policy exposing
                ``get_state_probabilities(state) -> {action: prob}``.
                If None, an epsilon-greedy policy over the current estimates is used.

        Returns:
            dict mapping state -> {action: Q value} for every state in ``mdp.actions``.
        '''
        q_values = {s: {a: 0.0 for a in acts} for s, acts in self.mdp.actions.items()}
        terminal = self.mdp.terminal_states
        gamma = self.mdp.gamma
        updates = 0

        for _episode in range(self.num_episodes):
            # init_state_gen() returns a sampler; call it to draw the start state.
            state = self.mdp.init_state_gen()()
            if state in terminal:
                continue
            for _step in range(self.max_steps):
                action = self._sample_action(pol, state, q_values)
                next_state, reward = self.mdp.state_reward_gen_dict[state][action]()
                done = next_state in terminal
                # A terminal state has value 0; never bootstrap from it.
                next_q = 0.0 if done else max(q_values[next_state].values())
                lr = self.alpha * (updates / self.learning_rate_decay + 1) ** -0.5
                q_values[state][action] += lr * (reward + gamma * next_q - q_values[state][action])
                updates += 1
                if done:
                    break
                state = next_state

        return q_values
