# All this section consists in the implementation of the Mirror Descent with the relative entropy. This work is based on the paper "A Unified View of Entropy-Regularized Markov Decision Processes" with the pseudo code detailed in the paper "Relative Entropy Policy Search" from Jan Peters.

## GRID RENDER

In [52]:
from tkinter import *
import numpy as np
import matplotlib.pyplot as plt

class GUI(Canvas):
    def __init__(self, master, *args, **kwargs):
        Canvas.__init__(self, master=master, *args, **kwargs)


def draw_square_q(polygon, x, y, q, actions, dim=50):
    polygon.create_polygon([x, y, x + dim, y, x + dim, y + dim, x, y + dim], outline='black',
                           fill='white', width=2)

    font = ('Helvetica', '30', 'bold')

    for i, a in enumerate(actions):
        if a == 0:
            polygon.create_polygon([x + dim, y, x + dim / 2., y + dim / 2., x + dim, y + dim], outline='gray',
                                   fill='red', width=2)
            polygon.create_text(x + 3 * dim / 4., y + dim / 2., font=font, text="{:.3f}".format(q[i]), anchor='center')
        elif a == 1:
            polygon.create_polygon([x, y + dim, x + dim / 2., y + dim / 2., x + dim, y + dim], outline='gray',
                                   fill='green', width=2)
            polygon.create_text(x + dim / 2., y + 3 * dim / 4., font=font, text="{:.3f}".format(q[i]), anchor='n')
        elif a == 2:
            polygon.create_polygon([x, y, x + dim / 2., y + dim / 2., x, y + dim], outline='gray',
                                   fill='yellow', width=2)
            polygon.create_text(x + dim / 4., y + dim / 2., font=font, text="{:.3f}".format(q[i]), anchor='center')
        elif a == 3:
            polygon.create_polygon([x + dim, y, x + dim / 2., y + dim / 2., x, y], outline='gray',
                                   fill='purple', width=2)
            polygon.create_text(x + dim / 2., y + dim / 4., font=font, text="{:.3f}".format(q[i]), anchor='s')


def draw_square_policy(w, x, y, pol, actions, dim=50):
    w.create_polygon([x, y, x + dim, y, x + dim, y + dim, x, y + dim], outline='black',
                     fill='white', width=2)

    font = ('Helvetica', '30', 'bold')
    if (hasattr(pol, "size") and pol.size > 1) or isinstance(pol, list):
        d = pol
    else:
        d = [-1] * len(actions)
        idx = actions.index(pol)
        d[idx] = 1

    for j, v in enumerate(d):
        if j < len(actions):
            a = actions[j]
            if a == 0 and v > 0:
                w.create_line(x + dim / 2., y + dim / 2., x + 3*dim / 4., y + dim / 2., tags=("line",), arrow="last")
                if not np.isclose(v, 1.):
                    w.create_text(x + 3*dim / 4., y + dim / 2., font=font, text="{:.1f}".format(v), anchor='w')
            elif a == 1 and v > 0:
                w.create_line(x + dim / 2., y + dim / 2., x + dim / 2., y + 3* dim / 4., tags=("line",), arrow="last")
                if not np.isclose(v, 1.):
                    w.create_text(x + dim / 2., y + 3*dim / 4., font=font, text="{:.1f}".format(v), anchor='n')
            elif a == 2 and v >0:
                w.create_line(x + dim / 2., y + dim / 2., x+dim/4., y + dim/2., tags=("line",), arrow="last")
                if not np.isclose(v, 1.):
                    w.create_text(x + dim / 4., y + dim / 2., font=font, text="{:.1f}".format(v), anchor='e')
            elif a == 3 and v >0:
                w.create_line(x + dim / 2., y + dim / 2., x + dim / 2., y + dim / 4., tags=("line",), arrow="last")
                if not np.isclose(v, 1.):
                    w.create_text(x + dim / 2., y + dim / 4., font=font, text="{:.1f}".format(v), anchor='s')


