In [5]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

# --- Environment and hyperparameter setup ---
# Training hyperparameters: discount factor, Adam learning rate, episode count.
gamma = 0.99
lr = 0.01
max_episodes = 300

# Rendering is disabled (render_mode=None).
env = gym.make("CartPole-v1", render_mode=None)

# Network input/output sizes are read from the environment's spaces.
action_dim = env.action_space.n
state_dim = env.observation_space.shape[0]

# --- Actor-Critic model definition ---
class ActorCritic(nn.Module):
    """Two-headed network: a shared hidden layer feeding an actor and a critic.

    Args:
        state_dim: Number of input features of the shared layer.
        action_dim: Number of outputs of the actor head (one per action).
        hidden_dim: Width of the shared hidden layer. Defaults to 128.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.fc = nn.Linear(state_dim, hidden_dim)
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor):
        """Compute action probabilities and the state-value estimate.

        Args:
            x: Input tensor whose last dimension has size ``state_dim``.

        Returns:
            Tuple ``(action_probs, state_value)``. ``action_probs`` is the
            softmax of the actor head over the last dimension;
            ``state_value`` is the raw output of the critic head (last
            dimension of size 1).
        """
        hidden = torch.relu(self.fc(x))
        action_probs = torch.softmax(self.actor(hidden), dim=-1)
        state_value = self.critic(hidden)
        return action_probs, state_value

# Build the network and the Adam optimizer that updates all of its parameters.
model = ActorCritic(state_dim=state_dim, action_dim=action_dim)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# --- Run one episode and collect the training signals ---
def train_one_episode(environment=None, policy=None):
    """Roll out one full episode, sampling every action from the policy.

    Args:
        environment: Object whose ``reset()`` returns ``(state, info)`` and
            whose ``step(action)`` returns
            ``(next_state, reward, terminated, truncated, info)``.
            Defaults to the module-level ``env``.
        policy: Callable that maps a float32 tensor of shape
            ``(1, len(state))`` to ``(action_probs, state_value)``.
            Defaults to the module-level ``model``.

    Returns:
        Tuple ``(log_probs, values, rewards)`` of equally long lists with one
        entry per step: the log-probability tensor of the sampled action, the
        value tensor returned by the policy, and the reward returned by
        ``step``. The tensors are not detached, so they can be used directly
        in a loss.
    """
    if environment is None:
        environment = env
    if policy is None:
        policy = model

    log_probs, values, rewards = [], [], []
    state, _ = environment.reset()
    done = False

    while not done:
        # Add a leading batch dimension of 1 before calling the policy.
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        probs, value = policy(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        state, reward, terminated, truncated, _ = environment.step(action.item())
        # The episode ends on either signal; both are treated the same here.
        done = terminated or truncated

        log_probs.append(dist.log_prob(action))
        values.append(value)
        rewards.append(reward)

    return log_probs, values, rewards

# --- Training loop ---
episode_rewards = []

for episode in range(max_episodes):
    log_probs, values, rewards = train_one_episode()

    # Discounted return G_t for every step, accumulated from the last step
    # backwards. Appending and reversing once avoids the quadratic cost of
    # repeatedly inserting at the front of the list.
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)

    # Flatten both sequences to shape (T,). Each per-step log-prob has shape
    # (1,), so a plain stack gives (T, 1); multiplying that by a (T,)
    # advantage would broadcast to (T, T) and pair every step's log-prob with
    # every step's advantage. reshape(-1) also keeps a one-step episode 1-D.
    values = torch.stack(values).reshape(-1)
    log_probs = torch.stack(log_probs).reshape(-1)

    # Loss: policy gradient weighted by the advantage, plus value regression.
    # The advantage is detached so the actor term does not train the critic.
    advantage = returns - values.detach()
    actor_loss = -torch.sum(log_probs * advantage)
    critic_loss = nn.functional.mse_loss(values, returns)
    loss = actor_loss + critic_loss

    # Gradient step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_reward = sum(rewards)
    episode_rewards.append(total_reward)
    print(f"Episode {episode+1}: Total Reward = {total_reward}")

# --- Build the figure for the animated reward curve ---
fig, ax = plt.subplots()
ax.set_title("PyTorch Actor-Critic: CartPole Learning")
ax.set_xlabel("Episode")
ax.set_ylabel("Total Reward")

# Fix the axes up front so they stay constant while the curve is drawn;
# the y-axis leaves 10% headroom above the best episode.
ax.set_xlim(0, max_episodes)
ax.set_ylim(0, 1.1 * max(episode_rewards))

# Empty line artist and the buffers that feed it.
(line,) = ax.plot([], [], linewidth=2)
xdata = []
ydata = []

def update(frame):
    """Show the reward curve for episodes ``0..frame`` and return the artist.

    The buffers are rebuilt from ``frame`` instead of being appended to, so
    the result depends only on ``frame``. The animation machinery may call
    this more than once for the same frame (for example for the initial draw
    when no ``init_func`` is supplied), and appending would then add
    duplicate points.

    Args:
        frame: Index into ``episode_rewards`` of the last episode to show.

    Returns:
        One-element tuple containing the updated line artist.
    """
    xdata[:] = range(frame + 1)
    ydata[:] = episode_rewards[: frame + 1]
    line.set_data(xdata, ydata)
    return line,

# One animation frame per recorded episode; blitting redraws only the
# artists returned by update().
ani = FuncAnimation(
    fig=fig,
    func=update,
    frames=range(len(episode_rewards)),
    blit=True,
)

# --- Save as GIF ---
ani.save(
    filename="pytorch_actor_critic_learning.gif",
    writer=PillowWriter(fps=10),
)
# To save as MP4 instead (requires ffmpeg):
# ani.save("pytorch_actor_critic_learning.mp4", writer="ffmpeg", fps=10)

plt.close()


  if not isinstance(terminated, (bool, np.bool8)):


Episode 1: Total Reward = 11.0
Episode 2: Total Reward = 12.0
Episode 3: Total Reward = 10.0
Episode 4: Total Reward = 10.0
Episode 5: Total Reward = 10.0
Episode 6: Total Reward = 9.0
Episode 7: Total Reward = 8.0
Episode 8: Total Reward = 10.0
Episode 9: Total Reward = 9.0
Episode 10: Total Reward = 10.0
Episode 11: Total Reward = 8.0
Episode 12: Total Reward = 10.0
Episode 13: Total Reward = 9.0
Episode 14: Total Reward = 11.0
Episode 15: Total Reward = 10.0
Episode 16: Total Reward = 8.0
Episode 17: Total Reward = 10.0
Episode 18: Total Reward = 10.0
Episode 19: Total Reward = 8.0
Episode 20: Total Reward = 10.0
Episode 21: Total Reward = 10.0
Episode 22: Total Reward = 9.0
Episode 23: Total Reward = 9.0
Episode 24: Total Reward = 10.0
Episode 25: Total Reward = 10.0
Episode 26: Total Reward = 11.0
Episode 27: Total Reward = 10.0
Episode 28: Total Reward = 8.0
Episode 29: Total Reward = 8.0
Episode 30: Total Reward = 10.0
Episode 31: Total Reward = 10.0
Episode 32: Total Reward = 9