In [1]:
import numpy as np
from food_truck_env import FoodTruck

In [2]:
def choose_action(state, policy):
    """Sample one action for ``state`` from a tabular stochastic policy.

    Args:
        state: Key into ``policy``.
        policy: Mapping ``state -> {action: probability}``.

    Returns:
        The action drawn by ``np.random.choice`` using the probabilities
        stored for ``state``.
    """
    distribution = policy[state]
    # Walk the keys once so candidates and weights are guaranteed to line up.
    candidates = [candidate for candidate in distribution]
    weights = [distribution[candidate] for candidate in candidates]

    return np.random.choice(a=candidates, p=weights)

In [3]:
def td_prediction(env, policy, discount, learning_rate, n_iter):
    """Estimate the state-value function of ``policy`` with one-step TD updates.

    Args:
        env: Environment exposing ``state_space``, ``reset()`` and
            ``step(action)``; ``step`` must return a 4-tuple
            ``(next_state, reward, done, extra)``.
        policy: Mapping ``state -> {action: probability}`` that is followed
            while collecting experience.
        discount: Factor applied to the value of the next state.
        learning_rate: Step size of each value update.
        n_iter: Total number of environment steps to take (not episodes).

    Returns:
        Dict mapping every state in ``env.state_space`` to its estimated value.
    """
    values = dict.fromkeys(env.state_space, 0)

    state = env.reset()
    for _ in range(n_iter):
        action = choose_action(state, policy)
        next_state, reward, done, _extra = env.step(action)

        # Move the current estimate a fraction of the way towards the
        # bootstrapped target reward + discount * V(next_state).
        td_target = reward + discount * values[next_state]
        values[state] += learning_rate * (td_target - values[state])

        # Start a fresh episode once the environment reports termination.
        state = env.reset() if done else next_state

    return values

In [4]:
def some_policy(states):
    """Build a fixed stochastic policy keyed by ``(day, inventory)`` states.

    With an inventory of 300 or more the policy always picks action 0.
    Otherwise it picks, with equal probability, the action that equals
    ``200 - inventory`` or the one that equals ``300 - inventory``.

    Args:
        states: Iterable of 2-tuples ``(day, inventory)``.

    Returns:
        Dict mapping each state to its ``{action: probability}`` dict.
    """
    policy = {}
    for state in states:
        _day, stock = state

        if stock < 300:
            policy[state] = {200 - stock: 0.5, 300 - stock: 0.5}
        else:
            policy[state] = {0: 1}

    return policy

In [5]:
# Seed NumPy's global RNG so the actions sampled through np.random.choice
# are repeatable when the notebook is restarted and run from the top.
# NOTE(review): FoodTruck's own randomness is only covered if it also draws
# from NumPy's global RNG -- confirm against food_truck_env.
SEED = 42
np.random.seed(SEED)

# Environment instance and the fixed policy evaluated in the next cell.
env = FoodTruck()
policy = some_policy(env.state_space)

In [6]:
# Evaluate the fixed policy: undiscounted, step size 0.01, 100,000 steps.
v_estimate = td_prediction(
    env,
    policy,
    discount=1,
    learning_rate=0.01,
    n_iter=100_000,
)
# The value stored for the ("Mon", 0) state is reported as the weekly profit.
print("Expected weekly profit for some policy is: ", v_estimate["Mon", 0])

Expected weekly profit for some policy is:  2499.6345339012455


We can see that TD prediction also correctly finds the weekly profit for this policy. Let's create the on-policy TD control method, SARSA.

In [7]:
def get_eps_greedy(actions, eps, a_best):
    """Return an epsilon-greedy distribution over ``actions``.

    Every action receives ``eps / len(actions)``; the action equal to
    ``a_best`` additionally receives ``1 - eps``. For example, with 4 actions
    and ``eps=0.4`` the best action gets 0.7 and each other action gets 0.1.

    Args:
        actions: Sized iterable of available actions.
        eps: Exploration rate.
        a_best: Action that should receive the greedy share.

    Returns:
        Dict ``{action: probability}`` in the iteration order of ``actions``.
    """
    # The division stays inside the per-action expression so that an empty
    # ``actions`` yields an empty dict instead of dividing by zero.
    return {
        a: (1 - eps + eps / len(actions)) if a == a_best else eps / len(actions)
        for a in actions
    }

In [8]:
def get_random_policy(states, actions):
    """Create a uniform random policy over ``actions`` for every state.

    Args:
        states: Iterable of states to cover.
        actions: Sized iterable of actions available in each state.

    Returns:
        Dict mapping each state to its own ``{action: 1 / len(actions)}`` dict.
    """
    # Each state gets a distinct inner dict, so changing one state's
    # distribution later cannot affect another state's.
    return {
        state: {action: 1 / len(actions) for action in actions}
        for state in states
    }

