In [3]:
# パッケージのimport
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
from matplotlib import animation
from IPython.display import display, Video

# Animation helper (saves a movie file, then plays it inline in Jupyter)
def display_frames_as_gif(frames, filename="movie_cartpole.mp4"):
    """Render ``frames`` as an animation, save it to ``filename`` and display it.

    Despite the name, the default output file is an mp4, not a gif.

    Args:
        frames: Indexable, non-empty sequence of image arrays. Each array's
            ``shape`` must unpack into three values (height, width and a
            third dimension that is ignored here).
        filename: Path the animation is written to and then embedded from.
    """
    rows, cols, _ = frames[0].shape
    dpi = 72
    # Size the figure so that one image pixel maps onto one figure pixel.
    figure_size = (cols / dpi, rows / dpi)
    fig = plt.figure(figsize=figure_size, dpi=dpi)
    image = plt.imshow(frames[0])
    plt.axis('off')

    def _draw(index):
        # Swap the displayed image for the frame at this position.
        image.set_data(frames[index])

    movie = animation.FuncAnimation(
        fig,
        _draw,
        frames=len(frames),
        interval=50,
    )
    movie.save(filename)
    # Close the figure so only the embedded video is shown, not a static plot.
    plt.close(fig)
    display(Video(filename, embed=True))

# Constants
ENV = 'CartPole-v0'  # environment id passed to gym.make
NUM_DIZITIZED = 6  # number of bins each observation component is discretized into
GAMMA = 0.99  # discount factor applied to the next state's maximum Q-value
ETA = 0.5  # learning rate of the Q-table update
MAX_STEPS = 200  # upper bound on the number of steps in one episode
NUM_EPISODES = 1000  # upper bound on the number of episodes to run

# Agent
class Agent:
    """Thin wrapper that owns a ``Brain`` and forwards every call to it."""

    def __init__(self, num_states, num_actions):
        """Create the underlying ``Brain`` with the given state/action counts."""
        self.brain = Brain(num_states, num_actions)

    def update_Q_function(self, observation, action, reward, observation_next):
        """Forward one observed transition to ``Brain.update_Q_table``."""
        transition = (observation, action, reward, observation_next)
        self.brain.update_Q_table(*transition)

    def get_action(self, observation, step):
        """Return the action chosen by ``Brain.decide_action``.

        ``step`` is passed through unchanged as the second argument of
        ``decide_action``.
        """
        chosen_action = self.brain.decide_action(observation, step)
        return chosen_action

# Brain: tabular Q-learning
class Brain:
    """Tabular Q-learning over a discretized four-component observation.

    Each observation component is clipped into ``NUM_DIZITIZED`` bins and the
    four bin indices are combined into a single row index of ``q_table``.
    """

    # (clip_min, clip_max) ranges used to discretize, in order: cart position,
    # cart velocity, pole angle and pole velocity.
    _STATE_RANGES = ((-2.4, 2.4), (-3.0, 3.0), (-0.5, 0.5), (-2.0, 2.0))

    def __init__(self, num_states, num_actions):
        """Create a randomly initialised Q-table.

        Args:
            num_states: Number of observation components; the table has
                ``NUM_DIZITIZED ** num_states`` rows.
            num_actions: Number of discrete actions (table columns).
        """
        self.num_actions = num_actions
        self.q_table = np.random.uniform(
            low=0, high=1, size=(NUM_DIZITIZED**num_states, num_actions)
        )
        # The bin edges never change, so build them once here instead of
        # recomputing four linspace arrays on every digitize_state call.
        self._bin_edges = tuple(
            self.bins(clip_min, clip_max, NUM_DIZITIZED)
            for clip_min, clip_max in self._STATE_RANGES
        )

    def bins(self, clip_min, clip_max, num):
        """Return the ``num - 1`` interior edges splitting a range into ``num`` bins.

        The outer edges are dropped so values beyond ``clip_min`` or
        ``clip_max`` fall into the first or last bin.
        """
        return np.linspace(clip_min, clip_max, num + 1)[1:-1]

    def digitize_state(self, observation):
        """Map a four-component observation to a single Q-table row index.

        The bin index of component ``i`` is weighted by ``NUM_DIZITIZED ** i``,
        i.e. the indices are read as digits of a base-``NUM_DIZITIZED`` number.

        Raises:
            ValueError: If ``observation`` does not unpack into four values.
        """
        # Explicit unpacking keeps the "exactly four components" check.
        cart_pos, cart_v, pole_angle, pole_v = observation
        components = (cart_pos, cart_v, pole_angle, pole_v)
        return sum(
            np.digitize(value, bins=edges) * (NUM_DIZITIZED**i)
            for i, (value, edges) in enumerate(zip(components, self._bin_edges))
        )

    def update_Q_table(self, observation, action, reward, observation_next):
        """Apply one Q-learning update for the given transition.

        ``Q[s, a] += ETA * (reward + GAMMA * max_a' Q[s', a'] - Q[s, a])``

        NOTE(review): the target always bootstraps from ``observation_next``;
        this method receives no flag marking the transition as terminal.
        """
        state = self.digitize_state(observation)
        state_next = self.digitize_state(observation_next)
        max_q_next = np.max(self.q_table[state_next])
        td_error = reward + GAMMA * max_q_next - self.q_table[state, action]
        self.q_table[state, action] += ETA * td_error

    def decide_action(self, observation, episode):
        """Choose an action epsilon-greedily.

        Epsilon is ``0.5 / (episode + 1)``, so the chance of a random action
        shrinks as ``episode`` grows.

        Returns:
            The index of the highest Q-value for the current state, or a
            uniformly random action index with probability epsilon.
        """
        state = self.digitize_state(observation)
        epsilon = 0.5 * (1 / (episode + 1))
        if epsilon <= np.random.uniform(0, 1):
            return np.argmax(self.q_table[state])
        return np.random.choice(self.num_actions)

