In [106]:
import numpy as np
import os
import pandas as pd
import random
from collections import defaultdict
import gym
import gym_minigrid
import matplotlib.pyplot as plt
from IPython.display import clear_output
import imageio
import time
%matplotlib inline

In [107]:
class SARSA:
    """Tabular on-policy SARSA agent with an epsilon-greedy behaviour policy.

    Q-values live in a dictionary keyed by the flattened observation (as a
    tuple); each entry is a list holding one value per action, initialised
    to zero the first time a state is looked up.

    Args:
        actions: Number of discrete actions; valid actions are
            ``0 .. actions - 1``.
        agent_indicator: Stored on the instance; no method of this class
            reads it. Kept so existing callers continue to work.
        alpha: Learning rate applied to the TD error.
        gamma: Discount factor applied to the next state-action value.
        epsilon: Probability of picking a uniformly random action in ``act``.
    """

    def __init__(self, actions, agent_indicator=10, alpha=0.01, gamma=0.9, epsilon=0.2):
        self.actions = actions
        self.agent_indicator = agent_indicator
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.q_values = defaultdict(lambda: [0.0] * actions)

    def _convert_state(self, s):
        """Return a hashable key (tuple of scalars) for an array-like state."""
        return tuple(np.asarray(s).ravel().tolist())

    def update(self, state, action, reward, next_state, next_action):
        """Apply one SARSA update for the transition (s, a, r, s', a').

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))
        """
        state = self._convert_state(state)
        next_state = self._convert_state(next_state)

        q_value = self.q_values[state][action]
        next_q_value = self.q_values[next_state][next_action]

        td_error = reward + self.gamma * next_q_value - q_value
        self.q_values[state][action] = q_value + self.alpha * td_error

    def act(self, state):
        """Choose an action epsilon-greedily and return it as a plain ``int``.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise one of the actions with the highest Q-value is returned.
        """
        if np.random.rand() < self.epsilon:
            return int(np.random.choice(self.actions))

        q_values = np.asarray(self.q_values[self._convert_state(state)])
        # np.argmax always returns the first maximal index, which would make
        # the greedy policy repeat action 0 while all estimates are still
        # tied (e.g. every unseen state). Break ties uniformly instead.
        best_actions = np.flatnonzero(q_values == q_values.max())
        return int(np.random.choice(best_actions))

In [108]:
def gen_wrapped_env():
    """Create the 'MiniGrid-LavaCrossingS9N1-v0' environment wrapped in FullyObsWrapper."""
    base_env = gym.make('MiniGrid-LavaCrossingS9N1-v0')
    return gym_minigrid.wrappers.FullyObsWrapper(base_env)

In [109]:
def show_video(frames, path='./minigrid_crossing.mp4'):
    """Write ``frames`` in order to a 30 fps video at ``path`` and print where it was saved.

    Despite the name, nothing is displayed; the frames are only written to disk.
    """
    with imageio.get_writer(path, fps=30) as sink:
        for image in frames:
            sink.append_data(image)
    print(f"Video saved to {path}")

In [110]:
def train_sarsa_on_minigrid(episodes=10000, video_episodes=1):
    """Train a tabular SARSA agent on the wrapped MiniGrid environment.

    Prints the mean episode reward of the last 20 episodes every 20 episodes
    and, if any frames were recorded, saves them as a video via ``show_video``.

    Args:
        episodes: Number of training episodes to run.
        video_episodes: How many of the final episodes to render and include
            in the video. Rendering and keeping every frame of every episode
            makes memory grow without bound, so only the last episode is
            recorded by default. Pass ``None`` to record every episode, or
            ``0`` to disable recording.
    """
    env = gen_wrapped_env()
    agent = SARSA(actions=env.action_space.n)

    reward_history = []
    frames = []

    # Index of the first episode whose frames are kept.
    if video_episodes is None:
        first_recorded = 0
    else:
        first_recorded = max(episodes - video_episodes, 0)

    try:
        for episode in range(episodes):
            record = episode >= first_recorded

            state = env.reset()["image"]
            action = agent.act(state)
            total_reward = 0
            done = False

            while not done:
                next_state, reward, done, _ = env.step(action)
                next_state = next_state["image"]
                # SARSA is on-policy: the next action is chosen before the
                # update and is then actually executed on the next step.
                next_action = agent.act(next_state)

                agent.update(state, action, reward, next_state, next_action)

                state = next_state
                action = next_action
                total_reward += reward

                # Collect frames for video (recorded episodes only)
                if record:
                    frames.append(env.render(mode='rgb_array'))

            reward_history.append(total_reward)

            # Show progress and average reward periodically
            if (episode + 1) % 20 == 0:
                avg_reward = np.mean(reward_history[-20:])
                progress = (episode + 1) / episodes * 100
                print(f"Episode: {episode + 1}, Rewards: {avg_reward:.2f}, Progress: {progress:.1f}%")
    finally:
        # Release the environment's resources even if training is interrupted.
        env.close()

    # Generate video of the recorded episodes
    if frames:
        show_video(frames)

In [None]:
# Run SARSA training for 10,000 episodes; progress is printed every 20 episodes.
train_sarsa_on_minigrid(episodes=10000)

Episode: 20, Rewards: 0.00, Progress: 0.2%
Episode: 40, Rewards: 0.00, Progress: 0.4%
Episode: 60, Rewards: 0.00, Progress: 0.6%
Episode: 80, Rewards: 0.00, Progress: 0.8%
Episode: 100, Rewards: 0.00, Progress: 1.0%
Episode: 120, Rewards: 0.00, Progress: 1.2%
Episode: 140, Rewards: 0.00, Progress: 1.4%
Episode: 160, Rewards: 0.00, Progress: 1.6%
Episode: 180, Rewards: 0.00, Progress: 1.8%
Episode: 200, Rewards: 0.00, Progress: 2.0%
Episode: 220, Rewards: 0.00, Progress: 2.2%
Episode: 240, Rewards: 0.00, Progress: 2.4%
Episode: 260, Rewards: 0.00, Progress: 2.6%
Episode: 280, Rewards: 0.00, Progress: 2.8%
Episode: 300, Rewards: 0.00, Progress: 3.0%
Episode: 320, Rewards: 0.00, Progress: 3.2%
