In [11]:
import itertools
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import poisson
from functools import lru_cache

In [None]:
@lru_cache(maxsize=None)
def p_pmf(x, mu):
    """Memoized Poisson probability mass ``P(X = x)`` for ``X ~ Poisson(mu)``.

    Both arguments form the cache key, so they must be hashable.
    """
    distribution = poisson(mu)  # frozen scipy distribution; pmf delegates to poisson.pmf(x, mu)
    return distribution.pmf(x)

def alpha_level(mu, epsilon):
    """Truncation window for a Poisson(mu) distribution.

    Returns ``(left, right)`` where ``left`` is the smallest ``x >= 0`` with
    ``pmf(x) >= epsilon`` and ``right`` is the first ``x > left`` with
    ``pmf(x) < epsilon``.  Note that ``right`` itself is *below* the threshold;
    callers that feed the pair to an inclusive range therefore also visit ``right``.

    Raises:
        ValueError: if ``mu`` is not a non-negative number, if ``epsilon`` is not
            positive, or if no ``x`` reaches ``epsilon``.  In each of these cases
            one of the searches below would otherwise never terminate.
    """
    if not mu >= 0:  # also rejects NaN
        raise ValueError(f'mu must be a non-negative number, got {mu!r}')
    if not epsilon > 0:  # also rejects NaN
        # A pmf is never negative, so the upper search could never find pmf < epsilon.
        raise ValueError(f'epsilon must be positive, got {epsilon!r}')

    left = 0
    while not p_pmf(left, mu) >= epsilon:
        # The Poisson pmf is unimodal: p(k+1)/p(k) = mu/(k+1), so it only decreases
        # once k exceeds mu.  If we are past that point and still below epsilon,
        # no later x can reach it.
        if left > mu:
            raise ValueError(
                f'epsilon={epsilon!r} exceeds the maximum Poisson({mu!r}) probability')
        left += 1

    right = left + 1
    while not p_pmf(right, mu) < epsilon:
        right += 1
    return left, right


def segment(start, stop, step=1):
    """Integer range from ``start`` to ``stop`` *inclusive* (for a positive ``step``).

    Equivalent to ``range(start, stop + 1, step)``; with a negative ``step`` the
    ``+ 1`` does not make ``stop`` inclusive.
    """
    inclusive_stop = stop + 1
    return range(start, inclusive_stop, step)


@dataclass
class PickUpPoint:
    """Parameters of one rental location.

    Attributes:
        issuance_mu: Poisson mean of the number of cars requested (issued) per transition.
        returns_mu: Poisson mean of the number of cars returned per transition.
        max_cars: capacity; car counts at this location are capped to this value.
    """
    issuance_mu: int
    returns_mu: int
    max_cars: int


def apply_action(state, action):
    """Move ``action`` cars from the first location to the second.

    A negative ``action`` moves cars the other way.  No bounds are checked; the
    caller is responsible for rejecting states outside capacity.

    Returns:
        The resulting ``(first, second)`` pair.
    """
    (cars_first, cars_second) = state
    shifted = (cars_first - action, cars_second + action)
    return shifted


def clip(v, from_v, to_v):
    """Clamp ``v`` to ``[from_v, to_v]``.

    This first caps at ``to_v``, then raises to at least ``from_v``.  So when
    ``from_v > to_v``, the result is ``from_v``.
    """
    upper_capped = v if v < to_v else to_v
    return upper_capped if upper_capped > from_v else from_v

