In [1]:
# Render matplotlib figures inline.
# NOTE(review): matplotlib is imported below but no visible cell plots anything — confirm it is still needed.
%matplotlib inline
# Enable the %tensorboard magic used later to inspect the SAC training logs.
%load_ext tensorboard

In [2]:
import time

import gymnasium as gym
import matplotlib
import matplotlib.pyplot as plt
import torch
from gymnasium.envs.registration import register, registry
from stable_baselines3 import SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

2025-01-30 06:53:57.619329: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-30 06:53:57.645031: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738220037.667497   17372 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738220037.674500   17372 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-30 06:53:57.698029: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [3]:
# Make the custom MarineEnv available to gym.make(). The membership guard means
# re-running this cell does not register the same id a second time; the class is
# referenced by its 'module:ClassName' string rather than imported here.
if 'MarineEnv-v0' not in registry:
    register(entry_point='environments:MarineEnv', id='MarineEnv-v0')

In [4]:
# Keyword arguments forwarded to MarineEnv through gym.make(); shared by the
# training env and the evaluation env below.
env_kwargs = {
    'render_mode': 'rgb_array',  # off-screen frames, no window during training
    'continuous': True,          # presumably selects a continuous action space (SAC needs one) — TODO confirm
    'training_stage': 2,         # curriculum stage understood by MarineEnv
    'timescale': 1 / 6,          # simulation step size passed to the env
}

In [5]:
# Training environment; SB3 adds the Monitor and DummyVecEnv wrappers itself when the agent is built.
env = gym.make(id='MarineEnv-v0', **env_kwargs)

In [6]:
# Hyper-parameters for Stable-Baselines3 SAC, unpacked into the SAC constructor.
# Written with dict(...) keywords to match env_kwargs above.
sac_kwargs = dict(
    policy="MlpPolicy",                       # fully-connected nets; observations are flat numeric vectors
    learning_rate=3e-4,                       # Adam step size for actor, critics and entropy coefficient
    buffer_size=1_000_000,                    # replay buffer capacity (transitions)
    learning_starts=10_000,                   # steps collected before the first gradient update
    batch_size=256,                           # transitions sampled per gradient step
    tau=0.005,                                # Polyak coefficient for soft target-network updates
    gamma=0.99,                               # discount factor
    train_freq=(1, "step"),                   # one training round per environment step...
    gradient_steps=1,                         # ...consisting of a single gradient step
    action_noise=None,                        # no extra noise: exploration comes from the entropy term / gSDE
    replay_buffer_class=None,                 # default replay buffer
    optimize_memory_usage=False,              # keep the standard (non memory-optimised) buffer layout
    ent_coef="auto_0.1",                      # learn the entropy coefficient, starting from 0.1
    target_update_interval=1,                 # soft-update the targets after every gradient step
    target_entropy="auto",                    # let SB3 choose the target entropy
    use_sde=True,                             # generalised State-Dependent Exploration
    sde_sample_freq=64,                       # resample the gSDE noise matrix every 64 steps
    use_sde_at_warmup=True,                   # use gSDE noise during the warm-up steps too
    tensorboard_log="./tensorboard_sac_asv/", # read by the %tensorboard cell
    policy_kwargs=dict(
        net_arch=[256, 256],                  # two hidden layers of 256 units
        activation_fn=torch.nn.ReLU,
        log_std_init=-2,                      # small initial exploration std
    ),
    verbose=1,                                # print the training tables shown in the output
    device="auto",                            # CUDA when available (output reports "Using cuda device")
)


In [7]:
# Build the SAC agent on the training env using the hyper-parameters defined above.
agent = SAC(**sac_kwargs, env=env)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [17]:
# Train the agent for another 100k environment steps. With reset_num_timesteps=False
# the step counter keeps accumulating across repeated calls (the output shows
# total_timesteps well past 100k). total_timesteps is an int, as SB3's API expects
# (1e5 would be a float). The trailing ';' hides the returned model's repr.
agent.learn(total_timesteps=100_000, reset_num_timesteps=False, progress_bar=True, tb_log_name='sac_3');

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 341      |
|    ep_rew_mean     | 3.42e+03 |
| time/              |          |
|    episodes        | 548      |
|    fps             | 3        |
|    time_elapsed    | 29993    |
|    total_timesteps | 190391   |
| train/             |          |
|    actor_loss      | -588     |
|    critic_loss     | 58.8     |
|    ent_coef        | 0.202    |
|    ent_coef_loss   | 0.244    |
|    learning_rate   | 0.0003   |
|    n_updates       | 180390   |
|    std             | 0.131    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 337      |
|    ep_rew_mean     | 3.39e+03 |
| time/              |          |
|    episodes        | 552      |
|    fps             | 3        |
|    time_elapsed    | 30081    |
|    total_timesteps | 191410   |
| train/             |          |
|    actor_loss      | -589     |
|    critic_lo

