# Escape Room Game

In [1]:
import random
from collections import deque

import emoji
from sklearn import tree

Functionality:

In [2]:
class Grid:
    """9x3 escape-room grid world.

    The agent (at ``pos``) has to pick up the key in column 5 (row ``key_init``)
    and then reach the door in column 7 (row ``door_init``).  Every move costs
    ``move_reward``; entering a fire field additionally costs ``fire_reward``.
    An episode ends when the agent stands on the door while holding the key,
    or after ``max_nrof_moves`` moves.

    A state is the tuple ``(posX, posY, has_key, key_init, door_init)`` with
    posX in 0..8 and posY in 0..2.  Directions: 0 = up (y-1), 1 = right (x+1),
    2 = down (y+1), 3 = left (x-1).
    """

    def __init__(self, pos_init=(1, 1), key_init=1, door_init=0):
        self.pos_init = pos_init
        self.key_init = key_init
        self.key_pos = (5, key_init)
        self.door_init = door_init
        self.door_pos = (7, door_init)
        self.pos = list(pos_init)
        self.has_key = False
        self.move_count = 0
        self.fires = [(3, 1), (3, 2)]
        self.score = 0
        self.move_reward = -1
        self.fire_reward = -50
        self.key_reward = 0
        self.escape_reward = 0
        self.max_nrof_moves = 25
        # Cache: policy -> reward function, filled by get_reward_function().
        self.reward_functions = {}

        # Lazily built by the `all_states` property.
        self.all_states_list = None

        # Starting on the key field means the key is picked up immediately.
        # `self.pos` is a list and `key_pos` a tuple; a list never compares
        # equal to a tuple, so compare as tuples.
        if tuple(self.pos) == self.key_pos:
            self.has_key = True

    def from_state(self, state):
        """Re-initialise this grid from a state tuple and return ``self``."""
        posX, posY, has_key, key_init, door_init = state
        self.__init__((posX, posY), key_init, door_init)
        self.has_key = has_key
        return self

    def make_random(self):
        """Re-initialise with a random start position and key/door rows; return ``self``."""
        x = random.randrange(9)
        y = random.randrange(3)
        k = random.randrange(3)
        d = random.randrange(3)
        self.__init__((x, y), k, d)
        return self

    def add_key(self, remove=False):
        """Give the agent the key (or take it away if ``remove``); return ``self``."""
        self.has_key = not remove
        return self

    def as_list(self, with_emoji=True):
        """Return the grid as ``grid_list[x][y]`` of field symbols.

        Agent, fire, key and door are drawn as emoji (letters A/F/K/D when
        ``with_emoji`` is False); a picked-up key is shown as ". " and empty
        fields as two spaces.
        """
        grid_list = []
        for i in range(9):
            grid_list.append([])
            for j in range(3):
                if (i, j) == tuple(self.pos):
                    field = emoji.emojize(":cowboy_hat_face:") if with_emoji else "A"
                elif (i, j) in self.fires:
                    field = emoji.emojize(":fire:") if with_emoji else "F"
                elif (i, j) == self.key_pos:
                    if self.has_key:
                        field = ". "
                    else:
                        field = emoji.emojize(":key:") if with_emoji else "K"
                elif (i, j) == self.door_pos:
                    field = emoji.emojize(":door:") if with_emoji else "D"
                else:
                    field = "  "
                grid_list[i].append(field)
        return grid_list

    def __str__(self, grid_list=None):
        """Render ``grid_list`` (default: ``self.as_list()``) as a table with x/y labels."""
        if grid_list is None:
            grid_list = self.as_list()
        conlen = len(grid_list[0][0]) + 2
        row_strs = []
        for j in range(3):
            row = [grid_list[i][j] for i in range(9)]
            row_str = str(j) + "  " + " | ".join(row)
            row_strs.append(row_str)
        grid_row = "\n  " + "-"*(conlen*9+8) + "\n"
        firstline = "  "
        firstline += " "*(conlen//2)
        firstline += (" "*conlen).join([str(i) for i in range(9)]) + " \n"
        grid_str = firstline + grid_row.join(row_strs)
        return grid_str

    def move(self, direction, suppress_print=False):
        """Move the agent one step and update key/score/move count.

        Returns False when the episode is over (door reached with the key, or
        move budget exhausted), True otherwise.  An invalid direction is
        reported and still counts as a move.
        """
        if direction == 0:
            self.pos[1] = max(0, self.pos[1]-1)
        elif direction == 1:
            self.pos[0] = min(8, self.pos[0]+1)
        elif direction == 2:
            self.pos[1] = min(2, self.pos[1]+1)
        elif direction == 3:
            self.pos[0] = max(0, self.pos[0]-1)
        else:
            print("\nInvalid move!! (pos {}, direction {})\n".format(self.pos, direction))

        self.move_count += 1
        self.score += self.move_reward

        if tuple(self.pos) == self.key_pos and not self.has_key:
            self.has_key = True
            self.score += self.key_reward
            if not suppress_print: print("Found key!")

        if tuple(self.pos) in self.fires:
            self.score += self.fire_reward
            if not suppress_print: print("Oouch, you got burned!")

        if tuple(self.pos) == self.door_pos and self.has_key:
            self.score += self.escape_reward
            if not suppress_print: print("Found door! (finished in " + str(self.move_count) + " moves.)")
            if not suppress_print: print("Congratulations! Achieved a score of " + str(self.score) + ".")
            return False

        if self.move_count >= self.max_nrof_moves:
            if not suppress_print: print("You lose! (move " + str(self.move_count) + ")")
            if not suppress_print: print("Ended on a score of", self.score)
            return False

        return True

    def next_state(self, state, direction):
        """Pure transition: return ``(next_state, reward, finished)``.

        Does not touch ``self``'s position.  Returns None (after printing a
        message) for an invalid direction.
        """
        posX, posY, has_key, key_init, door_init = state
        reward = self.move_reward
        finished = False

        if direction == 0:
            posY = max(0, posY-1)
        elif direction == 1:
            posX = min(8, posX+1)
        elif direction == 2:
            posY = min(2, posY+1)
        elif direction == 3:
            posX = max(0, posX-1)
        else:
            print("\nInvalid move!! (state {}, direction {})\n".format(state, direction))
            return None

        if (posX, posY) == (self.key_pos[0], key_init) and not has_key:
            has_key = True
            reward += self.key_reward

        if (posX, posY) in self.fires:
            reward += self.fire_reward

        if (posX, posY) == (self.door_pos[0], door_init) and has_key:
            reward += self.escape_reward
            finished = True

        state = posX, posY, has_key, key_init, door_init
        return state, reward, finished

    def oracle_move(self, state=None):
        """Hand-coded expert policy: direction for ``state`` (default: current state).

        Walks above the fires, collects the key, then heads to the door.
        Returns None for the terminal state (on the door with the key) and for
        the impossible state (on the key field without the key).
        """
        if state is None:
            x, y = self.pos
            has_key = int(self.has_key)
            key_init = self.key_init
            door_init = self.door_init
        else:
            x, y, has_key, key_init, door_init = state

        if x <= 4:
            # Left of the key column: go up while left of the fires, else right.
            if x < 3 and y > 0:
                return 0
            else:
                return 1
        elif has_key == 1:
            if x <= 6:
                return 1
            elif x > 7:
                return 3
            elif y == 0:
                return 2
            elif y == 2:
                return 0
            elif door_init == 0:
                return 0
            elif door_init == 2:
                return 2
            else:
                return None
        elif x > 5:
            return 3
        elif y == 0:
            return 2
        elif y == 2:
            return 0
        elif key_init == 0:
            return 0
        elif key_init == 2:
            return 2
        else:
            return None

    @property
    def state(self):
        """Current state tuple ``(posX, posY, has_key as int, key_init, door_init)``."""
        return (self.pos[0], self.pos[1], int(self.has_key), self.key_init, self.door_init)

    def get_state_str(self, state=None):
        """Human-readable description of ``state`` (default: the current state)."""
        if state is None:
            state = self.state

        state_str = "Agent position: ({}, {}), has key: {}, key and door positions: {}, {}".format(*state)
        return state_str
    state_str = property(get_state_str)

    @property
    def all_states(self):
        """All 9*3*2*3*3 = 486 state tuples, cached after the first call.

        Ordered from (8, 2, 1, 2, 2) down to (0, 0, 0, 0, 0) with posX
        varying fastest, then posY, has_key, key_init, door_init.
        """
        if self.all_states_list is None:
            self.all_states_list = [
                (x, y, h, k, d)
                for d in range(2, -1, -1)
                for k in range(2, -1, -1)
                for h in range(1, -1, -1)
                for y in range(2, -1, -1)
                for x in range(8, -1, -1)
            ]
            assert len(self.all_states_list) == 9*3*2*3*3
        return self.all_states_list

    def print_fn_on_grid(self, fn, fn_name="Policy"):
        """Print ``fn`` evaluated on every field, once without and once with the key.

        fn: (grid_symb, state) -> string (of constant length), where
        ``grid_symb`` is the field's symbol from ``as_list`` and ``state`` the
        full state tuple.  Only states whose key/door rows match this grid are
        shown.  ``self.has_key`` is restored afterwards.

        Returns the two annotated grid lists ``(without_key, with_key)``.
        """
        original_has_key = self.has_key
        annotated = []
        try:
            for has_key, label in ((False, " without key:"), (True, " with key:")):
                self.has_key = has_key
                context = self.state[2:]  # (has_key, key_init, door_init)
                grid_list = self.as_list()
                for state in self.all_states:
                    if state[2:] == context:
                        grid_list[state[0]][state[1]] = fn(grid_list[state[0]][state[1]], state)
                print("\n" + fn_name + label)
                print(self.__str__(grid_list=grid_list))
                annotated.append(grid_list)
        finally:
            # Printing must not leave the grid in a different key state.
            self.has_key = original_has_key
        return tuple(annotated)

    def print_policy(self, policy, policy_name="Policy"):
        """Print the direction ``policy(state)`` as an arrow on every field."""
        def fn(grid_symb, state):
            # as_list() never yields an empty symbol, so append the arrow.
            return grid_symb + " " + dir_str(policy(state))

        return self.print_fn_on_grid(fn, fn_name=policy_name)

    def print_reward_fn(self, reward_fn, show_diff=False, small=False, nr_len=3):
        """Print ``reward_fn(state)`` on every field.

        show_diff: print the difference to the oracle's reward instead.
        small: print only the number on empty fields (symbols stay alone).
        nr_len: width of the printed number; "?" marks states without a value.
        """
        def reward_str(grid_symb, state):
            rew = reward_fn(state)
            opt_rew = self.reward_function(state) if show_diff else 0
            rew_str = " ?"+" "*(nr_len-2) if rew is None else ("{: "+str(nr_len)+"d}").format(rew-opt_rew)

            if not small:
                if grid_symb == " "*len(grid_symb):
                    return " "*len(grid_symb) + rew_str + " "
                else:
                    return grid_symb + " " + rew_str
            else:
                if grid_symb == " "*len(grid_symb):
                    return rew_str
                else:
                    return grid_symb + " "*(nr_len-2)

        fn_name = "Future reward" if not show_diff else "Relative future reward"
        return self.print_fn_on_grid(reward_str, fn_name=fn_name)

    def get_reward_function(self, policy=None, suppress_print=True):
        """Compute the future reward of every state when following ``policy``.

        Backward induction over ``(state, moves_left)`` pairs:
        ``R(s, 0) = 0``, ``R(end, m) = 0`` and
        ``R(s, m) = R(next(s), m - 1) + r(s, policy(s))``.  Pairs whose
        successor is not resolved yet are re-queued until it is.

        Returns a function ``state -> future reward`` with the full move budget;
        it yields None for impossible states (on the key field without the key).
        Results are cached per policy object (default: ``self.oracle_move``).

        Raises ValueError if ``policy`` gives no valid move for a state that
        needs one.
        """
        if policy is None:
            policy = self.oracle_move

        if policy in self.reward_functions:
            return self.reward_functions[policy]

        # deque: popping from the front of a list is O(n) per pop.
        states_to_process = deque((state, moves_left) for state in self.all_states
                                  for moves_left in range(self.max_nrof_moves+1))
        reward_table = {key: None for key in states_to_process}
        if not suppress_print: print("Starting going through " + str(len(states_to_process)) + " states.")

        count = 0
        first_new = len(states_to_process)
        epoch_nr = 1
        impossible_states, end_states = [], []
        while states_to_process:
            if count == first_new:
                if not suppress_print: print("Went through epoch {}. Still got {} to process.".format(
                    epoch_nr, len(states_to_process)))
                epoch_nr += 1
                first_new = count + len(states_to_process)

            # moves_left: number of moves left before time is over.
            state, moves_left = states_to_process.popleft()
            if moves_left == 0:
                reward_table[(state, moves_left)] = 0
                continue

            if state[0] == self.key_pos[0] and state[1] == state[3] and state[2] == 0:
                # On the key field but without the key: cannot occur.
                impossible_states.append(state)
                continue
            if state[0] == self.door_pos[0] and state[1] == state[4] and state[2] == 1:
                # On the door with the key: episode already over.
                end_states.append(state)
                reward_table[(state, moves_left)] = 0
                continue

            # Only query the policy for states that actually need a move;
            # policies can be expensive (e.g. a classifier's predict()).
            move = policy(state)
            transition = None if move is None else self.next_state(state, move)
            if transition is None:
                print("Don't know how to move :0")
                print("State:", state)
                raise ValueError("policy gave no valid move ({!r}) for state {}".format(move, state))
            next_state, move_reward, finished = transition

            if finished:
                reward_table[(next_state, moves_left)] = 0
            successor_reward = reward_table[(next_state, moves_left-1)]
            if successor_reward is not None:
                reward_table[(state, moves_left)] = successor_reward + move_reward
            else:
                states_to_process.append((state, moves_left))

            if count == 1000000:
                print("\nCba!!!!\n")
                break
            count += 1

        def reward(state):
            return reward_table[(state, self.max_nrof_moves)]

        if not suppress_print: print("\nImpossible states:", set(impossible_states))
        if not suppress_print: print("\nEnd states:", set(end_states))

        self.reward_functions[policy] = reward
        return reward
    reward_function = property(get_reward_function)

    def get_average_reward(self, policy=None, suppress_print=True, key_init=1, door_init=0):
        """Average future reward of ``policy`` over all 27 start positions.

        Key/door rows default to the standard layout (key row 1, door row 0).
        Starting on the key field counts as already holding the key.
        """
        reward_fn = self.get_reward_function(policy=policy, suppress_print=suppress_print)
        key_field = (self.key_pos[0], key_init)

        rew_sum = 0
        for posX in range(9):
            for posY in range(3):
                has_key = int((posX, posY) == key_field)
                rew_sum += reward_fn((posX, posY, has_key, key_init, door_init))

        return rew_sum/9/3
    average_reward = property(get_average_reward)

    def generate_data(self, policy=None, oracle=None, nrof_rollouts=1000, fixed_pos=False, fixed_key=True, fixed_door=True):
        """Roll out ``policy`` and label every visited state with ``oracle``'s move.

        Each rollout starts on a fresh grid; start position and/or key and door
        rows are drawn at random unless fixed to this grid's values.  Rollouts
        whose start state has no oracle move are skipped.  ``policy`` and
        ``oracle`` default to ``self.oracle_move``; ``oracle`` is called with
        the keyword ``state=``.

        Returns ``(xs, ys)``: lists of state tuples and oracle directions.
        Raises RuntimeError if the oracle has no move mid-rollout.
        """
        if policy is None:
            policy = self.oracle_move
        if oracle is None:
            oracle = self.oracle_move

        (x, y), k, d = self.pos, self.key_init, self.door_init

        xs = []
        ys = []

        for _ in range(nrof_rollouts):
            if not fixed_pos:
                x = random.randrange(9)
                y = random.randrange(3)
            if not fixed_key:
                k = random.randrange(3)
            if not fixed_door:
                d = random.randrange(3)
            grid = Grid((x, y), k, d)

            oracle_move = oracle(state=grid.state)
            policy_move = policy(grid.state)

            if oracle_move is None:  # start state is terminal / impossible
                continue

            xs.append(grid.state)
            ys.append(oracle_move)

            # move() returns False once the door is reached or time runs out.
            while grid.move(policy_move, suppress_print=True):
                oracle_move = oracle(state=grid.state)
                policy_move = policy(grid.state)
                if oracle_move is None:
                    print(grid.state)
                    print(grid)
                    raise RuntimeError("oracle has no move for state {}".format(grid.state))
                xs.append(grid.state)
                ys.append(oracle_move)

        return xs, ys
    
# Draw a random start/key/door configuration and show it.
grid = Grid().make_random()
print("Grid state: {}\n{}".format(grid.state, grid))

def dir_str(direction):
    """Arrow for a direction code: 0 up, 1 right, 2 down, 3 left; "?" otherwise."""
    for code, arrow in enumerate(("\u2191", "\u2192", "\u2193", "\u2190")):
        if direction == code:
            return arrow
    return "?"

# Legend of direction codes, then the default layout.
mo_st = "\n\nPossible moves: " + " ".join("{}({})".format(dir_str(i), i) for i in range(4))
print(mo_st)

print("\n\nStandard grid:\n{}".format(Grid()))

Grid state: (7, 2, 0, 0, 2)
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  --------------------------------------------
1     |    |    | 🔥 |    |    |    |    |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    | 🤠 |   


Possible moves: ↑(0) →(1) ↓(2) ←(3)


Standard grid:
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    |    |    | 🚪 |   
  --------------------------------------------
1     | 🤠 |    | 🔥 |    | 🔑 |    |    |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    |    |   


In [3]:
print("Showing oracle on following grid:\n")

grid = Grid().make_random()
print(grid, "\n")

# Step with the oracle until move() reports the episode is over.
still_playing = grid.move(grid.oracle_move())
while still_playing:
    print("Move {}:".format(grid.move_count))
    print(grid, "\n\n")
    still_playing = grid.move(grid.oracle_move())

print(grid)

Showing oracle on following grid:

    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  --------------------------------------------
1     |    |    | 🔥 |    |    |    |    |   
  --------------------------------------------
2     |    |    | 🤠 |    |    |    | 🚪 |    

Move 1:
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  --------------------------------------------
1     |    |    | 🔥 |    |    |    |    |   
  --------------------------------------------
2     |    |    | 🔥 | 🤠 |    |    | 🚪 |    


Move 2:
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  --------------------------------------------
1     |    |    | 🔥 |    |    |    |    |   
  --------------------------------------------
2     |    |    | 🔥 |    | 🤠 |    | 🚪 |    


Move 3:
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  ----------------

In [4]:
GO_RIGHT = 1  # direction code for a step to the right

print("Showing policy that always goes right:\n")
grid = Grid(pos_init=(1, 1), key_init=1, door_init=1)
print(grid, "\n\n")

while grid.move(GO_RIGHT):
    print(grid, "\n\n")

Showing policy that always goes right:

    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    |    |    |    |   
  --------------------------------------------
1     | 🤠 |    | 🔥 |    | 🔑 |    | 🚪 |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    |    |    


    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    |    |    |    |   
  --------------------------------------------
1     |    | 🤠 | 🔥 |    | 🔑 |    | 🚪 |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    |    |    


Oouch, you got burned!
    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    |    |    |    |   
  --------------------------------------------
1     |    |    | 🤠 |    | 🔑 |    | 🚪 |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    |    |    


    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    |    |    |    |   
  ----------

In [5]:
print("Showing (hard coded) oracle policy:\n")
grid = Grid()
policy = grid.oracle_move  # bound method: state -> direction
grid.print_policy(policy=policy, policy_name="Oracle policy");

Showing (hard coded) oracle policy:


Oracle policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | 🔑 ? |    ← |    ← |    ←
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    ↑ |    ← |    ← |    ←

Oracle policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↓ |    ←
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | .  → |    → |    ↑ |    ←
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    → |    → |    ↑ |    ←


## Create reward function for each state:

In [6]:
# Oracle's future reward for every state, computed on a fresh standard grid
# rather than on whatever `grid` the previous cell left behind.
grid = Grid()
reward = grid.reward_function

grid.add_key(remove=True).print_reward_fn(reward, small=True);


Future reward without key:
    0     1     2     3     4     5     6     7     8 
0   -9 |  -8 |  -7 |  -6 |  -5 |  -4 |  -5 | 🚪  |  -7
  -----------------------------------------------------
1  -10 | 🤠  |  -8 | 🔥  |  -4 | 🔑  |  -4 |  -5 |  -6
  -----------------------------------------------------
2  -11 | -10 |  -9 | 🔥  |  -5 |  -4 |  -5 |  -6 |  -7

Future reward with key:
    0     1     2     3     4     5     6     7     8 
0   -7 |  -6 |  -5 |  -4 |  -3 |  -2 |  -1 | 🚪  |  -1
  -----------------------------------------------------
1   -8 | 🤠  |  -6 | 🔥  |  -4 | .   |  -2 |  -1 |  -2
  -----------------------------------------------------
2   -9 |  -8 |  -7 | 🔥  |  -5 |  -4 |  -3 |  -2 |  -3


In [7]:
# Q(s, a): reward of moving in direction a from s plus the future reward
# (from `reward`) of the state that move leads to.
def _q_value(state, direction):
    nxt, move_rew, _ = grid.next_state(state, direction)
    return reward(nxt) + move_rew


Q_fun_dict = {(state, direction): _q_value(state, direction)
              for state in grid.all_states
              for direction in range(4)}
        
        
def Q_fn(state, direction):
    """Look up the precomputed Q-value of taking ``direction`` in ``state``."""
    return Q_fun_dict[state, direction]
        
grid = Grid().make_random()
current_state = grid.state
print(grid)
print(grid.state_str)
print("Reward(state):", reward(current_state), "\n")
print("Oracle move:", dir_str(grid.oracle_move()), "\n")

for direction in range(4):
    q_value = Q_fn(current_state, direction)
    print("Q(state, %s) = %s" % (dir_str(direction), q_value))

    0    1    2    3    4    5    6    7    8 
0     |    |    |    |    | 🔑 |    |    |   
  --------------------------------------------
1     |    |    | 🔥 | 🤠 |    |    |    |   
  --------------------------------------------
2     |    |    | 🔥 |    |    |    | 🚪 |   
Agent position: (4, 1), has key: 0, key and door positions: 0, 2
Reward(state): -6 

Oracle move: → 

Q(state, ↑) = -6
Q(state, →) = -6
Q(state, ↓) = -8
Q(state, ←) = -58


# Learning a decision tree
## 1) Imitation Learning:
Generate training data by rolling out the oracle and recording its move in every visited state:

In [31]:
# Oracle rollouts, every visited state labelled with the oracle's move.
xs, ys = Grid().generate_data(nrof_rollouts=10000)
print("Collected %d data points." % len(xs))

Collected 63541 data points.


### Learn decision tree

In [32]:
max_leaf_nodes = 6
clf_imitation = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes).fit(xs, ys)
nodes = []  # placeholder for tree.plot_tree(...) output; plotting is disabled


def imitation_policy(state):
    """Direction predicted by the imitation tree for a single state."""
    return clf_imitation.predict([state])[0]


Grid().print_policy(imitation_policy);

node_fls = [node.get_text().split("\n")[0] for node in nodes]
node_text = "\nNodes:\n" + "\n".join(fl for fl in node_fls if fl[0] != "g")


Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    ↓ |    ← |    ← |    ←

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    → |    → |    ↑ |    ↑


In [33]:
rew_grid = Grid()
imitation_reward_fn = rew_grid.get_reward_function(policy=imitation_policy, suppress_print=True)

grid = Grid()
grid.print_policy(imitation_policy, policy_name="Imitation policy")
rew_grid.add_key(remove=True).print_reward_fn(imitation_reward_fn, show_diff=True, nr_len=4)

average_reward = rew_grid.get_average_reward(policy=imitation_policy, suppress_print=True)
optimal_reward = rew_grid.average_reward  # oracle's average
print("\n\nAverage reward: %.2f (optimal: %.2f)." % (average_reward, optimal_reward))


Imitation policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    ↓ |    ← |    ← |    ←

Imitation policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    → |    → |    ↑ |    ↑

Relative future reward without key:
      0         1         2         3         4         5         6         7         8 
0       0  |      0  |      0  |      0  |      0  |

## Dagger
First, train on data collected by rolling out the oracle:

In [34]:
xs, ys = Grid().generate_data(nrof_rollouts=1000)
dagger_policies = []
dagger_rewards = []

print("Collected {} data points.".format(len(xs)))

clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
clf = clf.fit(xs, ys)
# Bind this tree as a default argument. A plain `lambda state: clf.predict(...)`
# looks up the global `clf` when it is *called*, so every stored policy would
# silently switch to whichever tree later cells assign to `clf`.
dagger_policies.append(lambda state, model=clf: model.predict([state])[0])

grid = Grid((1, 1), 1, 0)
policy = dagger_policies[-1]
grid.print_policy(policy);

curr_dagger_reward_fn = rew_grid.get_reward_function(policy=dagger_policies[-1], suppress_print=True)
rew_grid.add_key(remove=True).print_reward_fn(curr_dagger_reward_fn, show_diff=True)

average_reward = rew_grid.get_average_reward(policy=dagger_policies[-1], suppress_print=True)
optimal_reward = rew_grid.average_reward
print("\n\nAverage reward: {:.2f} (optimal: {:.2f}).".format(average_reward, optimal_reward))

Collected 6294 data points.

Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    ↓ |    ← |    ← |    ←

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    → |    → |    ↑ |    ↑

Relative future reward without key:
      0        1        2        3        4        5        6        7        8 
0      0  |     0  |     0  |     0  |     0  |     

Second: train on data collected by following the previous policy, with every visited state labelled by the oracle. **Run the following cell multiple times.**

In [40]:
print("Doing another step of dataset aggregation (Done so far: {})".format(len(dagger_policies)))

# Roll out the most recent policy; the oracle labels every visited state.
new_xs, new_ys = Grid().generate_data(policy=dagger_policies[-1], nrof_rollouts=1000)
xs += new_xs
ys += new_ys
print("Extended dataset to size {}.".format(len(xs)))

clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
clf = clf.fit(xs, ys)
# Default argument binds this tree; otherwise the stored policy would follow
# the next reassignment of the global `clf` (late binding).
dagger_policies.append(lambda state, model=clf: model.predict([state])[0])
Grid().print_policy(dagger_policies[-1])

curr_dagger_reward_fn = rew_grid.get_reward_function(policy=dagger_policies[-1], suppress_print=True)
rew_grid.add_key(remove=True).print_reward_fn(curr_dagger_reward_fn, show_diff=True)

average_reward = rew_grid.get_average_reward(policy=dagger_policies[-1], suppress_print=True)
optimal_reward = rew_grid.average_reward
print("\n\nAverage reward: {:.2f} (optimal: {:.2f}).".format(average_reward, optimal_reward))


Doing another step of dataset aggregation (Done so far: 6)
Extended dataset to size 99201.

Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    ↑ |    ← |    ← |    ←

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    → |    → |    ↑ |    ↑

Relative future reward without key:
      0        1        2        3        4        5        6        7

In [41]:

for i, policy in enumerate(dagger_policies):
    avg_reward_i = rew_grid.get_average_reward(policy=policy, suppress_print=True)
    print("r%d: %.1f" % (i + 1, avg_reward_i))
    

r1: -23.3
r2: -25.6
r3: -23.3
r4: -38.0
r5: -26.9
r6: -8.6
r7: -17.1


## Viper
### Learn basic policy

In [42]:
xs, ys = Grid().generate_data(nrof_rollouts=1000)
print("Collected {} data points.".format(len(xs)))

viper_policies = []
clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
clf = clf.fit(xs, ys)
clfs = [clf]
# Default argument binds this tree; a bare `clf` reference would be resolved
# at call time and follow later reassignments of the global `clf`.
viper_policies.append(lambda state, model=clf: model.predict([state])[0])
Grid().print_policy(viper_policies[-1])

curr_viper_reward_fn = rew_grid.add_key(remove=True).get_reward_function(policy=viper_policies[-1], suppress_print=True)
rew_grid.add_key(remove=True).print_reward_fn(curr_viper_reward_fn, show_diff=True)

average_reward = rew_grid.get_average_reward(policy=viper_policies[-1], suppress_print=True)
optimal_reward = rew_grid.average_reward
print("\n\nAverage reward: {:.2f} (optimal: {:.2f}).".format(average_reward, optimal_reward))

Collected 6390 data points.

Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    ↓ |    ← |    ← |    ←

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     → | 🤠 → |    → | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     → |    → |    → | 🔥 → |    → |    → |    → |    ↑ |    ↑

Relative future reward without key:
      0        1        2        3        4        5        6        7        8 
0      0  |     0  |     0  |     0  |     0  |     

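## Sample states according to the l-function
Each state $s$ is weighted by $\ell(s) = V(s) - \min_a Q(s, a)$, the future reward lost in the worst case by acting wrongly in $s$. A data point is kept with probability $\ell(s) / \max_{s'} \ell(s')$.

**Run the next two cells multiple times!**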

In [52]:
# Generate Set of all states in data
xs_set = set(xs)
# print("Reduced dataset of size {} to set of size {}.".format(len(xs), len(xs_set)))
print("Found {} different states in dataset of size {}.".format(len(xs), len(xs_set)))


# Calculate l for those states
l = {}
grid = Grid()
for state in xs_set:
    rew = grid.reward_function(state)
    minQ = None
    for direction in range(4):
        q_fn_val = Q_fn(state, direction)
        if minQ is None:
            minQ = q_fn_val
        else:
            minQ = min(minQ, q_fn_val)
    l[state] = rew - minQ

l_dict = l
# l = lambda state: l_dict[state]
# l_str = lambda grid_symb, state: " {:3d} ".format(l(state)) if grid_symb==" " else "{} {:3d}".format(grid_symb, l(state))
def l_str(grid_symb, state):
    """Render the l-value of ``state`` as the text of one grid cell.

    States missing from the global ``l_dict`` are shown as " X ". For an empty
    cell (``grid_symb == " "``) the value is padded with one space on each side;
    otherwise the cell symbol is printed first, followed by a space and the value.
    """
    if state in l_dict:
        val_str = "{:3d}".format(l_dict[state])
    else:
        val_str = " X "
    if grid_symb != " ":
        return "{} {}".format(grid_symb, val_str)
    return " " + val_str + " "


# Show l over the grid (" X " marks states that never appeared in the data).
grid.print_fn_on_grid(l_str, fn_name="l-function");

# max() over the iterable itself: max(*values) raises TypeError when l_dict has one entry.
max_l = max(l_dict.values())
if max_l <= 0:
    # The acceptance probability l/max_l would divide by zero (max_l == 0) or stop
    # weighting by l at all (max_l < 0), so sampling would be meaningless.
    raise ValueError("l-function has no positive value (max {}); cannot sample.".format(max_l))

# Rejection sampling: draw candidates from the dataset (frequent states are drawn more
# often) and accept each with probability l(state) / max_l. Accepted states are
# labelled with Grid().oracle_move(state).
N_VIPER_SAMPLES = 10000
xs_dash, ys_dash = [], []
count = {state: 0 for state in Grid().all_states}
while len(xs_dash) < N_VIPER_SAMPLES:
    cand = random.choice(xs)
    p = random.uniform(0, 1)
    if p < l_dict[cand] / max_l:
        xs_dash.append(cand)
        ys_dash.append(Grid().oracle_move(cand))
        count[cand] += 1


def count_fn(symb, state):
    """Grid-cell text: number of accepted samples for ``state``."""
    return "{:4d}".format(count[state])


Grid().print_fn_on_grid(count_fn, fn_name="Data samples");


Found 69742 different states in dataset of size 31.

l-function without key:
      0        1        2        3        4        5        6        7        8 
0       2 |      2 |      2 |     50 |      2 |      2 |      2 | 🚪   2 |      1
  --------------------------------------------------------------------------------
1       2 | 🤠   2 |     48 | 🔥  52 |     52 | 🔑  X  |      2 |      2 |      2
  --------------------------------------------------------------------------------
2       1 |      2 |     48 | 🔥  51 |     52 |      2 |      2 |      2 |      1

l-function with key:
      0        1        2        3        4        5        6        7        8 
0      X  |     X  |     X  |     X  |     X  |      2 |     X  | 🚪  X  |     X 
  --------------------------------------------------------------------------------
1      X  | 🤠  X  |     X  | 🔥  X  |     X  | .    2 |      2 |      2 |      2
  --------------------------------------------------------------------------------
2    

## Train classifier from sampled data:

In [53]:
def make_clf_policy(classifier):
    """Wrap a fitted classifier as a policy mapping a state to an action.

    The classifier is bound as a parameter, so each returned policy keeps using the
    tree it was created with. A bare ``lambda state: clf.predict(...)`` looks up the
    notebook-global ``clf`` at call time. Every stored policy would then silently
    switch to the most recently trained tree whenever this cell is re-run.
    """
    def policy(state):
        return classifier.predict([state])[0]
    return policy


# Fit a decision tree on the l-weighted samples (xs_dash, ys_dash).
clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
clf = clf.fit(xs_dash, ys_dash)
viper_policies.append(make_clf_policy(clf))
clfs.append(clf)
Grid().print_policy(viper_policies[-1])

curr_viper_reward_fn = rew_grid.get_reward_function(policy=viper_policies[-1], suppress_print=True)
rew_grid.add_key(remove=True).print_reward_fn(curr_viper_reward_fn, show_diff=True)

average_reward = rew_grid.get_average_reward(policy=viper_policies[-1], suppress_print=True)
optimal_reward = rew_grid.average_reward
print("\n\nAverage reward: {:.2f} (optimal: {:.2f}).".format(average_reward, optimal_reward))

# Split conditions of the tree just trained (internal nodes only; children_left is -1
# for leaves). Built from clf.tree_ so it never reflects a stale `nodes` list left
# over from an earlier plot_tree call.
tree_struct = clf.tree_
node_text = "\nNodes:\n" + "\n".join(
    "x[{}] <= {}".format(tree_struct.feature[i], round(tree_struct.threshold[i], 3))
    for i in range(tree_struct.node_count)
    if tree_struct.children_left[i] != -1
)
# Uncomment to inspect: print(node_text)

# Dataset aggregation: roll out the new policy and append the generated data to
# xs / ys (labels presumably come from the oracle inside generate_data — confirm).
new_xs, new_ys = Grid().generate_data(policy=viper_policies[-1], nrof_rollouts=1000)
xs += new_xs
ys += new_ys



Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | 🔑 ↓ |    ← |    ← |    ←
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    ↓ |    ← |    ← |    ←

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    ← | 🚪 ← |    ←
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | .  → |    ← |    ← |    ←
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    → |    ← |    ← |    ←

Relative future reward without key:
      0        1        2        3        4        5        6        7        8 
0    -16  |   -17  |   -18  |   -19  |   -20  |   -21  |   -20  | 🚪 -19 |   -18 

In [54]:
# Average reward of each stored VIPER policy, in the order they were trained.
for policy in viper_policies:
    print(rew_grid.get_average_reward(suppress_print=True, policy=policy))

-23.333333333333332
-15.111111111111112
-25.0
-25.0
-25.0


## Show best policy:

In [55]:
# Pick the trained tree with the highest average reward.
# Policies are rebuilt directly from `clfs`: the same classifier is scored, displayed
# and used for the reward map. Lambdas stored in `viper_policies` that close over the
# global `clf` are late-bound and may call the most recently trained tree instead.
best_pol_idx = -1
best_pol_val = None
for i, classifier in enumerate(clfs):
    # The lambda is consumed immediately, so binding `classifier` late is safe here.
    pol_val = rew_grid.get_average_reward(
        policy=lambda state: classifier.predict([state])[0], suppress_print=True)
    if best_pol_val is None or pol_val > best_pol_val:
        best_pol_val = pol_val
        best_pol_idx = i

if best_pol_val is None:
    raise ValueError("No trained classifiers in `clfs`; run the training cells first.")

print("Showing best viper policy: (idx {})".format(best_pol_idx))

best_clf = clfs[best_pol_idx]


def pol(state):
    """Action chosen by the best decision tree for ``state``."""
    return best_clf.predict([state])[0]


Grid().print_policy(pol)

best_pol_reward_fn = rew_grid.get_reward_function(policy=pol, suppress_print=True)
rew_grid.add_key(remove=True).print_reward_fn(best_pol_reward_fn, show_diff=True)

# Read directly instead of relying on `optimal_reward` left over from another cell.
optimal_reward = rew_grid.average_reward
print("\n\nAverage reward: {:.2f} (optimal: {:.2f}).".format(best_pol_val, optimal_reward))


Showing best viper policy: (idx 1)

Policy without key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    ↓ |    ↓ | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | 🔑 ↓ |    ↓ |    ↑ |    ↑
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    ↓ |    ↓ |    ↑ |    ↑

Policy with key:
     0      1      2      3      4      5      6      7      8 
0     → |    → |    → |    → |    → |    → |    → | 🚪 ↑ |    ↑
  --------------------------------------------------------------
1     ↑ | 🤠 ↑ |    ↑ | 🔥 → |    → | .  → |    → |    ↑ |    ↑
  --------------------------------------------------------------
2     ↑ |    ↑ |    ↑ | 🔥 → |    → |    → |    → |    ↑ |    ↑

Relative future reward without key:
      0        1        2        3        4        5        6        7        8 
0      0  |     0  |     0  |     0  |     0 