@dataclass
class JackCarRental:
    """Two-location car-rental MDP (cf. "Jack's Car Rental", Sutton & Barto) solved by
    policy iteration.

    A state is ``(cars_at_first, cars_at_second)``, with each count in ``0..max_cars``
    of its location.  An action is the signed number of cars moved between the
    locations: positive means first -> second, negative means second -> first.  Its
    magnitude is bounded by ``point_to_point_movement_max_cars``.

    Only the first five fields are required.  The remaining ones are tuning knobs with
    defaults; they can be passed to the constructor or assigned on the instance.
    """

    first_pickup: PickUpPoint
    second_pickup: PickUpPoint
    # Added to the expected return once per moved car beyond the free allowance, so a
    # negative value models a cost.
    point_to_point_movement_cost: float
    point_to_point_movement_max_cars: int
    # Income per car actually rented out.
    rent_prize: float
    # Number of cars that may be moved free of charge in each direction.
    first_to_second_free_movement: int = 0
    second_to_first_free_movement: int = 0
    # Initial stopping tolerance of policy evaluation, and the factor it is multiplied
    # by after every policy-iteration round.
    theta: float = 65
    theta_degrade: float = 0.5
    # Truncation threshold for the Poisson rental/return distributions (see alpha_level).
    epsilon: float = 0.01
    # Discount factor.
    gamma: float = 0.9
    # Print one progress line per state during evaluation and improvement.
    verbose: bool = True

    def iterate_states(self):
        """Yield every state ``(i, j)`` in row-major order.

        ``i`` ranges over ``0..first_pickup.max_cars`` and ``j`` over
        ``0..second_pickup.max_cars``.
        """
        for i in segment(0, self.first_pickup.max_cars):
            for j in segment(0, self.second_pickup.max_cars):
                yield i, j

    def iterate_actions(self, state):
        """Yield the moves in ``-max..max`` that keep both locations within capacity."""
        d = self.point_to_point_movement_max_cars
        for action in segment(-d, d):
            first, second = apply_action(state, action)
            if (0 <= first <= self.first_pickup.max_cars
                    and 0 <= second <= self.second_pickup.max_cars):
                yield action

    def action_cost(self, action):
        """Movement term of the reward.

        Returns ``point_to_point_movement_cost`` times the number of moved cars that
        are not covered by the free allowance for that direction.
        """
        free_ride = (self.first_to_second_free_movement if action >= 0
                     else self.second_to_first_free_movement)
        return max(abs(action) - free_ride, 0) * self.point_to_point_movement_cost

    def action_evaluation(self, state, action, values):
        """Expected one-step return of taking ``action`` in ``state`` under ``values``.

        The result is the movement term plus a sum over the truncated Poisson outcomes
        (rentals and returns at both locations).  Each outcome contributes
        ``probability * (rental income + gamma * values[next_state])``.  Rentals are
        capped by the cars on the lot, and next-state counts are capped at capacity.
        """
        # The truncated outcome ranges do not depend on the loop variables, so compute
        # them once instead of re-running alpha_level on every outer iteration.
        first_issued_range = segment(*alpha_level(self.first_pickup.issuance_mu, self.epsilon))
        first_returned_range = segment(*alpha_level(self.first_pickup.returns_mu, self.epsilon))
        second_issued_range = segment(*alpha_level(self.second_pickup.issuance_mu, self.epsilon))
        second_returned_range = segment(*alpha_level(self.second_pickup.returns_mu, self.epsilon))

        balance = 0
        balance += self.action_cost(action)

        first, second = apply_action(state, action)

        # Partial probability products are hoisted per loop level.  The multiplication
        # order ((p1 * p2) * p3) * p4 is the same as a single four-factor expression.
        for first_cars_issued in first_issued_range:
            first_actually_issued = min(first, first_cars_issued)
            p_first_issued = p_pmf(first_cars_issued, self.first_pickup.issuance_mu)
            for first_cars_returned in first_returned_range:
                first_next = min(first - first_actually_issued + first_cars_returned,
                                 self.first_pickup.max_cars)
                p_first = p_first_issued * p_pmf(first_cars_returned, self.first_pickup.returns_mu)
                for second_cars_issued in second_issued_range:
                    second_actually_issued = min(second, second_cars_issued)
                    income = (first_actually_issued + second_actually_issued) * self.rent_prize
                    p_issued = p_first * p_pmf(second_cars_issued, self.second_pickup.issuance_mu)
                    for second_cars_returned in second_returned_range:
                        second_next = min(second - second_actually_issued + second_cars_returned,
                                          self.second_pickup.max_cars)
                        probability = p_issued * p_pmf(second_cars_returned,
                                                       self.second_pickup.returns_mu)
                        balance += probability * (income + self.gamma * values[first_next, second_next])
        return balance

    def strategy_evaluation(self, values, strategy, curr_theta, i=0):
        """Iterative policy evaluation of ``strategy`` starting from ``values``.

        Each sweep computes all new values from the previous sweep's values.  Sweeps
        repeat until the largest absolute change falls below ``curr_theta``.  ``i``
        only labels the progress output.

        Returns:
            The value array after the last sweep.
        """
        for j in itertools.count():
            # Float dtype so an integer-typed initial array does not truncate returns.
            new_values = np.zeros_like(values, dtype=float)
            delta = 0
            for s in self.iterate_states():
                new_values[s] = self.action_evaluation(s, strategy[s], values)
                delta = max(delta, abs(values[s] - new_values[s]))
                if self.verbose:
                    print(f'eval{i}:{j}', s, new_values[s], values[s], delta)
            values = new_values
            if delta < curr_theta:
                return values

    def strategy_improvement(self, values, old_strategy, i=0):
        """Greedy policy improvement with respect to ``values``.

        Each state gets the valid action with the highest ``action_evaluation``.  The
        current action is kept whenever it is at least as good as the best one.  Without
        that tie rule, equally good actions can alternate between rounds and policy
        iteration may never report a stable policy.  ``i`` only labels the progress
        output.

        Returns:
            ``(new_strategy, policy_stable)``.
        """
        policy_stable = True
        new_strategy = np.zeros_like(old_strategy)
        for s in self.iterate_states():
            old_action = old_strategy[s]

            # Compare all valid actions by their actual value.  The original started the
            # search from a balance of 0, which ignored non-positive values.
            balances = {a: self.action_evaluation(s, a, values) for a in self.iterate_actions(s)}
            new_action = max(balances, key=balances.get)
            if old_action in balances and balances[old_action] >= balances[new_action]:
                new_action = old_action

            new_strategy[s] = new_action

            if old_action != new_action:
                policy_stable = False

            if self.verbose:
                print(f'improvement{i}', s, old_strategy[s], new_strategy[s], policy_stable)
        return new_strategy, policy_stable

    def get_optimal_jack_strategy(self):
        """Run policy iteration from the all-zero (no movement) policy.

        This alternates ``strategy_evaluation`` and ``strategy_improvement`` until the
        policy is stable.  After each round the evaluation tolerance is multiplied by
        ``theta_degrade``.

        Returns:
            ``(values, strategy, strategy_history)``.  ``strategy_history`` holds the
            initial policy followed by the policy produced by every improvement step.
        """
        shape = self.first_pickup.max_cars + 1, self.second_pickup.max_cars + 1
        values = np.zeros(shape)
        strategy = np.zeros(shape, dtype=int)
        strategy_history = [strategy]
        curr_theta = self.theta
        for i in itertools.count():
            values = self.strategy_evaluation(values, strategy, curr_theta, i)
            strategy, policy_stable = self.strategy_improvement(values, strategy, i)
            strategy_history.append(strategy)
            if policy_stable:
                return values, strategy, strategy_history
            curr_theta *= self.theta_degrade


