## Monte Carlo Prediction with a Tabular Value Function

In [1]:
import os

# Later cells import the `rl` package, presumably resolved from the RL_book checkout,
# so the notebook runs from that directory. Set RL_BOOK_DIR to the checkout location
# on another machine; the original author's path remains the fallback.
RL_BOOK_DIR = os.environ.get(
    'RL_BOOK_DIR',
    '/Users/Alex/Desktop/Documents_4A/Winter_quarter_1/MS&E_346/RL_book/',
)
os.chdir(RL_BOOK_DIR)

In [2]:
from rl.chapter2.simple_inventory_mrp import SimpleInventoryMRPFinite

# Inventory-model parameters, reused by every estimator below.
user_capacity = 2
user_poisson_lambda = 1.0
user_holding_cost = 1.0
user_stockout_cost = 10.0
user_gamma = 0.9  # discount factor

si_mrp = SimpleInventoryMRPFinite(
    capacity=user_capacity,
    poisson_lambda=user_poisson_lambda,
    holding_cost=user_holding_cost,
    stockout_cost=user_stockout_cost,
)

# Reference value function reported by the library; the MC and TD estimates are compared to it.
si_mrp.display_value_function(gamma=user_gamma)

{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.511,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.932,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.345,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.932,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.345,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.345}


In [3]:
from rl.chapter2.simple_inventory_mrp import InventoryState
from rl.function_approx import Tabular
from rl.approximate_dynamic_programming import ValueFunctionApprox
from rl.distribution import Choose
from rl.iterate import last
from rl.monte_carlo import mc_prediction
from itertools import islice
from pprint import pprint
from typing import Iterable, Iterator, TypeVar, Callable

# Stream of reward traces whose start states are drawn via Choose over the non-terminal states.
traces = si_mrp.reward_traces(Choose(si_mrp.non_terminal_states))

it: Iterator[ValueFunctionApprox[InventoryState]] = mc_prediction(
    traces=traces,
    approx_0=Tabular(),
    gamma=user_gamma,
    episode_length_tolerance=1e-6
)
num_traces = 1000
# mc_prediction (presumably) yields one updated approximation per trace; keep the last of `num_traces`.
last_func: ValueFunctionApprox[InventoryState] = last(islice(it, num_traces))

pprint({s: round(last_func.evaluate([s])[0], 3) for s in si_mrp.non_terminal_states})

test
{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.498,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.884,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.32,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.915,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.323,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.29}


In [17]:
# Tabular Monte Carlo prediction from scratch

from typing import Iterable, Tuple, Mapping, TypeVar, Iterator, Sequence
from operator import itemgetter
import numpy as np
from itertools import *
from numpy.random import randint
from collections import defaultdict
import itertools

# Generic state type shared by the helpers below.
S = TypeVar('S')

# A trace is declared as a stream of (state, reward) pairs; `Traces` is a stream of traces.
Traces = Iterable[Iterable[Tuple[S, float]]]

# Reward and value functions both assign a scalar to each state.
RewardFunc, ValueFunc = Mapping[S, float], Mapping[S, float]

def full_group_by(l, key=lambda x: x[0], value=lambda x: x[1]):
    """Group items by `key`, collecting `value(item)` for every item.

    Unlike `itertools.groupby`, the input does not need to be sorted: all items
    sharing a key are gathered together, and keys appear in first-seen order.

    Args:
        l: iterable of items (by default, (key, value) pairs).
        key: extracts the grouping key (default: first element).
        value: extracts the collected value (default: second element). Before
            this parameter existed the value was always `item[1]`, which was
            wrong whenever a custom `key` was supplied.

    Returns:
        A dict items view of (key, list of collected values) pairs.
    """
    groups = defaultdict(list)
    for item in l:
        groups[key(item)].append(value(item))
    return groups.items()

