In [4]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML

# Q-Network
class DQN(nn.Module):
    """Small fully connected Q-network.

    Maps an input with ``obs_size`` features to ``n_actions`` outputs through
    one 128-unit hidden layer with a ReLU activation.
    """

    def __init__(self, obs_size, n_actions):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(obs_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        ]
        # Stored under ``net`` so the state-dict keys stay ``net.0.*``/``net.2.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        q_values = self.net(x)
        return q_values

# Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with ``maxlen`` drops its oldest item once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        """Append ``transition`` at the newest end of the buffer."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement."""
        picked = random.sample(self.buffer, batch_size)
        return picked

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.buffer)

# Animation helper
def display_frames_as_gif(frames, filename="cartpole.mp4"):
    """Render a sequence of image frames into an animation file.

    Args:
        frames: Non-empty sequence of image arrays. The first frame's
            ``shape[0]`` and ``shape[1]`` set the figure size; every frame is
            drawn via ``imshow``/``set_data``.
        filename: Output path handed to ``Animation.save``. No writer is
            specified here, so Matplotlib's default writer is used.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Fail with a clear message instead of an IndexError on ``frames[0]``.
    if len(frames) == 0:
        raise ValueError("frames must contain at least one image")

    height, width = frames[0].shape[0], frames[0].shape[1]
    # At 72 dpi, one figure "inch" equals 72 pixels of the frame.
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])
            return patch,

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        anim.save(filename)
    finally:
        # Close this specific figure even if saving fails (e.g. no movie
        # writer available), so figures do not accumulate.
        plt.close(fig)

# Training loop
def _dqn_update(policy_net, target_net, optimizer, loss_fn, transitions, gamma):
    """Apply one gradient step to ``policy_net`` from a batch of transitions.

    Each transition is an ``(obs, action, reward, next_obs, done)`` tuple, as
    pushed by ``train_cartpole``.
    """
    states, actions, rewards, next_states, dones = zip(*transitions)

    # Stack into single ndarrays first: building a tensor directly from a
    # tuple of ndarrays is slow and makes PyTorch emit a warning.
    states_tensor = torch.from_numpy(np.asarray(states, dtype=np.float32))
    next_states_tensor = torch.from_numpy(np.asarray(next_states, dtype=np.float32))
    actions_tensor = torch.from_numpy(np.asarray(actions, dtype=np.int64)).unsqueeze(1)
    rewards_tensor = torch.from_numpy(np.asarray(rewards, dtype=np.float32)).unsqueeze(1)
    dones_tensor = torch.from_numpy(np.asarray(dones, dtype=np.float32)).unsqueeze(1)

    # Q(s, a) for the actions that were actually taken.
    q_values = policy_net(states_tensor).gather(1, actions_tensor)
    # The target network only provides bootstrap values; no gradients needed.
    with torch.no_grad():
        next_q_values = target_net(next_states_tensor).max(1)[0].unsqueeze(1)
    expected_q_values = rewards_tensor + gamma * next_q_values * (1 - dones_tensor)

    loss = loss_fn(q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_cartpole(episodes=500, max_steps=200, video_filename="cartpole_54.mp4"):
    """Train a DQN agent on CartPole-v1 and save a video of the last episode.

    Uses epsilon-greedy exploration (epsilon decays from 1.0 by a factor of
    0.995 per episode down to 0.01), a 10000-transition replay buffer,
    batches of 64, Adam with lr=1e-3, and a target network that is synced
    with the policy network every 10 episodes.

    Args:
        episodes: Number of training episodes.
        max_steps: Maximum number of environment steps per episode.
        video_filename: Path the final episode's animation is written to.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    final_frames = []  # rendered frames of the last episode only

    try:
        obs_size = env.observation_space.shape[0]
        n_actions = env.action_space.n

        policy_net = DQN(obs_size, n_actions)
        target_net = DQN(obs_size, n_actions)
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

        optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
        loss_fn = nn.MSELoss()  # created once instead of on every step
        buffer = ReplayBuffer(10000)
        batch_size = 64
        gamma = 0.99
        epsilon = 1.0
        epsilon_decay = 0.995
        epsilon_min = 0.01

        for episode in range(episodes):
            is_last_episode = episode == episodes - 1
            obs, _ = env.reset()
            total_reward = 0

            for _ in range(max_steps):
                # Epsilon-greedy action selection.
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    state = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
                    with torch.no_grad():
                        action = policy_net(state).argmax().item()

                next_obs, reward, terminated, truncated, _ = env.step(action)

                # Store only true termination as the terminal flag: a
                # time-limit truncation is not a terminal state, so the
                # bootstrap term must not be zeroed for it.
                buffer.push((obs, action, reward, next_obs, terminated))
                obs = next_obs
                total_reward += reward

                # Capture frames for the animation (last episode only).
                if is_last_episode:
                    final_frames.append(env.render())

                if len(buffer) >= batch_size:
                    _dqn_update(
                        policy_net,
                        target_net,
                        optimizer,
                        loss_fn,
                        buffer.sample(batch_size),
                        gamma,
                    )

                if terminated or truncated:
                    break

            epsilon = max(epsilon_min, epsilon * epsilon_decay)
            if episode % 10 == 0:
                target_net.load_state_dict(policy_net.state_dict())

            print(f"Episode {episode}: total reward = {total_reward}")
    finally:
        # Release the environment even if training is interrupted.
        env.close()

    # With episodes == 0 nothing was recorded, so there is nothing to save.
    if final_frames:
        display_frames_as_gif(final_frames, video_filename)
        # Report the file that was actually written.
        print(f"動画保存完了：{video_filename}")

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_cartpole()


Episode 0: total reward = 32.0
Episode 1: total reward = 26.0
Episode 2: total reward = 40.0
Episode 3: total reward = 18.0
Episode 4: total reward = 48.0
Episode 5: total reward = 28.0
Episode 6: total reward = 22.0
Episode 7: total reward = 27.0
Episode 8: total reward = 15.0
Episode 9: total reward = 19.0
Episode 10: total reward = 12.0
Episode 11: total reward = 11.0
Episode 12: total reward = 12.0
Episode 13: total reward = 19.0
Episode 14: total reward = 20.0
Episode 15: total reward = 25.0
Episode 16: total reward = 66.0
Episode 17: total reward = 16.0
Episode 18: total reward = 26.0
Episode 19: total reward = 15.0
Episode 20: total reward = 21.0
Episode 21: total reward = 39.0
Episode 22: total reward = 23.0
Episode 23: total reward = 15.0
Episode 24: total reward = 23.0
Episode 25: total reward = 18.0
Episode 26: total reward = 29.0
Episode 27: total reward = 25.0
Episode 28: total reward = 21.0
Episode 29: total reward = 23.0
Episode 30: total reward = 29.0
Episode 31: total 