In [1]:
%matplotlib ipympl

import functools
import numpy as np
import scipy.special as sps
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook, tnrange

In [5]:
MAX_NUM_CARS = 20
MAX_NUM_TRANSFERS = 5


def memoize(obj): # from https://wiki.python.org/moin/PythonDecoratorLibrary
    cache = obj.cache = {}
    @functools.wraps(obj)
    def memoizer(*args, **kwargs):
        key = str(args) + str(kwargs)
        if key not in cache:
            cache[key] = obj(*args, **kwargs)
        return cache[key]
    return memoizer


@memoize
def poisson_prob(n, lam):
    return np.power(lam, n) / sps.factorial(n, exact=True) * np.exp(-lam)


"""
Poisson probability for all ns >= n
"""
@memoize
def poisson_remaining_prob(n, lam):
    if n == 0:
        return 1.
    ns = np.arange(n)
    return 1. - np.sum(np.power(lam, ns) / sps.factorial(ns, exact=True) * np.exp(-lam))


def policy_iteration(transition_prob_and_reward_fn, gamma=0.9, theta=1e-3, MAX_NUM_CARS=MAX_NUM_CARS, MAX_NUM_TRANSFERS=MAX_NUM_TRANSFERS):
    nS = MAX_NUM_CARS+1

    # Following p. 65 - Policy iteration (using iterative policy evaluation)
    # 1. Initialization
    V = np.random.rand(nS, nS)*200 + 420 # Re-scaled based on Figure 4.2
    pi = np.random.randint(low=-MAX_NUM_TRANSFERS, high=MAX_NUM_TRANSFERS+1, size=(nS, nS))
    n_iters = []

    while True:
        # 2. Policy Evaluation
        n_evals = 0
        while True:
            delta = 0
            for s1 in tnrange(nS, desc='eval %d s1' % n_evals):
                for s2 in range(nS):
                    v = V[s1, s2]
                    v_new = 0
                    for s1p in range(nS):
                        for s2p in range(nS):
                            p, r = transition_prob_and_reward_fn(s=(s1, s2), a=pi[s1, s2], sp=(s1p, s2p))
                            v_new += p*(r + gamma*V[s1p, s2p])
                    V[s1, s2] = v_new
                    delta = max(delta, abs(v-v_new))
            n_evals += 1
            print('delta:', delta)
            if delta < theta:
                break

        # 3. Policy Improvement
        n_iters.append(n_evals)
        num_a_diff = 0
        for s1 in tnrange(nS, desc='impr %d s1' % len(n_iters)):
            for s2 in range(nS):
                old_action = pi[s1, s2]
                backups = np.zeros(shape=(11,), dtype='float32')
                for ai, a in enumerate(range(-MAX_NUM_TRANSFERS, MAX_NUM_TRANSFERS+1)):
                    for s1p in range(nS):
                        for s2p in range(nS):
                            p, r = transition_prob_and_reward_fn(s=[s1, s2], a=a, sp=[s1p, s2p])
                            backups[ai] += p*(r + gamma*V[s1p, s2p])
                ai_max = np.argmax(backups)
                a_max = int(ai_max - MAX_NUM_TRANSFERS)
                pi[s1, s2] = a_max
                if old_action != a_max:
                    num_a_diff += 1
        
        print('num_a_diff',num_a_diff)
        
        if num_a_diff == 0:
            break
    
    return V, pi, n_iters


In [3]:
# Example 4.2

