# 自作のActor-Criticノートブック

In [None]:
import copy
import logging
import sys
from dataclasses import asdict, dataclass, field, is_dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

from myActivator import tanhAndScale
from myFunction import make_squashed_gaussian

In [None]:
# Send INFO-and-above log records to stdout (so they show up in the notebook
# output) with a compact "HH:MM:SS [LEVEL] message" layout.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt="%H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(message)s",
)

In [None]:
# Create the Pendulum environment and log every attribute of its spec and of
# the innermost (unwrapped) environment object for inspection.
# NOTE(review): render_mode="human" draws every step in a window, which is
# likely to slow training considerably; consider render_mode=None for runs.
env = gym.make("Pendulum-v1", render_mode="human")
for key, value in vars(env.spec).items():
    logging.info('%s: %s', key, value)
for key, value in vars(env.unwrapped).items():
    logging.info('%s: %s', key, value)

In [None]:
@dataclass
class Config:
    """Hyperparameters for ``ActorCriticAgent``.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field. Without annotations they were plain class attributes,
    ``asdict(Config())`` returned ``{}``, and checkpoints written by
    ``save_all`` contained no configuration at all.
    """

    # Hidden-layer widths of the critic (V) and actor (P) MLPs.
    # default_factory gives every instance its own list instead of one shared list.
    V_net_sizes: list[int] = field(default_factory=lambda: [6, 12, 12, 6])
    P_net_sizes: list[int] = field(default_factory=lambda: [6, 12, 12, 6])
    # Input width of both networks (the observation size of the env).
    V_net_in: int = 3
    P_net_in: int = 3
    # The critic outputs one state value; the actor's output is split in half
    # into (mu, log_std), so P_net_out must be 2 * action_dim.
    V_net_out: int = 1
    P_net_out: int = 2

    # Adam learning rates for the critic and the actor.
    V_lr: float = 1e-3
    P_lr: float = 1e-3

    # Action bounds handed to the policy distribution.
    u_high: float = 2.0
    u_low: float = -2.0

    # Discount factor of the one-step TD target.
    gamma: float = 0.99

In [None]:
class ActorCriticAgent:
    def __init__(self,Config,device=None):
        if Config:
            self.Config = Config
        else:
            raise ValueError("No Config!!")
        
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.u_high = torch.as_tensor(Config.u_high, dtype=torch.float32, device=self.device)
        self.u_low = torch.as_tensor(Config.u_low, dtype=torch.float32, device=self.device)
        
        self.V_net = self.build_net(
            Config.V_net_in,
            Config.V_net_sizes,
            Config.V_net_out
        ).to(self.device)
        self.V_net.train()

        self.P_net = self.build_net(
            Config.P_net_in,
            Config.P_net_sizes,
            Config.P_net_out
        ).to(self.device)
        self.P_net.train()

        self.V_optim = optim.Adam(self.V_net.parameters(),Config.V_lr)
        self.P_optim = optim.Adam(self.P_net.parameters(),Config.P_lr)

    
    def to(self,device):
        self.device = torch.device(device)
        self.V_net.to(self.device)
        self.P_net.to(self.device)
        return self


    def build_net(self,input_size,hidden_sizes,output_size=1,output_activator=None):
        layers = []
        for input_size, output_size in zip([input_size]+hidden_sizes, hidden_sizes+[output_size]):
            layers.append(nn.Linear(input_size,output_size))
            layers.append(nn.ReLU())
        layers = layers[:-1]
        if output_activator:
            layers.append(output_activator)
        net = nn.Sequential(*layers)
        return net
    

    @torch.no_grad()
    def step(self,state):
        state = torch.as_tensor(state,dtype=torch.float32,device=self.device)
        out = self.P_net(state)
        mu, log_std = torch.chunk(out, 2, dim=-1)
        std = torch.exp(log_std)
        dist = make_squashed_gaussian(mu=mu,std=std,low=self.u_low,high=self.u_high)
        action = dist.rsample()
        return action.squeeze(0).cpu().numpy()
    

    def save_all(self,path:str,extra:dict|None=None):
        cfg = asdict(self.Config) if is_dataclass(self.Config) else self.Config
        ckpt = {
            "Config":cfg,
            "V_net":self.V_net.state_dict(),
            "P_net":self.P_net.state_dict(),
        }
        if extra is not None:
            ckpt["extra"] = extra
        
        torch.save(ckpt,path)

    
    def load_all(self,path:str,map_location=None):
        ckpt = torch.load(path,map_location=map_location)
        self.V_net.load_state_dict(ckpt["V_net"])
        self.P_net.load_state_dict(ckpt["P_net"])

        return ckpt.get("extra",None)
    

    def mode2eval(self):
        self.V_net.eval()
        self.P_net.eval()


    def mode2train(self):
        self.V_net.train()
        self.P_net.train()
    

    def update_net_batch(self,states,actions,rewards,states_next,dones):
        states = torch.as_tensor(states,dtype=torch.float32,device=self.device)
        actions = torch.as_tensor(actions,dtype=torch.float32,device=self.device)
        rewards = torch.as_tensor(rewards,dtype=torch.float32,device=self.device)
        states_next = torch.as_tensor(states_next,dtype=torch.float32,device=self.device)

        if rewards.dim() == 1:
            rewards = rewards.unsqueeze(1)
        
        if dones is None:
            dones = torch.zeros((states.shape[0], 1), dtype=torch.float32, device=self.device)
        else:
            dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)
            if dones.dim() == 1:
                dones = dones.unsqueeze(1)

        with torch.no_grad():
            y_targets = rewards+self.Config.gamma*(1-dones)*self.V_net(states_next)

        V_values = self.V_net(states)
        V_loss = F.mse_loss(y_targets,V_values)
        self.V_optim.zero_grad()
        V_loss.backward()
        self.V_optim.step()

        for p in self.V_net.parameters():
            p.requires_grad_(False)

        outs = self.P_net(states)
        mus, log_stds = torch.chunk(outs, 2, dim=-1)
        stds = torch.exp(log_stds)
        dists = make_squashed_gaussian(mu=mus,std=stds,low=self.u_low,high=self.u_high)
        advantages = y_targets - V_values
        P_loss = -(advantages*dists.log_prob(actions)).mean()
        self.P_optim.zero_grad()
        P_loss.backward()
        self.P_optim.step()

        for p in self.V_net.parameters():
            p.requires_grad_(True)

        return float(V_loss.item()), float(P_loss.item())

