In [1]:
# Based on
# https://github.com/dennybritz/reinforcement-learning/blob/master/TD/Q-Learning%20Solution.ipynb
# Under MIT License by Denny Britz

In [21]:
import sys
import numpy as np
import pandas as pd
from collections import defaultdict

import actions
import states

In [6]:
def get_training_set():
    return pd.read_csv("../../../../train_data/rl_labels.csv")

In [27]:
def q_learning(train_set, Q=None, num_episodes=1, discount_factor=1.0, alpha=0.5, epsilon=0.1):
    """
    Q-Learning algorithm: Off-policy TD control.
    
    Args:
        Q: if you have an existing Q to train more
        num_episodes: Number of loops over the train set
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance the sample a random action. Float betwen 0 and 1.
    
    Returns:
        Q
        Q is the optimal action-value function, a dictionary mapping state -> action values.
    """
    
    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    if Q is None:
        Q = defaultdict(lambda: np.zeros(actions.n))  
    
    for i_episode in range(num_episodes):
        print("\rEpisode {}/{}...".format(i_episode + 1, num_episodes), end="")
        sys.stdout.flush()
        
        for index, state, action, reward, next_state in train_set.itertuples():
            
            # TD Update
            best_next_action = np.argmax(Q[next_state])    
            td_target = reward + discount_factor * Q[next_state][best_next_action]
            td_delta = td_target - Q[state][action]
            Q[state][action] += alpha * td_delta
    
    return Q

In [28]:
q_learning(get_training_set())

Episode 1/1...

defaultdict(<function __main__.q_learning.<locals>.<lambda>()>,
            {nan: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                    0., 0., 0., 0.]),
             1: array([0. , 0. , 1.5, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
                    0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ])})