In [2]:
import gym
import numpy as np
import matplotlib.pyplot as plt

In [3]:
class MazeEnv(gym.Env):
    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state
    
    def step(self, action):
        if action == 0:
            self.state -= 3
        if action == 1:
            self.state += 1
        if action == 2:
            self.state += 3
        if action == 3:
            self.state -= 1
        done = False
        if self.state == 8:
            done = True
        return self.state, 1, done, {}

In [5]:
class Agent:
    def __init__(self):
        self.actions = list(range(4))
        self.theta_0 = np.asarray(
            [[np.nan, 1, 1, np.nan],      # s0
            [np.nan, 1, np.nan, 1],      # s1
            [np.nan, np.nan, 1, 1],      # s2
            [1, np.nan, np.nan, np.nan], # s3 
            [np.nan, 1, 1, np.nan],      # s4
            [1, np.nan, np.nan, 1],      # s5
            [np.nan, 1, np.nan, np.nan], # s6 
            [1, 1, np.nan, 1]]           # s7
        )
        self.pi = self._cvt_theta_0_to_pi(self.theta_0)

    def _cvt_theta_0_to_pi(self, theta):
        m, n = theta.shape
        pi = np.zeros((m, n))
        for r in range(m):
            pi[r, :] = theta[r, :] / np.nansum(theta[r, :])
        return np.nan_to_num(pi)
    
    def choose_action(self, state):
        action = np.random.choice(self.actions, p=self.pi[state, :])
        return action

In [6]:
env = MazeEnv()
state = env.reset()
agent = Agent()

done = False
action_history = []
state_history = []
while not done:
    action = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    action_history.append(action)
    state_history.append(state)
print("action history:", action_history)
print("state_history:", state_history)

action history: [2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 0, 1, 1, 3, 3, 1, 1, 3, 1, 3, 3, 1, 3, 1, 3, 2, 0, 2, 0, 1, 1, 2, 3, 1, 3, 1, 3, 2, 0, 2, 3, 1, 1]
state_history: [3, 0, 3, 0, 3, 0, 3, 0, 3, 0, 3, 0, 3, 0, 1, 2, 1, 0, 1, 2, 1, 2, 1, 0, 1, 0, 1, 0, 3, 0, 3, 0, 1, 2, 5, 4, 5, 4, 5, 4, 7, 4, 7, 6, 7, 8]
