In [3]:
# Based on
# https://github.com/dennybritz/reinforcement-learning/blob/master/TD/Q-Learning%20Solution.ipynb
# Under MIT License by Denny Britz

In [31]:
import os
import pickle
import sys
from collections import defaultdict

import numpy as np
import pandas as pd

import actions
import states

In [5]:
def get_training_set(path="../../../../train_data/rl_labels.csv"):
    """Load the logged transitions used to train the Q-table.

    Args:
        path: CSV file to read. Defaults to
            ``../../../../train_data/rl_labels.csv``, resolved relative to the
            current working directory.

    Returns:
        pandas.DataFrame as read by ``pd.read_csv``. ``q_learning`` unpacks each
        row positionally as (state, action, reward, next_state), so the CSV is
        expected to have exactly those four columns in that order.
    """
    return pd.read_csv(path)

In [33]:
def default_action_values():
    """Build a fresh all-zero action-value array of length ``actions.n``.

    Serves as the ``defaultdict`` factory for the Q-table. It is a named,
    module-level function rather than a lambda so that the defaultdict can be
    pickled (lambdas cannot be pickled by reference).
    """
    n_actions = actions.n
    return np.zeros(n_actions)

In [34]:
def q_learning(train_set, Q=None, num_episodes=1, discount_factor=1.0, alpha=0.5, epsilon=0.1):
    """
    Q-Learning algorithm: off-policy TD control over a fixed batch of logged
    transitions (no environment interaction).

    Args:
        train_set: DataFrame of (state, action, reward, next_state) transitions.
            Rows are unpacked positionally, so it must have exactly these four
            columns in this order. ``state``/``next_state`` are used as keys
            into Q and ``action`` as an index into the per-state value array.
        Q: an existing Q to train further; it is updated in place and returned.
            It must yield an entry for states it has not seen yet (e.g. a
            ``defaultdict``) unless every state in ``train_set`` is present.
            When None, a new ``defaultdict(default_action_values)`` is created.
        num_episodes: Number of passes over the train set.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Unused. Actions come from the logged data rather than from an
            epsilon-greedy policy, so no exploration happens here; the
            parameter is kept only so existing callers keep working.

    Returns:
        Q: the estimated action-value function, a dict mapping
        state -> array of action values indexed by action.
    """
    if Q is None:
        # state -> array of action values; unseen states start at all zeros.
        Q = defaultdict(default_action_values)

    for i_episode in range(num_episodes):
        print("\rEpisode {}/{}...".format(i_episode + 1, num_episodes), end="")
        sys.stdout.flush()

        # index=False drops the DataFrame index, leaving just the four columns.
        for state, action, reward, next_state in train_set.itertuples(index=False):
            # TD update: bootstrap from the greedy (max) value of the next state,
            # independent of which action the logged data took next (off-policy).
            td_target = reward + discount_factor * np.max(Q[next_state])
            td_delta = td_target - Q[state][action]
            Q[state][action] += alpha * td_delta

    return Q

In [35]:
# One pass over the logged transitions with the default learning hyper-parameters.
Q = q_learning(train_set=get_training_set(), num_episodes=1)

Episode 1/1...

In [36]:
def print_action_values(action_values):
    """Print one tab-separated line per entry of ``action_values``.

    The entry at position ``i`` is labelled with ``actions.Action(i)``.
    """
    rows = (
        "{0}\t{1}".format(actions.Action(idx), value)
        for idx, value in enumerate(action_values)
    )
    for row in rows:
        print(row)

In [37]:
# Use .get rather than Q[...]: indexing a defaultdict inserts a zero entry for a
# state that was never trained on, and that entry would then be pickled below.
print_action_values(
    Q.get(states.State.far_straight.value, default_action_values())
)

Action.stop	13.360255515525962
Action.low_hard_left	0.0
Action.high_hard_left	7.268059128891072
Action.low_soft_left	0.0
Action.high_soft_left	0.0
Action.low_straight	0.0
Action.high_straight	13.391505513399997
Action.low_soft_right	0.0
Action.high_soft_right	0.0
Action.low_hard_right	0.0
Action.high_hard_right	4.543242898560518
Action.rev_low_hard_left	0.0
Action.rev_high_hard_left	0.0
Action.rev_low_soft_left	0.0
Action.rev_high_soft_left	0.0
Action.rev_low_straight	0.0
Action.rev_high_straight	0.0
Action.rev_low_soft_right	0.0
Action.rev_high_soft_right	0.0
Action.rev_low_hard_right	0.0
Action.rev_high_hard_right	0.0


In [38]:
# Serialize the learned table with pickle's default protocol. This works only
# because the defaultdict's factory (default_action_values) is a module-level
# function rather than a lambda.
Q_serialized = pickle.dumps(Q, protocol=None)

In [42]:
out_path = "../models/q.pkl"
# Create the target directory if needed so the write does not fail when it is missing.
out_dir = os.path.dirname(out_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# NOTE: only load this file from trusted sources; unpickling can execute arbitrary code.
with open(out_path, "wb") as f:
    f.write(Q_serialized)