In [1]:
%matplotlib inline

from __future__ import unicode_literals
import matplotlib
# Render plot titles/labels with LaTeX (requires a working LaTeX install).
matplotlib.rcParams['text.usetex'] = True
# 'text.latex.unicode' was deprecated in Matplotlib 3.0 and later removed;
# assigning an unknown rcParam raises KeyError, so only set it on versions
# that still define the key.
if 'text.latex.unicode' in matplotlib.rcParams:
    matplotlib.rcParams['text.latex.unicode'] = True

import math
from itertools import product
import gym
from gym import spaces, logger
from gym.utils import seeding
import numpy as np
import time
import matplotlib.patches as patches

import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython import display
from IPython.display import HTML, Image, clear_output
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots

In [2]:
class GridWorldEnv(gym.Env):
    """Grid world with special cells that pay a reward and teleport the agent.

    The state is an integer pair with ``0 <= state[0] < width`` and
    ``0 <= state[1] < height``. On every ``step``:

    * if the agent is on ``rewarded_locations[i]`` it receives
      ``reward_vals[i]`` and is moved to ``rewarded_targets[i]``, whatever
      the chosen action;
    * otherwise the action moves it by one cell (see ``move_dict``); a move
      that would leave the grid keeps the agent in place with reward -1, any
      other move gives reward 0.

    ``done`` is always ``False`` (a continuing task). Every step is logged in
    ``exp_history`` -- one ``{'states', 'actions', 'rewards'}`` dict per call
    to ``reset`` -- which ``plt_render`` animates.

    Actions (labels follow the ``render`` convention, where ``state[0]`` is
    the x coordinate; ``plt_render`` instead draws ``state[0]`` as the row):
        0: Right (+1,  0)
        1: Up    ( 0, +1)
        2: Left  (-1,  0)
        3: Down  ( 0, -1)
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second' : 100
    }

    def __init__(self, grid_size=5, rewarded_locations=((0, 1), (0, 3)),
                 rewarded_targets=((4, 1), (2, 3)), reward_vals=(10, 5),
                 hard_version=False, slip_prob=0.3):
        """Create the environment.

        Args:
            grid_size: int for a square grid, or a ``(width, height)`` pair.
            rewarded_locations: one ``(state[0], state[1])`` pair per special
                cell (A and B in the default layout).
            rewarded_targets: ``rewarded_targets[i]`` is where the agent lands
                after leaving ``rewarded_locations[i]`` (A' and B').
            reward_vals: ``reward_vals[i]`` is paid when leaving
                ``rewarded_locations[i]``.
            hard_version: if True, each action is replaced by a uniformly
                random action with probability ``slip_prob``.
            slip_prob: slip probability used when ``hard_version`` is True.

        Raises:
            ValueError: if the three reward sequences differ in length.
        """
        # Grid size: a scalar means a square grid.
        if np.ndim(grid_size) == 0:
            self.width = self.height = int(grid_size)
        else:
            self.width, self.height = int(grid_size[0]), int(grid_size[1])

        # Defaults are tuples so they cannot be mutated between calls;
        # np.array converts tuples and lists identically. np.int was removed
        # in NumPy 1.24 -- it was only an alias for the builtin int.
        self.rewarded_locations = np.array(rewarded_locations, dtype=int)
        self.rewarded_targets = np.array(rewarded_targets, dtype=int)
        self.reward_vals = np.array(reward_vals)
        if not (len(self.rewarded_locations) == len(self.rewarded_targets)
                == len(self.reward_vals)):
            raise ValueError(
                'rewarded_locations, rewarded_targets and reward_vals must '
                'have the same length, got %d, %d and %d'
                % (len(self.rewarded_locations), len(self.rewarded_targets),
                   len(self.reward_vals)))

        # Action space (4 moves) and observation space (one cell index per axis).
        self.action_space = spaces.Discrete(4)
        self.move_dict = {0: np.array([ 1, 0]),
                          1: np.array([ 0, 1]),
                          2: np.array([-1, 0]),
                          3: np.array([ 0,-1])}
        self.observation_space = spaces.MultiDiscrete([self.width, self.height])

        self.hard_version = hard_version
        self.slip_prob = slip_prob

        self.seed()

        self.viewer = None
        self.state = None
        self.exp_history = []

    def seed(self, seed=None):
        """Seed the environment's own RNG (``self.np_random``); returns ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new episode from a random cell and open a new history entry.

        Returns:
            A copy of the initial state.
        """
        self.state = self.observation_space.sample()
        self.steps_beyond_done = None
        self.exp_history.append({'states':[], 'actions':[], 'rewards':[]})
        return np.array(self.state)

    def step(self, action):
        """Advance one time step.

        Returns:
            ``(observation, reward, done, info)``: ``observation`` is a copy
            of the new state, ``done`` is always ``False`` and ``info`` is an
            empty dict.
        """
        assert self.observation_space.contains(self.state), "%r (%s) invalid"%(self.state, type(self.state))
        assert self.action_space.contains(action), "%r (%s) invalid"%(action, type(action))

        # Slippery variant: draw from the env's seeded RNG (not the global
        # np.random) so env.seed(...) makes the slips reproducible.
        if self.hard_version and self.np_random.uniform() < self.slip_prob:
            action = self.np_random.choice(self.action_space.n)

        reward = None
        for loc, target, val in zip(self.rewarded_locations,
                                    self.rewarded_targets, self.reward_vals):
            if np.array_equal(self.state, loc):
                reward = val
                # Copy so the logged/returned state never aliases a row of
                # rewarded_targets.
                self.state = target.copy()
                # At most one teleport per step: without the break, a target
                # that is itself a later rewarded location would be teleported
                # again and its reward would overwrite this one.
                break

        if reward is None:
            next_state = self.state + self.move_dict[action]
            if self.observation_space.contains(next_state):
                reward = 0
                self.state = next_state
            else:
                # Bumping into the border: stay put and pay -1.
                reward = -1

        done = False

        self.exp_history[-1]['states'].append(self.state)
        self.exp_history[-1]['actions'].append(action)
        self.exp_history[-1]['rewards'].append(reward)

        return np.array(self.state), reward, done, {}

    def render(self, mode='human'):
        """Draw the grid and the agent with gym's classic-control viewer.

        The cell outlines are created on the first call only; later calls just
        move the agent marker. ``state[0]`` maps to x and ``state[1]`` to y.

        Returns:
            ``None`` if ``reset`` has not been called yet; otherwise the
            result of ``Viewer.render`` (an RGB array when
            ``mode == 'rgb_array'``).
        """
        top_left = [50, 50]
        cell_size = 100
        point_size = 30
        # Fit the window to the grid plus a margin on each side
        # (600x600 for the default 5x5 grid).
        screen_width = self.width * cell_size + 2 * top_left[0]
        screen_height = self.height * cell_size + 2 * top_left[1]

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # One closed outline per cell.
            for i in range(self.width):
                for j in range(self.height):
                    l = i*cell_size + top_left[0]
                    r = (i+1)*cell_size + top_left[0]
                    t = j*cell_size + top_left[1]
                    b = (j+1)*cell_size + top_left[1]

                    cell_outline = rendering.PolyLine([(l,b), (l,t), (r,t), (r,b)], True)
                    cell_outline.set_color(0,0,0)
                    self.viewer.add_geom(cell_outline)

            # Agent marker: a red square centred in cell (0, 0); it is moved
            # to the current cell through self.pointtrans.
            l = top_left[0] + cell_size/2 - point_size/2
            r = top_left[0] + cell_size/2 + point_size/2
            t = top_left[1] + cell_size/2 - point_size/2
            b = top_left[1] + cell_size/2 + point_size/2
            self.point = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
            self.point.set_color(1,0,0)

            self.pointtrans = rendering.Transform(translation=(0, 0))
            self.point.add_attr(self.pointtrans)
            self.viewer.add_geom(self.point)

        if self.state is None: return None

        # Offset of the current cell relative to cell (0, 0).
        self.pointtrans.set_translation(self.state[0]*cell_size,
                                        self.state[1]*cell_size)

        return self.viewer.render(return_rgb_array = mode=='rgb_array')

    def plt_render(self):
        """Build a matplotlib animation of the latest episode in exp_history.

        Four small panels plot action, row (``state[0]``), reward and column
        (``state[1]``) against time; a larger panel shows the agent moving on
        the grid with row 0 at the top. The figure, axes, marker, lines and
        the ``FuncAnimation`` are stored on ``self`` (the animation is
        ``self.ani``, e.g. ``HTML(env.ani.to_jshtml())``).

        Raises:
            ValueError: if the latest episode has no recorded steps.
        """
        point_radius = 0.1

        if not self.exp_history or not self.exp_history[-1]['states']:
            raise ValueError('plt_render needs at least one step(); '
                             'call reset() and step() first')
        episode = self.exp_history[-1]
        all_states = episode['states']
        all_actions = episode['actions']
        all_rewards = episode['rewards']

        # state[0] is drawn as the row and state[1] as the column. The
        # observation space is MultiDiscrete([width, height]), so there are
        # `width` rows and `height` columns; taking the limits from that
        # keeps non-square grids inside the axes.
        row = [s[0] for s in all_states]
        col = [s[1] for s in all_states]
        n_rows, n_cols = self.width, self.height
        time_array = np.arange(len(row))

        self.figure = plt.figure()
        self.figure.set_size_inches(15, 15, forward=True)

        self.act_ax = plt.subplot2grid((4, 4), (0, 0))
        self.row_ax = plt.subplot2grid((4, 4), (0, 1))
        self.reward_ax = plt.subplot2grid((4, 4), (1, 0))
        self.col_ax = plt.subplot2grid((4, 4), (1, 1))
        self.traj_ax = plt.subplot2grid((4, 4), (2, 0), colspan=2, rowspan=2)

        self.traj_ax.set_xlim([-0.5, n_cols - 0.5])
        self.traj_ax.set_ylim([-0.5, n_rows - 0.5])

        plot_ax_list = [self.act_ax, self.row_ax,
                        self.reward_ax, self.col_ax]
        plot_data_list = [all_actions, row, all_rewards, col]
        plot_title_list = [r'$a$', r'$row$',
                           r'$R$', r'$column$']

        for curr_ax, curr_data, title in zip(plot_ax_list, plot_data_list,
                                             plot_title_list):
            # Avoid a degenerate [0, 0] x-range for single-step episodes.
            curr_ax.set_xlim([0, max(time_array[-1], 1)])
            lo, hi = np.min(curr_data), np.max(curr_data)
            if lo == hi:
                curr_ax.set_ylim([lo - 1, hi + 1])
            else:
                # Pad by 5% of the data range. Scaling the limits themselves
                # (1.05*min, 1.05*max) clipped the data whenever min > 0
                # or max < 0.
                pad = 0.05 * (hi - lo)
                curr_ax.set_ylim([lo - pad, hi + pad])
            curr_ax.set_title(title, fontsize=16)

        # Grid lines between cells.
        for k in range(n_rows - 1):
            self.traj_ax.axhline(y = k+0.5)
        for k in range(n_cols - 1):
            self.traj_ax.axvline(x = k+0.5)
        self.point = patches.Circle((0,0),point_radius,linewidth=10,edgecolor='r',facecolor='r')
        self.traj_ax.add_patch(self.point)

        traj_title = self.traj_ax.set_title('Trajectory', fontsize=16)

        self.plot_lines = []
        for curr_ax in plot_ax_list:
            curr_line, = curr_ax.plot([], [], color='k')
            self.plot_lines.append(curr_line)

        def init():
            """Blank frame used by FuncAnimation before the first step."""
            self.point.center = (0,0)
            traj_title.set_text('Trajectory')
            for curr_line in self.plot_lines:
                curr_line.set_data([], [])
            return (self.point,) + tuple(self.plot_lines)

        def animate(i):
            """Draw frame i: marker at the i-th state, series up to step i."""
            traj_title.set_text('Trajectory (time= '+str(i)+')')
            # Flip the row so that row 0 is drawn at the top.
            self.point.center = (col[i], n_rows - row[i] - 1)

            # Include sample i so the time-series panels stay in step with
            # the marker (slicing [:i] lagged by one frame and never drew the
            # final sample).
            for curr_line, curr_data in zip(self.plot_lines, plot_data_list):
                curr_line.set_data(time_array[:i + 1], curr_data[:i + 1])
            return (self.point,) + tuple(self.plot_lines)

        self.ani = animation.FuncAnimation(self.figure, animate, np.arange(len(row)),
                                           interval=25, blit=True, init_func=init)

    def close(self):
        """Close the gym viewer, if one was opened by ``render``."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

