In [None]:
%run Imports.ipynb
%run Discrete_Env.ipynb

In [None]:
class GridWorldEnv(DiscreteEnvironment):
    """Deterministic N x N grid world with a single goal cell.

    States are ``(row, col)`` tuples with ``0 <= row, col < N``.  Every state
    offers the four actions ``"up"``, ``"down"``, ``"left"`` and ``"right"``.
    A move that would leave the grid keeps the agent in its current cell.
    The goal is the bottom-right cell ``(N-1, N-1)``; arriving there marks
    the episode as terminated.

    Attributes:
        N: side length of the grid.
        states: list of all ``(row, col)`` states.
        actions: dict mapping each state to its list of action names.
        gamma: discount factor (stored only; not used by this class).
        rewards: dict mapping ``(state, action)`` to the immediate reward.
        transitions: dict mapping ``(state, action, next_state)`` to the
            transition probability (always ``1.0`` here).
        agent_state: the agent's current state.
        is_terminated: whether the goal has been reached since the last
            reset.
    """

    # Row/column offset applied by each action; insertion order defines the
    # order of the per-state action lists.
    _MOVES = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1),
    }

    def __init__(self, gamma, gridsize):
        """Build the state, action, reward and transition tables.

        Args:
            gamma: discount factor.
            gridsize: side length N of the square grid; must be at least 1.

        Raises:
            ValueError: if ``gridsize`` is smaller than 1 (the grid would
                have no goal cell and no state to start from).
        """
        if gridsize < 1:
            raise ValueError(f"gridsize must be >= 1, got {gridsize!r}")

        # Grid size NxN
        self.N = gridsize

        # State set S
        self.states = [(row, col)
                       for row in range(self.N)
                       for col in range(self.N)]

        # Action sets A(s); each state gets its own list object so that
        # editing one state's actions never affects another state.
        self.actions = {state: list(self._MOVES) for state in self.states}

        # Discount factor gamma
        self.gamma = gamma

        # Rewards r(s, a): zero everywhere except for actions taken *from*
        # the goal state.
        # NOTE(review): step() pays the reward of the state being left, so
        # the +100 is only collected when acting from the goal cell, not on
        # arrival - confirm this is the intended convention.
        goal = self.final_state()
        self.rewards = {(state, action): 0
                        for state in self.states
                        for action in self.actions[state]}
        for action in self.actions[goal]:
            self.rewards[(goal, action)] = +100

        # Transition function p(s'|s, a), keyed by (s, a, s').  Each move is
        # clamped to the grid so that walking into a wall is a self-loop.
        last = self.N - 1
        self.transitions = {}
        for row, col in self.states:
            for action, (d_row, d_col) in self._MOVES.items():
                dest = (min(max(row + d_row, 0), last),
                        min(max(col + d_col, 0), last))
                self.transitions[((row, col), action, dest)] = 1.0

        # Current state of the agent
        self.agent_state = self.initial_state()

        # Has the game terminated?
        self.is_terminated = False

    def step(self, action):
        """Apply ``action`` from the current state.

        Args:
            action: one of the action names available in the current state.

        Returns:
            A list ``[next_state, reward, is_terminated]`` where ``reward``
            is ``rewards[(state_before_move, action)]``.

        Raises:
            KeyError: if ``(current_state, action)`` has no reward entry.
        """
        state = self.agent_state
        reward = self.rewards[(state, action)]
        for dest in self.states:
            # This lookup relies on the grid world being deterministic: the
            # unique successor is the one with probability exactly 1.
            if self.transitions.get((state, action, dest)) == 1:
                self.agent_state = dest
                if dest == self.final_state():
                    self.is_terminated = True
                break
        # If no successor was found the agent simply stays where it is.
        return [self.agent_state, reward, self.is_terminated]

    def reset(self):
        """Start a new episode and return the new initial state.

        The agent is placed on a freshly sampled initial state and the
        termination flag is cleared.
        """
        self.agent_state = self.initial_state()
        # Must be set on the instance: a bare local assignment here would
        # leave the flag from the previous episode in place.
        self.is_terminated = False
        return self.agent_state

    def initial_state(self):
        """Return a start state drawn uniformly at random from all states.

        The goal cell is not excluded from the draw.
        """
        return random.choice(self.states)

    def final_state(self):
        """Return the goal state, the bottom-right cell ``(N-1, N-1)``."""
        return (self.N - 1, self.N - 1)