def render_q(env, q):
    root = Tk()
    w = GUI(root)
    rows, cols = len(env.grid), max(map(len, env.grid))
    dim = 200
    w.config(width=cols * (dim + 12), height=rows * (dim + 12))
    for s in range(env.n_states):
        r, c = env.state2coord[s]
        draw_square_q(w, 10 + c * (dim + 4), 10 + r * (dim + 4), dim=dim, q=q[s],
                      actions=env.state_actions[s])
        w.pack()
    w.pack()
    root.mainloop()


def render_policy(env, d):
    root = Tk()
    w = GUI(root)
    rows, cols = len(env.grid), max(map(len, env.grid))
    dim = 200
    w.config(width=cols * (dim + 12), height=rows * (dim + 12))
    for s in range(env.n_states):
        r, c = env.state2coord[s]
        draw_square_policy(w, 10 + c * (dim + 4), 10 + r * (dim + 4), dim=dim, pol=d[s],
                           actions=[i for i in range(env.n_states)])
        w.pack()
    w.pack()
    root.mainloop()

## GRID WORLD

In [None]:
import numpy as np
import numbers
from tkinter import Tk
import tkinter.font as tkFont
import copy



class GridWorld:
    def __init__(self, gamma=0.95, grid=None, render=False):
        self.grid = grid

        self.action_names = np.array(['right', 'down', 'left', 'up'])

        self.n_rows, self.n_cols = len(self.grid), max(map(len, self.grid))

        # Create a map to translate coordinates [r,c] to scalar index
        # (i.e., state) and vice-versa
        self.coord2state = np.empty_like(self.grid, dtype=np.int)
        self.n_states = 0
        self.n_actions = 4
        self.state2coord = []
        for i in range(self.n_rows):
            for j in range(len(self.grid[i])):
                if self.grid[i][j] != 'x':
                    self.coord2state[i, j] = self.n_states
                    self.n_states += 1
                    self.state2coord.append([i, j])
                else:
                    self.coord2state[i, j] = -1

        # compute the actions available in each state
        self.compute_available_actions()
        self.gamma = gamma
        self.proba_succ = 0.9
        self.render = render

    def reset(self):
        """
        Returns:
            An initial state randomly drawn from
            the initial distribution
        """
        n_states = self.n_states
        u = 0.9
        a = np.zeros((n_states,))
        a[0] = 0.5
        a[3] = -0.8
        u = np.power(np.ones((n_states,)) + a, u)
        p = np.exp(u) / np.sum(np.exp(u))
        x_0 = np.random.choice(np.arange(n_states), p=p)
        return x_0

    def step(self, state, action):
        """
        Args:
            state (int): the amount of good
            action (int): the action to be executed

        Returns:
            next_state (int): the state reached by performing the action
            reward (float): a scalar value representing the immediate reward
            absorb (boolean): True if the next_state is absorsing, False otherwise
        """
        r, c = self.state2coord[state]
        #assert action in self.state_actions[state]
        if isinstance(self.grid[r][c], numbers.Number):
            return state, 0, True
        else:
            failed = np.random.rand(1) > self.proba_succ
            if action == 0:
                c = min(self.n_cols - 1, c + 1) if not failed else max(0, c - 1)
            elif action == 1:
                r = min(self.n_rows - 1, r + 1) if not failed else max(0, r - 1)
            elif action == 2:
                c = max(0, c - 1) if not failed else min(self.n_cols - 1, c + 1)
            elif action == 3:
                r = max(0, r - 1) if not failed else min(self.n_rows - 1, r + 1)

            if self.grid[r][c] == 'x':
                next_state = state
                r, c = self.state2coord[next_state]
            elif action not in self.state_actions[state]:
                next_state = state
                r, c = self.state2coord[next_state]
            else:
                next_state = self.coord2state[r, c]
            if isinstance(self.grid[r][c], numbers.Number):
                reward = self.grid[r][c]
                absorb = True
            else:
                reward = 0.
                absorb = False

        if self.render:
            self.show(state, action, next_state, reward)

        return next_state, reward, absorb

    def show(self, state, action, next_state, reward):
        dim = 200
        rows, cols = len(self.grid) + 0.5, max(map(len, self.grid))
        if not hasattr(self, 'window'):
            root = Tk()
            self.window = GUI(root)

            self.window.config(width=cols * (dim + 12), height=rows * (dim + 12))
            my_font = tkFont.Font(family="Arial", size=32, weight="bold")
            for s in range(self.n_states):
                r, c = self.state2coord[s]
                x, y = 10 + c * (dim + 4), 10 + r * (dim + 4)
                if isinstance(self.grid[r][c], numbers.Number):
                    self.window.create_polygon([x, y, x + dim, y, x + dim, y + dim, x, y + dim], outline='black',
                                               fill='blue', width=2)
                    self.window.create_text(x + dim / 2., y + dim / 2., text="{:.1f}".format(self.grid[r][c]),
                                            font=my_font, fill='white')
                else:
                    self.window.create_polygon([x, y, x + dim, y, x + dim, y + dim, x, y + dim], outline='black',
                                               fill='white', width=2)
            self.window.pack()

        my_font = tkFont.Font(family="Arial", size=32, weight="bold")

        r0, c0 = self.state2coord[state]
        r0, c0 = 10 + c0 * (dim + 4), 10 + r0 * (dim + 4)
        x0, y0 = r0 + dim / 2., c0 + dim / 2.
        r1, c1 = self.state2coord[next_state]
        r1, c1 = 10 + c1 * (dim + 4), 10 + r1 * (dim + 4)
        x1, y1 = r1 + dim / 2., c1 + dim / 2.

        if hasattr(self, 'oval2'):
            # self.window.delete(self.line1)
            # self.window.delete(self.oval1)
            self.window.delete(self.oval2)
            self.window.delete(self.text1)
            self.window.delete(self.text2)

        # self.line1 = self.window.create_arc(x0, y0, x1, y1, dash=(3,5))
        # self.oval1 = self.window.create_oval(x0 - dim / 20., y0 - dim / 20., x0 + dim / 20., y0 + dim / 20., dash=(3,5))
        self.oval2 = self.window.create_oval(x1 - dim / 5., y1 - dim / 5., x1 + dim / 5., y1 + dim / 5., fill='red')
        self.text1 = self.window.create_text(dim, (rows - 0.25) * (dim + 12), font=my_font,
                                             text="r= {:.1f}".format(reward), anchor='center')
        self.text2 = self.window.create_text(2 * dim, (rows - 0.25) * (dim + 12), font=my_font,
                                             text="action: {}".format(self.action_names[action]), anchor='center')
        self.window.update()

    def compute_available_actions(self):
        # define available actions in each state
        # actions are indexed by: 0=right, 1=down, 2=left, 3=up
        self.state_actions = []
        for i in range(self.n_rows):
            for j in range(self.n_cols):
                if isinstance(self.grid[i][j], numbers.Number):
                    self.state_actions.append([0])
                elif self.grid[i][j] != 'x':
                    actions = [0, 1, 2, 3]
                    if i == 0:
                        actions.remove(3)
                    if j == self.n_cols - 1:
                        actions.remove(0)
                    if i == self.n_rows - 1:
                        actions.remove(1)
                    if j == 0:
                        actions.remove(2)

                    for a in copy.copy(actions):
                        r, c = i, j
                        if a == 0:
                            c = min(self.n_cols - 1, c + 1)
                        elif a == 1:
                            r = min(self.n_rows - 1, r + 1)
                        elif a == 2:
                            c = max(0, c - 1)
                        else:
                            r = max(0, r - 1)
                        if self.grid[r][c] == 'x':
                            actions.remove(a)

                    self.state_actions.append(actions)