<stable_baselines3.sac.sac.SAC at 0x7f3d951a6910>

In [18]:
# Evaluate the deterministic policy on a fresh environment over 10 episodes.
# Monitor records the true per-episode reward/length that evaluate_policy reports;
# without it SB3 warns that other wrappers may have modified those statistics.
eval_env = Monitor(gym.make('MarineEnv-v0', **env_kwargs))
try:
    mean, std = evaluate_policy(model=agent, env=eval_env, n_eval_episodes=10, deterministic=True)
finally:
    # Release the env's resources even if evaluation fails part-way.
    eval_env.close()
print(f'Mean: {mean:.2f}, Std: {std:.2f}')

Mean: 3878.65, Std: 2617.48


In [8]:
# Launch TensorBoard on the SAC log directory (same path as sac_kwargs["tensorboard_log"]).
# NOTE(review): --host=0.0.0.0 binds to all network interfaces, so the dashboard is
# reachable from other machines; drop the flag to serve on localhost only.
%tensorboard --logdir ./tensorboard_sac_asv/ --host=0.0.0.0

In [16]:
# Persist the trained agent under the name 'sac_3'.
# NOTE(review): this cell ran as In [16], before the In [17] `learn` call, so the saved
# checkpoint predates that last round of training — re-run it after training finishes.
agent.save('sac_3')
# To restore a checkpoint instead: `load` is a classmethod that returns a new model, so
# prefer SAC.load('sac_stage_2', env=env, device='cuda') (env is needed to keep training).
# agent = agent.load('sac_stage_2', device='cuda')

In [None]:
# Watch the trained agent act in on-screen ("human" render mode) episodes.
timescale = 1 / 10   # env step size for visualisation (training uses env_kwargs['timescale'])
n_episodes = 1
max_sim_time = 400   # step budget per episode is max_sim_time / timescale (units presumably seconds — TODO confirm)
verbose = True       # print every observation and reward; set False for a quiet run

# Same env configuration as training, but rendered on screen and with training=False.
# Using a separate name keeps the global training `env` from being rebound to a closed env.
viz_env_kwargs = {**env_kwargs, 'render_mode': 'human', 'timescale': timescale, 'training': False}
# round(), not int(): float division can land just below an integer and int() would truncate it.
max_steps = round(max_sim_time / timescale)

for _episode in range(n_episodes):
    viz_env = gym.make('MarineEnv-v0', **viz_env_kwargs)
    try:
        state, _reset_info = viz_env.reset()
        print(state)
        episode_rewards = 0
        for _step in range(max_steps):
            # predict() returns (action, policy_state); only the action is needed here.
            action, _policy_state = agent.predict(state, deterministic=True)
            observation, reward, terminated, truncated, info = viz_env.step(action)
            viz_env.render()
            episode_rewards += reward
            # Advance before the termination check so `state` holds the final
            # observation after a `break`.
            state = observation
            if verbose:
                print('===========================')
                print(observation)
                print(reward)
            if terminated or truncated:
                break
        print(episode_rewards)
        print(state)
    finally:
        # Always close so the render window is released, even if the loop is interrupted.
        viz_env.close()

[303.61487    18.30973    13.876236   45.47168     2.7667515  45.47168
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.       ]
[305.14658    18.343014   13.845651   45.289127    1.2377588  45.37168
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.       ]
4.305852890014648
[305.79556    18.379648   13.814999   45.09879     0.5900769  45.271683
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.       ]
10.306520462036133
[3.0608978e+02 1.8413929e+01 1.3784288e+01 4.4914764e+01 2.9653478e-01
 4.5171684e+01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00]
10.30710220336914
[3.0621805e+02 1.8444532e+01 1.3753528e+01 4.4740173e+01 1.6864583e-01
 4.5071686e+01 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+