This cell imports the libraries used in the program (random, seaborn, matplotlib.pyplot and numpy) and the project modules gridworld and monte_Carlo, which provide the Gridworld environment and the MonteCarloPolicyEvaluation class.

In [16]:
import random
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from gridworld import *
from monte_Carlo import *

The function plot_state_values(V) visualizes a state-value function V as a heatmap. V must be indexable by (row, column) state coordinates, for example a 5x5 NumPy array. The function copies the values into a 5x5 grid and draws it with matshow using the 'coolwarm' colormap. It annotates each cell with its value to two decimal places and adds a color bar for reference. The resulting plot shows the state values across the gridworld at a glance.

In [17]:
def plot_state_values(V):
    # Copy the state values into a 5x5 array (row y, column x)
    values = np.zeros((5, 5))
    for y in range(5):
        for x in range(5):
            values[y, x] = V[(y, x)]

    fig, ax = plt.subplots(figsize=(8, 6))
    # Draw the heatmap once and reuse the returned image for the color bar
    im = ax.matshow(values, cmap='coolwarm')
    # Annotate every cell; ax.text takes (x=column, y=row)
    for i in range(5):
        for j in range(5):
            ax.text(j, i, f'{values[i, j]:.2f}', va='center', ha='center')
    plt.colorbar(im, ax=ax)
    plt.title("state values")
    plt.show()
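plot_state_values reads V[(y, x)], while print_values below reads V[i][j]. Only some containers support both, such as a 2-D NumPy array. The following is a minimal sketch of a hypothetical helper, values_to_grid (not part of the gridworld or monte_Carlo modules), that normalizes either a dict keyed by (row, col) tuples or a nested sequence into a 5x5 float array.

```python
import numpy as np


def values_to_grid(V, size=5):
    """Hypothetical helper: return V as a (size, size) float array.

    Accepts a dict keyed by (row, col) tuples, or anything np.asarray can
    turn into a 2-D array (nested lists, NumPy arrays).
    """
    if isinstance(V, dict):
        grid = np.zeros((size, size))  # states missing from the dict stay 0.0
        for (y, x), value in V.items():
            grid[y, x] = value
        return grid
    grid = np.asarray(V, dtype=float)
    if grid.shape != (size, size):
        raise ValueError(f"expected shape {(size, size)}, got {grid.shape}")
    return grid
```

With such a helper, plot_state_values could start with values = values_to_grid(V) instead of its explicit copy loop.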


The main entry point (cell In [19]) initializes a Gridworld, sets the number of episodes to 10,000 and the discount factor gamma to 0.9. It then creates a MonteCarloPolicyEvaluation instance, mc_policy_eval, with these parameters. Calling evaluate_policy() runs Monte Carlo policy evaluation, and get_value_function() returns the resulting state-value function V. Finally, the values are printed as a table with print_values(V, gridworld) and visualized with plot_state_values(V).
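The monte_Carlo module itself is not shown, so its internals are unknown. As a rough illustration of what a Monte Carlo evaluate_policy() step typically does, here is a standalone sketch of first-visit Monte Carlo policy evaluation. It runs on a hypothetical episodic 5x5 gridworld with these assumptions, none taken from the original code:

- a uniform random policy over up/down/left/right;
- a reward of -1 per step;
- moves off the grid leave the agent in place;
- a terminal goal at (0, 4).

The real Gridworld and MonteCarloPolicyEvaluation may differ.

```python
import random
from collections import defaultdict

import numpy as np

SIZE = 5
ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right
TERMINAL = (0, 4)  # hypothetical goal state (assumption)


def step(state, action):
    """Hypothetical dynamics: move one cell; stay put if the move leaves the grid."""
    y, x = state
    ny, nx = y + action[0], x + action[1]
    if not (0 <= ny < SIZE and 0 <= nx < SIZE):
        ny, nx = y, x
    return (ny, nx), -1.0  # every step costs -1


def generate_episode(rng, max_steps=200):
    """Roll out the uniform random policy from a random non-terminal start."""
    starts = [(y, x) for y in range(SIZE) for x in range(SIZE) if (y, x) != TERMINAL]
    state = rng.choice(starts)
    episode = []  # list of (state, reward received after leaving that state)
    for _ in range(max_steps):
        next_state, reward = step(state, rng.choice(ACTIONS))
        episode.append((state, reward))
        state = next_state
        if state == TERMINAL:
            break
    return episode


def first_visit_mc(num_episodes=5000, gamma=0.9, seed=0):
    """Estimate V(s) as the average discounted return after the first visit to s."""
    rng = random.Random(seed)
    returns_sum = defaultdict(float)
    returns_count = defaultdict(int)
    for _ in range(num_episodes):
        episode = generate_episode(rng)
        first_visit = {}
        for t, (s, _) in enumerate(episode):
            first_visit.setdefault(s, t)  # remember the earliest time step per state
        G = 0.0
        for t in range(len(episode) - 1, -1, -1):  # walk backwards accumulating returns
            s, r = episode[t]
            G = gamma * G + r
            if first_visit[s] == t:
                returns_sum[s] += G
                returns_count[s] += 1
    V = np.zeros((SIZE, SIZE))  # the terminal state keeps value 0
    for s, n in returns_count.items():
        V[s] = returns_sum[s] / n
    return V
```

Because this sketch returns a 5x5 NumPy array, its result can be passed straight to print_values and plot_state_values.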

The function below prints the grid of state values as a plain-text table.

In [18]:
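def print_values(V, gridworld):
    """Print a 2-D grid of state values as a text table.

    V is indexed as V[row][col]; the gridworld argument is currently unused.
    """
    for i in range(len(V)):
        print('-------------------------------------------------')
        for j in range(len(V[0])):
            v = V[i][j]
            if v >= 0:
                # Leading space keeps non-negative values aligned with negative ones
                print(" %.2f|" % v, end="")
            else:
                print("%.2f|" % v, end="")  # the minus sign takes the place of the space
        print("")
    print('-------------------------------------------------')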
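The manual leading space aligns single-digit values, but a two-digit value such as 10.00 in the output below widens its cell and shifts the rest of the row. A width specifier in the format string pads every cell to the same size automatically. Below is a minimal sketch using a hypothetical format_values function, not part of the original code.

```python
import numpy as np


def format_values(V, width=6, decimals=2):
    """Return the value grid as text, right-aligning every cell to `width` chars."""
    V = np.asarray(V, dtype=float)
    separator = "-" * ((width + 1) * V.shape[1])  # +1 per cell for the "|" divider
    lines = [separator]
    for row in V:
        lines.append("".join(f"{v:{width}.{decimals}f}|" for v in row))
        lines.append(separator)
    return "\n".join(lines)
```

Calling print(format_values(V)) keeps the columns aligned for any value from -99.99 to 999.99 with the default width of 6.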

In [19]:
if __name__ == "__main__":
    gridworld = Gridworld()
    num_episodes = 10000
    gamma = 0.9

    # Estimate the state values with Monte Carlo policy evaluation
    mc_policy_eval = MonteCarloPolicyEvaluation(gridworld, num_episodes, gamma)
    mc_policy_eval.evaluate_policy()
    V = mc_policy_eval.get_value_function()

    # Show the learned values as a text table and as a heatmap
    print_values(V, gridworld)
    plot_state_values(V)

-------------------------------------------------
 3.95| 10.00| 4.92| 5.00| 1.25|
-------------------------------------------------
 1.89| 3.43| 2.51| 1.93| 0.53|
-------------------------------------------------
 0.27| 0.96| 0.74| 0.43|-0.45|
-------------------------------------------------
-0.89|-0.34|-0.28|-0.56|-1.16|
-------------------------------------------------
-1.81|-1.26|-1.15|-1.38|-1.95|
-------------------------------------------------


ModuleNotFoundError: No module named 'matplotlib.backends'