grid1 = [
    ['', '', '', 1],
    ['', 'x', '', -1],
    ['', '', '', '']
]
env = GridWorld(gamma=0.95, grid=grid1)


In [57]:
import scipy
import scipy.optimize as opt
import pdb

def collect_episodes(mdp, policy=None, horizon=None, n_episodes=1, render=False):
    paths = []

    for _ in range(n_episodes):
        observations = []
        actions = []
        rewards = []
        next_states = []

        state = mdp.reset()
        for _ in range(horizon):
            action = policy.draw_action(state)
            next_state, reward, terminal = mdp.step(state,action)
            if render:
                mdp.render()
            observations.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            state = copy.copy(next_state)
            if terminal:
                # Finish rollout if terminal state reached
                break
                # We need to compute the empirical return for each time step along the
                # trajectory

        paths.append(dict(
            states=np.array(observations),
            actions=np.array(actions),
            rewards=np.array(rewards),
            next_states=np.array(next_states)
        ))
    return paths

def compute_new_policy(eta,policy,phi,theta,samples):
    log_new_pi = np.zeros((policy.n_states,policy.n_actions))
    A = np.zeros((policy.n_states,policy.n_actions))
    counter = np.zeros((policy.n_states,policy.n_actions))
    nb_samples = 0
    for i in range(len(samples)):
        states = samples[i]['states']
        actions = samples[i]['actions']
        rewards = samples[i]['rewards']
        next_states = samples[i]['next_states']

        for j in range(len(states)):
            A[states[j],actions[j]] += rewards[j] + np.dot(phi[next_states[j],:],theta) - np.dot(phi[states[j],:],theta)
            counter[states[j],actions[j]] += 1
            nb_samples += 1
    for s in range(env.n_states):
        for a in range(env.n_actions):
            if counter[s,a]!=0:
                A[s,a] /= counter[s,a]
    for s in range(policy.n_states):
        for a in range(policy.n_actions):
            argexpo = np.zeros(policy.n_actions)
            if policy.pi[s,a] == 0:
                log_new_pi[s,a] = -float('inf')
            else:
                for b in range(policy.n_actions):
                    argexpo[b] = np.log(policy.pi[s,b]+0.0001) + eta * A[s,b]
                maxi = np.max(argexpo)
                log_new_pi[s,a] = argexpo[a] - np.log(np.sum(np.exp(argexpo - maxi))) - maxi
    print(np.exp(log_new_pi))
    return(Policy(np.exp(log_new_pi)))