In [3]:
class StateActionTable():
    """Dense numpy table indexed by a state and, optionally, an action.

    ``data`` has shape ``state_sizes + action_sizes``, or just ``state_sizes``
    when ``action_sizes`` is None. Index with ``table[state, action]`` (or
    ``table[state]`` for a state-only table); ``state`` and ``action`` may be
    scalars, sequences, arrays or slices and are flattened and concatenated
    into a single index tuple into ``data``.
    """

    def __init__(self, state_sizes, action_sizes, use_obj_dtype = False):
        self.state_shape = np.array(state_sizes).reshape(-1)
        if action_sizes is None:
            self.action_shape = None
            full_shape = self.state_shape
        else:
            self.action_shape = np.array(action_sizes).reshape(-1)
            full_shape = np.concatenate([self.state_shape, self.action_shape], axis = 0)
        # object dtype lets a cell hold arbitrary values (e.g. list actions).
        cell_dtype = object if use_obj_dtype else float
        self.data = np.zeros(tuple(full_shape), dtype=cell_dtype)

    def _flat_index(self, pos):
        """Flatten ``state`` or ``(state, action)`` into an index tuple for ``data``."""
        if self.action_shape is None:
            return tuple(np.array(pos).reshape(-1))
        state_part, action_part = pos
        return tuple(list(np.array(state_part).reshape(-1))
                     + list(np.array(action_part).reshape(-1)))

    def __getitem__(self, pos):
        return self.data[self._flat_index(pos)]

    def __setitem__(self, pos, val):
        self.data[self._flat_index(pos)] = val