def get_state_return_samples(trace: Iterable[Tuple[S, float]], gamma, tolerance) -> Sequence[Tuple[S, float]]:
    """Turn one trace into a list of (state, discounted return) samples.

    Each element of `trace` must expose `.state` and `.reward` (e.g. a
    `TransitionStep`). The return from step i is
    G_i = sum_{q >= 0} gamma**q * r_{i+q}.

    * gamma < 1: the (possibly infinite) trace is truncated. With
      max_steps = int(log(tolerance) / log(gamma)) (at least 1), only the first
      2 * max_steps steps are read, and samples are emitted for the first
      max_steps of them. Every emitted return therefore sums at least
      max_steps + 1 rewards, and the discarded tail is discounted by less than
      gamma**max_steps (roughly `tolerance`).
    * gamma >= 1: the whole trace is consumed (it must be finite) and every step
      yields a sample.

    Args:
        trace: iterable of transition steps.
        gamma: discount factor.
        tolerance: truncation tolerance, used only when gamma < 1.

    Returns:
        List of (state, return) pairs in step order.
    """
    steps = iter(trace)
    if gamma < 1:
        # At least one sample per trace, even when gamma is tiny or tolerance is large
        # (log-ratio rounding to 0 would otherwise yield nothing; gamma <= 0 has no log).
        if gamma <= 0:
            max_steps = 1
        else:
            max_steps = max(1, int(np.log(tolerance) / np.log(gamma)))
        steps = list(itertools.islice(steps, 2 * max_steps))
        n_samples = min(max_steps, len(steps))
    else:
        steps = list(steps)
        n_samples = len(steps)

    # Accumulate returns backwards, G_i = r_i + gamma * G_{i+1}: O(n) instead of
    # re-summing the whole suffix for every start index.
    returns = [0.0] * len(steps)
    running = 0.0
    for i in range(len(steps) - 1, -1, -1):
        running = steps[i].reward + gamma * running
        returns[i] = running

    return [(steps[i].state, returns[i]) for i in range(n_samples)]


def multiple_state_return_samples(traces: Iterable[Iterable[Tuple[S, float]]], gamma, tolerance, num_traces) -> Sequence[Tuple[S, float]]:
    """Pool (state, return) samples from the first `num_traces` traces.

    Args:
        traces: stream of traces; each one is passed to `get_state_return_samples`.
        gamma: discount factor.
        tolerance: truncation tolerance forwarded to `get_state_return_samples`.
        num_traces: number of traces consumed from `traces` (None consumes all).

    Returns:
        One list holding the samples of every trace, in trace order.
    """
    samples = []
    for trace in itertools.islice(traces, num_traces):
        # Extend in place; `samples = samples + ...` copies the whole list per trace (O(n^2)).
        samples.extend(get_state_return_samples(trace, gamma, tolerance))
    return samples


def get_mc_value_function(state_return_samples: Sequence[Tuple[S, float]]) -> ValueFunc:
    """Monte Carlo value estimate: the mean of all sampled returns of each state.

    Args:
        state_return_samples: (state, return) pairs, e.g. from
            `multiple_state_return_samples`.

    Returns:
        dict mapping each sampled state to `np.mean` of its returns, with states
        in first-seen order.
    """
    estimates = {}
    for state, sampled_returns in full_group_by(state_return_samples):
        estimates[state] = np.mean(sampled_returns)
    return estimates




In [8]:
# Sanity check: keys keep first-seen order and all values of a key are collected together.
grouped_demo = list(full_group_by([("a", 1), ("b", 2), ("a", 3), ("a", 4)]))
assert grouped_demo == [("a", [1, 3, 4]), ("b", [2])]
grouped_demo

[('a', [1, 3, 4]), ('b', [2])]

In [9]:
# Fresh stream of reward traces; start states are drawn via Choose over the non-terminal states.
traces = si_mrp.reward_traces(Choose(si_mrp.non_terminal_states))
N = 1000  # number of traces to sample

# The raw (state, return) samples are too many to print; only their per-state means are shown.
sr_samp = multiple_state_return_samples(
    traces,
    gamma=user_gamma,
    tolerance=1e-6,
    num_traces=N,
)
mc_val = get_mc_value_function(sr_samp)

print("------------- MONTE CARLO VALUE FUNCTION --------------")
pprint(mc_val)

------------- MONTE CARLO VALUE FUNCTION --------------
{NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -28.930084305825783,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.28994123637695,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.33402768144681,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.30694891908742,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -27.86804449655829,
 NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.46549741700056}


## TD Prediction with a Tabular Value Function

In [10]:
import rl.iterate as iterate
import rl.td as td
import itertools
from pprint import pprint
from rl.chapter10.prediction_utils import fmrp_episodes_stream
from rl.chapter10.prediction_utils import unit_experiences_from_episodes
from rl.function_approx import learning_rate_schedule
from rl.markov_process import TransitionStep