def g(theta,eta,phi,samples):
    res = 0
    A = np.zeros((env.n_states,env.n_actions))
    counter = np.zeros((env.n_states,env.n_actions))
    nb_samples = 0
    for i in range(len(samples)):
        states = samples[i]['states']
        actions = samples[i]['actions']
        rewards = samples[i]['rewards']
        next_states = samples[i]['next_states']

        for j in range(len(states)):
            A[states[j],actions[j]] += rewards[j] + np.dot(phi[next_states[j],:],theta) - np.dot(phi[states[j],:],theta)
            counter[states[j],actions[j]] += 1
            nb_samples += 1
    for s in range(env.n_states):
        for a in range(env.n_actions):
            if counter[s,a]!=0:
                A[s,a] /= counter[s,a]
    for i in range(len(samples)):
        states = samples[i]['states']
        actions = samples[i]['actions']
        for j in range(len(states)):
            res += np.exp(eta*A[states[j],actions[j]])
    res /= nb_samples
    return (np.log(res)/eta)

def Dg(theta,eta,phi,samples):
    n_states,p = np.shape(phi)
    numerator = 0
    denominator = 0
    A = np.zeros((env.n_states,env.n_actions))
    D = np.zeros((env.n_states,env.n_actions,p))
    counter = np.zeros((env.n_states,env.n_actions))
    for i in range(len(samples)):
        states = samples[i]['states']
        actions = samples[i]['actions']
        rewards = samples[i]['rewards']
        next_states = samples[i]['next_states']

        for j in range(len(states)):
            A[states[j],actions[j]] += rewards[j] + np.dot(phi[next_states[j],:],theta) - np.dot(phi[states[j],:],theta)
            D[states[j],actions[j],:] += phi[next_states[j],:] - phi[states[j],:]
            counter[states[j],actions[j]] += 1
    for s in range(env.n_states):
        for a in range(env.n_actions):
            if counter[s,a]!=0:
                A[s,a] /= counter[s,a]
    for s in range(env.n_states):
        for a in range(env.n_actions):
            if counter[s,a]!=0:
                D[s,a,:] /= counter[s,a]
    for i in range(len(samples)):
        states = samples[i]['states']
        actions = samples[i]['actions']
        for j in range(len(states)):
            numerator += np.exp(eta*A[states[j],actions[j]]) * D[states[j],actions[j]]
            denominator += np.exp(eta*A[states[j],actions[j]])
    return ((1/eta) * numerator / denominator)


