In [None]:
!pip install gymnasium
!pip install minigrid
!pip install matplotlib

import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from minigrid.wrappers import ImgObsWrapper

import random
import math
import numpy as np
from collections import deque, namedtuple
import matplotlib.pyplot as plt

# --- Hyperparameters ---
# Curriculum: train on progressively larger empty mazes, a fixed number of episodes each.
MAZE_SIZES = [5, 8, 16]
EPISODES_PER_STAGE = 1_000

# DQN agent
GAMMA = 0.99                 # discount factor in the TD target
EPSILON_START = 1.0          # initial exploration rate
EPSILON_END = 0.05           # exploration floor
EPSILON_DECAY = 30_000       # e-folding time of epsilon, counted in action selections
REPLAY_BUFFER_SIZE = 50_000  # maximum number of stored transitions
BATCH_SIZE = 128
LEARNING_RATE = 3e-5
TARGET_UPDATE_FREQ = 100     # target network is synced every N episodes

# Device selection: prefer CUDA when it is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- Record type for a single stored experience ---
# Field order matters: ReplayBuffer.push(*args) and DQNAgent.update_model rely on it.
_TRANSITION_FIELDS = ['state', 'action', 'next_state', 'reward', 'terminated']
Transition = namedtuple('Transition', _TRANSITION_FIELDS)

# --- Replay buffer ---
class ReplayBuffer:
    """Fixed-capacity FIFO store of Transition records used for experience replay."""

    def __init__(self, capacity):
        # A bounded deque silently drops the oldest entry once `capacity` is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition built from positional fields in Transition's field order."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions drawn uniformly at random."""
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        return self.memory.__len__()

# --- CNN-based Q-network ---
class QNetwork(nn.Module):
    """Convolutional Q-network mapping a batch of channels-last observations to per-action Q-values.

    Args:
        obs_space_shape: observation shape as (height, width, channels).
        action_space_n: number of discrete actions (width of the output layer).
    """

    def __init__(self, obs_space_shape, action_space_n):
        super().__init__()
        h, w, c = obs_space_shape

        # Three 3x3 convolutions with stride 1 and padding 1 keep the spatial size at (h, w).
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened conv output size with a dummy forward pass in NCHW layout.
        with torch.no_grad():
            dummy_input = torch.zeros(1, c, h, w)
            cnn_out_size = self.cnn(dummy_input).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(cnn_out_size, 256),
            nn.ReLU(),
            nn.Linear(256, action_space_n)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_space_n) for `x` of shape (batch, H, W, C).

        The input is moved to whichever device holds this network's parameters, cast to
        float and divided by 255 before the convolutional stack.
        """
        # Follow the network's own device instead of a module-level default, so the model
        # keeps working after `.to(...)` moves it elsewhere.
        param_device = next(self.parameters()).device
        # NOTE(review): with ImgObsWrapper, MiniGrid observations are presumably small
        # integer codes (object/color/state indices) rather than 0-255 pixels, so /255
        # leaves the inputs very small — confirm the intended scale.
        x = x.to(param_device).float() / 255.0
        # Channels-last (N, H, W, C) -> channels-first (N, C, H, W) as expected by Conv2d.
        x = x.permute(0, 3, 1, 2)
        return self.fc(self.cnn(x))

# --- DQN agent ---
class DQNAgent:
    """DQN agent with an online policy network, a separately synced target network and replay.

    Hyperparameters come from the module-level constants (LEARNING_RATE, REPLAY_BUFFER_SIZE,
    EPSILON_START/END/DECAY, BATCH_SIZE, GAMMA), and tensors are created on the module-level
    `device`.
    """

    def __init__(self, obs_space_shape, action_space_n):
        self.action_space_n = action_space_n
        self.policy_net = QNetwork(obs_space_shape, action_space_n).to(device)
        self.target_net = QNetwork(obs_space_shape, action_space_n).to(device)
        # The target network starts as an exact copy and only changes via sync_target_network().
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)
        # Number of select_action() calls so far; drives the epsilon schedule.
        self.steps_done = 0

    def select_action(self, state):
        """Choose an action epsilon-greedily and return it as a (1, 1) long tensor.

        Epsilon decays exponentially from EPSILON_START toward EPSILON_END with time
        constant EPSILON_DECAY, measured in calls to this method.
        """
        eps_threshold = EPSILON_END + (EPSILON_START - EPSILON_END) * \
                        math.exp(-1. * self.steps_done / EPSILON_DECAY)
        self.steps_done += 1

        if random.random() > eps_threshold:
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.action_space_n)]], device=device, dtype=torch.long)

    def update_model(self):
        """Run one optimization step on a random minibatch.

        Does nothing until at least BATCH_SIZE transitions are stored. The TD target is
        r + GAMMA * max_a' Q_target(s', a'); the bootstrap term is 0 for transitions whose
        next_state is None.
        """
        if len(self.replay_buffer) < BATCH_SIZE:
            return

        transitions = self.replay_buffer.sample(BATCH_SIZE)
        # Transpose a list of Transitions into one Transition of per-field tuples.
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=device, dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        # Cast explicitly so integer-typed reward tensors never reach the target computation.
        reward_batch = torch.cat(batch.reward).float()

        current_q_values = self.policy_net(state_batch).gather(1, action_batch)

        with torch.no_grad():
            next_q_values = torch.zeros(BATCH_SIZE, device=device)
            # torch.cat([]) raises, so skip the target pass when every sample is terminal.
            if non_final_next_states:
                next_q_values[non_final_mask] = self.target_net(
                    torch.cat(non_final_next_states)).max(1)[0]
            target_q_values = reward_batch + (GAMMA * next_q_values)

        loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def sync_target_network(self):
        """Copy the policy network's current weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

