# Session 1: Discrete States and Discrete Actions

## Imports

In [335]:
import numpy as np

## Gridworld

### Environment

In [336]:
class Gridworld:
    """Deterministic rectangular gridworld with one goal tile and optional traps.

    States are ``(row, col)`` tuples. Every call to :meth:`step` costs
    ``move_reward``; a move that stays on the grid additionally collects the
    reward stored for the tile that is entered. The episode is flagged as done
    once the agent stands on the goal tile. Trap tiles only carry a reward,
    they do not end the episode.

    Parameters
    ----------
    sz : (rows, cols) size of the grid.
    start : tile the agent occupies after ``reset()``.
    goal : tile that sets ``done`` when the agent stands on it.
    traps : iterable of trap tiles; it is copied, so the caller's object
        (and the default) is never shared with the instance.
    goal_reward, trap_reward : reward stored on the goal / trap tiles.
    move_reward : reward added on every step, legal or not.
    """

    # Row/column offset applied by each named action.
    _MOVES = {'up': (-1, 0), 'left': (0, -1), 'down': (1, 0), 'right': (0, 1)}

    def __init__(self, sz = (3,3), start = (0,0), goal = (0,2), traps = ((0,1),),
                 goal_reward = 5, trap_reward = -3, move_reward = -1):
        self.sz = sz
        self.action_space = ['up','left','down','right']
        #create grids
        self.grid_keys = [(i,j) for i in range(sz[0]) for j in range(sz[1])]
        self.start = start
        self.goal = goal
        # Private copy: a list default (or a caller-owned list) stored directly
        # would be shared between instances and leak later mutations.
        self.traps = list(traps)
        self.move_reward = move_reward
        self.trap_reward = trap_reward
        self.goal_reward = goal_reward
        self.reset()

    def reset(self):
        """Put the agent back on ``start``, rebuild both grids, return the state."""
        self.i = self.start[0]
        self.j = self.start[1]
        self.done = False
        #physical grid: each tile is [tile type, marker], marker 'o' = agent here
        # One fresh list per tile; dict.fromkeys would alias a single list
        # across every floor tile.
        self.physical_grid = {key: ['F','x'] for key in self.grid_keys}
        self.physical_grid[self.start] = ['F','o']
        self.physical_grid[self.goal] = ['G','x']
        for t in self.traps:
            self.physical_grid[t] = ['T','x']
        #reward grid
        self.reward_grid = dict.fromkeys(self.grid_keys,0)
        self.reward_grid[self.goal] = self.goal_reward
        for t in self.traps:
            self.reward_grid[t] = self.trap_reward
        return (self.i,self.j)

    def print_reward(self):
        """Print the per-tile rewards, one grid row per output line."""
        for i in range(self.sz[0]):
            print('\n----------')
            for j in range(self.sz[1]):
                print(f'{self.reward_grid[(i,j)]} |',end='')

    def print_physical(self):
        """Print the ``[tile type, marker]`` pair of every tile, row by row."""
        for i in range(self.sz[0]):
            print('\n------------------------------------')
            for j in range(self.sz[1]):
                print(f'{self.physical_grid[(i,j)]} |',end='')

    def update_physical(self):
        """Clear every agent marker, then mark the agent's current tile."""
        for key in self.grid_keys:
            self.physical_grid[key][1] = 'x'
        tile = self.physical_grid[(self.i,self.j)][0]
        self.physical_grid[(self.i,self.j)] = [tile,'o']

    def step(self,action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        A move that would leave the grid keeps the agent in place and costs
        only ``move_reward``. An action that is not one of the four named
        moves produces no displacement, so it is handled like a legal move
        onto the current tile (that tile's reward is collected again).
        ``done`` is not checked before moving, so stepping after the goal has
        been reached still moves the agent.
        """
        reward = self.move_reward
        di,dj = self._MOVES.get(action,(0,0))
        i,j = self.i + di,self.j + dj
        #check legality (same set of tiles as grid_keys, without the list scan)
        if 0 <= i < self.sz[0] and 0 <= j < self.sz[1]:
            #update position
            self.i,self.j = i,j
            #update physical
            self.update_physical()
            #update reward
            reward += self.reward_grid[(i,j)]
        if (self.i,self.j) == self.goal:
            self.done = True
        #return s',r, done or not
        return ((self.i,self.j),reward,self.done)

In [337]:
# Build an environment with the default settings (3x3 grid) and show
# its tile layout followed by its per-tile rewards.
g = Gridworld()
g.print_physical()
g.print_reward()


------------------------------------
['F', 'o'] |['T', 'x'] |['G', 'x'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |
----------
0 |-3 |5 |
----------
0 |0 |0 |
----------
0 |0 |0 |

In [338]:
# Walk right twice from the start: first onto the trap at (0, 1), then onto
# the goal at (0, 2). Each printed tuple is (next_state, reward, done).
print(g.step('right'))
print(g.step('right'))
g.print_physical()

((0, 1), -4, False)
((0, 2), 4, True)

------------------------------------
['F', 'x'] |['T', 'x'] |['G', 'o'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |

In [339]:
# Put the agent back on the start tile; reset() returns that start state.
g.reset()

(0, 0)

### Agent

In [340]:
class Agent:
    """Agent that follows a fixed state -> action lookup table.

    env    : environment object; only its ``grid_keys`` is read here
    policy : mapping from state to the action to take in that state
    gamma  : stored as given, not used inside this class
    v      : per-state value table, initialised to 0 for every grid key
    """
    def __init__(self, env, policy, gamma = 1):
        self.env, self.policy, self.gamma = env, policy, gamma
        self.v = {cell: 0 for cell in self.env.grid_keys}

    def select_action(self,state):
        """Return the action the policy assigns to ``state``."""
        return self.policy[state]

In [348]:
# Fresh default environment plus a hand-written deterministic policy:
# one action per (row, col) state.
env = Gridworld()
# From the start tile (0, 0) the policy goes down, right, right, up, which
# reaches the goal at (0, 2) without entering the trap at (0, 1).
policy = {(0, 0): 'down',
          (0, 1): 'right',
          (0, 2): None,  # goal tile: no action assigned
          (1, 0): 'right',
          (1, 1): 'right',
          (1, 2): 'up',
          (2, 0): 'right',
          (2, 1): 'up',
          (2, 2): 'up'}
a = Agent(env,policy,gamma=1)

In [349]:
# Show the agent's environment before any step; 'o' marks the agent's tile.
a.env.print_physical()


------------------------------------
['F', 'o'] |['T', 'x'] |['G', 'x'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |
------------------------------------
['F', 'x'] |['F', 'x'] |['F', 'x'] |

In [350]:
# Roll out one episode under the agent's policy, printing every transition
# as: previous state, new state, reward, done flag.
state = env.reset()
done = False
while not done:
    old_state = state
    action = a.select_action(state)
    state, reward, done = env.step(action)
    print(old_state, state, reward, done)

(0, 0) (1, 0) -1 False
(1, 0) (1, 1) -1 False
(1, 1) (1, 2) -1 False
(1, 2) (0, 2) 4 True
