In [1]:
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.stats import poisson
import sys

In [2]:
# Unused alternative to scipy.stats.poisson: a precomputed table of Poisson pmf values.
# class Poisson:
#     def __init__(self, n, lam):
#         self.store = np.zeros([n, lam])
#         for i in range(0, n):
#             for j in range(0, lam):
#                 self.store[i][j] = self.calc(i, j)

#     def get(self, n, lam):
#         return self.store[n][lam]

#     def calc(self, n, lam):
#         if n == 0: return np.exp(-lam)
#         factorial = math.gamma(n + 1)
#         return (lam ** n) * np.exp(-lam) / factorial


In [3]:
def get_prob(transition, params):
    """Joint probability of a single day's outcome.

    ``transition`` holds four counts that are matched, position by position,
    with the Poisson rates stored in ``params`` under the keys ``ret_A``,
    ``rent_A``, ``ret_B`` and ``rent_B``.  The counts are treated as
    independent, so the result is the product of the four pmf values.
    """
    rate_keys = ("ret_A", "rent_A", "ret_B", "rent_B")

    joint = 1.0
    for position, key in enumerate(rate_keys):
        # Multiply in positional order so rounding matches a left-to-right product.
        joint = joint * poisson.pmf(transition[position], params[key])
    return joint

def transitions(n, params, thresh=1e-7):
    """Enumerate the non-negligible daily outcomes and their probabilities.

    An outcome is a 4-tuple of counts, each in ``range(n)``, whose positions
    use the Poisson rates ``params["ret_A"]``, ``params["rent_A"]``,
    ``params["ret_B"]`` and ``params["rent_B"]`` respectively.

    Parameters
    ----------
    n : int
        Number of count values considered per variable (0 .. n-1).
    params : dict
        Must provide the four Poisson rates listed above.
    thresh : float, optional
        Outcomes whose single-precision probability is below this value are
        dropped.  Defaults to 1e-7.

    Returns
    -------
    trans : ndarray of int, shape (m, 4)
        Retained outcomes, ordered with the first count varying slowest and
        the last count varying fastest.
    probabs : ndarray of float32, shape (m,)
        Joint probability of each retained outcome.  They are not
        renormalised, so they sum to slightly less than one.
    """
    counts = np.arange(n)

    # One pmf vector per variable: 4*n pmf evaluations instead of 4*n**4.
    pmf_ret_A = poisson.pmf(counts, params["ret_A"])
    pmf_rent_A = poisson.pmf(counts, params["rent_A"])
    pmf_ret_B = poisson.pmf(counts, params["ret_B"])
    pmf_rent_B = poisson.pmf(counts, params["rent_B"])

    # Broadcast to an (n, n, n, n) grid.  The product is taken left to right
    # in float64 and only then rounded to float32 for storage.
    joint = (pmf_ret_A[:, None, None, None] *
             pmf_rent_A[None, :, None, None] *
             pmf_ret_B[None, None, :, None] *
             pmf_rent_B[None, None, None, :]).astype('f')

    # Keep everything that is not strictly below the threshold.
    keep = ~(joint < thresh)

    # argwhere walks the grid in row-major order, i.e. the same order a
    # four-deep nested loop over range(n) would visit the outcomes.
    trans = np.argwhere(keep).astype('int')
    probabs = joint[keep]
    return trans, probabs

In [4]:
def update_value(v_pi, params, A, B, transitions):
    """Expected one-day return when the day starts with A and B cars on hand.

    ``transitions`` is a pair ``(outcomes, probabilities)``; each outcome
    unpacks to (returned_A, rented_A, returned_B, rented_B).  For every
    outcome the rentals actually served are limited by the stock on hand,
    the next-day stock is clamped to ``[0, params["cars_max"] - 1]``, and
    the probability-weighted sum of rental reward plus discounted next-state
    value (looked up in ``v_pi``) is accumulated.
    """
    total = 0

    for idx, outcome in enumerate(transitions[0]):
        back_A, wanted_A, back_B, wanted_B = outcome

        # Cannot rent out more cars than are parked at the location.
        served_A = min(A, wanted_A)
        served_B = min(B, wanted_B)

        # Stock after rentals and returns, clamped to the valid index range.
        next_A = max(0, min(params["cars_max"] - 1, A - served_A + back_A))
        next_B = max(0, min(params["cars_max"] - 1, B - served_B + back_B))

        rental_income = params["r_rent"] * (served_A + served_B)
        discounted_future = params["gamma"] * v_pi[next_A][next_B]
        total += transitions[1][idx] * (rental_income + discounted_future)

    return total


def policy_eval(v_pi, pi_s, params, trans):
    """Iterative policy evaluation for the policy ``pi_s``.

    ``v_pi`` is updated in place, state by state, so later states in a sweep
    already see the values refreshed earlier in the same sweep.  Sweeps
    repeat until the largest absolute change in one sweep is below 0.01.
    The same ``v_pi`` object is returned.

    ``pi_s[A][B]`` is the number of cars moved from A to B (negative means
    from B to A); each moved car contributes ``params["r_move"]``.
    """
    tolerance = 0.01
    converged = False

    while not converged:
        largest_change = 0.0

        for A in range(params["cars_max"]):
            for B in range(params["cars_max"]):
                previous = v_pi[A][B]
                moved = pi_s[A][B]

                # Value of the stock levels reached after the overnight move.
                v_pi[A][B] = update_value(v_pi, params, A - moved, B + moved, trans)
                v_pi[A][B] += params["r_move"] * abs(moved)

                largest_change = max(largest_change, abs(previous - v_pi[A][B]))

        converged = largest_change < tolerance

    return v_pi

