In [2]:
""" Example 4.2 Jack's Car Rental

Author : SeongJin Yoon
"""
import numpy as np
from enum import Enum
import math
import itertools
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import datetime as dt
# Positions of the lower and upper bound inside a two-element
# [min, max] range list (see Branch.state_range).
MIN = 0
MAX = 1

class ConfigDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    ``cfg.name`` is equivalent to ``cfg["name"]``.  Reading or deleting a
    missing name through attribute syntax raises ``AttributeError`` rather
    than ``KeyError``, so ``hasattr(cfg, name)`` and
    ``getattr(cfg, name, default)`` behave as they do for ordinary objects.
    Item access (``cfg["name"]``) keeps the normal ``dict`` behaviour.
    """

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails, so real dict
        # methods (keys, items, ...) always take precedence over stored keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
        
# Global settings, grouped by concern.  Keys are inserted in the order
# env, branch1, branch2, value, poisson, plot.
config = ConfigDict(
    # lot capacity per branch, cars movable per action, number of branches
    env=ConfigDict(max_managed_cars=20, max_movable_cars=5, num_branch=2),
    # per-branch means handed to the request / return Transaction objects
    branch1=ConfigDict(lambda_rent=3, lambda_return=3),
    branch2=ConfigDict(lambda_rent=4, lambda_return=2),
    # discount factor and the policy-evaluation stopping threshold
    value=ConfigDict(discount=0.9, theta=1e-4),
    # request / return counts are enumerated over range(upperbound)
    poisson=ConfigDict(upperbound=11),
    # prefix prepended to the saved figure file names
    plot=ConfigDict(save_dir='./sjyoon/'),
)

class Plot():
    """A grid of subplots that is filled with heatmaps, one panel per call."""

    def __init__(self, nrows, ncols, num_ticks, coord_list):
        # index of the next free panel in the flattened axes array
        self.fig_idx = 0
        self.coord_list = coord_list

        # tick positions 0..num_ticks; the y axis uses them in reverse order
        tick_positions = list(range(num_ticks + 1))
        self.xticks = tick_positions
        self.yticks = tick_positions[::-1]

        _, axes_grid = plt.subplots(nrows=nrows, ncols=ncols, figsize=(40, 20))
        plt.subplots_adjust(wspace=0.1, hspace=0.2)
        # flatten so panels can be addressed with a single running index
        self.axes = axes_grid.flatten()

    def draw_heatmap(self, data, labels):
        """Draw `data` into the next free panel.

        `labels` is indexed as [y-axis label, x-axis label, title].
        """
        target_ax = self.axes[self.fig_idx]
        heat_ax = sns.heatmap(data, cmap="YlGnBu", ax=target_ax)
        heat_ax.set_xticks(self.xticks)
        heat_ax.set_yticks(self.yticks)
        heat_ax.set_ylabel(labels[0], fontsize=30)
        heat_ax.set_xlabel(labels[1], fontsize=30)
        heat_ax.set_title(labels[2], fontsize=30)
        self.fig_idx += 1

    def save(self, title):
        """Save the current figure; `title` is passed to plt.savefig as the target."""
        plt.savefig(title)

    def show(self):
        """Display the current figure."""
        plt.show()

    def close(self):
        """Close the current figure."""
        plt.close()


class Poisson():
    """Poisson probability mass function with a per-instance cache."""

    def __init__(self):
        # cache: (n, mean) -> probability
        self.dist = dict()

    def prob(self, n, mean):
        """Return exp(-mean) * mean**n / n! for the given count and mean.

        Results are cached under the pair ``(n, mean)``.  A tuple key is used
        because an arithmetic key such as ``n * 10 + mean`` maps different
        pairs (e.g. (0, 13) and (1, 3)) to the same slot and would hand back
        the wrong cached value.
        """
        key = (n, mean)
        if key not in self.dist:
            self.dist[key] = math.exp(-mean) * pow(mean, n) / math.factorial(n)

        return self.dist[key]

    def __call__(self, n, mean):
        """Shorthand for :meth:`prob`."""
        return self.prob(n, mean)


class Reward(Enum):
    """Kinds of reward event; used as keys when looking up a reward amount."""
    # reward earned for renting out cars
    rent = 1
    # reward (a cost) charged for moving cars between branches
    move = 2

class Transaction():
    """A request or return stream: a distribution object plus its mean."""

    def __init__(self, dist, mean):
        self.dist, self.mean = dist, mean

    def prob(self, num_tx):
        """Return ``dist.prob(num_tx, mean)`` for this stream's mean."""
        distribution = self.dist
        return distribution.prob(num_tx, self.mean)
    