class Policy(object):
    def __init__(self,pi):
        n_states,n_actions = np.shape(pi)
        self.n_actions = n_actions
        self.n_states = n_states
        self.pi = pi
    def draw_action(self,state):
        u = np.random.rand()
        probas = np.cumsum(self.pi[state,:])
        a = 0
        while (a < self.n_actions-1 and (u > probas[a] or self.pi[state,a]==0)):
            a += 1
        return a

def compute_phi(env,p):
    phi = np.zeros((env.n_states,p))
    for k in range(env.n_states):
        phi[k,:] = [k,k**2,np.log(k+1)]
    return(phi)
    
def initialize_pi(env):
    pi = np.zeros((env.n_states,env.n_actions))
    for s in range(env.n_states):
        actions = env.state_actions[s]
        for a in actions:
            pi[s,a] = 1./len(actions)
    return(pi)
    
    
def REPS_mirror_descent(env):
    """Relative Entropy Policy Search using Mirror Descent"""
    p = 3    
    # initialization of the distribution
    pi = initialize_pi(env)
    policy = Policy(pi)
    #Tmax =  -100*np.log(10e-6)/(1-env.gamma)
    K = 50
    N = 100
    eta = 0.1
    theta = [0 for i in range(p)]
    phi = compute_phi(env,p)
    for k in range(K):
        print('Iteration n°',k)
        ##### SAMPLING
        samples = collect_episodes(env,policy=policy,horizon=100,n_episodes=N)
        
        #### OPTIMIZE
        theta = opt.fmin_bfgs(g,x0=theta,fprime=Dg,args=(eta,phi,samples))
        
        #### COMPUTE THE NEW POLICY
        policy = compute_new_policy(eta,policy,phi,theta,samples)   
    return(policy,theta,phi)

In [58]:
policy,theta,phi = REPS_mirror_descent(env)

[[0.5        0.5        0.         0.        ]
 [0.5        0.         0.5        0.        ]
 [0.33333333 0.33333333 0.33333333 0.        ]
 [1.         0.         0.         0.        ]
 [0.         0.5        0.         0.5       ]
 [0.33333333 0.33333333 0.         0.33333333]
 [1.         0.         0.         0.        ]
 [0.5        0.         0.         0.5       ]
 [0.5        0.         0.5        0.        ]
 [0.33333333 0.         0.33333333 0.33333333]
 [0.         0.         0.5        0.5       ]]
Optimization terminated successfully.
         Current function value: -0.029215
         Iterations: 56
         Function evaluations: 110
         Gradient evaluations: 110