rental = JackCarRental(
    first_pickup=PickUpPoint(issuance_mu=3, returns_mu=3, max_cars=20),
    second_pickup=PickUpPoint(issuance_mu=4, returns_mu=2, max_cars=20),
    # Negative on purpose: the movement term is added to the expected return.
    point_to_point_movement_cost=-2,
    point_to_point_movement_max_cars=5,
    rent_prize=10,
)
v, pi, history = rental.get_optimal_jack_strategy()

eval0:0 (0, 0) 0.0 0.0 0
eval0:0 (0, 1) 9.70348307521959 0.0 9.70348307521959
eval0:0 (0, 2) 18.680699722129045 0.0 18.680699722129045
eval0:0 (0, 3) 26.205383512418074 0.0 26.205383512418074
eval0:0 (0, 4) 31.7933568272135 0.0 31.7933568272135
eval0:0 (0, 5) 35.4446196665148 0.0 35.4446196665148
eval0:0 (0, 6) 37.54651412542098 0.0 37.54651412542098
eval0:0 (0, 7) 38.61549633073094 0.0 38.61549633073094
eval0:0 (0, 8) 39.09424296255669 0.0 39.09424296255669
eval0:0 (0, 9) 39.277871807640395 0.0 39.277871807640395
eval0:0 (0, 10) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 11) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 12) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 13) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 14) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 15) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 16) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 17) 39.33033719195003 0.0 39.33033719195003
eval0:0 (0, 18) 39.3303371919500

In [None]:
# State values of the final policy: rows = cars at the first location, columns = cars at the second.
v

In [None]:
# Final policy: signed number of cars moved from the first location to the second in each state.
pi

In [None]:
# One heat map per policy-iteration round; index 0 is the initial (no-movement) policy.
# The color scale spans the full action range so all frames share the same mapping.
max_move = rental.point_to_point_movement_max_cars
for round_index, h in enumerate(history):
    fig, ax = plt.subplots()
    im = ax.imshow(h, cmap='viridis', vmin=-max_move, vmax=max_move, origin='lower')
    fig.colorbar(im, ax=ax, label='cars moved (+: first -> second)')
    ax.set(
        title=f'Policy after improvement step {round_index}',
        xlabel='cars at second location',
        ylabel='cars at first location',
    )
    plt.show()
    plt.close(fig)