class Environment():
    """Two-branch car rental model used for one-step lookahead.

    The environment owns two ``Branch`` objects.  Every ``lookahead_*`` method
    first loads the given state into the branches, applies one kind of event
    (move, rent or return) and reports the resulting state, so calls are
    independent of whatever state a previous call left behind.
    """

    def __init__(self):

        self.prepare_environment()

        self.branch1 = Branch(config.branch1)
        self.branch2 = Branch(config.branch2)
        self.branch_list = [self.branch1, self.branch2]

        # One distribution object is shared so its cache serves all streams.
        poisson_dist = Poisson()
        self.request1 = Transaction(poisson_dist, config.branch1.lambda_rent)
        self.request2 = Transaction(poisson_dist, config.branch2.lambda_rent)
        self.return1 = Transaction(poisson_dist, config.branch1.lambda_return)
        self.return2 = Transaction(poisson_dist, config.branch2.lambda_return)

    def prepare_environment(self):
        """Build the size constants and the enumeration lists from ``config``."""
        self.num_state = config.env.max_managed_cars + 1
        self.num_action = config.env.max_movable_cars  # max number of cars to be moved
        self.num_request = config.poisson.upperbound
        self.num_return = config.poisson.upperbound

        self.state_list = self.make_pair_list(self.num_state)
        # 0, 1..num_action, then -1..-num_action.  A positive action moves
        # cars from branch 1 to branch 2, a negative one the other way.
        self.action_list = list(range(0, self.num_action + 1)) \
            + list(range(-1, -(self.num_action + 1), -1))
        self.request_list = self.make_pair_list(self.num_request)
        self.return_list = self.make_pair_list(self.num_return)
        self.reward_list = {Reward.rent : 10, Reward.move : -2}

    def make_pair_list(self, num_pair):
        """Return every pair (a, b) with a and b in range(num_pair)."""
        return list(itertools.product(range(num_pair), repeat=2))

    def get_statelist(self):
        return self.state_list

    def get_actionlist(self):
        return self.action_list

    def get_requestlist(self):
        return self.request_list

    def get_returnlist(self):
        return self.return_list

    def get_num_state(self):
        return self.num_state

    def get_num_action(self):
        return self.num_action

    def get_state(self):
        """Return the current (branch 1, branch 2) car counts as a tuple."""
        return (self.branch1.get_state(), self.branch2.get_state())

    def get_reward(self, reward_type):
        """Return the per-car reward amount for a ``Reward`` member."""
        return self.reward_list[reward_type]

    def rent_prob(self, requests):
        """Joint probability of a (branch 1, branch 2) pair of request counts."""
        r1, r2 = requests
        return self.request1.prob(r1) * self.request2.prob(r2)

    def return_prob(self, returns):
        """Joint probability of a (branch 1, branch 2) pair of return counts."""
        r1, r2 = returns
        return self.return1.prob(r1) * self.return2.prob(r2)

    def _plan_move(self, action):
        """Translate a signed action into (source, destination, car count)."""
        if action < 0:
            return self.branch2, self.branch1, -action
        return self.branch1, self.branch2, action

    def validate(self, state, action):
        """Return True if `action` can be carried out in full from `state`.

        The source branch must hold at least the requested number of cars and
        the destination must have room for all of them.  As a side effect the
        branches are left loaded with `state`.
        """
        self.reset_state(state)
        source, destination, num_cars = self._plan_move(action)
        # bool() keeps the result a plain True/False even for numpy integers.
        return bool(source.get_available_cars() >= num_cars
                    and destination.get_acceptable_cars() >= num_cars)

    def lookahead_do_action(self, state, action):
        """Move cars from one branch to another branch.

        Returns:
            (new_state, reward), where reward is the number of cars moved
            times the per-car move reward.

        Raises:
            ValueError: if `action` is not feasible in `state`.  Callers
                unpack a two-element result, so an explicit error replaces
                the bare ``0`` that used to be returned here.
        """
        if not self.validate(state, action):
            raise ValueError(
                "action %r is not feasible in state %r" % (action, state))

        # validate() has just loaded `state` into both branches.
        source, destination, num_cars = self._plan_move(action)

        cars_moved_out = source.move_from(num_cars)
        cars_moved_in = destination.move_into(num_cars)

        assert cars_moved_out == cars_moved_in

        reward = num_cars * self.get_reward(Reward.move)
        return self.get_state(), reward

    def lookahead_rent_cars(self, state, requests):
        """Apply a pair of rental requests to `state`.

        Returns (new_state, reward); the reward counts only the cars the
        branches actually rented out.
        """
        self.reset_state(state)

        r1, r2 = requests
        rent_cars1 = self.branch1.rent_cars(r1)
        rent_cars2 = self.branch2.rent_cars(r2)

        total_rent_cars = rent_cars1 + rent_cars2
        reward = total_rent_cars * self.get_reward(Reward.rent)

        return self.get_state(), reward

    def lookahead_return_cars(self, state, returns):
        """Apply a pair of return counts to `state` and return the new state."""
        self.reset_state(state)

        r1, r2 = returns
        self.branch1.return_cars(r1)
        self.branch2.return_cars(r2)

        return self.get_state()

    def reset_state(self, state):
        """Load a (branch 1, branch 2) state into the two branches."""
        s1, s2 = state
        self.branch1.set_state(s1)
        self.branch2.set_state(s2)

