In [14]:
import gymnasium as gym
import numpy as np
from pathlib import Path
import os
import sys
import torch as th
import argparse

from stable_baselines3 import HerReplayBuffer, SAC
from stable_baselines3.common.buffers import DictReplayBuffer
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import configure, Logger
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
import flycraft
from flycraft.utils.load_config import load_config

# Make the repository root importable before pulling in project-local modules.
import sys
from pathlib import Path
import os
current_dir = os.getcwd()
print(current_dir)
# The project root is taken to be three directory levels above the kernel's
# working directory, so this resolves correctly only when the notebook is
# started from a folder at that depth.
PROJECT_ROOT_DIR = Path(current_dir).parent.parent.parent
# Append the root only once so re-running the cell does not grow sys.path.
if str(PROJECT_ROOT_DIR.absolute()) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR.absolute()))

# Project-local imports; presumably resolved through the sys.path entry added
# above, so they must stay below it.
from utils_my.sb3.my_eval_callback import MyEvalCallback
from utils_my.sb3.my_evaluate_policy import evaluate_policy_with_success_rate
from train_scripts.D2D.utils.get_vec_env import get_vec_env
from train_scripts.D2D.utils.load_data_from_csv import load_random_trajectories_from_csv_files,load_random_transitions_from_csv_files
from train_scripts.D2D.utils.InfoDictReplayBuffer import InfoDictReplayBuffer
from utils_my.sb3.my_wrappers import ScaledObservationWrapper, ScaledActionWrapper
import pathlib
import warnings
warnings.filterwarnings("ignore")  # Meant to silence Gymnasium's UserWarning; with no category given it ignores every warning
# Hand the flycraft package to Gymnasium's env registration hook.
gym.register_envs(flycraft)

from stable_baselines3.common.logger import configure




# --- Training configuration -------------------------------------------------
# Load the SAC training config shipped with the repository.
train_config = load_config(PROJECT_ROOT_DIR/ "configs/train/D2D/F2F/medium/b_1/two_stage_skip_3_skip_1/2e6/sac_config_10hz_128_128_1.json")

# Hyper-parameters come from the "rl_common" section of the config; every key
# except "net_arch" falls back to the default given here when it is missing.
rl_common = train_config["rl_common"]
NET_ARCH = rl_common["net_arch"]
GAMMA = rl_common.get("gamma", 0.995)
BUFFER_SIZE = rl_common.get("buffer_size", 1e6)
BATCH_SIZE = rl_common.get("batch_size", 1024)
RL_TRAIN_PROCESS_NUM = rl_common.get("rollout_process_num", 32)
RL_EVALUATE_PROCESS_NUM = rl_common.get("evaluate_process_num", 32)
CALLBACK_PROCESS_NUM = rl_common.get("callback_process_num", 32)
# GRADIENT_STEPS = train_config["rl_common"].get("gradient_steps", 2)
EVAL_FREQ = rl_common.get("eval_freq", 1000)
N_EVAL_EPISODES = rl_common.get("n_eval_episodes", CALLBACK_PROCESS_NUM*10)
USE_HER = rl_common.get("use_her", False)
learning_rate = 3e-4

# --- Environment --------------------------------------------------------------
# Keyword arguments forwarded verbatim to get_vec_env.
# NOTE(review): "num_process" is fixed at 32 here and does not follow
# RL_TRAIN_PROCESS_NUM from the config - confirm this is intentional.
env_config_in_train = {
    "num_process": 32,
    "seed": 183,
    "config_file": str(PROJECT_ROOT_DIR / "configs" / "env" / "D2D/env_config_for_sac_medium_b_1.json"),
    # "custom_config": {"debug_mode": True, "flag_str": "Callback"}
}

vec_env = get_vec_env(**env_config_in_train)

# --- Model ----------------------------------------------------------------------
# The replay buffer class (and its HER-specific kwargs) depends on USE_HER.
if USE_HER:
    replay_buffer_class = HerReplayBuffer
    replay_buffer_kwargs = dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    )
