# 自作のActor-Criticノートブック

In [32]:
import numpy as np
import copy
from dataclasses import dataclass, asdict, is_dataclass

import sys
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

import gymnasium as gym

from myActivator import tanhAndScale
from myFunction import make_squashed_gaussian

In [33]:
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s [%(levelname)s] %(message)s",
                    stream=sys.stdout, datefmt="%H:%M:%S")

In [34]:
env = gym.make("Pendulum-v1",render_mode="human")
for key in vars(env.spec):
    logging.info('%s: %s', key, vars(env.spec)[key])
for key in vars(env.unwrapped):
    logging.info('%s: %s', key, vars(env.unwrapped)[key])

22:12:11 [INFO] id: Pendulum-v1
22:12:11 [INFO] entry_point: gymnasium.envs.classic_control.pendulum:PendulumEnv
22:12:11 [INFO] reward_threshold: None
22:12:11 [INFO] nondeterministic: False
22:12:11 [INFO] max_episode_steps: 200
22:12:11 [INFO] order_enforce: True
22:12:11 [INFO] disable_env_checker: False
22:12:11 [INFO] kwargs: {'render_mode': 'human'}
22:12:11 [INFO] additional_wrappers: ()
22:12:11 [INFO] vector_entry_point: None
22:12:11 [INFO] namespace: None
22:12:11 [INFO] name: Pendulum
22:12:11 [INFO] version: 1
22:12:11 [INFO] max_speed: 8
22:12:11 [INFO] max_torque: 2.0
22:12:11 [INFO] dt: 0.05
22:12:11 [INFO] g: 10.0
22:12:11 [INFO] m: 1.0
22:12:11 [INFO] l: 1.0
22:12:11 [INFO] render_mode: human
22:12:11 [INFO] screen_dim: 500
22:12:11 [INFO] screen: None
22:12:11 [INFO] clock: None
22:12:11 [INFO] isopen: True
22:12:11 [INFO] action_space: Box(-2.0, 2.0, (1,), float32)
22:12:11 [INFO] observation_space: Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
22:12:11 [INFO] spec

In [35]:
@dataclass
class Config:
    V_net_sizes = [6,12,24,12,6]
    P_net_sizes = [6,12,24,24,12,6]
    V_net_in = 3
    P_net_in = 3
    V_net_out = 1
    P_net_out = 2

    V_lr = 1e-3
    P_lr = 1e-3

    u_high = 2.0
    u_low = -2.0

    log_std_min = -10.0
    log_std_max = 1.0

    gamma = 0.9

In [36]:
class ActorCriticAgent:
    def __init__(self,Config,device=None):
        if Config:
            self.Config = Config
        else:
            raise ValueError("No Config!!")
        
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.u_high = torch.as_tensor(Config.u_high, dtype=torch.float32, device=self.device)
        self.u_low = torch.as_tensor(Config.u_low, dtype=torch.float32, device=self.device)
        
        self.V_net = self.build_net(
            Config.V_net_in,
            Config.V_net_sizes,
            Config.V_net_out
        ).to(self.device)
        self.V_net.train()

        self.P_net = self.build_net(
            Config.P_net_in,
            Config.P_net_sizes,
            Config.P_net_out
        ).to(self.device)
        self.P_net.train()

        self.V_optim = optim.Adam(self.V_net.parameters(),Config.V_lr)
        self.P_optim = optim.Adam(self.P_net.parameters(),Config.P_lr)

        self.log_std_min = torch.as_tensor(Config.log_std_min,dtype=torch.float32,device=self.device)
        self.log_std_max = torch.as_tensor(Config.log_std_max,dtype=torch.float32,device=self.device)

    
    def to(self,device):
        self.device = torch.device(device)
        self.V_net.to(self.device)
        self.P_net.to(self.device)
        return self


    def build_net(self,input_size,hidden_sizes,output_size=1,output_activator=None):
        layers = []
        for input_size, output_size in zip([input_size]+hidden_sizes, hidden_sizes+[output_size]):
            layers.append(nn.Linear(input_size,output_size))
            layers.append(nn.ReLU())
        layers = layers[:-1]
        if output_activator:
            layers.append(output_activator)
        net = nn.Sequential(*layers)
        return net
    

    @torch.no_grad()
    def step(self,state):
        state = torch.as_tensor(state,dtype=torch.float32,device=self.device)
        if state.dim() == 1:
            state = state.unsqueeze(0)
        out = self.P_net(state)
        mu, log_std = torch.chunk(out, 2, dim=-1)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        dist = make_squashed_gaussian(mu=mu,std=std,low=self.u_low,high=self.u_high)
        action = dist.rsample()
        return action.squeeze(0).cpu().numpy()
    

    def save_all(self,path:str,extra:dict|None=None):
        cfg = asdict(self.Config) if is_dataclass(self.Config) else self.Config
        ckpt = {
            "Config":cfg,
            "V_net":self.V_net.state_dict(),
            "P_net":self.P_net.state_dict(),
        }
        if extra is not None:
            ckpt["extra"] = extra
        
        torch.save(ckpt,path)

    
    def load_all(self,path:str,map_location=None):
        ckpt = torch.load(path,map_location=map_location)
        self.V_net.load_state_dict(ckpt["V_net"])
        self.P_net.load_state_dict(ckpt["P_net"])

        return ckpt.get("extra",None)
    

    def mode2eval(self):
        self.V_net.eval()
        self.P_net.eval()


    def mode2train(self):
        self.V_net.train()
        self.P_net.train()
    

    def update_net_batch(self,states,actions,rewards,states_next,dones):
        states = torch.as_tensor(states,dtype=torch.float32,device=self.device)
        actions = torch.as_tensor(actions,dtype=torch.float32,device=self.device)
        rewards = torch.as_tensor(rewards,dtype=torch.float32,device=self.device)
        states_next = torch.as_tensor(states_next,dtype=torch.float32,device=self.device)

        if rewards.dim() == 1:
            rewards = rewards.unsqueeze(1)
        
        if dones is None:
            dones = torch.zeros((states.shape[0], 1), dtype=torch.float32, device=self.device)
        else:
            dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)
            if dones.dim() == 1:
                dones = dones.unsqueeze(1)

        with torch.no_grad():
            y_targets = rewards+self.Config.gamma*(1-dones)*self.V_net(states_next)

        V_values = self.V_net(states)
        V_loss = F.mse_loss(y_targets,V_values)
        self.V_optim.zero_grad()
        V_loss.backward()
        self.V_optim.step()

        for p in self.V_net.parameters():
            p.requires_grad_(False)

        outs = self.P_net(states)
        mus, log_stds = torch.chunk(outs, 2, dim=-1)
        log_stds = torch.clamp(log_stds, self.log_std_min, self.log_std_max)
        stds = torch.exp(log_stds)
        dists = make_squashed_gaussian(mu=mus,std=stds,low=self.u_low,high=self.u_high)
        advantages = (y_targets - V_values).detach()
        P_loss = -(advantages*dists.log_prob(actions).unsqueeze(-1)).mean()
        self.P_optim.zero_grad()
        P_loss.backward()
        self.P_optim.step()

        # print("logp:", dists.log_prob(actions).unsqueeze(-1).shape, "adv:", advantages.shape,
        #       "std mean:", stds.mean().item(),
        #       "logp mean:", dists.log_prob(actions).unsqueeze(-1).mean().item())

        for p in self.V_net.parameters():
            p.requires_grad_(True)

        return float(V_loss.item()), float(P_loss.item())