class Branch():
    """One rental location: a car count kept inside a fixed [min, max] range."""

    def __init__(self, branch):
        self.branch = branch
        capacity = config.env.max_managed_cars
        self.state_range = [0, capacity]
        # a new branch starts with a full lot
        self.state = capacity

    def get_state(self):
        return self.state

    def set_state(self, state):
        assert self.in_range(state)
        self.state = state

    def in_range(self, state):
        """Return True when `state` lies within the range, bounds included."""
        lower = self.state_range[MIN]
        upper = self.state_range[MAX]
        return bool(lower <= state <= upper)

    def get_available_cars(self):
        """Number of cars that can still leave this branch."""
        return self.state - self.state_range[MIN]

    def get_acceptable_cars(self):
        """Number of cars this branch can still take in."""
        return self.state_range[MAX] - self.state

    def rent_cars(self, request_cars):
        """Rent out up to `request_cars` cars; return how many were rented."""
        in_stock = self.state - self.state_range[MIN]
        rented = min(request_cars, in_stock)
        self.state -= rented
        return rented

    def return_cars(self, request_cars):
        """Take back up to `request_cars` cars; return how many were accepted."""
        free_slots = self.state_range[MAX] - self.state
        accepted = min(request_cars, free_slots)
        self.state += accepted
        return accepted

    def move_from(self, request_cars):
        """Send away up to `request_cars` cars; return how many left.

        A request above the configured per-move limit moves nothing.
        """
        if request_cars > config.env.max_movable_cars:
            return 0

        moved = min(request_cars, self.get_available_cars())
        self.state -= moved
        return moved

    def move_into(self, request_cars):
        """Receive up to `request_cars` cars; return how many arrived.

        A request above the configured per-move limit moves nothing.
        """
        if request_cars > config.env.max_movable_cars:
            return 0

        moved = min(self.get_acceptable_cars(), request_cars)
        self.state += moved
        return moved