else:
    replay_buffer_class = InfoDictReplayBuffer
    replay_buffer_kwargs = None

sac_algo = SAC(
    "MultiInputPolicy",
    vec_env,
    seed=245,
    replay_buffer_class=replay_buffer_class,
    replay_buffer_kwargs=replay_buffer_kwargs,
    verbose=1,
    buffer_size=int(BUFFER_SIZE),  # the config default (1e6) is a float
    learning_starts=int(0),
    gradient_steps=int(20000),
    learning_rate=learning_rate,
    gamma=GAMMA,
    batch_size=int(BATCH_SIZE),
    policy_kwargs=dict(
        net_arch=NET_ARCH,
        activation_fn=th.nn.Tanh
    ),
)

# Configure the logger (output folder / output formats).
# The folder is derived from PROJECT_ROOT_DIR rather than a machine-specific
# absolute path, so logs land next to the checkpoints on any checkout.
LOG_DIR = PROJECT_ROOT_DIR / "checkpoints/D2D/F2F/medium/b_1/skip_3_1_test_warmup"
sac_algo.set_logger(configure(folder=str(LOG_DIR), format_strings=["stdout", "csv"]))
def eval(policy_list, n_eval_episodes=1000):
    """Evaluate saved SAC checkpoints and average their metrics.

    Each entry of ``policy_list`` is loaded with ``SAC.load`` and its policy is
    evaluated on the module-level ``vec_env`` through
    ``evaluate_policy_with_success_rate``. A per-checkpoint summary line is
    printed as a side effect.

    NOTE: this function shadows the built-in ``eval`` in the notebook
    namespace; the name is kept so existing calls keep working.

    Args:
        policy_list: Iterable of checkpoint paths accepted by ``SAC.load``.
        n_eval_episodes: Value forwarded as the third positional argument of
            ``evaluate_policy_with_success_rate`` (presumably the number of
            evaluation episodes - confirm against that helper). Defaults to
            1000, the value previously hard-coded.

    Returns:
        Tuple ``(mean_reward, mean_std_reward, mean_success_rate)``: the
        arithmetic mean over checkpoints of the three values returned by the
        evaluation helper. The second element is the mean of the
        per-checkpoint standard deviations, not a pooled standard deviation.

    Raises:
        ValueError: If ``policy_list`` is empty (previously this surfaced as an
            opaque ZeroDivisionError on the final division).
    """
    # Materialise so any iterable works and the length is known up front.
    policy_list = list(policy_list)
    if not policy_list:
        raise ValueError("policy_list must contain at least one checkpoint path")

    reward_sum = 0.0
    std_sum = 0.0
    success_rate_sum = 0.0
    for index, policy in enumerate(policy_list):
        sac = SAC.load(policy)
        eval_reward, std_reward, eval_success_rate = evaluate_policy_with_success_rate(
            sac.policy, vec_env, n_eval_episodes
        )
        print(f"seed{index+1}  ：",f"mean_reward={eval_reward:.2f} +/- {std_reward}")
        reward_sum += eval_reward
        std_sum += std_reward
        success_rate_sum += eval_success_rate

    n_policies = len(policy_list)
    return reward_sum / n_policies, std_sum / n_policies, success_rate_sum / n_policies


# --- Warm-up experiment ---------------------------------------------------------
# Fill the replay buffer from a previous run, perform gradient updates on that
# data only (no environment interaction), then save and evaluate the result.
#
# All locations are derived from PROJECT_ROOT_DIR instead of machine-specific
# absolute paths, so the checkpoint that gets evaluated is always the one that
# was just saved.
CHECKPOINT_ROOT = PROJECT_ROOT_DIR / "checkpoints/D2D/F2F/medium/b_1"
WARMUP_DIR = CHECKPOINT_ROOT / "skip_3_1_test_warmup"
SOURCE_BUFFER_PATH = CHECKPOINT_ROOT / "skip_3/sac_128_128_b_1_1e6steps_skip_3_seed_1_singleRL/replay_buffer"

