# Temporal Difference Prediction and Control

In [None]:
import gym
import numpy as np
from tqdm.notebook import trange

## TD Prediction

In [None]:
def td_prediction(env, policy, obs_space, num_episodes, alpha, gamma):
    """Estimate the state values of ``policy`` with tabular TD(0).

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(next_obs, reward, done, info)``.
        policy: Callable mapping an observation to an action.
        obs_space: Sized collection of observations; only its length is
            used. Observations are used directly as indices into the table.
        num_episodes: Number of episodes to run.
        alpha: Step size of the update.
        gamma: Discount factor.

    Returns:
        1-D NumPy array with one value estimate per observation.
    """
    values = np.zeros(len(obs_space))

    for _ in trange(num_episodes):
        finished = False
        state = env.reset()

        while not finished:
            successor, reward, finished, _info = env.step(policy(state))
            # The bootstrap term is not masked on `finished`; the successor's
            # table entry is used as-is.
            td_error = reward + gamma * values[successor] - values[state]
            values[state] += alpha * td_error
            state = successor

    return values

## TD Control

### SARSA

In [None]:
def sarsa(env, obs_space, action_space, num_episodes, alpha, gamma, epsilon):
    """Learn action values with tabular SARSA (on-policy TD control).

    Args:
        env: Environment whose ``reset()`` returns an observation, whose
            ``step(action)`` returns ``(next_obs, reward, done, info)`` and
            whose ``action_space.sample()`` returns a random action.
        obs_space: Sized collection of observations; only its length is used.
        action_space: Sized collection of actions; only its length is used.
        num_episodes: Number of episodes to run.
        alpha: Step size of the update.
        gamma: Discount factor.
        epsilon: Probability of taking a sampled action instead of the
            greedy one while learning.

    Returns:
        Tuple ``(policy, q)``: ``policy`` maps an observation to the greedy
        action under the final table, and ``q`` is the
        ``(len(obs_space), len(action_space))`` action-value array.
    """
    q_table = np.zeros(shape=(len(obs_space), len(action_space)))

    def behave(state):
        # Epsilon-greedy: explore with probability epsilon, else act greedily.
        if np.random.rand() < epsilon:
            return env.action_space.sample()
        return q_table[state].argmax()

    for _ in trange(num_episodes):
        finished = False
        state = env.reset()
        choice = behave(state)

        while not finished:
            successor, reward, finished, _info = env.step(choice)
            # The next action is always drawn, even on the final step; its
            # value is zeroed out below once the episode has finished.
            next_choice = behave(successor)

            bootstrap = gamma * q_table[successor, next_choice] * (not finished)
            td_error = reward + bootstrap - q_table[state, choice]
            q_table[state, choice] += alpha * td_error

            state, choice = successor, next_choice

    # Greedy policy read off the final table.
    greedy_actions = np.argmax(q_table, axis=1)

    def greedy(state):
        return greedy_actions[state]

    return greedy, q_table

### Q-Learning

In [None]:
def q_learning(env, obs_space, action_space, num_episodes, alpha, gamma, epsilon):
    """Learn action values with tabular Q-learning (off-policy TD control).

    Args:
        env: Environment whose ``reset()`` returns an observation, whose
            ``step(action)`` returns ``(next_obs, reward, done, info)`` and
            whose ``action_space.sample()`` returns a random action.
        obs_space: Sized collection of observations; only its length is used.
        action_space: Sized collection of actions; only its length is used.
        num_episodes: Number of episodes to run.
        alpha: Step size of the update.
        gamma: Discount factor.
        epsilon: Probability of taking a sampled action instead of the
            greedy one while learning.

    Returns:
        Tuple ``(policy, q)``: ``policy`` maps an observation to the greedy
        action under the final table, and ``q`` is the
        ``(len(obs_space), len(action_space))`` action-value array.
    """
    # q as action value function
    q = np.zeros(shape=(len(obs_space), len(action_space)))

    # epsilon greedy behaviour policy
    def behaviour_policy(obs):
        if np.random.rand() < epsilon:
            return env.action_space.sample()
        return q[obs].argmax()

    for _ in trange(num_episodes):
        # reset variables
        done, obs = False, env.reset()

        while not done:
            action = behaviour_policy(obs)
            next_obs, reward, done, _info = env.step(action)

            # The target bootstraps from the best next action value, so no
            # next action has to be sampled here. The bootstrap term is
            # zeroed once the episode has finished.
            target = reward + gamma * q[next_obs].max() * (not done)
            q[obs, action] += alpha * (target - q[obs, action])
            obs = next_obs

    # greedy policy
    policy_mapping = np.argmax(q, axis=1)

    def policy(obs):
        return policy_mapping[obs]

    return policy, q

## Testing using FrozenLake

In [None]:
# Environment instance shared by the TD prediction, SARSA and Q-learning runs below.
env = gym.make('FrozenLake-v1')

In [None]:
# Index sets 0..n-1 for the environment's discrete observations and actions.
obs_space = set(range(env.observation_space.n))
action_space = set(range(env.action_space.n))

### TD Prediction

In [None]:
# Fixed action for each of the 16 states, built once instead of on every call.
# NOTE(review): presumably laid out as the 4x4 FrozenLake grid with gym's
# action codes (0=left, 1=down, 2=right, 3=up) -- confirm against the gym docs.
_FIXED_POLICY_ACTIONS = {
    0: 2, 1: 2, 2: 1, 3: 0,
    4: 1, 5: 1, 6: 1, 7: 1,
    8: 2, 9: 1, 10: 1, 11: 1,
    12: 2, 13: 2, 14: 2, 15: 2,
}


def policy(state):
    """Return the hand-written action for ``state``.

    Args:
        state: State index in the range 0-15.

    Returns:
        The action code assigned to ``state``.

    Raises:
        KeyError: If ``state`` is not one of the 16 known states.
    """
    return _FIXED_POLICY_ACTIONS[state]

In [None]:
# Evaluate the hand-written policy with TD(0); the value table is the cell's result.
td_prediction(
    env=env,
    policy=policy,
    obs_space=obs_space,
    num_episodes=100000,
    alpha=0.01,
    gamma=0.99,
)

### SARSA

In [None]:
# Train with SARSA; `policy` is rebound to the learned greedy policy.
policy, q = sarsa(
    env=env,
    obs_space=obs_space,
    action_space=action_space,
    num_episodes=100000,
    alpha=0.1,
    gamma=0.99,
    epsilon=0.2,
)

In [None]:
# Print the action values and the greedy action for every observation,
# using the `q` table and `policy` returned by the sarsa call.
for obs in obs_space:
    print(f'Observation: {obs}, q-values: {q[obs]}, action: {policy(obs)}')

### Q-Learning

In [None]:
# Train with Q-learning; `policy` and `q` are rebound to the new results.
policy, q = q_learning(
    env=env,
    obs_space=obs_space,
    action_space=action_space,
    num_episodes=100000,
    alpha=0.1,
    gamma=0.99,
    epsilon=0.2,
)

In [None]:
# Print the action values and the greedy action for every observation,
# using the `q` table and `policy` returned by the q_learning call.
for obs in obs_space:
    print(f'Observation: {obs}, q-values: {q[obs]}, action: {policy(obs)}')