"""
s, sp: 2x1 state - [num_cars_in_first_loc; num_cars_in_second_loc], each \in [0, MAX_NUM_CARS]
a: action \in [-MAX_NUM_TRANSFERS, MAX_NUM_TRANSFERS], representing number of cars moved from first_loc to second_loc

Returns tuple: probability, (expected) reward
"""
@memoize
def transition_prob_and_reward_ex42(s, a, sp, MAX_NUM_CARS=MAX_NUM_CARS, MAX_NUM_TRANSFERS=MAX_NUM_TRANSFERS, r_transfer=-2, r_rental=10):
    nS = MAX_NUM_CARS+1
    
    r = 0
    p = 0
    
    # Check for invalid actions
    if a > 0 and s[0] < a:
        return 0., 0
    elif a < 0 and s[1] < -a:
        return 0., 0
    
    # Determine number of cars after transfer, before returns and new rentals
    r = r_transfer*abs(a)
    s_after_transfer = np.copy(s)
    s_after_transfer[0] -= a
    s_after_transfer[1] += a
    if s_after_transfer[0] > MAX_NUM_CARS:
        s_after_transfer[0] = MAX_NUM_CARS
    if s_after_transfer[1] > MAX_NUM_CARS:
        s_after_transfer[1] = MAX_NUM_CARS

    # Compute sum of probabilities by enumerating all possible returns (accumulate all remaining probs on cars=MAX_NUM_CARS) and determine corresponding number of rentals
    exp_r_rentals = 0
    for num_returns_first_loc in range(MAX_NUM_CARS+1):
        cars_after_return_first_loc = s_after_transfer[0]+num_returns_first_loc
        if cars_after_return_first_loc == MAX_NUM_CARS:
            prob_returns_first_loc = poisson_remaining_prob(n=num_returns_first_loc, lam=3) # more returns still result in same car state
        else:
            prob_returns_first_loc = poisson_prob(n=num_returns_first_loc, lam=3)
        num_rentals_first_loc = cars_after_return_first_loc - sp[0]
        if num_rentals_first_loc < 0:
            continue
        if sp[0] == 0:
            prob_rentals_first_loc = poisson_remaining_prob(n=num_rentals_first_loc, lam=3) # more rentals still result in same car state
        else:
            prob_rentals_first_loc = poisson_prob(n=num_rentals_first_loc, lam=3)
        
        for num_returns_sec_loc in range(MAX_NUM_CARS+1):
            cars_after_return_sec_loc = s_after_transfer[1]+num_returns_sec_loc
            if cars_after_return_sec_loc == MAX_NUM_CARS:
                prob_returns_sec_loc = poisson_remaining_prob(n=num_returns_sec_loc, lam=2) # more returns still result in same car state
            else:
                prob_returns_sec_loc = poisson_prob(n=num_returns_sec_loc, lam=2)
            num_rentals_sec_loc = cars_after_return_sec_loc - sp[1]
            if num_rentals_sec_loc < 0:
                continue
            if sp[1] == 0:
                prob_rentals_sec_loc = poisson_remaining_prob(n=num_rentals_sec_loc, lam=4) # more rentals still result in same car state
            else:
                prob_rentals_sec_loc = poisson_prob(n=num_rentals_sec_loc, lam=4)

            prob = prob_returns_first_loc * prob_returns_sec_loc * prob_rentals_first_loc * prob_rentals_sec_loc
            p += prob
            
            exp_r_rentals += r_rental*(num_rentals_first_loc+num_rentals_sec_loc) * prob
            
            if cars_after_return_sec_loc == MAX_NUM_CARS:
                break
            
        if cars_after_return_first_loc == MAX_NUM_CARS:
            break
    
    if p > 0:
        exp_r_rentals /= p
        r += exp_r_rentals
    
    return p, r


In [6]:
V, pi, n_iters = policy_iteration(transition_prob_and_reward_ex42)


delta: 613.467497185



delta: 346.869287225



delta: 137.711879243



delta: 75.3784987072



delta: 45.3064169314



delta: 25.2765527216



delta: 12.2168433901



delta: 5.52085740208



delta: 2.44027924164



delta: 1.0848778699



delta: 0.472861974498



delta: 0.203265352644



delta: 0.0865138185595



delta: 0.0365580853023



delta: 0.0153670020984



delta: 0.00643424923015



delta: 0.00268621921845



delta: 0.00111901674131



delta: 0.000465390608724



num_a_diff 405



delta: 332.5612214



delta: 69.1041939471



delta: 57.0815801807



delta: 49.2407397454



delta: 41.5545608998



delta: 34.8247322303



delta: 29.1039577814



delta: 24.2913156191



delta: 20.2608614175



delta: 16.8929307918