class Agent():
    """Policy-iteration agent for the two-branch car rental environment.

    ``policy[s1, s2]`` stores the signed action chosen for state (s1, s2) and
    ``value_function[s1, s2]`` the current value estimate of that state.
    """

    def __init__(self, env):

        self.env = env
        self.qurey_environment()

        # NOTE(review): the figure has 2 x 3 = 6 panels and every iteration of
        # policy_iteration() draws one, plus one for the final value function.
        self.plot = Plot(2, 3, self.num_state, self.state_list)
        self.policy_labels = ['# of cars in branch 1', '# of cars in branch 2', 'Policy']
        self.v_labels = ['# of cars in branch 1', '# of cars in branch 2', 'Value Function']

    def qurey_environment(self):
        """Copy the environment's sizes and enumeration lists onto the agent.

        The method keeps its original (misspelled) name so callers still work.
        """
        self.num_state = self.env.get_num_state()
        self.num_action = self.env.get_num_action()

        self.state_list = self.env.get_statelist()
        self.action_list = self.env.get_actionlist()
        self.request_list = self.env.get_requestlist()
        self.return_list = self.env.get_returnlist()

    def policy_iteration(self):
        """Alternate evaluation and improvement until the policy stops changing.

        A heatmap of the policy is drawn and saved before each evaluation; the
        final value function is drawn and saved once the policy is stable.
        """
        # 1. Initialize
        self.policy = np.zeros((self.num_state, self.num_state), dtype=np.int8)
        self.value_function = np.zeros((self.num_state, self.num_state))

        iter_count = 0
        policy_stable = False
        while not policy_stable:
            # 2. Visualization of Policy
            self.policy_labels[2] = "Policy %d" % (iter_count)
            self.plot.draw_heatmap(np.flipud(self.policy), self.policy_labels)
            filename = config.plot.save_dir + "figure_4_2_rentcar_%d.png" % (iter_count)
            self.plot.save(filename)

            # 3. Policy Evaluation
            print("Policy evaluation (%d)" % (iter_count), dt.datetime.now())
            self.policy_evaluation()

            # 4. Policy Improvement
            print("Policy Improvement (%d)" % (iter_count), dt.datetime.now())
            policy_stable = self.policy_improvement()

            iter_count += 1

        # 5. Visualization of Value Function
        self.v_labels[2] = "Value Function %d" % (iter_count)
        self.plot.draw_heatmap(np.flipud(self.value_function), self.v_labels)
        filename = config.plot.save_dir + "figure_4_2_rentcar_final.png"
        self.plot.save(filename)
        self.plot.close()

    def policy_evaluation(self):
        """Sweep all states in place until the largest update is below theta."""
        theta = config.value.theta
        loop_count = 1
        while True:
            delta = 0
            for state in self.state_list:
                s1, s2 = state
                old_value = self.value_function[s1, s2]
                action = self.policy[s1, s2]
                new_value = self.calc_qvalue(state, action)
                # in-place update: later states in this sweep see the new value
                self.value_function[s1, s2] = new_value

                delta = max(delta, math.fabs(old_value - new_value))

            print("loop %d : delta %.4f, theta %.4f" % (loop_count, delta, theta), dt.datetime.now())
            loop_count += 1

            if delta < theta:
                break

    def policy_improvement(self):
        """Make the policy greedy with respect to the current value function.

        Returns:
            True when no state's action changed (the policy is stable),
            False otherwise.  Stability is decided from the per-state change
            count; comparing the sum of the policy before and after is not
            enough because changes of opposite sign cancel out.
        """
        change_count = 0

        for state in self.state_list:
            s1, s2 = state
            old_action = self.policy[s1, s2]

            # expected return of each action; infeasible actions can never win
            qvalue_list = [
                self.calc_qvalue(state, action)
                if self.env.validate(state, action) else -float('inf')
                for action in self.action_list
            ]

            # pick best action
            best_action = self.action_list[self.pick_best(qvalue_list)]
            self.policy[s1, s2] = best_action

            if old_action != best_action:
                change_count += 1

        print("%d policies are changed" % (change_count))
        return change_count == 0

    def pick_best(self, candidates):
        """Return the index of the largest candidate (first one on ties)."""
        return np.argmax(candidates)

    def calc_qvalue(self, state, action):
        """Expected return of taking `action` in `state`.

        Sum of the immediate move reward and the expected discounted return of
        the day's rentals and returns from the state reached after the move.
        """
        new_state, immediate_reward = self.env.lookahead_do_action(state, action)
        discounted_return = self.lookahead_daily_transaction(new_state)

        return immediate_reward + discounted_return

    def lookahead_daily_transaction(self, state):
        """Expected rental reward plus discounted next-state value from `state`.

        Enumerates every (requests, returns) combination, weighting each
        outcome by the product of its request and return probabilities.
        """
        discount = config.value.discount
        discounted_return = 0
        for requests in self.request_list:
            new_state, reward = self.env.lookahead_rent_cars(state, requests)
            prob_request = self.env.rent_prob(requests)

            for returns in self.return_list:
                new_state2 = self.env.lookahead_return_cars(new_state, returns)
                prob_return = self.env.return_prob(returns)
                prob = prob_request * prob_return
                s1, s2 = new_state2
                discounted_return += prob * \
                    (reward + discount * self.value_function[s1, s2])

        return discounted_return

