In [None]:
!pip install gym[toy_text]
import gym
import numpy as np

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pygame==2.1.0
  Downloading pygame-2.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[K     |████████████████████████████████| 18.3 MB 3.5 MB/s 
Installing collected packages: pygame
Successfully installed pygame-2.1.0


## 0. Monte Carlo

In [None]:
# V[x] += alpha * (G[x] - V[x])
def monte_carlo(env, V, policy, episodes=5000, max_steps=100, alpha=0.1, gamma=0.99):
    """
    performs the Monte Carlo algorithm
    env is the openAI environment instance
    V is a numpy.ndarray of shape (s,) containing the value estimate
    policy is a function that takes in a state and returns the next action to take
    episodes is the total number of episodes to train over
    max_steps is the maximum number of steps per episode
    alpha is the learning rate
    gamma is the discount rate
    Returns: V, the updated value estimate
    """
    for episode in range(episodes):
      results_list = []
      state = env.reset()
      #print(env.render()[0])
      for frame in range(max_steps):
        action = policy(state)
        nxt_state, reward, terminate, _ = env.step(action)
        results_list.append((state, reward))
        if terminate or frame > max_steps:
          break
        state = nxt_state
      
      results_list = np.array(results_list, dtype=int)
      G = 0
      # replay episode backwards, applying discount rate each time
      for (s, r) in results_list[::-1]:
        G = (gamma * G) + r
        # first visit approach
        if s not in results_list[:episode, 0]:
          V[s] += alpha * (G - V[s])
    return V.round(2)

In [None]:
np.random.seed(0)

env = gym.make('FrozenLake8x8-v1', render_mode="ansi")
LEFT, DOWN, RIGHT, UP = 0, 1, 2, 3

def policy(s):
    p = np.random.uniform()
    if p > 0.5:
        if s % 8 != 7 and env.desc[s // 8, s % 8 + 1] != b'H':
            return RIGHT
        elif s // 8 != 7 and env.desc[s // 8 + 1, s % 8] != b'H':
            return DOWN
        elif s // 8 != 0 and env.desc[s // 8 - 1, s % 8] != b'H':
            return UP
        else:
            return LEFT
    else:
        if s // 8 != 7 and env.desc[s // 8 + 1, s % 8] != b'H':
            return DOWN
        elif s % 8 != 7 and env.desc[s // 8, s % 8 + 1] != b'H':
            return RIGHT
        elif s % 8 != 0 and env.desc[s // 8, s % 8 - 1] != b'H':
            return LEFT
        else:
            return UP

V = np.where(env.desc == b'H', -1, 1).reshape(64).astype('float64') 
np.set_printoptions(precision=2)
env.seed(0)
print(monte_carlo(env, V, policy).reshape((8, 8)))

  "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
  "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
  "Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "


[[ 0.9   0.73  0.66  0.73  0.9   0.9   0.59  0.53]
 [ 0.59  0.66  0.73  0.81  0.66  0.39  0.48  0.39]
 [ 0.66  0.25  0.35 -1.    1.    0.48  0.43  0.43]
 [ 0.9   0.43  0.28  0.59  0.9  -1.    0.48  0.48]
 [ 1.    0.73  0.59 -1.    1.    1.    0.73  0.73]
 [ 1.   -1.   -1.    1.    1.    1.   -1.    0.9 ]
 [ 1.   -1.    1.    1.   -1.    1.   -1.    1.  ]
 [ 1.    1.    1.   -1.    1.    1.    1.    1.  ]]


In [None]:
'''
[[ 0.81    0.9     0.4783  0.4305  0.3874  0.4305  0.6561  0.9   ]
 [ 0.9     0.729   0.5905  0.4783  0.5905  0.2824  0.2824  0.3874]
 [ 1.      0.5314  0.729  -1.      1.      0.3874  0.2824  0.4305]
 [ 1.      0.5905  0.81    0.9     1.     -1.      0.3874  0.6561]
 [ 1.      0.6561  0.81   -1.      1.      1.      0.729   0.5314]
 [ 1.     -1.     -1.      1.      1.      1.     -1.      0.9   ]
 [ 1.     -1.      1.      1.     -1.      1.     -1.      1.    ]
 [ 1.      1.      1.     -1.      1.      1.      1.      1.    ]]
'''

'''
[[ 0.9   0.73  0.66  0.73  0.9   0.9   0.59  0.53]
 [ 0.59  0.66  0.73  0.81  0.66  0.39  0.48  0.39]
 [ 0.66  0.25  0.35 -1.    1.    0.48  0.43  0.43]
 [ 0.9   0.43  0.28  0.59  0.9  -1.    0.48  0.48]
 [ 1.    0.73  0.59 -1.    1.    1.    0.73  0.73]
 [ 1.   -1.   -1.    1.    1.    1.   -1.    0.9 ]
 [ 1.   -1.    1.    1.   -1.    1.   -1.    1.  ]
 [ 1.    1.    1.   -1.    1.    1.    1.    1.  ]]
'''

## TD(λ)

In [None]:
def td_lambtha(env, V, policy, lambtha, episodes=5000, max_steps=100, alpha=0.1, gamma=0.99):
  """
  performs the TD(λ) algorithm
  env is the openAI environment instance
  V is a numpy.ndarray of shape (s,) containing the value estimate
  policy is a function that takes in a state and returns the next action to take
  lambtha is the eligibility trace factor
  episodes is the total number of episodes to train over
  max_steps is the maximum number of steps per episode
  alpha is the learning rate
  gamma is the discount rate
  Returns: V, the updated value estimate
  """
  elig_trace = np.zeros(V.shape[0])
  for episode in range(episodes):
    state = env.reset()
    #print(env.render()[0])
    for frame in range(max_steps):
      action = policy(state)
      nxt_state, reward, terminate, _ = env.step(action)
      elig_trace *= (gamma * lambtha)
      elig_trace[state] += 1

      V += alpha * (reward + (gamma * V[nxt_state]) - V[state]) * elig_trace
            
      if terminate or frame > max_steps:
        break
      
      state = nxt_state
  return V

In [None]:
np.random.seed(0)

env = gym.make('FrozenLake8x8-v1')
LEFT, DOWN, RIGHT, UP = 0, 1, 2, 3

def policy(s):
    p = np.random.uniform()
    if p > 0.5:
        if s % 8 != 7 and env.desc[s // 8, s % 8 + 1] != b'H':
            return RIGHT
        elif s // 8 != 7 and env.desc[s // 8 + 1, s % 8] != b'H':
            return DOWN
        elif s // 8 != 0 and env.desc[s // 8 - 1, s % 8] != b'H':
            return UP
        else:
            return LEFT
    else:
        if s // 8 != 7 and env.desc[s // 8 + 1, s % 8] != b'H':
            return DOWN
        elif s % 8 != 7 and env.desc[s // 8, s % 8 + 1] != b'H':
            return RIGHT
        elif s % 8 != 0 and env.desc[s // 8, s % 8 - 1] != b'H':
            return LEFT
        else:
            return UP

V = np.where(env.desc == b'H', -1, 1).reshape(64).astype('float64') 
np.set_printoptions(precision=4)
print(td_lambtha(env, V, policy, 0.9).reshape((8, 8)))

[[-0.8646 -0.8508 -0.8404 -0.7797 -0.6912 -0.6594 -0.5903 -0.5297]
 [-0.8924 -0.8899 -0.878  -0.8919 -0.8141 -0.7076 -0.6863 -0.5484]
 [-0.9237 -0.9204 -0.9442 -1.     -0.9273 -0.7414 -0.6061 -0.3766]
 [-0.9291 -0.9221 -0.9604 -0.979  -0.9795 -1.     -0.6747 -0.4439]
 [-0.9444 -0.9726 -0.9697 -1.     -0.9302 -0.8485 -0.642  -0.3288]
 [-0.8948 -1.     -1.      0.5021 -0.9561 -0.9146 -1.      0.2218]
 [-0.9057 -1.     -0.4269  0.5536 -1.     -0.78   -1.      0.7115]
 [-0.903  -0.9503 -0.8844 -1.      1.     -0.3773  0.1801  1.    ]]


In [None]:
'''
[[ 0.5314  0.5905  0.3138  0.3138  0.6561  0.9     0.81    0.9   ]
 [ 0.5314  0.5905  0.4783  0.6561  0.5905  0.6561  0.6561  0.5314]
 [ 0.6561  0.729   0.5905 -1.      0.9     0.9     0.5905  0.3874]
 [ 0.729   0.81    0.81    0.9     1.     -1.      0.5314  0.4305]
 [ 0.5905  0.6561  0.81   -1.      1.      1.      0.729   0.4783]
 [ 0.9    -1.     -1.      1.      1.      1.     -1.      0.81  ]
 [ 1.     -1.      1.      1.     -1.      1.     -1.      1.    ]
 [ 0.9     0.81    1.     -1.      1.      1.      1.      1.    ]]
'''

## SARSA(λ)

In [None]:
def sarsa_lambtha(env, Q, lambtha, episodes=5000, max_steps=100, alpha=0.1, gamma=0.99, epsilon=1, min_epsilon=0.1, epsilon_decay=0.05):
    """
    performs SARSA(λ):

    env is the openAI environment instance
    Q is a numpy.ndarray of shape (s,a) containing the Q table
    lambtha is the eligibility trace factor
    episodes is the total number of episodes to train over
    max_steps is the maximum number of steps per episode
    alpha is the learning rate
    gamma is the discount rate
    epsilon is the initial threshold for epsilon greedy
    min_epsilon is the minimum value that epsilon should decay to
    epsilon_decay is the decay rate for updating epsilon between episodes
    Returns: Q, the updated Q table
    """
    elig_traces = np.zeros_like(Q)
    
    for episode in range(episodes):
        state = env.reset()
        if np.random.uniform() < epsilon:
            action = np.random.randint(0, Q.shape[1])
        else:
            action = np.argmax(Q[state, :])
        for step in range(max_steps):
            next_state, reward, terminate, _ = env.step(action)
            if np.random.uniform() < epsilon:
              action2 = np.random.randint(0, Q.shape[1])
            else:
              action2 = np.argmax(Q[next_state, :])
          
            delta = reward + gamma * Q[next_state, action2] - Q[state, action]
            elig_traces[state, action] += 1

            Q += alpha * delta * elig_traces
            elig_traces *= lambtha * gamma

            state = next_state
            action = action2
            if terminate:
                break
        epsilon = min_epsilon + (1 - min_epsilon) * np.exp(-1 * epsilon_decay * episode)
    
    return Q

In [None]:
np.random.seed(0)
env = gym.make('FrozenLake8x8-v1')
Q = np.random.uniform(size=(64, 4))
np.set_printoptions(precision=4)
print(sarsa_lambtha(env, Q, 0.9))

[[0.5945 0.5181 0.6    0.5935]
 [0.5541 0.6034 0.5546 0.5923]
 [0.5843 0.6234 0.5557 0.6015]
 [0.5913 0.6093 0.6021 0.6462]
 [0.5698 0.6021 0.6952 0.5798]
 [0.6378 0.6463 0.7073 0.6227]
 [0.6435 0.7124 0.7074 0.6367]
 [0.7213 0.6751 0.6702 0.6721]
 [0.6123 0.6136 0.6114 0.6001]
 [0.6256 0.6044 0.6144 0.5959]
 [0.6061 0.6079 0.6586 0.6014]
 [0.4624 0.4572 0.4363 0.6766]
 [0.5598 0.5404 0.7211 0.5365]
 [0.6648 0.6871 0.7408 0.6352]
 [0.7558 0.7017 0.7004 0.6817]
 [0.5548 0.5287 0.7293 0.5583]
 [0.6184 0.6236 0.6241 0.6202]
 [0.6611 0.6974 0.6432 0.6507]
 [0.6903 0.5452 0.547  0.4837]
 [0.2828 0.1202 0.2961 0.1187]
 [0.4604 0.4928 0.7766 0.4787]
 [0.6738 0.7131 0.8178 0.673 ]
 [0.766  0.8042 0.7492 0.7227]
 [0.6743 0.7893 0.6717 0.6535]
 [0.6735 0.6367 0.6685 0.6576]
 [0.6984 0.7666 0.6821 0.6527]
 [0.7538 0.6166 0.6692 0.6471]
 [0.6015 0.6684 0.6531 0.485 ]
 [0.6615 0.5575 0.813  0.5578]
 [0.8811 0.5813 0.8817 0.6925]
 [0.8117 0.7689 0.7826 0.8222]
 [0.7439 0.8192 0.6834 0.7238]
 [0.6416

In [None]:
'''
[[0.5452 0.5363 0.6315 0.5329]
 [0.5591 0.6166 0.5316 0.5425]
 [0.5336 0.602  0.529  0.5463]
 [0.5475 0.5974 0.5362 0.5436]
 [0.5531 0.5693 0.6117 0.568 ]
 [0.6147 0.6011 0.6511 0.5966]
 [0.6472 0.6183 0.599  0.6176]
 [0.6334 0.6267 0.6519 0.634 ]
 [0.5571 0.5233 0.646  0.5867]
 [0.6456 0.5602 0.545  0.5321]
 [0.6303 0.53   0.5055 0.5394]
 [0.4495 0.4853 0.4384 0.5781]
 [0.5291 0.5351 0.5489 0.5821]
 [0.6182 0.6166 0.6186 0.6047]
 [0.6266 0.5832 0.6497 0.5645]
 [0.5369 0.3657 0.7081 0.4936]
 [0.5924 0.7393 0.5806 0.5818]
 [0.5621 0.7052 0.5681 0.5429]
 [0.6894 0.509  0.4663 0.5361]
 [0.2828 0.1202 0.2961 0.1187]
 [0.4457 0.4633 0.411  0.5208]
 [0.5899 0.6983 0.7595 0.5963]
 [0.7263 0.699  0.6698 0.6954]
 [0.6126 0.7508 0.4898 0.4768]
 [0.6615 0.5872 0.7568 0.5987]
 [0.5805 0.5433 0.5839 0.7284]
 [0.6236 0.6239 0.7243 0.5689]
 [0.6498 0.7383 0.6077 0.5422]
 [0.6334 0.6377 0.7003 0.6311]
 [0.8811 0.5813 0.8817 0.6925]
 [0.7557 0.7478 0.7796 0.7706]
 [0.6687 0.8253 0.65   0.5062]
 [0.6277 0.7568 0.6078 0.6561]
 [0.6366 0.6973 0.6338 0.7487]
 [0.7121 0.7965 0.7082 0.7455]
 [0.8965 0.3676 0.4359 0.8919]
 [0.7498 0.8535 0.3625 0.7401]
 [0.7681 0.7448 0.2974 0.837 ]
 [0.4996 0.6835 0.4382 0.8703]
 [0.8936 0.7053 0.4904 0.3181]
 [0.6677 0.7224 0.8078 0.6766]
 [0.9755 0.8558 0.0117 0.36  ]
 [0.73   0.1716 0.521  0.0543]
 [0.2466 0.0813 0.8518 0.2852]
 [0.3454 0.8602 0.7229 0.1075]
 [0.2801 0.7741 0.6684 0.288 ]
 [0.9342 0.614  0.5356 0.5899]
 [1.0137 0.391  0.4284 0.2431]
 [0.382  0.4696 0.4571 0.599 ]
 [0.2274 0.2544 0.058  0.4344]
 [0.3118 0.6755 0.4197 0.1796]
 [0.0247 0.0672 0.778  0.4537]
 [0.5366 0.8967 0.9903 0.2169]
 [0.6914 0.3132 0.0996 0.7817]
 [0.32   0.3835 0.5883 0.831 ]
 [0.629  1.3232 0.2735 0.8131]
 [0.2803 0.5022 0.5382 0.2851]
 [0.6295 0.6324 0.2997 0.2133]
 [0.5699 0.0643 0.2075 0.4247]
 [0.3742 0.4636 0.2776 0.5868]
 [0.8639 0.1175 0.5174 0.1321]
 [0.7169 0.3961 0.5654 0.1833]
 [0.1448 0.4881 0.3556 0.9404]
 [0.7653 0.7487 0.9037 0.0834]]
 '''