In [1]:
# Render matplotlib figures inline and load the TensorBoard extension that provides
# the %tensorboard magic used later in this notebook.
%matplotlib inline
%load_ext tensorboard

In [2]:
import gymnasium as gym
import numpy as np

from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

from gymnasium.envs.registration import register, registry
import time

import matplotlib
import matplotlib.pyplot as plt

import torch

In [3]:
# Expose the project's MarineEnv class to gym.make / make_vec_env under a string id.
# The membership guard makes this cell idempotent: once the id is in the registry,
# re-running the cell in the same kernel changes nothing.
# The entry point is a "module:Class" string, so the class is resolved when the env is made.
if 'MarineEnv-v0' not in registry.keys():
    register('MarineEnv-v0', entry_point='environments:MarineEnv')

In [4]:
# Keyword arguments shared by the training env (make_vec_env) and the evaluation env
# (gym.make). Semantics of the MarineEnv-specific keys live in environments.MarineEnv.
env_kwargs = {
    'render_mode': 'rgb_array',  # frames returned as arrays, no on-screen window
    'continuous': True,          # presumably selects a continuous action space — confirm in MarineEnv
    'max_episode_steps': 1200,   # per-episode step limit; 1200 * (1/3) matches the 400 horizon used in the demo cell
    'training_stage': 2,         # stage selector defined by MarineEnv (TODO: document its meaning)
    'timescale': 1 / 3,          # presumably simulated time per step — confirm in MarineEnv
}

In [5]:
# Training environment: one MarineEnv instance behind SB3's vectorized-env interface.
env = make_vec_env("MarineEnv-v0", n_envs=1, env_kwargs=env_kwargs)

In [6]:
# TD3 hyperparameters.
#
# Exploration: TD3's actor is deterministic. In SB3, off-policy algorithms sample
# uniformly random actions for the first `learning_starts` steps. After that, the
# only exploration is the `action_noise` added to the actor's output. Leaving it as
# None would collect every post-warm-up transition with the noiseless policy. We
# therefore add Gaussian noise (sigma=0.1 on SB3's [-1, 1]-scaled actions, as in the
# TD3 paper).
#
# Assumes `env.action_space` is a 1-D Box (continuous=True in env_kwargs) — confirm in MarineEnv.
n_actions = env.action_space.shape[-1]

td3_kwargs = {
    "policy": "MlpPolicy",  # Multi-Layer Perceptron actor and critics
    "learning_rate": 3e-4,  # Adam learning rate for actor and critics
    "buffer_size": int(1e6),  # Replay buffer capacity (transitions)
    "learning_starts": 10000,  # Random-action warm-up steps before gradient updates begin
    "batch_size": 256,  # Minibatch size sampled from the replay buffer
    "tau": 0.005,  # Polyak averaging coefficient for target networks
    "gamma": 0.99,  # Discount factor
    "train_freq": (1, "step"),  # Train every environment step
    "gradient_steps": 1,  # One gradient update per environment step
    "action_noise": NormalActionNoise(  # Gaussian exploration noise (see note above)
        mean=np.zeros(n_actions),
        sigma=0.1 * np.ones(n_actions),
    ),
    "replay_buffer_class": None,  # Use the default replay buffer
    "optimize_memory_usage": False,  # Keep SB3's default (non memory-optimized) replay buffer
    "policy_delay": 2,  # Update actor/targets once per 2 critic updates (TD3 delayed updates)
    "target_policy_noise": 0.2,  # Target policy smoothing noise (TD3 trick)
    "target_noise_clip": 0.5,  # Clip on the target smoothing noise
    "tensorboard_log": "./tensorboard_td3_asv/",  # TensorBoard logging path
    "policy_kwargs": {
        "net_arch": [256, 256],  # Two hidden layers of 256 units (used for actor and critics)
        "activation_fn": torch.nn.ReLU,  # Hidden-layer activation
    },
    "seed": 42,  # Fix RNG seeds (torch/numpy/env) so training runs are reproducible
    "verbose": 1,  # Print training updates
    "device": "auto",  # Use GPU if available
}


