In [3]:
%matplotlib inline
%load_ext tensorboard

In [2]:
import gymnasium as gym
import numpy as np

from stable_baselines3 import A2C
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

from gymnasium.envs.registration import register, registry
import time

import matplotlib
import matplotlib.pyplot as plt

import torch

In [6]:
if 'MarineEnv-v0' not in registry:
    register(
        id='MarineEnv-v0',
        entry_point='environments:MarineEnv',  # String reference to the class
    )

In [7]:
env_kwargs = dict(
    render_mode='rgb_array',
    continuous=True,
    max_episode_steps=1200,
    training_stage=2,
    timescale=1/3
)

In [8]:
env = make_vec_env(env_id="MarineEnv-v0", n_envs=1, env_kwargs=env_kwargs)

In [12]:
a2c_kwargs = {
    "policy": "MlpPolicy",  # Multi-layer perceptron (MLP) for structured data
    "learning_rate": 7e-4,  # Standard learning rate for A2C (stable)
    "n_steps": 5,  # Number of steps per update (balancing bias-variance)
    "gamma": 0.99,  # Discount factor for long-term decision-making
    "gae_lambda": 0.95,  # Generalized advantage estimation (bias-variance tradeoff)
    "ent_coef": 0.01,  # Encourages exploration
    "vf_coef": 0.5,  # Value function loss coefficient
    "max_grad_norm": 0.5,  # Gradient clipping for stability
    "use_rms_prop": True,  # RMSprop is better for A2C (instead of Adam)
    "rms_prop_eps": 1e-5,  # Stabilizing factor for RMSprop
    "use_sde": False,  # A2C does not benefit much from state-dependent noise
    "normalize_advantage": True,  # Normalizing advantage speeds up learning
    "tensorboard_log": "./tensorboard_a2c_asv/",  # Logging path
    "policy_kwargs": {
        "net_arch": [256, 256],  # Good balance between complexity and efficiency
        "activation_fn": torch.nn.ReLU,  # Stable and efficient activation
    },
    "verbose": 1,  # Print training updates
    "device": "cpu", 
}


In [13]:
agent = A2C(env=env, **a2c_kwargs)

Using cpu device


In [21]:
agent.learn(total_timesteps=1e5, reset_num_timesteps=False, progress_bar=True, tb_log_name='a2c_1')

Logging to ./tensorboard_a2c_asv/a2c_1_0


Output()

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 1.2e+03   |
|    ep_rew_mean        | -2.21e+03 |
| time/                 |           |
|    fps                | 202       |
|    iterations         | 100       |
|    time_elapsed       | 2         |
|    total_timesteps    | 60500     |
| train/                |           |
|    entropy_loss       | -2.85     |
|    explained_variance | 0.438     |
|    learning_rate      | 0.0007    |
|    n_updates          | 12099     |
|    policy_loss        | 0.0111    |
|    std                | 1.01      |
|    value_loss         | 1.49e+15  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 1.2e+03   |
|    ep_rew_mean        | -2.21e+03 |
| time/                 |           |
|    fps                | 141       |
|    iterations         | 200       |
|    time_elapsed       | 7         |
|    total_t

<stable_baselines3.a2c.a2c.A2C at 0x7f8e34d41590>

In [18]:
eval_env = gym.make('MarineEnv-v0', **env_kwargs)
mean, std = evaluate_policy(model=agent, env=eval_env, n_eval_episodes=10, deterministic=True)
print(f'Mean: {mean:.2f}, Std: {std:.2f}')

Mean: -3043.42, Std: 2111.12


In [19]:
agent.save('a2c_asv')

In [20]:
timescale = 1 / 6
for _ in range(5):
    env = gym.make('MarineEnv-v0', render_mode='human', continuous=True, training_stage=2, timescale=timescale, training=False)
    state, _ = env.reset()
    print(state)
    episode_rewards = 0 
    # flatten_state = flatten(env.observation_space, state)
    # state = torch.tensor(flatten_state, dtype=torch.float32, device=device).unsqueeze(0)
    for _ in range(int(400 / timescale)):
        action = agent.predict(state, deterministic=True)
        # print(action)
        # observation, reward, terminated, truncated, info = env.step((0, 0))
        observation, reward, terminated, truncated, info = env.step(action[0])
        env.render()
        # time.sleep(0.001)
        episode_rewards += reward
        print('===========================')
        print(observation)
        print(reward)
        
        if terminated or truncated:
            print(episode_rewards)
            break
    
        state = observation
            
    print(episode_rewards)
    print(state)
    env.close()

[250.65114   10.028007  12.014323  71.88461   -4.830954  71.88461
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.      ]
[253.98447    10.1113405  11.986503   71.12708    -8.183369   71.71795
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.       ]
33.37820587158203
[257.3178    10.194674  11.958737  70.38226  -11.543806  71.551285
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.      ]
33.37765274047852
[260.65115   10.278007  11.931128  69.65044  -14.912376  71.38462
   0.         0.         0.         0.         0.         0.
   0.         0.         0.         0.      ]
33.376098251342775
[263.9845     10.3613405  11.903778   68.931885  -18.289125   71.21796
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.       ]
31.773494720458984
[267.31784   10.444674  11.87679

AttributeError: 'NoneType' object has no attribute 'fill'