In [1]:
# Adapted from Denny Britz's Q-learning solution notebook:
# https://github.com/dennybritz/reinforcement-learning/blob/master/TD/Q-Learning%20Solution.ipynb
# Original code by Denny Britz, released under the MIT License.

In [2]:
import os
import pickle
import sys
from collections import defaultdict

import numpy as np
import pandas as pd

import actions
import states

In [3]:
def get_training_set(path="../../../../train_data/rl_labels.csv"):
    """Load the logged transitions used for offline Q-learning.

    Args:
        path: CSV file to read. The default is the project's label file,
            resolved relative to the current working directory.

    Returns:
        The DataFrame produced by ``pd.read_csv``. ``q_learning`` unpacks each
        row as (state, action, reward, next_state), so the file is expected to
        contain exactly those four columns, in that order.
    """
    return pd.read_csv(path)

In [4]:
def default_action_values():
    """Build a fresh all-zero float array of length ``actions.n``.

    Serves as the defaultdict factory for Q, so a state seen for the first
    time starts with a value of 0.0 in every slot (``actions.n`` is
    presumably the number of actions; confirm in the ``actions`` module).
    """
    return np.zeros(actions.n, dtype=np.float64)

In [5]:
def q_learning(train_set, Q=None, num_episodes=1, discount_factor=1.0, alpha=0.5, verbose=True):
    """
    Q-Learning algorithm: off-policy TD control, run offline over a fixed
    batch of logged transitions (there is no environment interaction).

    Args:
        train_set: DataFrame whose rows are (state, action, reward, next_state),
            in exactly that column order. Actions are used as integer indices
            into the per-state value arrays.
        Q: if you have an existing Q to train more. It is updated in place and
            also returned. Lookups of unseen states must be handled by the
            mapping itself (e.g. a defaultdict); when None, a new defaultdict
            backed by default_action_values is created.
        num_episodes: Number of loops over the train set.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        verbose: If True (default), print an in-place episode counter to stdout.

    Returns:
        Q, a dictionary mapping state -> array of action values. This is the
        estimate after the final pass. It only approximates the optimal
        action-value function, and only for states and actions the data covers.
    """
    # The action-value function: maps state -> array of per-action values.
    if Q is None:
        Q = defaultdict(default_action_values)

    for i_episode in range(num_episodes):
        if verbose:
            print("\rEpisode {}/{}...".format(i_episode + 1, num_episodes), end="")
            sys.stdout.flush()

        # Rows are replayed in DataFrame order. Each update already sees the
        # values written by earlier rows in the same pass.
        for state, action, reward, next_state in train_set.itertuples(index=False):
            # TD update. The rows carry no terminal flag, so every transition
            # bootstraps from the greedy value of next_state.
            td_target = reward + discount_factor * np.max(Q[next_state])
            state_values = Q[state]
            state_values[action] += alpha * (td_target - state_values[action])

    return Q

In [101]:
def print_action_values(action_values):
    """Print one line per entry: the Action for its index, a tab, then its value."""
    rows = (
        "{0}\t{1}".format(actions.Action(idx), value)
        for idx, value in enumerate(action_values)
    )
    for row in rows:
        print(row)
    
def print_best_action(action_values):
    """Print the Action whose value is highest (first one on ties, per np.argmax)."""
    best_index = np.argmax(action_values)
    best_action = actions.Action(best_index)
    print("Best: {0}".format(best_action))

In [188]:
# Train from scratch: 100 passes over the logged transitions with gamma = 0.9
# (alpha left at q_learning's default).
Q = q_learning(
    get_training_set(),
    discount_factor=0.9,
    num_episodes=100,
)

Episode 100/100...

In [189]:
# Greedy action learned for each state: `unknown` first, then the
# near/good/far variants grouped by heading.
print_best_action(Q[states.State.unknown.value])
for heading in ("straight", "left", "right"):
    print(heading + ":")
    for distance in ("near", "good", "far"):
        state_member = getattr(states.State, distance + "_" + heading)
        print_best_action(Q[state_member.value])

Best: Action.low_straight
straight:
Best: Action.stop
Best: Action.low_hard_left
Best: Action.rev_low_soft_left
left:
Best: Action.rev_low_straight
Best: Action.high_soft_left
Best: Action.rev_low_soft_left
right:
Best: Action.stop
Best: Action.low_soft_left
Best: Action.stop


In [190]:
# Show the full action-value vector and its greedy choice for two straight-ahead states.
for inspected in (states.State.good_straight, states.State.far_straight):
    inspected_values = Q[inspected.value]
    print_action_values(inspected_values)
    print_best_action(inspected_values)

Action.stop	8.377058010275665
Action.low_hard_left	9.999999984928056
Action.high_hard_left	8.233055918727262
Action.low_soft_left	0.0
Action.high_soft_left	9.999999981891463
Action.low_straight	0.0
Action.high_straight	8.905251277062312
Action.low_soft_right	0.0
Action.high_soft_right	8.232960375720488
Action.low_hard_right	8.99999998362361
Action.high_hard_right	9.646361919967173
Action.rev_low_hard_left	0.0
Action.rev_high_hard_left	7.336978333561209
Action.rev_low_soft_left	8.733333327903427
Action.rev_high_soft_left	9.99999998362361
Action.rev_low_straight	0.0
Action.rev_high_straight	9.999999984414373
Action.rev_low_soft_right	9.999999982488298
Action.rev_high_soft_right	0.0
Action.rev_low_hard_right	0.0
Action.rev_high_hard_right	0.0
Best: Action.low_hard_left
Action.stop	7.226288778448289
Action.low_hard_left	7.080348241799777
Action.high_hard_left	6.9182067162413325
Action.low_soft_left	6.4837178273827725
Action.high_soft_left	6.50484219505876
Action.low_straight	6.846630537745

In [191]:
# Persist the learned Q-table.
# NOTE(review): Q is a defaultdict whose default_factory is default_action_values.
# pickle stores functions by module + qualified name, and inside a notebook kernel
# that module is __main__. Whatever loads q.pkl must therefore define a
# default_action_values reachable under the same name, or unpickling fails.
Q_serialized = pickle.dumps(Q)
out_path = "../models/q.pkl"
# open() does not create missing parent directories.
os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
with open(out_path, "wb") as f:
    f.write(Q_serialized)