In [1]:
import gymnasium as gym
import numpy as np

In [11]:
# Number of bins each observation component is split into
NUM_DIZITIZED = 6
GAMMA = 0.99  # discount factor applied to the best Q-value of the next state
ETA = 0.5  # learning rate (step size) of the Q-table update


def bins(clip_min, clip_max, num):
    """Return the interior thresholds that cut [clip_min, clip_max] into `num` equal bins.

    The range is sampled at `num + 1` evenly spaced points and both end
    points are discarded, leaving `num - 1` boundaries that can be passed
    to `np.digitize`.
    """
    edges = np.linspace(clip_min, clip_max, num + 1)
    interior = edges[1:-1]
    return interior

def digitize(observation):
    """Encode a four-component observation as a single integer state index.

    Each component is bucketed with `np.digitize` against thresholds built
    by `bins` over a fixed clipping range. The four bucket indices are then
    combined as the digits of a base-`NUM_DIZITIZED` number, with the first
    component as the least-significant digit.
    """
    cart_pos, cart_v, pole_angle, pole_v = observation
    # (value, lower clip, upper clip) for every component, in digit order.
    components = (
        (cart_pos, -2.4, 2.4),
        (cart_v, -3.0, 3.0),
        (pole_angle, -0.5, 0.5),
        (pole_v, -2.0, 2.0),
    )
    state = 0
    for place, (value, low, high) in enumerate(components):
        bucket = np.digitize(value, bins=bins(low, high, NUM_DIZITIZED))
        state += bucket * (NUM_DIZITIZED**place)
    return state

class Agent:
    """Tabular Q-learning agent that works on discretized observations.

    The Q-table has one row per discretized state and one column per action,
    and is initialized with uniform random values in [-1, 1).
    """

    def __init__(self, num_states, num_actions):
        # `num_states` arrives as the number of observation components; the
        # table needs one row for every combination of per-component bins.
        table_rows = NUM_DIZITIZED ** num_states
        self.num_actions = num_actions
        self.q_table = np.random.uniform(
            low=-1, high=1, size=(table_rows, num_actions)
        )

    def update_Q_function(self, observation, action, reward, observation_next):
        """Discretize both raw observations and apply one Q-table update."""
        self.update_Q_table(
            digitize(observation), action, reward, digitize(observation_next)
        )

    def get_action(self, observation, step):
        """Pick an action for a raw observation.

        `step` is forwarded unchanged to `decide_action`, where it controls
        the exploration rate.
        """
        return self.decide_action(digitize(observation), step)

    def update_Q_table(self, state, action, reward, state_next):
        """Move Q(state, action) toward `reward + GAMMA * max Q(state_next)` by a factor ETA."""
        best_next = max(self.q_table[state_next])
        current = self.q_table[state, action]
        td_error = reward + GAMMA * best_next - current
        self.q_table[state, action] = current + ETA * td_error

    def decide_action(self, state, episode):
        """Epsilon-greedy choice; epsilon shrinks as 0.5 / (episode + 1)."""
        epsilon = 0.5 * (1 / (episode + 1))
        # One uniform draw decides between exploring and exploiting.
        if np.random.uniform(0, 1) < epsilon:
            return np.random.choice(self.num_actions)
        return np.argmax(self.q_table[state])

In [14]:
from matplotlib import animation

# Upper bound on the number of steps in a single episode
MAX_STEPS = 500
# Number of training episodes to run
NUM_EPISODES = 10000

class Environment():
    """Couples a Gymnasium environment with a tabular Q-learning `Agent`."""

    def __init__(self, toy_env):
        """Create the environment named `toy_env` and an agent sized to its spaces."""
        self.env = gym.make(toy_env)
        # Number of observation components and number of discrete actions.
        num_states = self.env.observation_space.shape[0]
        num_actions = self.env.action_space.n
        self.agent = Agent(num_states, num_actions)

    def run(self):
        """Train the agent for NUM_EPISODES episodes of at most MAX_STEPS steps.

        The reward reported by the environment is ignored. The agent is
        instead given -1 when the episode terminates, +1 when it is
        truncated, and 0 on every other step.

        Returns:
            tuple: `(step_list, frames)`. `step_list` has one entry per
            episode, the number of steps that episode lasted. `frames` is
            always an empty list because nothing is rendered; it is kept so
            the two-element return shape stays the same.
        """
        complete_episodes = 0  # consecutive episodes that ended by truncation
        step_list = []
        # `frames` must exist for the return statement below; no rendering
        # is captured here, so it stays empty.
        frames = []

        for episode in range(NUM_EPISODES):
            observation, _ = self.env.reset()  # start a fresh episode
            for step in range(MAX_STEPS):
                action = self.agent.get_action(observation, episode)
                observation_next, _, term, trunc, _ = self.env.step(action)
                if term:
                    reward = -1
                    complete_episodes = 0  # the streak is broken
                elif trunc:
                    print(f"{episode}: truncated {step}, {complete_episodes}")
                    reward = 1
                    complete_episodes += 1  # extend the streak
                else:
                    reward = 0
                # Learn from this transition, then advance to the next state.
                self.agent.update_Q_function(observation, action, reward, observation_next)
                observation = observation_next

                if term or trunc:
                    step_list.append(step+1)
                    break
            else:
                # The step budget ran out before the environment signalled
                # an end; record the episode so none is silently dropped.
                step_list.append(MAX_STEPS)

        return step_list, frames

In [15]:
# Gymnasium environment id that Environment passes to gym.make()
TOY = "CartPole-v1"

cartpole = Environment(TOY)
# run() returns a (step_list, frames) tuple, so `step_list` here holds that
# whole tuple rather than only the list of per-episode step counts.
step_list = cartpole.run()

136: truncated 499, 0
206: truncated 499, 0
211: truncated 499, 0
241: truncated 499, 0
280: truncated 499, 0
285: truncated 499, 0
290: truncated 499, 0
291: truncated 499, 1
292: truncated 499, 2
293: truncated 499, 3
294: truncated 499, 4
295: truncated 499, 5
449: truncated 499, 0
450: truncated 499, 1
453: truncated 499, 0
455: truncated 499, 0
457: truncated 499, 0
460: truncated 499, 0
462: truncated 499, 0
463: truncated 499, 1
465: truncated 499, 0
470: truncated 499, 0
471: truncated 499, 1
475: truncated 499, 0
477: truncated 499, 0
478: truncated 499, 1
479: truncated 499, 2
480: truncated 499, 3
483: truncated 499, 0
485: truncated 499, 0
486: truncated 499, 1
489: truncated 499, 0
490: truncated 499, 1
491: truncated 499, 2
492: truncated 499, 3
493: truncated 499, 4
494: truncated 499, 5
496: truncated 499, 0
499: truncated 499, 0
500: truncated 499, 1
501: truncated 499, 2
503: truncated 499, 0
504: truncated 499, 1
505: truncated 499, 2
507: truncated 499, 0
508: trunc

KeyboardInterrupt: 

In [9]:
import matplotlib.pyplot as plt

# `step_list` is the (step_list, frames) tuple returned by Environment.run();
# element 0 is the list of per-episode step counts.
episode_steps = step_list[0]

fig, ax = plt.subplots()
ax.plot(episode_steps)
ax.set(
    title="Episode length during training",
    xlabel="Episode",
    ylabel="Steps before the episode ended",
)
plt.show()

NameError: name 'step_list' is not defined