In [2]:
import gymnasium as gym
import collections
from torch.utils.tensorboard import SummaryWriter

In [35]:
# Discount factor applied to the successor state's value in each Bellman backup.
GAMMA: float = 0.9
# Number of greedy evaluation episodes averaged per training iteration.
TEST_EPISODES: int = 20
# Gymnasium environment id used for both random exploration and evaluation.
ENV: str = "FrozenLake8x8-v1"

In [36]:
class Agent:
    """Tabular, model-based agent that learns a state-value table V(s).

    The agent builds an empirical model of the environment from experience:
    ``rewards[(s, a, s')]`` holds the most recently observed reward for that
    transition, and ``transits[(s, a)]`` counts how often each successor ``s'``
    followed taking action ``a`` in state ``s``. ``value_iteration`` performs
    Bellman backups over that model, and ``select_action`` acts greedily on it.

    NOTE(review): assumes Discrete observation and action spaces (the code
    relies on ``observation_space.n`` / ``action_space.n``).
    """

    def __init__(self, env=ENV):
        """Create the exploration environment and empty model/value tables.

        Args:
            env: Gymnasium environment id passed to ``gym.make``.
        """
        self.env = gym.make(env)
        # Current observation of the exploration env. It is carried across calls
        # to play_n_random_steps, so exploration resumes where it stopped.
        self.old_obs = self.env.reset()[0]
        self.rewards = collections.defaultdict(float)
        self.transits = collections.defaultdict(collections.Counter)
        self.values = collections.defaultdict(float)

    def play_n_random_steps(self, count=1000):
        """Take ``count`` uniformly random actions in ``self.env`` and record them.

        Every transition updates the reward table and the successor counts.
        When an episode ends (terminated or truncated) the env is reset;
        otherwise the walk continues from the new observation.
        """
        for _ in range(count):
            act = self.env.action_space.sample()
            next_obs, rew, terminated, truncated, _info = self.env.step(act)
            prev_obs = self.old_obs
            self.rewards[(prev_obs, act, next_obs)] = rew
            self.transits[(prev_obs, act)][next_obs] += 1
            if terminated or truncated:
                next_obs, _info = self.env.reset()
            self.old_obs = next_obs

    def calc_action_value(self, state, action):
        """Estimate Q(state, action) from the learned model.

        Each observed successor contributes its empirical probability times
        ``reward + GAMMA * V(successor)``. Returns 0.0 when the pair has never
        been observed.
        """
        counts = self.transits[(state, action)]
        n_seen = sum(counts.values())
        expected = 0.0
        for next_state, hits in counts.items():
            immediate = self.rewards[(state, action, next_state)]
            backed_up = immediate + GAMMA * self.values[next_state]
            expected += (hits / n_seen) * backed_up
        return expected

    def select_action(self, state):
        """Return the action with the highest estimated value in ``state``.

        Ties go to the lowest-numbered action, because ``max`` keeps the first
        maximum it encounters.
        """
        return max(
            range(self.env.action_space.n),
            key=lambda act: self.calc_action_value(state, act),
            default=None,
        )

    def play_episode(self, env):
        """Run one greedy episode in ``env`` and return its undiscounted return.

        Transitions seen here also update the reward table and successor counts,
        so evaluation episodes refine the model as well.
        """
        episode_return = 0.0
        obs, _ = env.reset()
        done = False
        while not done:
            act = self.select_action(obs)
            next_obs, rew, terminated, truncated, _ = env.step(act)
            self.rewards[(obs, act, next_obs)] = rew
            self.transits[(obs, act)][next_obs] += 1
            episode_return += rew
            done = terminated or truncated
            obs = next_obs
        return episode_return

    def value_iteration(self):
        """Perform one in-place sweep of Bellman optimality backups over all states.

        Sets V(s) <- max_a Q(s, a), with Q taken from ``calc_action_value``.
        ``self.values`` is updated during the sweep, so states visited later in
        the same sweep already see the new values of earlier states.
        """
        n_actions = self.env.action_space.n
        for state in range(self.env.observation_space.n):
            self.values[state] = max(
                self.calc_action_value(state, act) for act in range(n_actions)
            )

In [37]:
if __name__ == '__main__':
    # Average greedy-evaluation reward (over TEST_EPISODES) that counts as "solved".
    SOLVED_THRESHOLD = 0.9

    # Separate env used only for greedy evaluation. render_mode='human' draws
    # every step in a window, which can make evaluation slow; pass
    # render_mode=None for a headless run.
    test_env = gym.make(ENV, render_mode='human')
    agent = None
    try:
        agent = Agent()
        iter_no = 0
        best_reward = 0.0
        # NOTE: no iteration cap -- runs until the threshold is exceeded or the
        # kernel is interrupted.
        while True:
            iter_no += 1
            # Explore randomly to refine the transition/reward model...
            agent.play_n_random_steps(100)
            # ...then do one value-iteration sweep over that model.
            agent.value_iteration()

            # Greedy evaluation; these episodes also update the agent's model.
            reward = 0.0
            for _ in range(TEST_EPISODES):
                reward += agent.play_episode(test_env)
            reward /= TEST_EPISODES

            if reward > best_reward:
                best_reward = reward
                print(f"Best reward updated: {best_reward}")
                print(f"V(S) = {agent.values}")

            if best_reward > SOLVED_THRESHOLD:
                print("WON!")
                print(f"Solved in {iter_no} iterations")
                break
    finally:
        # Always release the render window and both environments, including
        # when the loop is interrupted (KeyboardInterrupt) or raises.
        test_env.close()
        if agent is not None:
            agent.env.close()

Best reward updated: 0.05
V(S) = defaultdict(<class 'float'>, {0: 0.0, 8: 0.0, 1: 0.0, 9: 0.0, 2: 0.0, 10: 0.005575117304963663, 3: 0.011911102946954445, 11: 0.028048046587298807, 4: 0.044795040169256006, 12: 0.07658866889381666, 5: 0.09105860569532159, 6: 0.13082626632988031, 14: 0.21657535512604015, 7: 0.15214852197923207, 15: 0.2627352809595015, 19: 0.0, 13: 0.14705221143586447, 21: 0.17688941832175503, 22: 0.3329316540263473, 16: 0.0, 17: 0.003106373164268094, 18: 0.008368951824260515, 20: 0.12129085559554652, 29: 0.0, 23: 0.36716326203698035, 24: 0.0013922923823437505, 25: 0.006731183916981508, 26: 0.020015251553844864, 27: 0.04008345833114755, 28: 0.1271400985997299, 30: 0.535087757223992, 31: 0.5961489786110942, 32: 0.0011482129056359087, 33: 0.003263995926529338, 34: 0.006483392955266139, 35: 0.0, 36: 0.165723129227438, 37: 0.19319232428392125, 38: 0.48157898150159284, 39: 0.81, 40: 0.000516695807536159, 41: 0.0, 42: 0.0, 43: 0.1016883878056726, 44: 0.12994039387816458, 45: 0.1