In [None]:

import gridworlds           # import to trigger registration of the environment
import gymnasium as gym
import numpy as np

# Create the environment registered under the id "gridworld-v0" (the
# `gridworlds` import above is what triggers that registration) and reset it
# once so it starts from an initial state. As the last expression of the
# cell, the value returned by `env.reset()` is what the notebook displays.
env = gym.make("gridworld-v0")
env.reset()

In [None]:
class Policy:
    """Tabular Monte-Carlo control agent for a rectangular grid world.

    The policy is stored as an array of shape ``(n_rows, n_cols, n_actions)``
    holding one action distribution per grid cell. All learning tables are
    indexed by the flat state index ``row * n_cols + col`` and are sized from
    the policy's grid shape, so the policy and the tables always agree.

    States may be passed either as a flat integer index or as a
    ``(row, col)`` pair; both forms are accepted by every method.

    Attributes:
        policy: ``(n_rows, n_cols, n_actions)`` action probabilities.
        N_detailed: ``(n_states, n_actions, n_states)`` transition counts.
        Rewards_cum: ``(n_states, n_actions)`` summed immediate rewards.
        Returns: per-state list of first-visit returns since the last
            policy improvement.
        v: ``(n_states,)`` mean of ``Returns`` per state.
        Q: ``(n_states, n_actions)`` action values filled by ``compute_q``.
        pos_dict: flat state index -> ``[row, col]``.
    """

    # Arrow glyph per action index, mirroring the mapping used by the
    # notebook's `plot_gridworld` helper.
    # NOTE(review): assumes the environment encodes actions as
    # 0=up, 1=right, 2=down, 3=left -- confirm against the environment.
    _ARROWS = {0: '↑', 1: '→', 2: '↓', 3: '←'}

    def __init__(self, exploration_factor=0.1, gamma=0.9, initial_policy=None,
                 grid_shape=(5, 5)):
        """Create the agent.

        Args:
            exploration_factor: probability of taking a uniformly random
                action in ``act`` (the epsilon of epsilon-greedy).
            gamma: discount factor applied to future rewards.
            initial_policy: optional array-like of shape
                ``(n_rows, n_cols, n_actions)``. When omitted, a uniform
                policy over 4 actions is created for ``grid_shape``.
            grid_shape: ``(n_rows, n_cols)`` of the grid; only used when
                ``initial_policy`` is omitted.

        Raises:
            ValueError: if ``initial_policy`` is not 3-dimensional.
        """
        self.exploration_factor = exploration_factor
        self.gamma = gamma

        # `is None` rather than `== None`: comparing an ndarray with `==`
        # is element-wise and cannot be used as an `if` condition.
        if initial_policy is None:
            n_rows, n_cols = grid_shape
            self.policy = np.full((n_rows, n_cols, 4), 0.25)
        else:
            # No copy is made when the caller already passes a float ndarray.
            self.policy = np.asarray(initial_policy, dtype=float)
            if self.policy.ndim != 3:
                raise ValueError(
                    "initial_policy must have shape (n_rows, n_cols, n_actions), "
                    f"got {self.policy.shape}"
                )

        self.n_rows, self.n_cols, self.n_actions = self.policy.shape
        self.n_states = self.n_rows * self.n_cols

        self.N_detailed = np.zeros((self.n_states, self.n_actions, self.n_states))
        self.Q = np.zeros((self.n_states, self.n_actions))
        self.v = np.zeros(self.n_states)
        self.Returns = [[] for _ in range(self.n_states)]
        self.Rewards_cum = np.zeros((self.n_states, self.n_actions))
        self.pos_dict = {
            row * self.n_cols + col: [row, col]
            for row in range(self.n_rows)
            for col in range(self.n_cols)
        }

    def _state_index(self, state):
        """Return the flat table index for an integer or ``(row, col)`` state."""
        flat = np.asarray(state).ravel()
        if flat.size == 1:
            return int(flat[0])
        row, col = flat
        return int(row) * self.n_cols + int(col)

    def _state_cell(self, state):
        """Return the ``(row, col)`` grid cell for an integer or pair state."""
        return divmod(self._state_index(state), self.n_cols)

    def act(self, state):
        """Pick an action for ``state`` with an epsilon-greedy rule.

        With probability ``exploration_factor`` a uniformly random action is
        returned; otherwise one of the actions with the highest policy
        probability. Ties are broken at random so that a uniform (or
        tie-split) policy does not always collapse onto action 0.
        """
        if np.random.rand() < self.exploration_factor:
            return int(np.random.choice(self.n_actions))  # random action

        row, col = self._state_cell(state)
        probs = self.policy[row, col]
        best_actions = np.flatnonzero(probs == probs.max())
        return int(np.random.choice(best_actions))  # greedy action

    def update(self, state, action):
        """Make the policy deterministic in ``state``: always take ``action``."""
        row, col = self._state_cell(state)
        self.policy[row, col] = 0.0
        self.policy[row, col, action] = 1.0

    def eval_episode(self, episode):
        """Fold one episode into the transition, reward and return tables.

        Args:
            episode: ``(states, actions, rewards)`` where ``states`` has one
                more entry than ``actions``/``rewards`` (it includes the final
                state reached by the last action).

        Transition counts and cumulative rewards are updated for every step;
        the return ``G`` is recorded only for the first visit of each state
        within the episode, and ``v`` is refreshed as the mean recorded return.
        """
        states, actions, rewards = episode
        indices = [self._state_index(s) for s in states]
        if not indices:
            return

        # Step at which each state first occurs, so the first-visit test is a
        # dictionary lookup instead of a scan over the episode prefix.
        first_visit = {}
        for step, state in enumerate(indices[:-1]):
            first_visit.setdefault(state, step)

        G = 0
        next_state = indices[-1]
        # Walk backwards so G accumulates the discounted future reward.
        for step in range(len(indices) - 2, -1, -1):
            state = indices[step]
            action = actions[step]
            reward = rewards[step]

            self.Rewards_cum[state, action] += reward
            G = self.gamma * G + reward
            self.N_detailed[state, action, next_state] += 1

            if first_visit[state] == step:
                self.Returns[state].append(G)
                self.v[state] = sum(self.Returns[state]) / len(self.Returns[state])
            next_state = state

    def compute_q(self):
        """Estimate ``Q`` from the empirical model and the current ``v``.

        ``Q[s, a] = mean_reward(s, a) + gamma * sum_s' P(s' | s, a) * v[s']``,
        with ``P`` and the mean reward taken from the visit counts. The next
        state value is discounted because ``v`` itself is built from
        discounted returns.
        """
        visits = self.N_detailed.sum(axis=2)
        # Unvisited (state, action) pairs divide by 1 instead of 0, which
        # leaves their probabilities, mean reward and Q value at exactly 0.
        safe_visits = np.where(visits == 0, 1, visits)
        transition_probs = self.N_detailed / safe_visits[:, :, None]
        mean_rewards = self.Rewards_cum / safe_visits
        self.Q[:] = mean_rewards + self.gamma * (transition_probs @ self.v)

    def improve_policy(self):
        """Make the policy greedy w.r.t. ``Q`` and start a new batch of returns.

        Probability mass is split equally between all actions that attain the
        maximal Q value of a state. ``Returns`` is cleared afterwards; ``v``
        keeps its last estimate until new returns arrive.
        """
        self.compute_q()
        best_values = self.Q.max(axis=1, keepdims=True)
        greedy = (self.Q == best_values).astype(float)
        greedy /= greedy.sum(axis=1, keepdims=True)
        # Row-major reshape matches the flat index row * n_cols + col.
        self.policy[...] = greedy.reshape(self.policy.shape)
        self.Returns = [[] for _ in range(self.n_states)]

    def optimise(self, env, max_iterations, eps_per_iter=100, max_eps_length=100,
                 improve_every=10, plot=True):
        """Alternate between sampling episodes and improving the policy.

        Args:
            env: environment whose ``reset()`` returns ``(state, info)`` and
                whose ``step(action)`` returns
                ``(state, reward, terminated, truncated, info)``.
            max_iterations: number of sampling iterations.
            eps_per_iter: episodes sampled per iteration.
            max_eps_length: maximum number of steps per episode.
            improve_every: the policy is improved on iterations that are a
                multiple of this value (iteration 0 included).
            plot: when true, show the value heatmap and the greedy policy
                after each improvement.
        """
        for it in range(max_iterations):
            for _ in range(eps_per_iter):
                # A single reset per episode is enough.
                state, _info = env.reset()
                episode = [[state], [], []]
                for _step in range(max_eps_length):
                    action = self.act(state)
                    state, reward, terminated, truncated, _info = env.step(action)
                    episode[0].append(state)
                    episode[1].append(action)
                    episode[2].append(reward)

                    if terminated or truncated:
                        break

                self.eval_episode(episode)

            if it % improve_every == 0:
                self.improve_policy()
                if plot:
                    self.show_heatmap(it)
                    self.show_policy(it)

    def show_heatmap(self, iteration):
        """Plot the state values as an annotated heatmap."""
        grid = self.v.reshape(self.n_rows, self.n_cols)
        labels = [[round(float(value), 2) for value in row] for row in grid]
        self._draw_grid(labels, f"State values after iteration {iteration}")

    def show_policy(self, iteration):
        """Plot the most probable action of every cell on top of the values."""
        greedy = self.policy.argmax(axis=2)
        labels = [
            [self._ARROWS.get(int(action), str(int(action))) for action in row]
            for row in greedy
        ]
        self._draw_grid(labels, f"Greedy policy after iteration {iteration}")

    def _draw_grid(self, labels, title):
        """Draw the value heatmap with one text label per cell."""
        # Imported here so the class stays usable without a plotting backend.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.imshow(self.v.reshape(self.n_rows, self.n_cols), cmap='viridis')
        ax.set_xticks(np.arange(self.n_cols))
        ax.set_yticks(np.arange(self.n_rows))
        for row in range(self.n_rows):
            for col in range(self.n_cols):
                ax.text(col, row, labels[row][col],
                        ha="center", va="center", color="w")
        ax.set_title(title)
        fig.tight_layout()
        plt.show()
        # Called repeatedly from `optimise`; release the figure explicitly.
        plt.close(fig)
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _draw_annotated_grid(values, labels, title):
    """Show `values` as a heatmap with `labels[i][j]` written into cell (i, j)."""
    n_rows, n_cols = values.shape

    fig, ax = plt.subplots()
    ax.imshow(values, cmap='viridis')

    # One tick per column / row.
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # matplotlib's text position is (x, y), i.e. (column, row).
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, labels[i][j],
                    ha="center", va="center", color="w")

    ax.set_title(title)
    fig.tight_layout()
    plt.show()


