In [1]:
import itertools
import os
import sys

import gym
import numpy as np
import pandas as pd
import seaborn as sns
from numpy.random import choice

from hiive.mdptoolbox import mdp
from hiive.mdptoolbox import util
from hiive.mdptoolbox.example import forest
from hiive.mdptoolbox.mdp import ValueIteration, PolicyIteration, QLearning
np.random.seed(44)

In [3]:
def test_policy(P, R, policy, test_count=1000, gamma=0.9):
    num_state = P.shape[-1]
    total_episode = num_state * test_count
    # start in each state
    total_reward = 0
    for state in range(num_state):
        state_reward = 0
        for state_episode in range(test_count):
            episode_reward = 0
            disc_rate = 1
            while True:
                # take step
                action = policy[state]
                # get next step using P
                probs = P[action][state]
                candidates = list(range(len(P[action][state])))
                next_state =  choice(candidates, 1, p=probs)[0]
                # get the reward
                reward = R[state][action] * disc_rate
                episode_reward += reward
                # when go back to 0 ended
                disc_rate *= gamma
                if next_state == 0:
                    break
            state_reward += episode_reward
        total_reward += state_reward
    return total_reward / total_episode

In [4]:
def grid_search_VI(P, R, discount=0.9, epsilon=(1e-9,)):
    """Run value iteration once per stopping tolerance and tabulate the results.

    Parameters
    ----------
    P, R
        Transition and reward structures, passed unchanged to
        ``ValueIteration`` and ``test_policy``.
    discount : float, optional
        Discount factor used both to solve the MDP and to score the
        resulting policy.
    epsilon : iterable of float, optional
        Values passed as ``epsilon`` to ``ValueIteration``, one run each.

    Returns
    -------
    pandas.DataFrame
        One row per epsilon with the columns "Epsilon", "Policy",
        "Iteration", "Time", "Reward" and "Value Function".
    """
    columns = ["Epsilon", "Policy", "Iteration",
               "Time", "Reward", "Value Function"]
    records = []
    for eps in epsilon:
        vi = ValueIteration(P, R, gamma=discount, epsilon=eps, max_iter=int(1e15))
        vi.run()
        # Score with the same discount the solver used, so the simulated
        # return is comparable to the solver's objective.
        reward = test_policy(P, R, vi.policy, gamma=discount)
        records.append({
            "Epsilon": float(eps),
            "Policy": vi.policy,
            "Iteration": vi.iter,
            "Time": vi.time,
            "Reward": reward,
            "Value Function": vi.V,
        })
    # Build the frame once instead of growing it row by row.
    return pd.DataFrame(records, columns=columns)

Value Iteration

In [6]:
# Build the forest example MDP (first argument 400, r1=100, r2=20, p=0.1)
# and sweep value iteration over progressively tighter epsilon values.
P, R = forest(400, r1=100, r2=20, p=0.1)
fm_400_vi = grid_search_VI(
    P,
    R,
    discount=0.9,
    epsilon=[1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9],
)

In [7]:
# Display the value-iteration grid-search table (one row per epsilon).
fm_400_vi

