In [None]:
from env import *

In [None]:
# returns an epsilon greedy chosen action given the action-value function
# N_0 is a constant, N_s is the dictionary containing how many times each state has been visited
def epsilon_greedy(Q, N_0, N_s, state):
    action_values = Q[state]
    epsilon = N_0 / (N_0 + N_s.get(state, 0))

    # the second check makes sure the algorithm chooses an action
    # at random when the values are the same (rather than always choosing HIT)
    if epsilon > random.uniform(0, 1) or action_values[0] == action_values[1]:
        return random.choice([action.value for action in Action])

    return max(action_values, key=lambda k: action_values[k])

In [24]:
def sarsa_lambda(env: Env, _lambda=0, gamma=1,num_episodes=1000, N_0=100):
  Q = initializeQ()
  e = initializeQ() # e is also a dict initially with 0 for every (state, action) pair
  N_s = {}
  N_sa = {}
  wins = 0

  for i in range(num_episodes):
    dealer_sum = NewCard(firstCard=True).get_value()
    player_sum = NewCard(firstCard=True).get_value()
    terminated = False
    state = (dealer_sum, player_sum)
    action = epsilon_greedy(Q, N_0, N_s, state)

    while not terminated:
      N_s[state] = N_s.get(state, 0) + 1
      N_sa[(state, action)] = N_sa.get((state, action), 0) + 1

      dealer_sum, player_sum, reward, terminated = env.step(
        dealer_sum, player_sum, action
      )
      # selecting the next state and action
      # TODO: here the dealer_sum and player_sum can become < 0
      if terminated:
        next_state_Q = 0
      else:
        new_state = (dealer_sum, player_sum)
        new_action = epsilon_greedy(Q, N_0, N_s, new_state) # choosing a new action (a')
        next_state_Q = Q[new_state][new_action]

      delta = reward + _lambda * next_state_Q - Q[state][action] # td error
      e[state][action] += 1

      wins += reward == 1

      for s in Q.keys():
        for a in range(2):
          alpha = 1 / N_sa.get((s, a), 1) # is 1 a good choice here?
          Q[s][a] += alpha * delta * e[s][a]
          e[s][a] = _lambda * gamma * e[s][a]

      if not terminated:
        state = new_state
        action = new_action
        
    if i % (num_episodes / 10) == 0 and i > 0:
      print("Episode: %d, score: %f" % (i, (float(wins) / i * 100.0)))

In [25]:
env = Env()

In [29]:
Q0 = sarsa_lambda(env, _lambda=0.5, num_episodes=100000)

Episode: 10000, score: 46.960000
Episode: 20000, score: 48.290000
Episode: 30000, score: 48.826667
Episode: 40000, score: 49.037500
Episode: 50000, score: 49.430000
Episode: 60000, score: 49.730000
Episode: 70000, score: 49.822857
Episode: 80000, score: 49.948750
Episode: 90000, score: 50.051111
