In [None]:
import os
import pickle
import time

import gymnasium as gym
import numpy as np
from stable_baselines3 import HerReplayBuffer, TD3, PPO, SAC
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.utils import set_random_seed

from Approach_env import SRC_approach
from RL_algo.DDPG_BC import DDPG_BC
from RL_algo.DemoHerReplayBuffer import DemoHerReplayBuffer
from RL_algo.td3_BC import TD3_BC
# Create the environment: seed everything, read the stored environment settings,
# then register and instantiate the custom approach environment.

seed = 10
set_random_seed(seed)

# NOTE: pickle can execute arbitrary code while loading -- only open trusted files.
with open('./Approach_noise_env_info', 'rb') as info_file:
    env_info = pickle.load(info_file)

# Both per-step settings are converted to float32 arrays in one pass.
step_size, threshold = (
    np.array(env_info[key], dtype=np.float32)
    for key in ("step_size", "threshold")
)
episode_steps = int(env_info["max_timestep"])

# Register and make the env
gym.envs.register(
    id="TD3_HER_BC",
    entry_point=SRC_approach,
    max_episode_steps=episode_steps,
)
env = gym.make(
    "TD3_HER_BC",
    render_mode="human",
    reward_type="dense",
    max_episode_step=episode_steps,
    seed=seed,
    step_size=step_size,
    threshold=threshold,
)

import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn

In [None]:
# Run the Stable-Baselines3 environment checker (stable_baselines3.common.env_checker)
# on `env` before a model is built on top of it.
# NOTE(review): presumably this validates the spaces / step / reset contract of
# SRC_approach -- see the SB3 docs for the exact checks performed.
check_env(env)

In [None]:
# Load the merged expert transitions; they are later handed to the replay buffer
# and to TD3_BC as demonstration data.
# NOTE: pickle can execute arbitrary code while loading -- only open trusted files.
with open('./Expert_traj/Approach/all_episodes_merged.pkl', 'rb') as demo_file:
    episode_transitions = pickle.load(demo_file)

In [None]:
# Drop the demonstrations: once this cell has run, every later use of
# `episode_transitions` (the replay-buffer kwargs and the TD3_BC constructor)
# receives None.
# NOTE(review): under Restart & Run All this undoes the pickle load in the
# previous cell -- confirm that is intended, or skip this cell when the
# demonstrations are required.
episode_transitions = None

In [None]:
# Set up the hyperparameters
goal_selection_strategy = "future"
n_actions = env.action_space.shape[-1]
# Zero-mean normal action noise with sigma 0.05 on every action dimension.
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=5e-2 * np.ones(n_actions))

model_path = "./Approach/TD3_BC_noise_dense/rl_model_final.zip"

if os.path.isfile(model_path):
    # A saved final model exists: resume from it. Building a fresh model first
    # would only allocate a second replay buffer that is immediately discarded.
    model = TD3_BC.load(model_path, env=env)
else:
    # No saved model yet (e.g. first run on a fresh checkout): start from scratch
    # instead of failing inside `load`.
    print(f"No saved model found at {model_path}; creating a new TD3_BC model.")
    model = TD3_BC(
        "MultiInputPolicy",
        env,
        learning_rate=3e-4,
        learning_starts=600,
        tau=0.005,
        gamma=0.995,
        batch_size=512,
        action_noise=action_noise,
        replay_buffer_class=DemoHerReplayBuffer,
        train_freq=(3, "episode"),
        policy_kwargs=dict(net_arch=dict(pi=[256, 256, 256], qf=[256, 256, 256])),
        # Parameters for HER
        replay_buffer_kwargs=dict(
            demo_transitions=episode_transitions,
            demo_sample_ratio=0.2,
            n_sampled_goal=4,
            goal_selection_strategy=goal_selection_strategy,
        ),
        verbose=1,
        tensorboard_log="./Approach/TD3_BC_noise_dense",
        episode_transitions=episode_transitions,
        BC_coeff=0.3,
        demo_ratio=0.1,
    )

# Write an intermediate checkpoint every 5000 callback steps into the run directory.
checkpoint_callback = CheckpointCallback(
    save_freq=5000,
    save_path='./Approach/TD3_BC_noise_dense',
    name_prefix='rl_model',
)

In [None]:
# Train the model for 300k timesteps, checkpointing through `checkpoint_callback`.
# reset_num_timesteps=False keeps the model's existing timestep counter rather
# than restarting it at zero.
model.learn(
    total_timesteps=300_000,
    progress_bar=True,
    callback=checkpoint_callback,
    reset_num_timesteps=False,
)

In [None]:
# Save the model under the run directory as "rl_model_final".
# NOTE(review): the model-setup cell loads "rl_model_final.zip" from this same
# directory -- presumably `save` appends the ".zip" suffix; confirm against the SB3 docs.
model.save("./Approach/TD3_BC_noise_dense/rl_model_final")

In [None]:
# Roll out the policy deterministically for a fixed number of steps, starting a
# new episode whenever the current one terminates or is truncated.
obs, info = env.reset()
print(obs)
for _step in range(90000):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    # Short pause between steps so the rollout advances at a watchable pace.
    time.sleep(0.01)
    if not (terminated or truncated):
        continue
    obs, info = env.reset()