In [74]:
from __future__ import annotations

import numpy as np
from rl.function_approx import FunctionApprox 
from dataclasses import dataclass
from typing import TypeVar, Iterable, Tuple, Optional
from collections import defaultdict

X = TypeVar('X')

import numpy as np
import sys
sys.path.append("../")
sys.path.append("/Users/abhinavrangarajan/opt/anaconda3/lib/python3.7/site-packages/")
from rl.chapter3.simple_inventory_mdp_cap import *

from pprint import pprint


# Problem 2 : Implement Tabular SARSA, Experiment with SimpleInventoryMDPCap

In [7]:

user_capacity = 2
user_poisson_lambda = 1.0
user_holding_cost = 1.0
user_stockout_cost = 10.0

user_gamma = 0.9

si_mdp: FiniteMarkovDecisionProcess[InventoryState, int] =\
    SimpleInventoryMDPCap(
        capacity=user_capacity,
        poisson_lambda=user_poisson_lambda,
        holding_cost=user_holding_cost,
        stockout_cost=user_stockout_cost
    )

In [68]:
non_term_size = len(si_mdp.non_terminal_states)
eps = 0.5
tab = TabularApprox()
gamma = 0.9

In [69]:
for _ in range(1000):
    s = si_mdp.non_terminal_states[np.random.choice(non_term_size)]
    available_actions = list(si_mdp.actions(s))
    if np.random.rand() < eps:
        action_idx = np.random.choice(len(available_actions))
    else:
        action_idx = np.argmax([tab.value((s, a)) for a in available_actions])
    a = available_actions[action_idx]
    s_, r = si_mdp.step(s, a).sample()
    
    available_actions_ = list(si_mdp.actions(s_))
    action_idx_ = np.argmax([tab.value((s_, a_)) for a_ in available_actions_])
    a_ = available_actions_[action_idx_]
    
    tab.update([( (s, a), r + gamma*tab.value((s_, a_)) )])


In [72]:
for s, val in tab.value_dict.items():
    print(f"State {s} => Value {val}")

State (InventoryState(on_hand=1, on_order=0), 0) => Value -13.427191179632626
State (InventoryState(on_hand=1, on_order=0), 1) => Value -11.590707772394008
State (InventoryState(on_hand=0, on_order=2), 0) => Value -10.1707106900751
State (InventoryState(on_hand=0, on_order=1), 0) => Value -14.381099237678958
State (InventoryState(on_hand=0, on_order=1), 1) => Value -10.943091978166041
State (InventoryState(on_hand=2, on_order=0), 0) => Value -12.302556168350002
State (InventoryState(on_hand=1, on_order=1), 0) => Value -11.065536401503138
State (InventoryState(on_hand=0, on_order=0), 0) => Value -23.62499752310123
State (InventoryState(on_hand=0, on_order=0), 1) => Value -18.07335805305571
State (InventoryState(on_hand=0, on_order=0), 2) => Value -17.487599511192123


In [49]:
class TabularApprox(FunctionApprox):
    def __init__(self):
        self.count_dict = defaultdict(int)
        self.value_dict = defaultdict(int)
    
    def value(self, s):
        return self.value_dict[s]
    
    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        return np.array([self.value_dict[x] for x in x_values_seq])
    
    def representational_gradient(self, x_value: X) -> TabularApprox[X]:
        pass
    
    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> TabularApprox[X]:
        tmp = TabularApprox()
        tmp.update(xy_vals_seq=xy_vals_seq)
        return tmp
    
    def update(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> TabularApprox:
        for (x,y) in xy_vals_seq:
            self.count_dict[x] += 1
            self.value_dict[x] += (1 / self.count_dict[x]) * (y - self.value_dict[x])
        return self
    
    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        if isinstance(other, TabularApprox):
            return np.all(
                (k in other) and (other.value_dict[k] == self.value_dict[k])
                for k in self.value_dict.keys()
            )

        return False