In [7]:
# Instantiate TD3 on the vectorized training env using the hyperparameter dict above.
agent = TD3(**td3_kwargs, env=env)

In [9]:
# Train for 100k environment steps. SB3 declares `total_timesteps` as an int, so pass
# an int literal rather than the float `1e5`. `reset_num_timesteps=False` keeps the
# timestep counter (used for logging) accumulating across re-runs of this cell instead
# of restarting from zero. The trailing `;` suppresses the returned model's repr.
agent.learn(total_timesteps=100_000, reset_num_timesteps=False, progress_bar=True, tb_log_name='td3_2');

In [11]:
# Evaluate the trained policy over 10 episodes on a fresh, non-vectorized env built
# from the same kwargs as training. evaluate_policy returns the mean and standard
# deviation of the per-episode return; deterministic=True uses the actor's output
# without sampling.
eval_env = gym.make('MarineEnv-v0', **env_kwargs)
try:
    mean, std = evaluate_policy(model=agent, env=eval_env, n_eval_episodes=10, deterministic=True)
finally:
    # Always release the env, even if evaluation raises or is interrupted.
    eval_env.close()
print(f'Mean: {mean:.2f}, Std: {std:.2f}')

In [None]:
# Persist the trained model to disk (SB3 writes a zip archive, td3_asv.zip).
# NOTE: the replay buffer is not part of this file; use agent.save_replay_buffer() if needed.
agent.save(path='td3_asv')

In [8]:
# Launch TensorBoard on the directory TD3 logs to (td3_kwargs["tensorboard_log"]).
# NOTE(review): --host=0.0.0.0 binds to all network interfaces, so anyone who can reach
# this machine can view the dashboard; drop it (localhost only) unless remote access
# is actually required, e.g. when running inside a container.
%tensorboard --logdir ./tensorboard_td3_asv/ --host=0.0.0.0

In [12]:
# Visual demo: roll out the trained policy deterministically in a human-rendered window.
# Uses timescale 1/6 (training used 1/3) and a step budget of 400 / timescale, so the
# horizon in timescale units stays at 400 regardless of the chosen timescale.
timescale = 1 / 6
N_DEMO_EPISODES = 5
PRINT_STEPS = True  # per-step logging; set False to avoid thousands of output lines


def run_demo_episode(demo_env, model, max_steps, print_steps=True):
    """Run one deterministic episode of `model` in `demo_env` and return its outcome.

    Args:
        demo_env: Gymnasium env to act in. It is reset here but NOT closed (caller owns it).
        model: Object with an SB3-style ``predict(obs, deterministic=...)`` that
            returns ``(action, hidden_state)``.
        max_steps: Maximum number of steps if the env never terminates or truncates.
        print_steps: If True, print the observation and reward after every step.

    Returns:
        Tuple ``(episode_return, final_observation)``.
    """
    state, _ = demo_env.reset()
    print(state)
    episode_rewards = 0
    for _ in range(max_steps):
        action, _ = model.predict(state, deterministic=True)
        observation, reward, terminated, truncated, _ = demo_env.step(action)
        demo_env.render()
        episode_rewards += reward
        # Advance before the termination check so the returned state is the last observation.
        state = observation
        if print_steps:
            print('===========================')
            print(observation)
            print(reward)
        if terminated or truncated:
            break
    return episode_rewards, state


for _ in range(N_DEMO_EPISODES):
    # Dedicated name so the training VecEnv bound to `env` is not overwritten.
    demo_env = gym.make('MarineEnv-v0', render_mode='human', continuous=True, training_stage=2, timescale=timescale, training=False)
    try:
        # round() instead of int() so float error in 400 / timescale cannot drop a step.
        episode_rewards, state = run_demo_episode(demo_env, agent, round(400 / timescale), print_steps=PRINT_STEPS)
    finally:
        # Close the render window even if the rollout is interrupted (e.g. KeyboardInterrupt).
        demo_env.close()
    print(episode_rewards)
    print(state)