In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn as sk
import seaborn as sns

from operator import add


from IPython.core.debugger import set_trace


In [14]:
REWARD_NONTERMINAL = -1
REWARD_TERMINAL = 10
REWARD_CLIFF = -100
ACTION_DIRECTIONS = [(0, 1), (1, 0), (0, -1), (-1, 0)]

class Environment:

    def __init__(self, nr_columns, nr_rows, nr_actions=4, init_qa_values=0):
        self.world = np.zeros((nr_rows, nr_columns))
        self.nr_columns = nr_columns
        self.nr_rows = nr_rows
        self.nr_actions = nr_actions
        
    def set_world_rewards(self):
        pass
    
    def is_out_of_bounds(self, state):
        if state[0] < 0 or state[0] > self.nr_rows - 1:
            return True
        
        if state[1] < 0 or state[1] > self.nr_columns -1:
            return True
        
        return False
    
    def next_state(self, state, action):
        pass
    
    def check_termination(self, state):
        pass

class CliffEnvironment(Environment):
    
    def __init__(self, nr_columns, nr_rows, nr_actions=4, init_qa_values=0):
        super().__init__(nr_columns, nr_rows)
        self.set_world_rewards()
        
    def set_world_rewards(self):
        self.world[:, :] = REWARD_NONTERMINAL
        self.world[self.nr_rows - 1:, 1:self.nr_columns - 1] = REWARD_CLIFF
        self.world[self.nr_rows - 1, self.nr_columns - 1] = REWARD_TERMINAL
        
    def next_state(self, state, action_index):
        action = ACTION_DIRECTIONS[action_index]
        next_state = tuple(map(add, state, action))
        
        if self.is_out_of_bounds(next_state):
            next_state = state
        
        return next_state
    
    def check_termination(self, state):
        return state[1] == self.nr_rows - 1 and state[0] > 0

In [18]:
ALPHA = 0.1
GAMMA = 1

class Agent:
    def __init__(self, env, epsilon=0.2, init_position=(0,0)):
        self.init_position = init_position
        self.curr_state = init_position
        self.env = env
        self.epsilon = epsilon
        self.q_table = np.zeros((env.nr_columns, env.nr_rows + 1, env.nr_actions))
        
    def run(self):
        pass
    
    def get_next_action(self):
        """Returns the next index of the action according to the epsilon-greedy choice"""
        actions = self.q_table[self.curr_state]
        
        # If we choose randomly
        if np.random.random() < self.epsilon:
            return np.random.choice(4)
        
        return np.argmax(actions)
    
    def get_next_state(self, action):
        """Return next theoretical state according to the environment."""
        return self.env.next_state(self.curr_state, action)
    
    def update_q_table(self, action, next_state, next_best_action):
        pass
    
    def get_reward_for_state(self, state):
        return self.env.world.transpose()[state]
    
    def terminated(self):
        set_trace()
        return self.env.check_termination(self.curr_state)
    
    def update_state(self, next_state):
        self.curr_state = next_state
        
class QLearner(Agent):
    def __init__(self, env, nr_episodes, epsilon=0.2 , init_position=(0,0)):
        super().__init__(env, epsilon, init_position)
        self.nr_episodes = nr_episodes
        
    def run(self):
        for i in range(self.nr_episodes):
            self.curr_state = self.init_position
            
            while not self.terminated():
                action_index = self.get_next_action()
                next_state = self.get_next_state(action_index)
                next_state_best_action_index = np.argmax(self.q_table[next_state])

                self.update_q_table(action_index, next_state, next_state_best_action_index)
                self.update_state(next_state)
                
    def update_q_table(self, action_index, next_state, next_best_action_index):
        curr_q = self.q_table[self.curr_state][action_index]
        update = (self.get_reward_for_state(next_state) + GAMMA * self.q_table[next_state][next_best_action_index] - curr_q)
        self.q_table[self.curr_state][action_index] = curr_q + ALPHA * update
        
    
class SarsaLearner(Agent):
    def __init__(self, env, nr_episodes, epsilon=0.2, init_position=(0,0)):
        super().__init__(env, init_position)
        self.nr_episodes = nr_episodes
        
    def run(self):
        for i in range(self.nr_episodes):
            pass
            

In [19]:
###
# Main Run - Initialization
###
NR_COLUMNS = 5

# Including the cliff
NR_ROWS = 4

cliffWorld = CliffEnvironment(NR_COLUMNS, NR_ROWS)
agent = QLearner(cliffWorld, 100, 0.2, (0, NR_ROWS))

In [20]:
agent.run()

> [0;32m<ipython-input-18-f7e87a280039>[0m(37)[0;36mterminated[0;34m()[0m
[0;32m     35 [0;31m    [0;32mdef[0m [0mterminated[0m[0;34m([0m[0mself[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 37 [0;31m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0menv[0m[0;34m.[0m[0mcheck_termination[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mcurr_state[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m[0;34m[0m[0m
[0m[0;32m     39 [0;31m    [0;32mdef[0m [0mupdate_state[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mnext_state[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[0m
ipdb> self.env.check_termination(self.curr_state)
False
ipdb> self.curr_state
(0, 4)
ipdb> c


IndexError: index 4 is out of bounds for axis 1 with size 4

In [6]:
agent.q_table

array([[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 1.26625651e+05,  0.00000000e+00,  0.00000000e+00,
          3.63495324e+03],
        [ 1.53607802e+06,  0.00000000e+00, -1.00000000e-01,
          0.00000000e+00],
        [ 2.84297523e+06,  3.90024330e+05,  2.86302610e+05,
          9.85181610e+05],
        [ 1.89412752e+06,  3.01021348e+06,  2.13826911e+06,
          2.46569319e+06]],

       [[ 8.83232395e+04,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 1.62684834e+06,  4.26151034e+04, -1.00000000e-01,
          3.63495324e+03],
        [ 2.83835647e+06,  1.31182903e+06,  6.82459874e+05,
          3.33303931e+05],
        [ 2.80155104e+06,  3.02441243e+06,  2.61618305e+06,
          2.58617473e+06],
        [ 3.00703375e+06,  3.03879785e+06,  2.97406177e+06,
          2.96463275e+06]],

       [[ 4.99312956e+05,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 2.77809546e+06,  1.02531923e+