def plot_gridworld(final_value, final_policy):
    """Plot a grid world's state values and its greedy policy as two figures.

    Args:
        final_value: state values, either a 2-D ``(n_rows, n_cols)`` array or
            a flat vector that is reshaped to the grid shape of
            ``final_policy``.
        final_policy: array of shape ``(n_rows, n_cols, n_actions)``; the
            action with the highest entry is drawn in each cell.

    Raises:
        ValueError: if the values cannot be arranged as a 2-D grid.

    The first figure annotates every cell with its value rounded to two
    decimals; the second draws an arrow for the best action on top of the
    same heatmap. The grid size is taken from the inputs rather than being
    fixed to 5x5.
    """
    values = np.asarray(final_value, dtype=float)
    policy = np.asarray(final_policy)

    if values.ndim == 1:
        # Flat per-state vector: lay it out on the policy's grid.
        values = values.reshape(policy.shape[:2])
    if values.ndim != 2:
        raise ValueError(
            f"final_value must be 1-D or 2-D, got shape {values.shape}"
        )

    n_rows, n_cols = values.shape
    size = f"{n_rows}x{n_cols}"

    value_labels = [
        [round(float(values[i, j]), 2) for j in range(n_cols)]
        for i in range(n_rows)
    ]
    _draw_annotated_grid(values, value_labels, f"{size} Gridworld Value Function")

    # Action indices without an arrow fall back to showing the index itself.
    policy_arrows = {0: '↑', 1: '→', 2: '↓', 3: '←'}
    policy_labels = []
    for i in range(n_rows):
        row_labels = []
        for j in range(n_cols):
            best_action = int(np.argmax(policy[i, j]))
            row_labels.append(policy_arrows.get(best_action, str(best_action)))
        policy_labels.append(row_labels)
    _draw_annotated_grid(values, policy_labels, f"{size} Gridworld Optimal Policy")