WARMUP_GRADIENT_STEPS = 20000
# NOTE(review): this batch size is independent of BATCH_SIZE from the training
# config (which the model was constructed with) - confirm that is intended.
WARMUP_BATCH_SIZE = 1024

sac_algo.load_replay_buffer(str(SOURCE_BUFFER_PATH))

sac_algo.train(batch_size=WARMUP_BATCH_SIZE, gradient_steps=WARMUP_GRADIENT_STEPS)
# train() only records its metrics; when it is called outside learn() nothing
# flushes them, so dump explicitly to reach the configured logger outputs.
sac_algo.logger.dump(step=WARMUP_GRADIENT_STEPS)

sac_algo.save(str(WARMUP_DIR / "final_model"))
sac_algo.save_replay_buffer(str(WARMUP_DIR / "final_buffer"))

# save() stores the model as a zip archive, hence the ".zip" suffix here.
policy_list = [str(WARMUP_DIR / "final_model.zip")]

reward,std_reward,success_rate  = eval(policy_list)
print(f"mean_reward={reward:.2f} +/- {std_reward}", f"success_rate = {success_rate}")



/home/sen/pythonprojects/fly-craft-examples/train_scripts/D2D/trys
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.json
190 Generator(PCG64) Generator(PCG64)
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.json
186 Generator(PCG64) Generator(PCG64)
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.json
196 Generator(PCG64) Generator(PCG64)
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.json
200 Generator(PCG64) Generator(PCG64)
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.json
188 Generator(PCG64) Generator(PCG64)
load config from: /home/sen/pythonprojects/fly-craft-examples/configs/env/D2D/env_config_for_sac_medium_b_1.jsonload config from: /home/sen/pythonprojects/fly-craft-examples/conf

In [8]:
def eval(policy_list, n_eval_episodes=1000):
    """Evaluate saved SAC checkpoints and average their metrics.

    Each entry of ``policy_list`` is loaded with ``SAC.load`` and its policy is
    evaluated on the module-level ``vec_env`` through
    ``evaluate_policy_with_success_rate``. A per-checkpoint summary line is
    printed as a side effect.

    NOTE: this function shadows the built-in ``eval`` in the notebook
    namespace; the name is kept so existing calls keep working.

    Args:
        policy_list: Iterable of checkpoint paths accepted by ``SAC.load``.
        n_eval_episodes: Value forwarded as the third positional argument of
            ``evaluate_policy_with_success_rate`` (presumably the number of
            evaluation episodes - confirm against that helper). Defaults to
            1000, the value previously hard-coded.

    Returns:
        Tuple ``(mean_reward, mean_std_reward, mean_success_rate)``: the
        arithmetic mean over checkpoints of the three values returned by the
        evaluation helper. The second element is the mean of the
        per-checkpoint standard deviations, not a pooled standard deviation.

    Raises:
        ValueError: If ``policy_list`` is empty (previously this surfaced as an
            opaque ZeroDivisionError on the final division).
    """
    # Materialise so any iterable works and the length is known up front.
    policy_list = list(policy_list)
    if not policy_list:
        raise ValueError("policy_list must contain at least one checkpoint path")

    reward_sum = 0.0
    std_sum = 0.0
    success_rate_sum = 0.0
    for index, policy in enumerate(policy_list):
        sac = SAC.load(policy)
        eval_reward, std_reward, eval_success_rate = evaluate_policy_with_success_rate(
            sac.policy, vec_env, n_eval_episodes
        )
        print(f"seed{index+1}  ：",f"mean_reward={eval_reward:.2f} +/- {std_reward}")
        reward_sum += eval_reward
        std_sum += std_reward
        success_rate_sum += eval_success_rate

    n_policies = len(policy_list)
    return reward_sum / n_policies, std_sum / n_policies, success_rate_sum / n_policies