if __name__ == "__main__":
    # Script entry point: build the environment and the agent, run policy
    # iteration once, and print the start time, end time and elapsed time.
    begin_time = dt.datetime.now()
    print("Start Rent Car World", begin_time)
    env = Environment()
    agent = Agent(env)
    agent.policy_iteration()
    end_time = dt.datetime.now()
    print("End Rent Car World", end_time)
    # subtracting two datetimes gives a timedelta
    print("Running Time", end_time - begin_time)

Start Rent Car World 2018-08-14 21:01:02.788934
Policy evaluation (0) 2018-08-14 21:01:03.813660
loop 1 : delta 191.1404, theta 0.0001 2018-08-14 21:01:45.640942
loop 2 : delta 131.9191, theta 0.0001 2018-08-14 21:02:28.341482
loop 3 : delta 88.6194, theta 0.0001 2018-08-14 21:03:12.451805
loop 4 : delta 66.2761, theta 0.0001 2018-08-14 21:03:55.957540
loop 5 : delta 52.3040, theta 0.0001 2018-08-14 21:04:38.448066
loop 6 : delta 40.5043, theta 0.0001 2018-08-14 21:05:20.968686
loop 7 : delta 31.5718, theta 0.0001 2018-08-14 21:06:03.675360
loop 8 : delta 25.0090, theta 0.0001 2018-08-14 21:06:45.596833
loop 9 : delta 20.7762, theta 0.0001 2018-08-14 21:07:26.761352
loop 10 : delta 17.3732, theta 0.0001 2018-08-14 21:08:08.006027
loop 11 : delta 14.4891, theta 0.0001 2018-08-14 21:08:49.311890
loop 12 : delta 12.0544, theta 0.0001 2018-08-14 21:09:30.521525
loop 13 : delta 10.0059, theta 0.0001 2018-08-14 21:10:11.508569
loop 14 : delta 8.2881, theta 0.0001 2018-08-14 21:10:52.620919
l

loop 2 : delta 2.7978, theta 0.0001 2018-08-14 22:33:46.951381
loop 3 : delta 1.8926, theta 0.0001 2018-08-14 22:34:27.720818
loop 4 : delta 1.3512, theta 0.0001 2018-08-14 22:35:08.331808
loop 5 : delta 0.9218, theta 0.0001 2018-08-14 22:35:48.868624
loop 6 : delta 0.6108, theta 0.0001 2018-08-14 22:36:29.316206
loop 7 : delta 0.4029, theta 0.0001 2018-08-14 22:37:09.877094
loop 8 : delta 0.2699, theta 0.0001 2018-08-14 22:37:50.390878
loop 9 : delta 0.1940, theta 0.0001 2018-08-14 22:38:30.918646
loop 10 : delta 0.1468, theta 0.0001 2018-08-14 22:39:11.591853
loop 11 : delta 0.1142, theta 0.0001 2018-08-14 22:39:52.234410
loop 12 : delta 0.0933, theta 0.0001 2018-08-14 22:40:33.117148
loop 13 : delta 0.0761, theta 0.0001 2018-08-14 22:41:13.855501
loop 14 : delta 0.0620, theta 0.0001 2018-08-14 22:41:54.524343
loop 15 : delta 0.0505, theta 0.0001 2018-08-14 22:42:35.152436
loop 16 : delta 0.0411, theta 0.0001 2018-08-14 22:43:15.717301
loop 17 : delta 0.0335, theta 0.0001 2018-08-14 