delta: 14.0819056728



delta: 11.737213681



delta: 9.78221344935



delta: 8.15248878175



delta: 6.79409447905



delta: 5.66194538573



delta: 4.71840479594



delta: 3.93207522991



delta: 3.27677475113



delta: 2.73067597254



delta: 2.27558457165



delta: 1.89633601596



delta: 1.58029178073



delta: 1.31691897978



delta: 1.0974397672



delta: 0.914539017186



delta: 0.762120640545



delta: 0.635104470072



delta: 0.52925697355



delta: 0.441050167576



delta: 0.367544038058



delta: 0.306288552883



delta: 0.255242003243



delta: 0.21270295319



delta: 0.17725352996



delta: 0.14771216529



delta: 0.123094212745



delta: 0.102579128534



delta: 0.0854831219601



delta: 0.071236363964



delta: 0.0593639941108



delta: 0.0494702929764



delta: 0.041225492377



delta: 0.0343547838358



delta: 0.0286291589088



delta: 0.0238577760721



delta: 0.0198815997667



delta: 0.0165680995612



delta: 0.0138068327641



delta: 0.0115057632461



delta: 0.00958819376825



delta: 0.00799020958135



delta: 0.00665854807522



delta: 0.00554882347217



delta: 0.00462404740063



delta: 0.0038533960352



delta: 0.00321118269665



delta: 0.00267600169241



delta: 0.00223001483755



delta: 0.00185835688649



delta: 0.00154864006311



delta: 0.00129054115592



delta: 0.00107545743867



delta: 0.000896219928165



num_a_diff 346



delta: 5.70433754601



delta: 2.62744222575



delta: 1.77110625617



delta: 1.25573307616



delta: 0.89110083336



delta: 0.646185291434



delta: 0.476026810666



delta: 0.366094326071



delta: 0.308922059637



delta: 0.260043095939



delta: 0.218589965243



delta: 0.183590186631



delta: 0.154114946181



delta: 0.129330396801



delta: 0.108509713153



delta: 0.0910292180985



delta: 0.0763584900753



delta: 0.0640487909262



delta: 0.0537217024274



delta: 0.0450587404665



delta: 0.0377921970572



delta: 0.0316972231096



delta: 0.0265850622982



delta: 0.0222973088297



delta: 0.0187010550796



delta: 0.0156848022784



delta: 0.0131550201327



delta: 0.0110332555904



delta: 0.00925370478592



delta: 0.00776117491205



delta: 0.00650937387394



delta: 0.00545947530765



delta: 0.00457891476299



delta: 0.00384037990409



delta: 0.00322096349078



delta: 0.00270145293632



delta: 0.00226573442217



delta: 0.00190029311699



delta: 0.00159379399116



delta: 0.00133673024055



delta: 0.00112112841441



delta: 0.000940301102787



num_a_diff 103



delta: 0.567228506345



delta: 0.110013681141



delta: 0.0540981291246



delta: 0.0341130317599



delta: 0.0241365125723



delta: 0.0178934442774



delta: 0.0130869510983



delta: 0.00970963002806



delta: 0.0082113711307



delta: 0.00692153303021



delta: 0.00582351322316



delta: 0.00489429043142



delta: 0.00411055340794



delta: 0.00345085103118



delta: 0.00289624144244



delta: 0.00243034604671



delta: 0.00203916804372



delta: 0.0017108288888



delta: 0.00143529044425



delta: 0.00120409230038



delta: 0.00101011574014



delta: 0.000847377375578



num_a_diff 5



delta: 0.00375059107023



delta: 0.000688687451543



num_a_diff 0


In [7]:
plt.matshow(V)
plt.xlabel('Number of cars in first location')
plt.ylabel('Number of cars in second location')
plt.title('Value function')

plt.matshow(pi)
plt.xlabel('Number of cars in first location')
plt.ylabel('Number of cars in second location')
plt.title('Policy function')

Text(0.5,1.05,'Policy function')

In [None]:
"""
# TODOs:
- discrepancy with sln? pbly not poisson folding
- Q value impl
- exercise 4.5
"""