In [1]:
# Based on
# https://github.com/dennybritz/reinforcement-learning/blob/master/TD/Q-Learning%20Solution.ipynb
# Under MIT License by Denny Britz

In [2]:
import os
import pickle
import sys
from collections import defaultdict

import numpy as np
import pandas as pd

import actions
import states

In [3]:
def get_training_set(path="../../../../train_data/rl_labels.csv"):
    """Load the logged RL transitions used to train Q.

    Args:
        path: CSV file to read. Defaults to the project's training labels;
            a relative path is resolved against the notebook's working
            directory.

    Returns:
        pandas.DataFrame with the CSV contents. ``q_learning`` unpacks each
        row positionally as (state, action, reward, next_state), so the file
        is presumably expected to have exactly those four columns in that
        order — TODO confirm against rl_labels.csv.
    """
    return pd.read_csv(path)

In [4]:
def default_action_values():
    """Return a fresh all-zero action-value vector of length ``actions.n``.

    Serves as the ``defaultdict`` factory for Q, so a state that has never
    been looked up starts with every action valued at 0.
    """
    n_actions = actions.n
    return np.zeros(shape=n_actions)

In [5]:
def q_learning(train_set, Q=None, num_episodes=1, discount_factor=1.0, alpha=0.5):
    """
    Q-Learning algorithm: Off-policy TD control, trained offline from logged transitions.

    Each episode is one pass over every row of ``train_set``, applying
        Q[s][a] += alpha * (r + discount_factor * max_a' Q[s'][a'] - Q[s][a])

    Args:
        train_set: DataFrame whose columns are, in order,
            (state, action, reward, next_state). Only column position is used,
            not column names. Presumably states are hashable keys and actions
            are integer indices into the action-value array — TODO confirm.
        Q: if you have an existing Q to train more. It is updated in place
            and also returned.
        num_episodes: Number of loops over the train set.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.

    Returns:
        Q, a mapping state -> numpy array of action values indexed by action.
        When created here it is a ``defaultdict(default_action_values)``, so
        looking up an unseen state inserts a zero vector.
    """
    if Q is None:
        Q = defaultdict(default_action_values)

    for i_episode in range(num_episodes):
        print("\rEpisode {}/{}...".format(i_episode + 1, num_episodes), end="")
        sys.stdout.flush()

        # index=False / name=None yields plain (state, action, reward, next_state)
        # tuples, so there is no DataFrame index to unpack and discard.
        for state, action, reward, next_state in train_set.itertuples(index=False, name=None):
            # TD update. NOTE(review): there is no terminal-state handling, so
            # every next_state is bootstrapped; with discount_factor=1.0 on
            # cyclic data the values may grow without bound — confirm intent.
            td_target = reward + discount_factor * np.max(Q[next_state])
            td_delta = td_target - Q[state][action]
            Q[state][action] += alpha * td_delta

    return Q

In [6]:
def print_action_values(action_values):
    """Print one line per action: its ``actions.Action`` label, a tab, its value.

    The position of each entry is mapped to ``actions.Action`` for the label,
    so ``action_values`` is expected to be ordered by action index.
    """
    for action_index, value in enumerate(action_values):
        label = actions.Action(action_index)
        print("{0}\t{1}".format(label, value))
    

In [7]:
# Train a fresh Q-table: 10 passes over the logged transitions with
# gamma = 0.999 and the default TD learning rate (alpha = 0.5).
Q = q_learning(
    train_set=get_training_set(),
    discount_factor=0.999,
    num_episodes=10,
)

Episode 10/10...

In [8]:
# Inspect the learned action values for the "unknown" state.
# Q is a defaultdict, so plain indexing would insert a zero row for a state
# that was never seen, and that row would then be pickled with the model in
# the next cell. Look the state up read-only instead.
unknown_state = states.State.unknown.value
print_action_values(
    Q[unknown_state] if unknown_state in Q else default_action_values()
)

Action.stop	173.1172354945815
Action.low_hard_left	172.35530112447321
Action.high_hard_left	172.73219556640532
Action.low_soft_left	163.29292553622741
Action.high_soft_left	173.12907820360658
Action.low_straight	171.77283347852273
Action.high_straight	172.85521815171774
Action.low_soft_right	166.61248594975308
Action.high_soft_right	173.07395618197756
Action.low_hard_right	171.88881766423802
Action.high_hard_right	173.15973925299267
Action.rev_low_hard_left	170.22902206334575
Action.rev_high_hard_left	170.51413437340557
Action.rev_low_soft_left	158.32616968404906
Action.rev_high_soft_left	172.67614958522927
Action.rev_low_straight	171.8204163508566
Action.rev_high_straight	173.20388075385762
Action.rev_low_soft_right	0.0
Action.rev_high_soft_right	172.4798890099812
Action.rev_low_hard_right	170.28090380250762
Action.rev_high_hard_right	171.3498729726768


In [9]:
# Persist the trained Q-table.
# NOTE(review): Q is a defaultdict whose factory, default_action_values, is
# defined in this notebook. pickle stores functions by qualified name, so any
# code that loads q.pkl must be able to resolve that function under the same
# module name (presumably __main__). Confirm the consumer defines it, or
# pickle dict(Q) instead.
Q_serialized = pickle.dumps(Q)
out_path = "../models/q.pkl"
# Create ../models if it does not exist yet instead of failing on open().
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "wb") as f:
    f.write(Q_serialized)