<a href="https://colab.research.google.com/github/arnavdodiedo/RL-Algorithms/blob/main/DynaQ_GridWorld.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
import random

In [81]:
# deterministic gridworld environment (with unreachable/blocked states) and deterministic greedy policy
class GridWorld():
    def __init__(self, seed=0, grid_h=4, grid_w=4, start_state=[0,0], goal_states=[[3,3]], blocked_states=[], episodes=50, alpha=0.4, gamma=0.9, epsilon_decay=0.9, N=20, reward_per_step=-1):
        np.random.seed(seed)
        self.grid_w = grid_w    # grid width
        self.grid_h = grid_h    # grid height       
        self.start_state = start_state  # start state, should have only one start state
        self.goal_states = goal_states  # target/goal states
        self.blocked_states = blocked_states # blocked states that cannot be visited
        self.episodes = episodes    # number of episodes to run for
        self.alpha = alpha  # equivalent of learning rate
        self.gamma = gamma  # gamma - weight for future rewards
        self.epsilon = 1    # epsilon greedy exploration hyper-parameter, always start with 1
        self.epsilon_decay = epsilon_decay  # epsilon is reduced by this factor
        self.N = N  # number of simulated experiences to run per episode
        self.reward_per_step = reward_per_step  
        self.directions = [[-1,0],[1,0],[0,-1],[0,1]] # up, down, left, right
        self.num_actions = len(self.directions)
        self.q_values = np.zeros((self.grid_h, self.grid_w, self.num_actions))  # state-action values initialized to 0
        self.env_model = dict() # null initial model of the environment

    # reset to initial values
    def reset(self):
        self.q_values = np.zeros((self.grid_h, self.grid_w, self.num_actions))
        self.epsilon = 1
        self.env_model = dict()

    # display gridworld environment
    def display_gridworld(self):
        for i in range(self.grid_h):
            for j in range(self.grid_w):
                if [i,j] == self.start_state:
                    print('[S]', end=' ')
                elif [i,j] in self.goal_states: 
                    print('[G]', end=' ')
                elif [i,j] in self.blocked_states:
                    print('[X]', end=' ')
                else:
                    print('[ ]', end=' ')
            print('')

    # display state-action values for each action in each state
    def display_q_values(self):
        print('[')
        for i in range(self.grid_h):
            print('[', end='')
            for j in range(self.grid_w):
                print('[', end='')
                for a in range(self.num_actions):
                    print(self.q_values[i][j][a], sep='', end=', ')
                print("]", end=' ')
            print(']')
        print(']')

    # display deterministic policy based in state-action values
    def display_policy(self):        
        for i in range(self.grid_h):            
            for j in range(self.grid_w):
                l = []      
                if [i,j] in self.goal_states: 
                    print("[G]", end=' ')
                    continue
                elif [i,j] in self.blocked_states:
                    print("[X]", end=' ')
                    continue
                for a in range(self.num_actions):
                    l.append([self.q_values[i][j][a], a])
                l.sort()
                l.reverse()
                m = [l[0][1]]
                moves = ["^", "v", "<", ">"]

                for k in range(1, self.num_actions):
                    if l[k][0] == l[0][0]:
                        m.append(l[k][1])
                    else: break
                
                print('[', end='')
                for k in range(len(m)):
                    if k!=len(m)-1: print(moves[m[k]], end='&')
                    else: print(moves[m[k]], end='] ')

            print('')
        print('')

    # get next state from current state, action
    def get_next_state(self, state, action):
        next_state = copy.deepcopy(state)

        next_state[0] += self.directions[action][0]
        next_state[1] += self.directions[action][1]

        if next_state[0]<0: next_state[0] = 0
        elif next_state[0]>=self.grid_h: next_state[0] = self.grid_h-1

        if next_state[1]<0: next_state[1] = 0
        elif next_state[1]>=self.grid_w: next_state[1] = self.grid_w-1

        if next_state in self.blocked_states: next_state = copy.deepcopy(state)

        return next_state

    # q learning on gridworld
    def q_learning(self):
        self.reset()
        for _ in range(self.episodes):

            # pick initial state randomly which are not either goal or blocked state
            k = np.random.choice(range(self.grid_h*self.grid_w))
            i = k//self.grid_w
            j = k - self.grid_w*i

            while [i,j] in self.blocked_states or [i,j] in self.goal_states:
                k = np.random.choice(self.grid_h*self.grid_w)
                i = k//self.grid_w
                j = k - self.grid_w*i

            state = [i,j]                        

            # run till terminal state is reached
            while state not in self.goal_states:
                # assign probability to pick action epsilon-greedily based on state-action values
                prob = np.zeros(self.num_actions) + self.epsilon/self.num_actions
                prob[np.argmax(self.q_values[state[0]][state[1]])] = 1 - self.epsilon + self.epsilon/self.num_actions
                action = np.random.choice(range(self.num_actions), p=prob)
                next_state = self.get_next_state(state, action)

                # q learning update
                self.q_values[state[0]][state[1]][action] += self.alpha * (self.reward_per_step + self.gamma * np.max(self.q_values[next_state[0]][next_state[1]]) - self.q_values[state[0]][state[1]][action])
                
                state = next_state

            if (_+1)%(self.episodes//10) == 0:
                print("episode %d:"%(_+1))
                self.display_q_values()            
                self.epsilon *= self.epsilon_decay

    # dynaq on current environment
    def dynaq(self):
        self.reset()
        # initialize number of visits to each state to 0
        number_of_visits = dict()
        
        for _ in range(self.episodes):            
            # pick initial state randomly which are not either goal or blocked state
            k = np.random.choice(range(self.grid_h*self.grid_w))
            i = k//self.grid_w
            j = k - self.grid_w*i

            while [i,j] in self.blocked_states or [i,j] in self.goal_states:
                k = np.random.choice(self.grid_h*self.grid_w)
                i = k//self.grid_w
                j = k - self.grid_w*i
            
            state = [i,j]
            
            path = []

            # run till terminal state is reached
            while state not in self.goal_states:
                # assign probability to pick action epsilon-greedily based on state-action values
                prob = np.zeros(self.num_actions) + self.epsilon/self.num_actions
                prob[np.argmax(self.q_values[state[0]][state[1]])] = 1 - self.epsilon + self.epsilon/self.num_actions
                action = np.random.choice(self.num_actions, p=prob)
                next_state = self.get_next_state(state, action)
                path.append((state, action))

                # build model of deterministic environment based on experience
                if (tuple(state), action) not in self.env_model:
                    self.env_model[(tuple(state), action)] = (self.reward_per_step, next_state)
                    number_of_visits[(tuple(state), action)] = 1
                else:
                    temp = self.env_model[(tuple(state), action)]                    
                    self.env_model[(tuple(state), action)] = ((temp[0]*number_of_visits[(tuple(state), action)]+self.reward_per_step)/(number_of_visits[(tuple(state), action)]+1), next_state)
                    number_of_visits[(tuple(state), action)] += 1

                # q learning update
                self.q_values[state[0]][state[1]][action] += self.alpha * (self.reward_per_step + self.gamma * np.max(self.q_values[next_state[0]][next_state[1]]) - self.q_values[state[0]][state[1]][action])
                
                state = next_state
            
            # learn from simulated experience N times
            for n in range(self.N):
                [(state, action)] = random.sample(path, 1)
                (reward, next_state) = self.env_model[(tuple(state), action)]

                # q learning update using simulated experience
                self.q_values[state[0]][state[1]][action] += self.alpha * (reward + self.gamma * np.max(self.q_values[next_state[0]][next_state[1]]) - self.q_values[state[0]][state[1]][action])

            if (_+1)%(self.episodes//2) == 0:
                print("episode %d:"%(_+1))
                self.display_q_values()            
                self.epsilon *= self.epsilon_decay
        
        print("\ngenerated policy is-")
        self.display_policy()

In [82]:
grid_h = 6
grid_w = 9
blocked_states = [[1,2],[2,2],[3,2],[4,5],[0,7],[1,7],[2,7]]
goal_states = [[0,8]]
grid = GridWorld(seed=0, grid_h=grid_h, grid_w=grid_w, start_state=[2,0], goal_states=goal_states, blocked_states=blocked_states, episodes=50, alpha=0.4, gamma=0.9, epsilon_decay=0.9, N=20, reward_per_step=-1)
grid.display_gridworld()

[ ] [ ] [ ] [ ] [ ] [ ] [ ] [X] [G] 
[ ] [ ] [X] [ ] [ ] [ ] [ ] [X] [ ] 
[S] [ ] [X] [ ] [ ] [ ] [ ] [X] [ ] 
[ ] [ ] [X] [ ] [ ] [ ] [ ] [ ] [ ] 
[ ] [ ] [ ] [ ] [ ] [X] [ ] [ ] [ ] 
[ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] 


In [83]:
grid.dynaq()

episode 25:
[
[[-7.936685550774175, -8.14264036540527, -7.936708980508157, -7.709053760603241, ] [-7.70955579711994, -7.937391692391289, -7.936450518311378, -7.456312557048685, ] [-7.456453255994408, -7.456801063933587, -7.709778008470832, -7.175100988450424, ] [-7.175214945874221, -6.861624190289208, -7.457173257655629, -6.861667384506117, ] [-6.861668273518995, -6.513110165101468, -7.17528644929358, -6.513178737790515, ] [-6.513183100441505, -6.125776990632821, -6.861519965723206, -6.125780076367637, ] [-6.125783789049729, -5.695316402833531, -6.513182921833128, -6.125783140969413, ] [0.0, 0.0, 0.0, 0.0, ] [0.0, 0.0, 0.0, 0.0, ] ]
[[-7.937853216631923, -7.94007112028268, -8.143015643984302, -7.9379835585294325, ] [-7.7093881576552885, -7.712069697088226, -8.143565288808711, -7.937821553290069, ] [0.0, 0.0, 0.0, 0.0, ] [-7.175125847173596, -6.5131665532442655, -6.861780396577413, -6.513183282719984, ] [-6.8616535210839436, -6.125773861004935, -6.861770455325234, -6.12577903827262, ] [