def get_space_shape(space):
    """Return the table dimensions of a discrete gym space.

    Args:
        space: a ``Discrete`` or ``MultiDiscrete`` gym space.

    Returns:
        ``space.n`` for ``Discrete``; ``space.nvec`` (the per-axis sizes)
        for ``MultiDiscrete``.

    Raises:
        NotImplementedError: for any other space type.
    """
    if isinstance(space, gym.spaces.discrete.Discrete):
        return space.n
    if isinstance(space, gym.spaces.multi_discrete.MultiDiscrete):
        return space.nvec

    # Raising a bare string is itself a TypeError on Python 3, which hid the
    # real problem; raise a proper exception that names the offending type.
    raise NotImplementedError(
        'get_space_shape does not support spaces of type %s'
        % type(space).__name__)

In [4]:
#SARSA(lambda) with accumulating eligibility traces
SEED = 0
env = GridWorldEnv()
alpha = 0.1          # step size
gamma = 0.9          # discount factor
lamda = 1            # trace-decay parameter ("lambda" is a Python keyword)
epsilon = 0.9        # probability of taking a uniformly random action
episode_length = 200
num_episodes = 100

# Seed every RNG the loop draws from: np.random (epsilon test), the env, and
# the spaces used by reset() (start state) and action sampling.
np.random.seed(SEED)
env.seed(SEED)
env.action_space.seed(SEED)
env.observation_space.seed(SEED)

