In [42]:
import numpy as np
from matplotlib import pyplot as plt

In [43]:
class Agent:
    """Base class that supplies epsilon-greedy action selection.

    The instance must carry ``epsilon``, ``action_space`` and ``Q`` before
    ``choose_action`` is called; this class does not set them itself.
    """

    def choose_action(self,state):
        """Return an action for ``state`` using an epsilon-greedy rule.

        A single uniform draw on [0, 1) is compared against ``self.epsilon``.
        If the draw is smaller, the action comes from
        ``self.action_space.sample()``; otherwise it is the index of the
        largest value in row ``state`` of ``self.Q`` (``np.argmax`` returns
        the first such index when several entries tie).
        """
        explore = np.random.uniform(0,1) < self.epsilon
        if explore:
            return self.action_space.sample()
        # Greedy branch: best-valued action for this state.
        return np.argmax(self.Q[state,:])

In [44]:
class SarsaAgent(Agent):
    """Tabular SARSA agent.

    Keeps a ``(num_states, num_actions)`` table ``Q`` initialised to zeros and
    relies on the inherited ``choose_action`` for epsilon-greedy selection.
    """

    def __init__(self,epsilon,alpha,gamma,num_states,num_actions,action_space):
        """Store the hyperparameters and allocate the Q-table.

        Args:
            epsilon: threshold compared against a uniform draw in
                ``choose_action`` to decide whether to explore.
            alpha: step size applied to the TD error in ``update``.
            gamma: discount factor applied to the next state-action value.
            num_states: number of rows of the Q-table.
            num_actions: number of columns of the Q-table.
            action_space: object whose ``sample()`` method is called by
                ``choose_action`` when exploring.
        """
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.num_states = num_states
        self.num_actions = num_actions
        self.action_space = action_space

        self.Q = np.zeros((num_states,num_actions))

    def update(self,prev_state,prev_action,reward,next_state,next_action):
        """Apply one SARSA update to ``Q[prev_state, prev_action]``.

        The target is ``reward + gamma * Q[next_state, next_action]``, i.e. it
        bootstraps from the action that was actually selected for the next
        step. The table entry is modified in place; nothing is returned.
        """
        prediction = self.Q[prev_state,prev_action]
        target = reward + self.gamma * self.Q[next_state,next_action]
        error = target - prediction
        # Use the step size stored on the instance. The bare name `alpha`
        # only resolved through a notebook-level global and ignored the
        # value passed to the constructor.
        self.Q[prev_state,prev_action] += self.alpha * error
    

In [55]:
# NOTE(review): QLearningAgent is defined again in a later cell; that later
# definition shadows this one when the notebook is run top to bottom.
class QLearningAgent(Agent):
    """Tabular Q-learning agent.

    Keeps a ``(num_states, num_actions)`` table ``Q`` initialised to zeros and
    relies on the inherited ``choose_action`` for epsilon-greedy selection.
    """

    def __init__(self,epsilon,alpha,gamma,num_states,num_actions,action_space):
        """Store the hyperparameters and allocate the Q-table.

        Args:
            epsilon: threshold compared against a uniform draw in
                ``choose_action`` to decide whether to explore.
            alpha: step size applied to the TD error in ``update``.
            gamma: discount factor applied to the next-state value.
            num_states: number of rows of the Q-table.
            num_actions: number of columns of the Q-table.
            action_space: object whose ``sample()`` method is called by
                ``choose_action`` when exploring.
        """
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.num_states = num_states
        self.num_actions = num_actions
        self.action_space = action_space

        self.Q = np.zeros((num_states,num_actions))

    def update(self,prev_state,prev_action,reward,next_state,next_action):
        """Apply one Q-learning update to ``Q[prev_state, prev_action]``.

        The target is ``reward + gamma * max(Q[next_state, :])``.
        ``next_action`` is accepted so the signature matches
        ``SarsaAgent.update`` but it is not used in the computation. The
        table entry is modified in place; nothing is returned.
        """
        prediction = self.Q[prev_state,prev_action]
        target = reward + self.gamma * np.max(self.Q[next_state,:])
        error = target - prediction
        # Use the step size stored on the instance. The bare name `alpha`
        # only resolved through a notebook-level global and ignored the
        # value passed to the constructor.
        self.Q[prev_state,prev_action] += self.alpha * error

