In [None]:
from pistonball_CleanRL import *
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from supersuit import color_reduction_v0, frame_stack_v1, resize_v1
from pettingzoo.butterfly import pistonball_v6
import pandas as pd
import cv2
import os
import re

In [None]:
# Directory that receives model checkpoints
save_dir = "saved_models"
# Directory that receives training logs
log_dir = "training_logs"

# Create both directories up front; exist_ok keeps re-runs of this cell harmless.
os.makedirs(save_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

In [None]:
"""ALGO PARAMS"""
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Loss weights and clipping range used by the update step
clip_coef = 0.1  # bound for the probability-ratio clamp and the value-delta clamp
ent_coef = 0.1  # multiplier on the entropy term
vf_coef = 0.1  # multiplier on the value loss
gamma = 0.99  # discount factor
batch_size = 32  # number of transitions per minibatch

# Environment / rollout dimensions
n_pistons = 10
max_cycles = 125
total_episodes = 2
frame_size = (64, 64)
stack_size = 4

# Checkpoint file (inside save_dir) to resume from, e.g. "model_episode_300.pt";
# None starts from a freshly initialised agent.
model_name = None

# Metric sinks: TensorBoard event writer plus an in-memory list of per-episode rows
writer = SummaryWriter(log_dir)
logs = []

In [None]:
""" ENV SETUP """
# Build the parallel Pistonball environment with discrete actions and rgb_array rendering.
env = pistonball_v6.parallel_env(
    render_mode="rgb_array",
    continuous=False,
    max_cycles=max_cycles,
    n_pistons=n_pistons,
)
# Apply the SuperSuit wrappers in order: colour reduction, resize, frame stacking.
env = color_reduction_v0(env)
env = resize_v1(env, frame_size[0], frame_size[1])
env = frame_stack_v1(env, stack_size=stack_size)

# Read the action/observation spaces from the first listed agent.
first_agent = env.possible_agents[0]
num_agents = len(env.possible_agents)
num_actions = env.action_space(first_agent).n
observation_size = env.observation_space(first_agent).shape

""" LEARNER SETUP """
# agent = Agent(num_actions=num_actions).to(device)
agent = Agent_ADG(num_actions=num_actions).to(device)

# Optionally resume from an earlier checkpoint
previous_episodes = 0
if model_name:
    model_path = os.path.join(save_dir, model_name)
    agent.load_state_dict(torch.load(model_path, map_location=device))
    print("Model loaded successfully.")
    # Recover the episode count encoded in the checkpoint file name, if it follows the pattern.
    match = re.match(r"model_episode_(\d+)\.pt", model_name)
    if match:
        previous_episodes = int(match.group(1))

optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5)

""" ALGO LOGIC: EPISODE STORAGE"""
end_step = 0
total_episodic_return = 0
# Rollout buffers indexed by (step, agent), allocated directly on the target device.
per_agent_shape = (max_cycles, num_agents)
rb_obs = torch.zeros((*per_agent_shape, stack_size, *frame_size), device=device)
rb_actions = torch.zeros(per_agent_shape, device=device)
rb_logprobs = torch.zeros(per_agent_shape, device=device)
rb_rewards = torch.zeros(per_agent_shape, device=device)
rb_terms = torch.zeros(per_agent_shape, device=device)
rb_values = torch.zeros(per_agent_shape, device=device)