Unnamed: 0,Epsilon,Policy,Iteration,Time,Reward,Value Function
0,0.001,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",66,0.048837,2.205296,"(4.4706146525683454, 5.023100336527209, 5.0231..."
1,0.0001,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",76,0.007672,2.267714,"(4.473560831234312, 5.026046957818786, 5.02604..."
2,1e-05,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",87,0.014224,2.269377,"(4.474643139169861, 5.027129333047953, 5.02712..."
3,1e-06,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",98,0.014509,2.239972,"(4.47498279201032, 5.027468979261533, 5.027468..."
4,1e-07,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",109,0.099923,2.232036,"(4.475089377376456, 5.027575565280265, 5.02757..."
5,1e-08,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",120,0.017961,2.271127,"(4.475122825121185, 5.027609012960728, 5.02760..."
6,1e-09,"(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",131,0.01796,2.256289,"(4.475133321365347, 5.027619509211218, 5.02761..."


In [13]:
# Count how many distinct policies the epsilon sweep produced.
fm_400_vi.Policy.nunique()

1

In [14]:
# Show the distinct policies themselves; each one is printed in full as a
# per-state action tuple, so the output is long.
fm_400_vi.Policy.unique()

array([(0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

Policy Iteration

In [10]:
# Solve the same MDP with policy iteration, then score the policy it found
# with the Monte-Carlo evaluator.
pi = PolicyIteration(P, R, gamma=0.9, max_iter=1e6)
pi.run()

pi_pol = pi.policy
pi_iter, pi_time = pi.iter, pi.time
pi_reward = test_policy(P, R, pi_pol)

# Last expression of the cell: (iterations, solver time, simulated reward).
(pi_iter, pi_time, pi_reward)

(25, 0.2506129741668701, 2.2402734536534465)

Q-Learning

In [11]:
def grid_search_Q(P, R, discount=0.9, alpha_dec=(.99,), alpha_min=(0.001,),
            epsilon=(1.0,), epsilon_decay=(0.99,), n_iter=(1000000,)):
    """Run Q-learning over a hyper-parameter grid and tabulate the results.

    One ``QLearning`` run is performed for every combination of the supplied
    values, iterating with ``n_iter`` outermost, then ``epsilon``,
    ``epsilon_decay``, ``alpha_dec`` and ``alpha_min`` innermost. After each
    run a progress line ``"<run number>: <reward>"`` is printed.

    Parameters
    ----------
    P, R
        Transition and reward structures, passed unchanged to ``QLearning``
        and ``test_policy``.
    discount : float, optional
        Discount factor used both for learning and for scoring the learned
        policy.
    alpha_dec, alpha_min, epsilon, epsilon_decay, n_iter : iterable, optional
        Candidate values for the ``QLearning`` arguments ``alpha_decay``,
        ``alpha_min``, ``epsilon``, ``epsilon_decay`` and ``n_iter``.

    Returns
    -------
    pandas.DataFrame
        One row per combination with the columns "Iterations",
        "Alpha Decay", "Alpha Min", "Epsilon", "Epsilon Decay", "Reward",
        "Time", "Policy", "Value Function" and "Training Rewards" (the
        'Reward' entry of every item in ``run_stats``).
    """
    columns = ["Iterations", "Alpha Decay", "Alpha Min",
               "Epsilon", "Epsilon Decay", "Reward",
               "Time", "Policy", "Value Function",
               "Training Rewards"]

    # product() walks the grid in the same order as the equivalent nested
    # loops: the first iterable varies slowest, the last one fastest.
    grid = itertools.product(n_iter, epsilon, epsilon_decay, alpha_dec, alpha_min)

    records = []
    for count, (i, eps, eps_dec, a_dec, a_min) in enumerate(grid, start=1):
        q = QLearning(P, R, discount, alpha_decay=a_dec,
                      alpha_min=a_min, epsilon=eps,
                      epsilon_decay=eps_dec, n_iter=i)
        q.run()
        # Score with the same discount the learner used.
        reward = test_policy(P, R, q.policy, gamma=discount)
        print("{}: {}".format(count, reward))
        rews = [s['Reward'] for s in q.run_stats]
        records.append({
            "Iterations": i,
            "Alpha Decay": a_dec,
            "Alpha Min": a_min,
            "Epsilon": eps,
            "Epsilon Decay": eps_dec,
            "Reward": reward,
            "Time": q.time,
            "Policy": q.policy,
            "Value Function": q.V,
            "Training Rewards": rews,
        })
    # Build the frame once instead of growing it row by row.
    return pd.DataFrame(records, columns=columns)

In [12]:
# Hyper-parameter grid for Q-learning: four starting epsilons crossed with
# four training lengths (1e5 .. 1e8 iterations); the decay settings are fixed.
epsilons = [0.5, 0.8, 0.9, 0.99]
epsilon_decays = [0.999]
alpha_decs = [0.999]
alpha_mins = [0.001]
iters = [10 ** exponent for exponent in range(5, 9)]

q_df = grid_search_Q(
    P,
    R,
    discount=0.9,
    alpha_dec=alpha_decs,
    alpha_min=alpha_mins,
    epsilon=epsilons,
    epsilon_decay=epsilon_decays,
    n_iter=iters,
)

1: 1.6399208092729483
2: 1.6446780936557064
3: 1.6485164002249715
4: 1.6212276919345465
5: 1.9146440973561087
6: 1.9151487664995168
7: 1.8814180670123866
8: 1.9271553162166386
9: 1.8984754114196418
10: 1.9391791842969952
11: 1.9592610215712216
12: 1.939309037973741
13: 1.9872245597790028
14: 1.960505150296107
15: 1.946179073237952
16: 1.909701288977713


In [15]:
# Display the Q-learning grid-search table (one row per parameter combination).
q_df


Unnamed: 0,Iterations,Alpha Decay,Alpha Min,Epsilon,Epsilon Decay,Reward,Time,Policy,Value Function,Training Rewards
0,100000,0.999,0.001,0.5,0.999,1.639921,4.935805,"(0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, ...","(4.479119717616096, 5.031495423431422, 4.90655...","[0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, ..."
1,100000,0.999,0.001,0.8,0.999,1.644678,4.999697,"(0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, ...","(4.469351615216611, 5.023869204066711, 4.88366...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ..."
2,100000,0.999,0.001,0.9,0.999,1.648516,4.927144,"(0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, ...","(4.478107299787259, 5.027798348189291, 4.88787...","[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ..."
3,100000,0.999,0.001,0.99,0.999,1.621228,4.920149,"(0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","(4.4704411988706365, 5.026609875899918, 4.8336...","[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, ..."
4,1000000,0.999,0.001,0.5,0.999,1.914644,47.281077,"(0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, ...","(4.477200800647011, 5.029746308915379, 5.02981...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0..."
5,1000000,0.999,0.001,0.8,0.999,1.915149,47.027452,"(0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, ...","(4.46540345700321, 5.0201170832271895, 5.02698...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
6,1000000,0.999,0.001,0.9,0.999,1.881418,47.885751,"(0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, ...","(4.47662706513603, 5.0320274616189815, 5.03037...","[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 100.0, 1.0, 0.0..."
7,1000000,0.999,0.001,0.99,0.999,1.927155,46.152514,"(0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, ...","(4.470532174291042, 5.026599183884151, 5.02930...","[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."
8,10000000,0.999,0.001,0.5,0.999,1.898475,476.423969,"(0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, ...","(4.482904736443411, 5.036442199047051, 5.02893...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,10000000,0.999,0.001,0.8,0.999,1.939179,441.349567,"(0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, ...","(4.4664022715081, 5.023587380474288, 5.0285880...","[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, ..."