def get_action(A, B, v_pi, params):
    """Pick the greedy overnight transfer for the state (A, B).

    Every transfer in ``[-move_max, move_max]`` (positive moves cars from A
    to B) that leaves both locations within ``0 .. cars_max - 1`` is scored
    as ``v_pi[A - transfer][B + transfer] + params["r_move"] * abs(transfer)``
    and the best-scoring transfer is returned.  Ties go to the largest
    transfer.

    ``move_max`` and ``cars_max`` are read from ``params`` and fall back to
    5 and 21 when absent.

    Returns -10 as a sentinel only when no transfer is feasible at all,
    which cannot happen for a state inside the grid because a transfer of 0
    is always feasible there.

    NOTE(review): the score uses v_pi at the post-move state, which already
    includes whatever move the evaluated policy makes from that state; a
    full one-step lookahead would differ whenever that move is non-zero.
    Confirm whether this approximation is intended.
    """
    move_max = params.get("move_max", 5)
    top_index = params.get("cars_max", 21) - 1

    # Start below every possible score so that an all-negative value table
    # still yields a real action instead of the sentinel.
    maxValue = float("-inf")
    action = -10

    for transfer in range(-move_max, move_max + 1):
        A_new = A - transfer
        B_new = B + transfer
        if A_new < 0 or A_new > top_index or B_new < 0 or B_new > top_index:
            continue

        candidate = v_pi[A_new][B_new] + params["r_move"] * abs(transfer)
        if candidate >= maxValue:
            maxValue = candidate
            action = transfer

    return action

In [5]:
# Problem definition shared by every function in this notebook.
params = dict(
    rent_A=3,       # Poisson rate of rental requests at location A
    ret_A=3,        # Poisson rate of returns at location A
    rent_B=4,       # Poisson rate of rental requests at location B
    ret_B=2,        # Poisson rate of returns at location B
    gamma=0.9,      # discount applied to the next state's value
    cars_max=21,    # stock levels per location, i.e. states 0 .. 20
    move_max=5,     # largest overnight transfer in either direction
    r_rent=10.0,    # reward per car rented out
    r_move=-2.0,    # reward per car moved overnight (negative: a cost)
)

# Outcome table with each of the four daily counts truncated to 0 .. 19.
trans = transitions(n=20, params=params)

In [6]:
## Start from an all-zero value table, one entry per (cars at A, cars at B)
v_pi = np.zeros((params["cars_max"],) * 2)

## Start from the "move nothing" policy. Entries lie in [-5, -4, ... , 4, 5];
## positive numbers indicate moving cars from A to B
pi_s = np.zeros_like(v_pi, dtype='int32')

In [7]:
# Policy iteration: alternate policy evaluation and greedy improvement.
# The loop stops as soon as an improvement step leaves the policy unchanged,
# and in any case after MAX_PASSES passes (each pass is slow).
MAX_PASSES = 2

for i in range(MAX_PASSES):
    print("\n\npass,", i)

    ## policy evaluation
    v_pi = policy_eval(v_pi, pi_s, params, trans)

    ## policy improvement
    policy_stable = True
    for A in range(0, params["cars_max"]):
        for B in range(0, params["cars_max"]):
            new_action = get_action(A, B, v_pi, params)
            # Compare before overwriting; afterwards the two are always equal.
            if new_action != pi_s[A][B]:
                policy_stable = False
            pi_s[A][B] = new_action

    print("\nValue:")
    print(v_pi.astype(int))
    print("\nPolicy:")
    print(pi_s)

    if policy_stable:
        break



pass, 0

Value:
[[405 415 425 435 444 453 461 469 477 484 491 498 504 510 516 521 526 531
  536 540 543]
 [415 425 435 445 454 463 471 479 487 494 501 508 514 520 526 531 536 541
  545 549 553]
 [424 434 444 454 463 472 480 489 496 503 510 517 523 529 535 540 545 550
  555 559 562]
 [433 443 453 462 471 480 489 497 504 512 519 525 531 537 543 548 554 558
  563 567 570]
 [440 450 460 469 478 487 496 504 511 519 526 532 539 544 550 556 561 565
  570 574 578]
 [446 456 465 475 484 493 502 510 517 525 532 538 544 550 556 561 567 571
  576 580 583]
 [451 461 470 480 489 498 507 515 522 530 536 543 549 555 561 566 571 576
  581 585 588]
 [455 465 474 484 493 502 511 519 526 534 540 547 553 559 565 570 575 580
  585 589 592]
 [458 468 478 487 497 506 514 522 530 537 544 550 557 563 568 574 579 584
  588 592 596]
 [461 471 481 490 499 508 517 525 532 540 547 553 559 565 571 576 582 586
  591 595 598]
 [463 473 483 492 502 511 519 527 535 542 549 555 562 568 573 579 584 589
  593 597 601]
 [4