# --- メインの学習ループ ---
if __name__ == "__main__":
    agent = None
    all_stage_rewards = {}

    for stage, size in enumerate(MAZE_SIZES):
        print(f"--- Curriculum Stage {stage + 1}: {size}x{size} Maze ---")
        env = gym.make(f'MiniGrid-Empty-{size}x{size}-v0')
        env = ImgObsWrapper(env)
        
        if agent is None:
            obs_shape = env.observation_space.shape
            action_n = env.action_space.n
            agent = DQNAgent(obs_shape, action_n)

        stage_rewards = []
        for episode in range(EPISODES_PER_STAGE):
            obs, info = env.reset()
            state = torch.tensor(obs, device=device).unsqueeze(0).float()
            
            terminated, truncated = False, False
            episode_reward = 0
            while not terminated and not truncated:
                action = agent.select_action(state)
                obs, reward, terminated, truncated, info = env.step(action.item())
                episode_reward += reward

                next_state = torch.tensor(obs, device=device).unsqueeze(0).float() if not terminated else None
                
                agent.replay_buffer.push(state, action, next_state, 
                                         torch.tensor([reward], device=device), 
                                         torch.tensor(terminated, device=device))
                state = next_state
                agent.update_model()

            stage_rewards.append(episode_reward)

            if (episode + 1) % TARGET_UPDATE_FREQ == 0:
                agent.sync_target_network()
            
            if (episode + 1) % 100 == 0:
                avg_reward = np.mean(stage_rewards[-100:])
                print(f"Stage {stage+1}, Episode {episode+1}, Avg Reward (last 100): {avg_reward:.2f}")
        
        all_stage_rewards[f"{size}x{size}"] = stage_rewards
        print(f"--- Stage {stage + 1} Complete ---")

    env.close()

    # --- 結果のプロット ---
    plt.figure(figsize=(12, 7))
    total_episodes = 0
    for size, rewards in all_stage_rewards.items():
        moving_avg = np.convolve(rewards, np.ones(100)/100, mode='valid')
        episodes = np.arange(total_episodes, total_episodes + len(moving_avg))
        plt.plot(episodes, moving_avg, label=f'Stage: {size}')
        total_episodes += len(rewards)
    
    plt.title("DQN with Curriculum Learning Performance")
    plt.xlabel("Total Episodes")
    plt.ylabel("Average Reward (Moving Avg over 100 episodes)")
    plt.legend()
    plt.grid(True)
    plt.savefig("dqn_curriculum_performance.png")
    plt.show()

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting gymnasium
  Downloading gymnasium-1.2.0-py3-none-any.whl (944 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m944.3/944.3 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting farama-notifications>=0.0.1 (from gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium
Successfully installed farama-notifications-0.0.4 gymnasium-1.2.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting minigrid
  Downloading minigrid-3.0.0-py3-none-any.whl (136 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━