In [9]:
def sarsa(env, discount, eps, lr, n_iter):
    """On-policy TD control (SARSA) with a tabular epsilon-greedy policy.

    Args:
        env: Environment exposing ``state_space``, ``action_space``,
            ``reset()`` and ``step(action)``; ``step`` must return a 4-tuple
            ``(next_state, reward, done, extra)``.
        discount: Factor applied to the action value of the next step.
        eps: Exploration rate passed to ``get_eps_greedy``.
        lr: Step size of each action-value update.
        n_iter: Total number of environment steps to take (not episodes).

    Returns:
        Tuple ``(policy, Q)``: ``policy`` maps state -> {action: probability}
        and ``Q`` maps state -> {action: estimated value}.
    """
    state_list, action_list = env.state_space, env.action_space

    q_table = {state: dict.fromkeys(action_list, 0) for state in state_list}
    policy = get_random_policy(state_list, action_list)

    state = env.reset()
    # The very first action is drawn from the initial uniform policy.
    action = choose_action(state, policy)
    for i in range(n_iter):
        if i % 100000 == 0:
            print(f"Iteration: {i}")

        next_state, reward, done, _extra = env.step(action)

        # Make the policy epsilon-greedy in the state just reached, then
        # sample the action that will actually be executed there.
        next_q = q_table[next_state]
        policy[next_state] = get_eps_greedy(action_list, eps, max(next_q, key=next_q.get))
        next_action = choose_action(next_state, policy)

        # SARSA bootstraps from the action that was actually sampled.
        td_error = reward + discount * next_q[next_action] - q_table[state][action]
        q_table[state][action] += lr * td_error

        if not done:
            state, action = next_state, next_action
        else:
            # New episode: refresh the policy at the start state and draw
            # the first action from it.
            state = env.reset()
            start_q = q_table[state]
            policy[state] = get_eps_greedy(action_list, eps, max(start_q, key=start_q.get))
            action = choose_action(state, policy)

    return policy, q_table

In [10]:
# Train SARSA: undiscounted, eps 0.1, step size 0.05, one million steps.
policy, Q = sarsa(
    env,
    discount=1,
    eps=0.1,
    lr=0.05,
    n_iter=1_000_000,
)
# Trailing expression so the notebook displays the learned policy.
policy

Iteration: 0
Iteration: 100000
Iteration: 200000
Iteration: 300000
Iteration: 400000
Iteration: 500000
Iteration: 600000
Iteration: 700000
Iteration: 800000
Iteration: 900000


{('Mon', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.92},
 ('Tue', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},
 ('Tue', 100): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},
 ('Tue', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Tue', 300): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Wed', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},
 ('Wed', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Wed', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Wed', 300): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Thu', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},
 ('Thu', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Thu', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Thu', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Fri', 0): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Fri', 100): {0: 0.02, 100: 

In [11]:
def q_learning(env, discount, eps, lr, n_iter):
    """Off-policy TD control (Q-learning) with epsilon-greedy behaviour.

    Args:
        env: Environment exposing ``state_space``, ``action_space``,
            ``reset()`` and ``step(action)``; ``step`` must return a 4-tuple
            ``(next_state, reward, done, extra)``.
        discount: Factor applied to the best action value of the next state.
        eps: Exploration rate passed to ``get_eps_greedy``.
        lr: Step size of each action-value update.
        n_iter: Total number of environment steps to take (not episodes).

    Returns:
        Tuple ``(policy, Q)``: ``policy`` maps state -> {action: probability}
        and ``Q`` maps state -> {action: estimated value}.
    """
    state_list, action_list = env.state_space, env.action_space

    q_table = {state: dict.fromkeys(action_list, 0) for state in state_list}
    policy = get_random_policy(state_list, action_list)

    state = env.reset()
    for step in range(n_iter):
        if step % 100000 == 0:
            print("Iteration:", step)

        # Behave epsilon-greedily with respect to the current estimates.
        current_q = q_table[state]
        policy[state] = get_eps_greedy(action_list, eps, max(current_q, key=current_q.get))
        action = choose_action(state, policy)

        next_state, reward, done, _extra = env.step(action)

        # Q-learning bootstraps from the best action value in the next state,
        # regardless of which action will be taken there.
        best_next_value = max(q_table[next_state].values())
        current_q[action] += lr * (reward + discount * best_next_value - current_q[action])

        # Start a fresh episode once the environment reports termination.
        state = env.reset() if done else next_state

    return policy, q_table

In [12]:
# Train Q-learning: undiscounted, eps 0.1, step size 0.01, one million steps.
policy, Q = q_learning(
    env,
    discount=1,
    eps=0.1,
    lr=0.01,
    n_iter=1_000_000,
)
# Trailing expression so the notebook displays the learned policy.
policy

Iteration: 0
Iteration: 100000
Iteration: 200000
Iteration: 300000
Iteration: 400000
Iteration: 500000
Iteration: 600000
Iteration: 700000
Iteration: 800000
Iteration: 900000


{('Mon', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.92},
 ('Tue', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.92},
 ('Tue', 100): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},
 ('Tue', 200): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Tue', 300): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Wed', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.92},
 ('Wed', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Wed', 200): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Wed', 300): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Thu', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.92},
 ('Thu', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Thu', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Thu', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},
 ('Fri', 0): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},
 ('Fri', 100): {0: 0.02, 100: 