In [None]:
""" TRAINING LOGIC """
# Train for `total_episodes` episodes; episode numbering continues after
# `previous_episodes` so that resumed runs keep a monotonically increasing index.
for episode in range(previous_episodes + 1, total_episodes + previous_episodes + 1):
    # ---- Rollout: play one episode without tracking gradients ----
    with torch.no_grad():
        next_obs, info = env.reset(seed=None)
        # per-agent reward sum for this episode (becomes an array after the first step)
        total_episodic_return = 0

        for step in range(0, max_cycles):
            # the previous step's "next" observation is the current observation
            obs = batchify_obs(next_obs, device)

            # query the agent; the third return value is ignored during rollout
            actions, logprobs, _, values = agent(obs)

            # advance the environment with the chosen actions
            next_obs, rewards, terms, truncs, infos = env.step(
                unbatchify(actions, env)
            )

            # store the transition in the rollout buffers
            rb_obs[step] = obs
            rb_rewards[step] = batchify(rewards, device)
            rb_terms[step] = batchify(terms, device)
            rb_actions[step] = actions
            rb_logprobs[step] = logprobs
            rb_values[step] = values.flatten()

            total_episodic_return += rb_rewards[step].cpu().numpy()

            # stop as soon as any agent is terminated or truncated
            if any(terms.values()) or any(truncs.values()):
                break

        # Index of the last executed step, whether the loop broke early or ran out of
        # cycles. Assigning it here avoids reusing a stale value from an earlier episode.
        end_step = step

    # Only transitions [0, end_step) are trained on, because the advantage at t needs
    # the value stored at t + 1. Nothing to learn from if the episode ended immediately.
    if end_step == 0:
        print(f"Training episode {episode} ended on its first step; skipping update.")
        print("\n-------------------------------------------\n")
        continue

    # ---- Advantage estimation (backwards recursion over the episode) ----
    with torch.no_grad():
        rb_advantages = torch.zeros_like(rb_rewards).to(device)
        for t in reversed(range(end_step)):
            # rb_terms[t] flags that the state reached AFTER step t is terminal, so the
            # bootstrap from rb_values[t + 1] must be kept when it is 0 and dropped when 1.
            next_nonterminal = 1.0 - rb_terms[t]
            delta = (
                rb_rewards[t]
                + gamma * rb_values[t + 1] * next_nonterminal
                - rb_values[t]
            )
            # gamma * gamma acts as gamma * lambda with the GAE lambda equal to gamma
            rb_advantages[t] = (
                delta + gamma * gamma * next_nonterminal * rb_advantages[t + 1]
            )
        rb_returns = rb_advantages + rb_values

    # ---- Flatten (step, agent) into a single batch dimension ----
    b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1)
    b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1)
    # actions are stored as floats in the buffer; convert once to integer indices
    b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1).long()
    b_returns = torch.flatten(rb_returns[:end_step], start_dim=0, end_dim=1)
    b_values = torch.flatten(rb_values[:end_step], start_dim=0, end_dim=1)
    b_advantages = torch.flatten(rb_advantages[:end_step], start_dim=0, end_dim=1)

    # ---- Optimise the policy and value network ----
    b_index = np.arange(len(b_obs))
    clip_fracs = []
    for repeat in range(3):
        # reshuffle the transition order on every pass
        np.random.shuffle(b_index)
        for start in range(0, len(b_obs), batch_size):
            end = start + batch_size
            batch_index = b_index[start:end]

            _, newlogprob, entropy, value = agent(
                b_obs[batch_index], b_actions[batch_index]
            )
            logratio = newlogprob - b_logprobs[batch_index]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clip_fracs += [
                    ((ratio - 1.0).abs() > clip_coef).float().mean().item()
                ]

            # normalize advantages per minibatch; a single-element batch is left
            # untouched because its standard deviation is undefined (NaN)
            advantages = b_advantages[batch_index]
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (
                    advantages.std() + 1e-8
                )

            # Policy loss (clipped surrogate) built from the normalized advantages
            pg_loss1 = -advantages * ratio
            pg_loss2 = -advantages * torch.clamp(
                ratio, 1 - clip_coef, 1 + clip_coef
            )
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss: worst case of the unclipped and clipped squared errors
            value = value.flatten()
            v_loss_unclipped = (value - b_returns[batch_index]) ** 2
            v_clipped = b_values[batch_index] + torch.clamp(
                value - b_values[batch_index],
                -clip_coef,
                clip_coef,
            )
            v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Explained variance of the rollout value predictions w.r.t. the returns.
    # float() handles both the NumPy scalar and the plain-float np.nan fallback.
    y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
    var_y = np.var(y_true)
    explained_var = float(
        np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
    )

    # Loss and KL figures below come from the last minibatch of the last pass.
    mean_return = np.mean(total_episodic_return)
    mean_clip_frac = np.mean(clip_fracs)

    # Record the metrics for the CSV log
    logs.append({
        "episode": episode,
        "episodic_return": mean_return,
        "episode_length": end_step,
        "value_loss": v_loss.item(),
        "policy_loss": pg_loss.item(),
        "old_approx_kl": old_approx_kl.item(),
        "approx_kl": approx_kl.item(),
        "clip_fraction": mean_clip_frac,
        "explained_variance": explained_var
    })

    # Mirror the same metrics to TensorBoard
    writer.add_scalar('Loss/Value Loss', v_loss.item(), episode)
    writer.add_scalar('Loss/Policy Loss', pg_loss.item(), episode)
    writer.add_scalar('Metrics/Episodic Return', mean_return, episode)
    writer.add_scalar('Metrics/Episode Length', end_step, episode)
    writer.add_scalar('Metrics/Old Approx KL', old_approx_kl.item(), episode)
    writer.add_scalar('Metrics/Approx KL', approx_kl.item(), episode)
    writer.add_scalar('Metrics/Clip Fraction', mean_clip_frac, episode)
    writer.add_scalar('Metrics/Explained Variance', explained_var, episode)
    # push pending events to disk so short runs are visible in TensorBoard right away
    writer.flush()

    print(f"Training episode {episode}")
    print(f"Episodic Return: {mean_return}")
    print(f"Episode Length: {end_step}")
    print("")
    print(f"Value Loss: {v_loss.item()}")
    print(f"Policy Loss: {pg_loss.item()}")
    print(f"Old Approx KL: {old_approx_kl.item()}")
    print(f"Approx KL: {approx_kl.item()}")
    print(f"Clip Fraction: {mean_clip_frac}")
    print(f"Explained Variance: {explained_var}")
    print("\n-------------------------------------------\n")

# Save the model parameters; the file name carries the cumulative episode count.
checkpoint_file = f"model_episode_{total_episodes+previous_episodes}.pt"
model_path = os.path.join(save_dir, checkpoint_file)
torch.save(agent.state_dict(), model_path)
print(f"Model saved to {model_path}")

# If a log file already exists, put its rows in front of this run's rows,
# then write the combined history back to the same CSV.
log_path = os.path.join(log_dir, "training_log.csv")
if os.path.exists(log_path):
    earlier_rows = pd.read_csv(log_path).to_dict(orient='records')
    logs = earlier_rows + logs
pd.DataFrame(logs).to_csv(log_path, index=False)
print(f"Log saved to {log_path}")