# TD(0) hyperparameters. The discount factor `user_gamma` is taken from the model-setup
# cell rather than redefined here, so MC and TD always use the same discount.
episode_length: int = 100      # transitions used from each sampled episode
initial_learning_rate: float = 0.03
half_life: float = 1000.0
exponent: float = 0.5

episodes: Iterable[Iterable[TransitionStep[S]]] = fmrp_episodes_stream(si_mrp)
td_experiences: Iterable[TransitionStep[S]] = unit_experiences_from_episodes(
    episodes,
    episode_length,
)

learning_rate_func: Callable[[int], float] = learning_rate_schedule(
    initial_learning_rate=initial_learning_rate,
    half_life=half_life,
    exponent=exponent
)

td_vfs: Iterator[ValueFunctionApprox[S]] = td.td_prediction(
    transitions=td_experiences,
    approx_0=Tabular(count_to_weight_func=learning_rate_func),
    gamma=user_gamma
)

num_episodes = 2000

# td_prediction (presumably) yields one value function per transition, and each episode
# contributes `episode_length` transitions, hence the slice length below.
final_td_vf: ValueFunctionApprox[S] = \
    iterate.last(itertools.islice(td_vfs, episode_length * num_episodes))

pprint({s: round(final_td_vf(s), 3) for s in si_mrp.non_terminal_states})

{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.688,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -28.096,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.43,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -29.137,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.818,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.486}


In [11]:
## Tabular TD from scratch (the inventory MRP has no terminal states)
# Fresh episode stream and learning-rate schedule, using the hyperparameters of the library TD run.
episodes: Iterable[Iterable[TransitionStep[S]]] = fmrp_episodes_stream(si_mrp)
learning_rate_func: Callable[[int], float] = learning_rate_schedule(
    initial_learning_rate=initial_learning_rate,
    half_life=half_life,
    exponent=exponent,
)

def get_td_value_function(episodes, num_episodes, episode_length, learning_rate_func, gamma, mp):
    """Tabular TD(0) prediction written from scratch.

    Args:
        episodes: stream of episodes; each episode is an iterable of transitions
            exposing `.state`, `.reward` and `.next_state`.
        num_episodes: exact number of episodes taken from `episodes`.
        episode_length: maximum number of transitions used from each episode.
        learning_rate_func: maps the global update count (0, 1, 2, ...) to a step size.
        gamma: discount factor.
        mp: process whose `non_terminal_states` initialise the table to 0.0.

    Returns:
        dict mapping each state to its estimated value.
    """
    values = {s: 0.0 for s in mp.non_terminal_states}
    n_updates = 0
    # islice takes exactly `num_episodes` episodes and pulls nothing extra from the stream.
    for episode in itertools.islice(episodes, num_episodes):
        for step in itertools.islice(episode, episode_length):
            # States outside the table (e.g. terminal ones) have value 0.
            target = step.reward + gamma * values.get(step.next_state, 0.0)
            values[step.state] = values[step.state] + learning_rate_func(n_updates) * (target - values[step.state])
            # NOTE(review): the step size uses one global counter, while the library's Tabular
            # appears to count updates per state; results differ in speed of decay.
            n_updates += 1
    return values

In [12]:
# The function's parameter is `learning_rate_func` (passing `lr=` raised TypeError).
# A fresh episode stream keeps this cell idempotent: re-running it does not continue
# from wherever an earlier run left a shared episode stream.
td_val = get_td_value_function(
    episodes=fmrp_episodes_stream(si_mrp),
    num_episodes=2000,
    episode_length=200,
    learning_rate_func=learning_rate_func,
    gamma=user_gamma,
    mp=si_mrp,
)

print("------------- Temporal Difference VALUE FUNCTION --------------")
pprint(td_val)

------------- Temporal Difference VALUE FUNCTION --------------
{NonTerminal(state=InventoryState(on_hand=0, on_order=0)): -35.58993025913883,
 NonTerminal(state=InventoryState(on_hand=0, on_order=1)): -28.066649907079743,
 NonTerminal(state=InventoryState(on_hand=0, on_order=2)): -28.376179039859835,
 NonTerminal(state=InventoryState(on_hand=1, on_order=0)): -29.025821492496252,
 NonTerminal(state=InventoryState(on_hand=1, on_order=1)): -29.41800601312299,
 NonTerminal(state=InventoryState(on_hand=2, on_order=0)): -30.474710274239822}