In [None]:
def train(
        env,
        agent,
        rollout_num=200,
        rollout_len=256,
):
    """Collect fixed-length rollouts and update the agent after each one.

    Each of the ``rollout_num`` iterations steps ``env`` ``rollout_len`` times
    with ``agent.step`` and then hands the whole batch to
    ``agent.update_net_batch``. Episodes are reset in place when they end, so a
    single rollout may span several episodes.

    Args:
        env: gymnasium-style env (``reset() -> (obs, info)``,
            ``step(a) -> (obs, reward, terminated, truncated, info)``).
        agent: object providing ``step``, ``update_net_batch``, ``device``,
            ``V_net`` and ``P_net``.
        rollout_num: number of rollout/update iterations.
        rollout_len: environment steps per rollout.

    Returns:
        List with the summed reward of each rollout, as Python floats.
    """
    print("cuda available:", torch.cuda.is_available())
    print("agent device:", agent.device)
    print("V_net device:", next(agent.V_net.parameters()).device)
    print("P_net device:", next(agent.P_net.parameters()).device)

    reward_history = []

    state, _ = env.reset()

    for r in range(rollout_num):
        # ---- rebuild the rollout buffers (length rollout_len) every iteration ----
        states, actions, rewards, states_next, dones = [], [], [], [], []
        reward_sum = 0.0

        for _ in range(rollout_len):
            with torch.no_grad():
                action = agent.step(state)
            state_next, reward, terminated, truncated, _ = env.step(action)

            states.append(np.atleast_1d(state).tolist())
            actions.append(np.atleast_1d(action).tolist())
            rewards.append(np.atleast_1d(reward).tolist())
            states_next.append(np.atleast_1d(state_next).tolist())
            # Only true termination stops bootstrapping; a time-limit
            # truncation still has a meaningful successor value.
            dones.append(np.atleast_1d(terminated).tolist())

            if terminated or truncated:
                state, _ = env.reset()
            else:
                state = state_next

            # float(): keep the history plain Python floats (env rewards may be
            # NumPy scalars), so it can be stored in checkpoints safely.
            reward_sum += float(reward)

        V_loss, P_loss = agent.update_net_batch(states, actions, rewards, states_next, dones)

        reward_history.append(reward_sum)

        logging.info('rollout %d: reward = %.2f', r, reward_sum)
        logging.debug('rollout %d: V_loss = %.4f, P_loss = %.4f', r, V_loss, P_loss)

    return reward_history

In [7]:
# Training run: rollout budget, agent on the GPU when one is available,
# then collect the per-rollout reward history.
rollout_num = 200
rollout_len = 256

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = ActorCriticAgent(Config=Config(), device=device)

rh = train(env=env, agent=agent, rollout_num=rollout_num, rollout_len=rollout_len)

cuda available: True
agent device: cuda
P_net device: cuda:0
Q_net device: cuda:0


ALSA lib confmisc.c:855:(parse_card) cannot find card '0'
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_card_inum returned error: No such file or directory
ALSA lib confmisc.c:422:(snd_func_concat) error evaluating strings
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_concat returned error: No such file or directory
ALSA lib confmisc.c:1334:(snd_func_refer) error evaluating name
ALSA lib conf.c:5178:(_snd_config_evaluate) function snd_func_refer returned error: No such file or directory
ALSA lib conf.c:5701:(snd_config_expand) Evaluate error: No such file or directory
ALSA lib pcm.c:2664:(snd_pcm_open_noupdate) Unknown PCM default


KeyboardInterrupt: 

In [None]:
# Release the environment's resources (including the render window opened by
# render_mode="human").
env.close()

In [None]:
from pathlib import Path

# Switch to eval mode for inference (saving would also work in train mode).
agent.mode2eval()

ckpt_path = Path("./models/actor_critic_final.pth")
# torch.save does not create missing directories.
ckpt_path.parent.mkdir(parents=True, exist_ok=True)

agent.save_all(
    str(ckpt_path),
    extra={
        "rollout_num": rollout_num,
        "rollout_len": rollout_len,
        # Plain floats keep the checkpoint loadable with torch.load(weights_only=True);
        # drop this entry if it is not needed.
        "reward_history": [float(x) for x in rh],
    },
)
print(f"saved to {ckpt_path}")