In [37]:
def train(
        env,
        agent,
        rollout_num=200,
        rollout_len=256,
):
    print("cuda available:", torch.cuda.is_available())
    print("agent device:", agent.device)
    print("P_net device:", next(agent.V_net.parameters()).device)
    print("Q_net device:", next(agent.P_net.parameters()).device)

    reward_history = []
    reward_log = 0
    episode = 0

    state, info = env.reset()

    for r in range(rollout_num):
        # ---- rollout バッファ（長さ rollout_len）を毎回作り直す ----
        states = []
        actions = []
        rewards = []
        states_next = []
        dones = []

        for t in range(rollout_len):
            with torch.no_grad():
                action = agent.step(state)
            state_next, reward, terminated, truncated, info = env.step(action)
            done = terminated

            states.append(np.atleast_1d(state).tolist())
            actions.append(np.atleast_1d(action).tolist())
            rewards.append(np.atleast_1d(reward).tolist())
            states_next.append(np.atleast_1d(state_next).tolist())
            dones.append(np.atleast_1d(done).tolist())

            reward_log += reward

            if terminated or truncated:
                state, info = env.reset()
                episode += 1
                logging.info('episode train %d: reward = %.2f', episode, reward_log)
                reward_history.append(reward_log)
                reward_log = 0
            else:
                state = state_next
        
        V_loss, P_loss = agent.update_net_batch(states,actions,rewards,states_next,dones)

    return reward_history

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = ActorCriticAgent(Config=Config(),device=device)

rollout_num=500
rollout_len=100

rh = train(
    env=env,
    agent=agent,
    rollout_num=rollout_num,
    rollout_len=rollout_len,
)

cuda available: True
agent device: cuda
P_net device: cuda:0
Q_net device: cuda:0
22:12:18 [INFO] episode train 1: reward = -1773.71
22:12:24 [INFO] episode train 2: reward = -1795.53
22:12:31 [INFO] episode train 3: reward = -1661.74
22:12:38 [INFO] episode train 4: reward = -1167.24
22:12:44 [INFO] episode train 5: reward = -1481.06
22:12:51 [INFO] episode train 6: reward = -1035.51
22:12:58 [INFO] episode train 7: reward = -1707.66
22:13:04 [INFO] episode train 8: reward = -861.98
22:13:11 [INFO] episode train 9: reward = -1075.98
22:13:18 [INFO] episode train 10: reward = -1012.65
22:13:24 [INFO] episode train 11: reward = -1061.22
22:13:31 [INFO] episode train 12: reward = -1401.98
22:13:38 [INFO] episode train 13: reward = -1045.42
22:13:44 [INFO] episode train 14: reward = -1126.07
22:13:51 [INFO] episode train 15: reward = -1230.00
22:13:58 [INFO] episode train 16: reward = -939.40
22:14:04 [INFO] episode train 17: reward = -986.38
22:14:11 [INFO] episode train 18: reward = -14

KeyboardInterrupt: 

In [39]:
env.close()

In [None]:
# 推論用に eval モードにしておく（保存自体は train のままでも可）
agent.mode2eval()

from datetime import datetime

stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

agent.save_all(
    "./models/actor_critic_final_" + stamp + ".pth",
    extra={
        "rollout_num": rollout_num,
        "rollout_len": rollout_len,
        "reward_history": rh,  # 要らなければ外してOK
    }
)
print("saved to actor_critic_final.pth")