# Environment manager
class Environment:
    """Owns the gym environment and the agent, and runs the training loop."""

    def __init__(self):
        """Create the environment named by ``ENV`` and an agent sized to it."""
        env = gym.make(ENV, render_mode='rgb_array')
        self.env = env
        self.num_states = env.observation_space.shape[0]
        self.num_actions = env.action_space.n
        self.agent = Agent(self.num_states, self.num_actions)

    def run(self):
        """Train for up to ``NUM_EPISODES`` episodes, then record one episode.

        The reward returned by the environment is discarded and replaced by:
        -1 when the episode ends before step index 195, +1 when it ends at
        step index 195 or later, and 0 on every other step. An early end
        resets the success streak and a late end extends it. Once the streak
        reaches 10, one more episode is run with its frames captured and
        passed to ``display_frames_as_gif``, after which the loop stops.
        """
        success_streak = 0
        recording = False
        captured = []

        for episode in range(NUM_EPISODES):
            observation, _ = self.env.reset()

            for step in range(MAX_STEPS):
                # Frames are only collected during the final episode.
                if recording:
                    captured.append(self.env.render())

                action = self.agent.get_action(observation, episode)
                observation_next, _, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                if not done:
                    reward = 0
                elif step < 195:
                    # The episode ended early: penalise and restart the streak.
                    reward = -1
                    success_streak = 0
                else:
                    # The episode lasted long enough to count as a success.
                    reward = 1
                    success_streak += 1

                self.agent.update_Q_function(observation, action, reward, observation_next)
                observation = observation_next

                if done:
                    print(f"{episode} Episode: Finished after {step + 1} time steps")
                    break

            if recording:
                display_frames_as_gif(captured)
                break

            if success_streak >= 10:
                print("10回連続成功")
                recording = True

# Run
cartpole_env = Environment()
try:
    cartpole_env.run()
finally:
    # Release the environment's resources even if the run raises or is
    # interrupted.
    cartpole_env.env.close()


  logger.deprecation(


0 Episode: Finished after 14 time steps
1 Episode: Finished after 10 time steps
2 Episode: Finished after 10 time steps
3 Episode: Finished after 98 time steps
4 Episode: Finished after 10 time steps
5 Episode: Finished after 17 time steps
6 Episode: Finished after 27 time steps
7 Episode: Finished after 12 time steps
8 Episode: Finished after 9 time steps
9 Episode: Finished after 13 time steps
10 Episode: Finished after 20 time steps
11 Episode: Finished after 14 time steps
12 Episode: Finished after 16 time steps
13 Episode: Finished after 11 time steps
14 Episode: Finished after 26 time steps
15 Episode: Finished after 11 time steps
16 Episode: Finished after 11 time steps
17 Episode: Finished after 98 time steps
18 Episode: Finished after 13 time steps
19 Episode: Finished after 35 time steps
20 Episode: Finished after 62 time steps
21 Episode: Finished after 74 time steps
22 Episode: Finished after 28 time steps
23 Episode: Finished after 42 time steps
24 Episode: Finished after 