In [62]:
# NOTE(review): an earlier cell also defines QLearningAgent; this later
# definition is the one in effect after the notebook is run top to bottom,
# so one of the two cells can be deleted.
class QLearningAgent(Agent):
    """Tabular Q-learning agent with a zero-initialised Q-table.

    Action selection comes from the inherited epsilon-greedy
    ``choose_action``; this class adds only the value update.
    """

    def __init__(self,epsilon,alpha,gamma,num_states,num_actions,action_space):
        """Record hyperparameters and create ``Q`` of shape ``(num_states, num_actions)``.

        Args:
            epsilon: threshold compared against a uniform draw in
                ``choose_action`` to decide whether to explore.
            alpha: step size applied to the TD error in ``update``.
            gamma: discount factor applied to the next-state value.
            num_states: number of rows of the Q-table.
            num_actions: number of columns of the Q-table.
            action_space: object whose ``sample()`` method is called by
                ``choose_action`` when exploring.
        """
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.num_states = num_states
        self.num_actions = num_actions
        self.action_space = action_space

        self.Q = np.zeros((num_states,num_actions))

    def update(self,prev_state,prev_action,reward,next_state,next_action):
        """Move ``Q[prev_state, prev_action]`` toward the greedy TD target.

        The target is ``reward + gamma * max(Q[next_state, :])``; the
        ``next_action`` argument is not used and exists so that the training
        loop can call every agent's ``update`` the same way. The table is
        modified in place and nothing is returned.
        """
        prediction = self.Q[prev_state,prev_action]
        target = reward + self.gamma * np.max(self.Q[next_state,:])
        error = target - prediction
        # Step size must come from the instance: the bare name `alpha` only
        # worked because a notebook global of that name happened to exist.
        self.Q[prev_state,prev_action] += self.alpha * error

In [67]:
import gym

env = gym.make('CliffWalking-v0')

# Hyperparameters shared by both agents.
epsilon = .1
total_episodes = 500
max_steps = 500
alpha = .5
gamma = 1

# Seed both sources of randomness used by Agent.choose_action (numpy's global
# RNG and the action space's sampler) so a fresh-kernel run repeats itself.
SEED = 0
np.random.seed(SEED)
env.action_space.seed(SEED)

# Per-episode sum of rewards, keyed by agent class name. No ExpectedSarsa
# agent is trained in this cell, so that list stays empty here.
totalReward = {
    'SarsaAgent':[],
    'QLearningAgent':[],
    'ExpectedSarsa': []
}

sarsaAgent = SarsaAgent(epsilon,alpha,gamma,env.observation_space.n,env.action_space.n,env.action_space)
qAgent = QLearningAgent(epsilon,alpha,gamma,env.observation_space.n,env.action_space.n,env.action_space)

agents = [sarsaAgent,qAgent]


def run_episode(env, agent, max_steps):
    """Run one training episode and return the sum of rewards collected.

    The agent picks an action, the environment is stepped, the agent picks
    the following action and is then updated with the full
    (state, action, reward, next_state, next_action) tuple. The action chosen
    for the update is also the one executed on the next step, for every agent
    type. The episode stops when the environment reports ``done`` or after
    ``max_steps`` steps, whichever comes first.
    """
    state = env.reset()
    action = agent.choose_action(state)
    episode_reward = 0
    for _ in range(max_steps):
        next_state,reward,done,info = env.step(action)
        next_action = agent.choose_action(next_state)

        agent.update(state,action,reward,next_state,next_action)

        state, action = next_state, next_action
        episode_reward += reward

        if done:
            break
    return episode_reward


# Train each agent in turn on the same environment instance; close the
# environment even if training raises.
try:
    for agent in agents:
        for _ in range(total_episodes):
            totalReward[type(agent).__name__].append(run_episode(env, agent, max_steps))
finally:
    env.close()

# Mean over all training episodes of the per-episode reward sum.
meanReturn = {
    'sarsaAgent': np.mean(totalReward['SarsaAgent']),
    'qLearningAgent': np.mean(totalReward['QLearningAgent']),
}
print(f"SARSA Average Sum of Reward: {meanReturn['sarsaAgent']}") 
print(f"Q-Learning Average Sum of Return: {meanReturn['qLearningAgent']}") 
            
            

SARSA Average Sum of Reward: -33.252
Q-Learning Average Sum of Return: -53.164