[[0.52747138 0.47231955 0.         0.        ]
 [0.49411275 0.         0.50568233 0.        ]
 [0.35135484 0.30260066 0.34594472 0.        ]
 [0.99970012 0.         0.         0.        ]
 [0.         0.47593022 0.         0.52387847]
 [0.29109291 0.35125458 0.         0.35755594]
 [0.99970012 0.         0

         Current function value: -0.007284
         Iterations: 46
         Function evaluations: 127
         Gradient evaluations: 117
[[0.63820398 0.36158549 0.         0.        ]
 [0.40774988 0.         0.59205398 0.        ]
 [0.45838951 0.18520345 0.3563117  0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.3696096  0.         0.63020313]
 [0.11704902 0.51441778 0.         0.36843733]
 [0.99970003 0.         0.         0.        ]
 [0.59385172 0.         0.         0.40594832]
 [0.62590741 0.         0.37389473 0.        ]
 [0.5494676  0.         0.28088842 0.16954342]
 [0.         0.         0.80489696 0.19488742]]
ok
[[0.63820398 0.36158549 0.         0.        ]
 [0.40774988 0.         0.59205398 0.        ]
 [0.45838951 0.18520345 0.3563117  0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.3696096  0.         0.63020313]
 [0.11704902 0.51441778 0.         0.36843733]
 [0.99970003 0.         0.         0.        ]
 [0.59385172 

Optimization terminated successfully.
         Current function value: -0.004082
         Iterations: 53
         Function evaluations: 104
         Gradient evaluations: 104
[[0.69456743 0.30521681 0.         0.        ]
 [0.29317317 0.         0.70663208 0.        ]
 [0.5622302  0.12691021 0.3107642  0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.31128134 0.         0.68853494]
 [0.04769354 0.68365198 0.         0.26855928]
 [0.99970003 0.         0.         0.        ]
 [0.71301719 0.         0.         0.28678457]
 [0.71027    0.         0.28953223 0.        ]
 [0.69075804 0.         0.2268078  0.08233521]
 [0.         0.         0.92742887 0.07236267]]
ok
[[0.69456743 0.30521681 0.         0.        ]
 [0.29317317 0.         0.70663208 0.        ]
 [0.5622302  0.12691021 0.3107642  0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.31128134 0.         0.68853494]
 [0.04769354 0.68365198 0.         0.26855928]
 [0.99970003 0.       

         Current function value: 0.005892
         Iterations: 60
         Function evaluations: 181
         Gradient evaluations: 171
[[0.72481545 0.27497533 0.         0.        ]
 [0.20930215 0.         0.79050579 0.        ]
 [0.6682258  0.08576428 0.245917   0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.26556149 0.         0.73424698]
 [0.02225346 0.7763123  0.         0.20133931]
 [0.99970003 0.         0.         0.        ]
 [0.78200375 0.         0.         0.21779666]
 [0.75394724 0.         0.24585428 0.        ]
 [0.75543664 0.         0.19348589 0.05097784]
 [0.         0.         0.97224667 0.02754908]]
ok
[[0.72481545 0.27497533 0.         0.        ]
 [0.20930215 0.         0.79050579 0.        ]
 [0.6682258  0.08576428 0.245917   0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.26556149 0.         0.73424698]
 [0.02225346 0.7763123  0.         0.20133931]
 [0.99970003 0.         0.         0.        ]
 [0.78200375 0

Optimization terminated successfully.
         Current function value: 0.010720
         Iterations: 58
         Function evaluations: 116
         Gradient evaluations: 116
[[0.71111765 0.2886753  0.         0.        ]
 [0.15422962 0.         0.84557509 0.        ]
 [0.7742887  0.06214781 0.16347032 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.27212512 0.         0.72768228]
 [0.01245255 0.85351567 0.         0.13393253]
 [0.99970003 0.         0.         0.        ]
 [0.84007795 0.         0.         0.15972224]
 [0.77998465 0.         0.21981551 0.        ]
 [0.78491062 0.         0.17964372 0.0353457 ]
 [0.         0.         0.98738906 0.01241056]]
ok
[[0.71111765 0.2886753  0.         0.        ]
 [0.15422962 0.         0.84557509 0.        ]
 [0.7742887  0.06214781 0.16347032 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.27212512 0.         0.72768228]
 [0.01245255 0.85351567 0.         0.13393253]
 [0.99970003 0.        

         Current function value: 0.010136
         Iterations: 53
         Function evaluations: 172
         Gradient evaluations: 161
[[0.75137593 0.24841686 0.         0.        ]
 [0.11694889 0.         0.88285855 0.        ]
 [0.84524901 0.03523819 0.11942119 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.19993788 0.         0.79986722]
 [0.01017231 0.84171554 0.         0.14801537]
 [0.99970003 0.         0.         0.        ]
 [0.83962144 0.         0.         0.1601788 ]
 [0.78429072 0.         0.21550989 0.        ]
 [0.79025711 0.         0.17521174 0.03443108]
 [0.         0.         0.99303172 0.00676735]]
ok
[[0.75137593 0.24841686 0.         0.        ]
 [0.11694889 0.         0.88285855 0.        ]
 [0.84524901 0.03523819 0.11942119 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.19993788 0.         0.79986722]
 [0.01017231 0.84171554 0.         0.14801537]
 [0.99970003 0.         0.         0.        ]
 [0.83962144 0

Optimization terminated successfully.
         Current function value: 0.008704
         Iterations: 49
         Function evaluations: 99
         Gradient evaluations: 99
[[0.74095073 0.25884039 0.         0.        ]
 [0.09040093 0.         0.90940714 0.        ]
 [0.90017978 0.02476859 0.07496049 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.1970253  0.         0.80278204]
 [0.00761161 0.88244271 0.         0.10984804]
 [0.99970003 0.         0.         0.        ]
 [0.86728132 0.         0.         0.1325191 ]
 [0.79268451 0.         0.20711579 0.        ]
 [0.79647471 0.         0.17397353 0.02945159]
 [0.         0.         0.99606714 0.00373272]]
ok
[[0.74095073 0.25884039 0.         0.        ]
 [0.09040093 0.         0.90940714 0.        ]
 [0.90017978 0.02476859 0.07496049 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.1970253  0.         0.80278204]
 [0.00761161 0.88244271 0.         0.10984804]
 [0.99970003 0.         0

         Current function value: 0.009232
         Iterations: 56
         Function evaluations: 234
         Gradient evaluations: 222
[[0.77729434 0.22249875 0.         0.        ]
 [0.06904496 0.         0.93076277 0.        ]
 [0.93396843 0.013576   0.05236518 0.        ]
 [0.99970003 0.         0.         0.        ]
 [0.         0.14107135 0.         0.85873209]
 [0.00745581 0.87590905 0.         0.11653814]
 [0.99970003 0.         0.         0.        ]
 [0.86228162 0.         0.         0.1375187 ]
 [0.79175214 0.         0.20804842 0.        ]
 [0.79515412 0.         0.17388319 0.03086255]
 [0.         0.         0.99730671 0.00249289]]
ok


In [59]:
render_policy(env,policy.pi)

In [61]:
policy.pi

array([[0.77729434, 0.22249875, 0.        , 0.        ],
       [0.06904496, 0.        , 0.93076277, 0.        ],
       [0.93396843, 0.013576  , 0.05236518, 0.        ],
       [0.99970003, 0.        , 0.        , 0.        ],
       [0.        , 0.14107135, 0.        , 0.85873209],
       [0.00745581, 0.87590905, 0.        , 0.11653814],
       [0.99970003, 0.        , 0.        , 0.        ],
       [0.86228162, 0.        , 0.        , 0.1375187 ],
       [0.79175214, 0.        , 0.20804842, 0.        ],
       [0.79515412, 0.        , 0.17388319, 0.03086255],
       [0.        , 0.        , 0.99730671, 0.00249289]])