action_space_shape = get_space_shape(env.action_space)
state_space_shape = get_space_shape(env.observation_space)
Q = StateActionTable(state_space_shape , action_space_shape)
e = StateActionTable(state_space_shape , action_space_shape)

last_print_time = time.time()
for i_episode in range(num_episodes):
    observation = env.reset()
    # Eligibility traces belong to one episode: clear them so credit from the
    # previous episode is not assigned across the (artificial) reset.
    e.data[...] = 0
    s = env.state
    a = env.action_space.sample()
    for t in range(episode_length):
        if time.time() - last_print_time > 0.2:
            clear_output(wait=True)
            print('Episode: ' +str(i_episode))
            print('Time step: ' +str(t))
            last_print_time = time.time()
        observation, r, done, info = env.step(a)
        s_prime = env.state
        # epsilon-"greedy": random action with probability epsilon,
        # otherwise the (first) action with the largest Q value.
        if np.random.random() < epsilon:
            a_prime = env.action_space.sample()
        else:
            canid_q = Q[s_prime, :]
            a_prime = np.unravel_index(canid_q.argmax(), canid_q.shape)
            if len(a_prime) == 1:
                a_prime = a_prime[0]
            else:
                a_prime = list(a_prime)
        e[s,a] = e[s,a] + 1
        delta = r + gamma * Q[s_prime, a_prime] - Q[s,a]
        # Vectorised form of the sweep over every (state, action) pair:
        # Q <- Q + alpha*delta*e, then e <- gamma*lambda*e (elementwise, so
        # identical to the per-pair loop but without |S|*|A| Python calls).
        Q.data += alpha * delta * e.data
        e.data *= gamma * lamda
        s = s_prime
        a = a_prime
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break

Episode: 99
Time step: 162


In [5]:
%matplotlib inline
test_env = GridWorldEnv(hard_version=True)
policy = StateActionTable(state_sizes = state_space_shape,
                          action_sizes = None, 
                          use_obj_dtype = True)

for s_tilda in product(*(range(y) for y in Q.state_shape)):
    canid_q = Q[s_tilda, :]
    best_a = np.unravel_index(canid_q.argmax(), canid_q.shape)
    if len(best_a) == 1:
        policy[s_tilda] = best_a[0]
    else:
        policy[s_tilda] = list(best_a)

for i_episode in range(1):
    observation = test_env.reset()
    for t in range(100):
        action = policy[test_env.state]
        observation, reward, done, info = test_env.step(action)
        if done:
            print("Episode finished after {} timesteps".format(t+1))
            break

In [6]:
%matplotlib inline
# Build the trajectory animation of test_env's latest episode; it is stored
# in test_env.ani and displayed by the next cell.
test_env.plt_render()
# plt_render also leaves its static figure open. Flush pending inline figures
# and then wipe this cell's output -- presumably so only the animated version
# below is shown. NOTE(review): plt.close(test_env.figure) would likely be a
# more direct way to get the same result; confirm before switching.
show_inline_matplotlib_plots()
clear_output()

In [7]:
# Embed the animation as an interactive JavaScript player at 5 frames/s.
anim_html = test_env.ani.to_jshtml(fps=5)
HTML(anim_html)

In [9]:
# Write the same animation to disk as a GIF at 10 frames/s.
gif_writer = animation.PillowWriter(fps=10)
test_env.ani.save('./sarsa_sim.gif', writer=gif_writer)