In [1]:
# Based on
# https://github.com/dennybritz/reinforcement-learning/blob/master/TD/Q-Learning%20Solution.ipynb
# Under MIT License by Denny Britz

In [2]:
import sys
import numpy as np
import pandas as pd
from collections import defaultdict
import pickle

import actions
import states

In [3]:
def get_training_set(path="../../../../train_data/rl_labels.csv"):
    """Load the recorded transitions used for training from a CSV file.

    Args:
        path: Location of the CSV file to read. The default is the path this
            notebook has always used, so zero-argument calls behave exactly
            as before; pass another path to train on a different file.

    Returns:
        A pandas DataFrame with the file's contents. Note that q_learning
        reads each row positionally as (state, action, reward, next_state).
    """
    return pd.read_csv(path)

In [4]:
def default_action_values():
    """Build the initial action-value array for a state not seen before.

    Returns a new zero-filled numpy array whose shape is given by
    ``actions.n``.

    This is a named module-level function (rather than a lambda) because it
    is used as the ``defaultdict`` factory for Q, and Q is later serialized
    with pickle; pickle stores functions by name and cannot store a lambda.
    """
    n_actions = actions.n
    return np.zeros(n_actions)

In [5]:
def q_learning(train_set, Q=None, num_episodes=1, discount_factor=1.0, alpha=0.5):
    """
    Q-Learning algorithm: Off-policy TD control over a fixed set of
    recorded transitions.

    Args:
        train_set: pandas DataFrame of transitions. Exactly four columns are
            expected and they are read positionally (column names are
            ignored) as (state, action, reward, next_state). States are used
            as dictionary keys, so they must be hashable; actions are used to
            index into a state's action-value array.
        Q: if you have an existing Q to train more. It is updated in place.
            When None, a new defaultdict backed by default_action_values
            is created.
        num_episodes: Number of loops over the train set
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.

    Returns:
        Q
        Q is the learned action-value function: a dictionary mapping
        state -> array of action values (indexed by action).
    """

    # Each state maps to an array holding one value per action.
    if Q is None:
        Q = defaultdict(default_action_values)

    # The transitions do not change between episodes, so convert the frame to
    # plain tuples once instead of rebuilding row objects on every pass.
    # index=False drops the (unused) row index; name=None yields plain tuples.
    transitions = list(train_set.itertuples(index=False, name=None))

    for i_episode in range(num_episodes):
        # "\r" + end="" rewrites the same console line for each episode.
        print("\rEpisode {}/{}...".format(i_episode + 1, num_episodes), end="", flush=True)

        for state, action, reward, next_state in transitions:
            # TD update: move Q(s, a) toward r + gamma * max_a' Q(s', a').
            # Q[next_state] is looked up before Q[state], so unseen states
            # are inserted into a defaultdict in that order.
            td_target = reward + discount_factor * np.max(Q[next_state])
            Q[state][action] += alpha * (td_target - Q[state][action])

    return Q

In [6]:
def print_action_values(action_values):
    """Print one tab-separated line per action: the action, then its value.

    ``action_values`` is walked in order; each position is turned into an
    ``actions.Action`` via ``actions.Action(position)`` and printed next to
    the value stored at that position.
    """
    row_template = "{0}\t{1}"
    for action_id, value in enumerate(action_values):
        label = actions.Action(action_id)
        print(row_template.format(label, value))
    

In [35]:
# Learn Q from the recorded transitions: 10 passes over the training set,
# with a discount factor just below 1.
Q = q_learning(
    train_set=get_training_set(),
    num_episodes=10,
    discount_factor=0.99999999,
)

Episode 10/10...

In [36]:
# Show the learned value of every action for the "unknown" state.
print_action_values(action_values=Q[states.State.unknown.value])

Action.stop	341.02307713187236
Action.low_hard_left	338.84475580405956
Action.high_hard_left	339.3571872695918
Action.low_soft_left	340.7697872668849
Action.high_soft_left	338.77692350704723
Action.low_straight	338.9753353328969
Action.high_straight	341.0121224153018
Action.low_soft_right	337.7880961314197
Action.high_soft_right	334.95921949383234
Action.low_hard_right	338.3079877407539
Action.high_hard_right	337.01466277948805
Action.rev_low_hard_left	336.2006096663831
Action.rev_high_hard_left	340.56965488279934
Action.rev_low_soft_left	339.6581530258793
Action.rev_high_soft_left	336.75478373441615
Action.rev_low_straight	337.88684913843986
Action.rev_high_straight	339.5684127545251
Action.rev_low_soft_right	340.027546010354
Action.rev_high_soft_right	340.29658450628034
Action.rev_low_hard_right	337.9062663463309
Action.rev_high_hard_right	340.3186193895947


In [16]:
# Persist the learned Q table to disk as a pickle.
# NOTE(review): when Q is the defaultdict created by q_learning, pickle stores
# its default factory (default_action_values) by name, so whoever loads this
# file needs that function available under the same module name — confirm
# against the loading code.
out_path = "../models/q.pkl"
Q_serialized = pickle.dumps(Q)
with open(out_path, "bw") as model_file:
    model_file.write(Q_serialized)