In [None]:
!pip install --upgrade grpcio grpcio-tools

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [2]:
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces.utils import flatten
from gymnasium.envs.registration import register, registry
import time
import numpy as np
import pygame

import matplotlib
import matplotlib.pyplot as plt

from typing import Any, Dict
import torch
import torch.nn as nn
import tensorboard

from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv

import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances

2025-01-26 22:48:06.223091: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-26 22:48:06.239018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737931686.263640     455 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737931686.270009     455 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-26 22:48:06.291460: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [3]:
# Register the custom environment with Gymnasium only when it is not already
# in the registry, so the cell can be re-run in a live kernel.
if 'MarineEnv-v0' not in registry:
    # entry_point is a "module:Class" string reference to the environment class
    register(id='MarineEnv-v0', entry_point='environments:MarineEnv')

In [None]:
# Torch device: CUDA when it is available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The backend name contains 'inline' when figures are rendered in a notebook;
# IPython's display helpers are imported only in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

In [None]:
# --- Optuna study budget ---
N_TRIALS = 100  # Upper bound on the number of trials
N_JOBS = 1  # Trials run in parallel
N_STARTUP_TRIALS = 5  # Random-sampling trials before the sampler/pruner take over
TIMEOUT = 15 * 60  # Study time limit: 15 minutes, in seconds

# --- Per-trial training and evaluation budget ---
N_TIMESTEPS = 20_000  # Training budget per trial (environment steps)
N_EVALUATIONS = 2  # Evaluations performed during one training run
EVAL_FREQ = N_TIMESTEPS // N_EVALUATIONS  # Callback calls between two evaluations
N_EVAL_ENVS = 10  # Parallel environments used for evaluation
N_EVAL_EPISODES = 10  # Episodes per evaluation

ENV_ID = 'MarineEnv-v0'

# Keyword arguments shared by every sampled PPO configuration.
DEFAULT_HYPERPARAMS = dict(policy="MlpPolicy", env=ENV_ID)

In [None]:
def sample_ppo_params(trial: "optuna.Trial") -> Dict[str, Any]:
    """
    Sample one PPO hyperparameter configuration from the Optuna search space.

    ``n_steps`` and ``batch_size`` are searched as powers of two: the value
    stored in ``trial.params`` is the exponent, while the effective values are
    recorded as user attributes (together with ``gamma``) so they can be read
    directly from the study report.

    :param trial: Optuna trial object used to draw the values
    :return: Keyword arguments to pass to the ``PPO`` constructor
        (network settings are nested under ``policy_kwargs``)
    """

    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1, log=True)  # Learning rate (log scale)

    n_steps = 2 ** trial.suggest_int('n_steps', 7, 12)  # Number of steps per update (128-4096)

    batch_size = 2 ** trial.suggest_int('batch_size', 5, 10)  # Minibatch size (32-1024)

    # A minibatch larger than the rollout cannot be filled (assuming a single
    # training environment, the rollout holds n_steps transitions). Both values
    # are powers of two, so after clamping n_steps is a multiple of batch_size.
    batch_size = min(batch_size, n_steps)

    gamma = trial.suggest_float('gamma', 0.9, 0.9999)  # Discount factor (close to 1 for long-term rewards)

    gae_lambda = trial.suggest_float('gae_lambda', 0.8, 1.0)  # GAE lambda (trade-off bias/variance)

    clip_range = trial.suggest_float('clip_range', 0.1, 0.3)  # PPO clipping range

    ent_coef = trial.suggest_float('ent_coef', 0.0001, 0.1, log=True)  # Entropy coefficient (for exploration)

    vf_coef = trial.suggest_float('vf_coef', 0.1, 1.0)  # Value function loss coefficient

    max_grad_norm = trial.suggest_float('max_grad_norm', 0.3, 5.0)  # Gradient clipping

    target_kl = trial.suggest_float('target_kl', 0.01, 0.2)  # KL divergence target

    n_epochs = trial.suggest_int('n_epochs', 3, 10)  # PPO update epochs per batch

    activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu'])

    net_arch = trial.suggest_categorical('net_arch', ['tiny', 'small'])

    # Convert the categorical choices into the objects PPO expects
    net_arch = [128, 128] if net_arch == 'tiny' else [256, 256, 256]

    activation_fn = {'tanh': nn.Tanh, 'relu': nn.ReLU}[activation_fn]

    # Store the effective values in the Optuna logs (trial.params only holds
    # the exponents for n_steps / batch_size, and the unclamped batch size)
    trial.set_user_attr('gamma', gamma)
    trial.set_user_attr('n_steps', n_steps)
    trial.set_user_attr('batch_size', batch_size)

    return {
        'n_steps': n_steps,
        'batch_size': batch_size,
        'gamma': gamma,
        'gae_lambda': gae_lambda,
        'learning_rate': learning_rate,
        'clip_range': clip_range,
        'ent_coef': ent_coef,
        'vf_coef': vf_coef,
        'max_grad_norm': max_grad_norm,
        'target_kl': target_kl,
        'n_epochs': n_epochs,
        'policy_kwargs': {
            'net_arch': net_arch,
            'activation_fn': activation_fn
        }
    }

In [None]:
class TrialEvalCallback(EvalCallback):
    """
    Evaluation callback that forwards every evaluation result to an Optuna
    trial and stops training when Optuna decides to prune that trial.

    :param eval_env: Evaluation environment
    :param trial: Optuna trial object that receives the reports
    :param n_eval_episodes: Number of evaluation episodes
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic policy.
    :param verbose: Verbosity level forwarded to ``EvalCallback``
    """

    def __init__(
        self,
        eval_env: gym.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        parent_kwargs = dict(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        super().__init__(**parent_kwargs)

        # Set once Optuna asks for the trial to be pruned
        self.is_pruned = False
        # Number of evaluations reported so far (used as the Optuna step)
        self.eval_idx = 0
        self.trial = trial

    def _on_step(self) -> bool:
        """
        Run an evaluation when one is due and report it to the trial.

        :return: ``False`` to stop training when the trial should be pruned,
            ``True`` otherwise.
        """
        evaluation_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        if not evaluation_due:
            return True

        # The parent class performs the actual policy evaluation
        super()._on_step()
        self.eval_idx += 1

        # Hand the latest mean reward to Optuna as an intermediate value
        self.trial.report(self.last_mean_reward, self.eval_idx)

        if not self.trial.should_prune():
            return True

        # Remember the pruning decision and stop training
        self.is_pruned = True
        return False

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Objective function used by Optuna to evaluate
    one configuration (i.e., one set of hyperparameters).

    Given a trial object, it samples hyperparameters, trains a PPO model
    with them and reports the result (mean episodic reward after training).

    :param trial: Optuna trial object
    :return: Mean episodic reward of the last evaluation, or ``nan`` when
        training failed numerically (Optuna then marks the trial as failed)
    :raises optuna.exceptions.TrialPruned: when the evaluation callback
        stopped training because the trial should be pruned
    """

    kwargs = DEFAULT_HYPERPARAMS.copy()

    # 1. Sample hyperparameters and update the keyword arguments
    kwargs.update(**sample_ppo_params(trial))
    print(kwargs)
    # Create the RL model
    model = PPO(device='cpu', verbose=1, **kwargs)
    # Create eval envs
    eval_envs = make_vec_env(ENV_ID, n_envs=N_EVAL_ENVS)

    eval_callback = TrialEvalCallback(eval_envs, trial, N_EVAL_EPISODES, EVAL_FREQ, deterministic=True, verbose=0)

    nan_encountered = False
    try:
        # Train the model
        model.learn(N_TIMESTEPS, callback=eval_callback, progress_bar=True)
    except (AssertionError, ValueError) as e:
        # Sometimes, random hyperparams can generate NaN. Depending on where
        # the NaN is detected it surfaces as an AssertionError or as a
        # ValueError (e.g. PyTorch's distribution argument validation), and an
        # uncaught exception would abort the whole study instead of one trial.
        print(e)
        nan_encountered = True
    finally:
        # Free memory
        model.env.close()
        eval_envs.close()

    # Tell the optimizer that the trial failed
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward

In [None]:
# Set pytorch num threads to 1 for faster training
torch.set_num_threads(1)
# Select the sampler, can be random, TPESampler, CMAES, ...
sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)
# Do not prune before 1/3 of the evaluations have been reported
# (integer division: the warm-up is 0 steps while N_EVALUATIONS < 3)
pruner = MedianPruner(
    n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
)
# Create the study and start the hyperparameter optimization
study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")

try:
    study.optimize(objective, n_trials=N_TRIALS, n_jobs=N_JOBS, timeout=TIMEOUT)
except KeyboardInterrupt:
    # Allow stopping the search by hand and still report what has finished
    pass

print("Number of finished trials: ", len(study.trials))

# Write report first, so it is saved even when no trial completed
study.trials_dataframe().to_csv("study_results_ppo_marineenv.csv")

# study.best_trial raises ValueError when no trial has completed (e.g. every
# trial was pruned/failed or the search was interrupted early), so only
# report and plot when there is something to show.
completed_trials = study.get_trials(
    deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
)

if completed_trials:
    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    print("  User attrs:")
    for key, value in trial.user_attrs.items():
        print(f"    {key}: {value}")

    fig1 = plot_optimization_history(study)
    fig1.show()

    # Parameter importances cannot be estimated from a single completed trial
    if len(completed_trials) > 1:
        fig2 = plot_param_importances(study)
        fig2.show()
else:
    print("No trial completed; skipping best-trial summary and plots.")

In [None]:
# Create the environment
def make_env():
    """Build one continuous-action MarineEnv instance wrapped in ``Monitor``."""
    marine_env = gym.make(
        'MarineEnv-v0',
        render_mode='rgb_array',
        continuous=True,
        max_episode_steps=400,
    )
    # Monitor is applied to the raw env, i.e. before any vectorization
    return Monitor(marine_env)

# Vectorize first: a DummyVecEnv holding a single env produced by make_env
env = DummyVecEnv([make_env])

# Then normalize observations and rewards on top of the vectorized env.
# NOTE(review): a later cell rebinds `env` to a plain gym.make(...) environment,
# so this normalized VecEnv is not the one the PPO model below is built on - confirm intent.
env = VecNormalize(env, clip_obs=10.0, norm_reward=True, norm_obs=True)

In [4]:
# Single (non-vectorized) training environment. `training_stage=2` is forwarded
# to MarineEnv; its meaning is defined by the environment, not visible here.
env = gym.make(
    'MarineEnv-v0',
    render_mode='rgb_array',
    continuous=True,
    max_episode_steps=400,
    training_stage=2,
)
# Unused alternative: vec_env = make_vec_env('MarineEnv-v0', n_envs=4)

In [5]:
# Hand-picked PPO settings for the final training run (passed to PPO as **kwargs).
kwargs = dict(
    clip_range=0.3,
    ent_coef=0.005,
    gamma=0.99,
    learning_rate=1e-4,
    max_grad_norm=0.99,
    policy_kwargs=dict(net_arch=[128, 128], activation_fn=torch.nn.Tanh),
)

In [6]:
# Build the PPO agent on `env` with an MLP policy, running on the CPU and
# logging to TensorBoard; the tuned settings come from `kwargs`.
# NOTE(review): the log directory is named "stage_1" while `env` is created
# with training_stage=2 - confirm the directory name is intended.
model = PPO(
    'MlpPolicy',
    env,
    tensorboard_log='./stage_1_tensorboard_logs/',
    device='cpu',
    verbose=1,
    **kwargs,
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [7]:
# total_timesteps is documented as an int; 1e5 is a float literal, so convert it
# (same convention as N_TIMESTEPS = int(2e4) in the tuning configuration).
model.learn(total_timesteps=int(1e5), progress_bar=True)

Scene:  overtaking
Logging to ./stage_1_tensorboard_logs/PPO_16


Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 285      |
|    ep_rew_mean     | 176      |
| time/              |          |
|    fps             | 688      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 290        |
|    ep_rew_mean          | 303        |
| time/                   |            |
|    fps                  | 390        |
|    iterations           | 2          |
|    time_elapsed         | 10         |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.00765695 |
|    clip_fraction        | 0.0168     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.83      |
|    explained_variance   | -0.00095   |
|    learning_rate        | 0.0001     |
|    loss                 | 494        |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0105    |
|    std                  | 0.995      |
|    value_loss           | 1.15e+03   |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 304         |
|    ep_rew_mean          | 339         |
| time/                   |             |
|    fps                  | 315         |
|    iterations           | 3           |
|    time_elapsed         | 19          |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.011565043 |
|    clip_fraction        | 0.0311      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.83       |
|    explained_variance   | 0.0358      |
|    learning_rate        | 0.0001      |
|    loss                 | 428         |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0153     |
|    std                  | 0.998       |
|    value_loss           | 998         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 305         |
|    ep_rew_mean          | 355         |
| time/                   |             |
|    fps                  | 245         |
|    iterations           | 4           |
|    time_elapsed         | 33          |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.005724605 |
|    clip_fraction        | 0.00762     |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.84       |
|    explained_variance   | 0.0495      |
|    learning_rate        | 0.0001      |
|    loss                 | 527         |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.00968    |
|    std                  | 0.999       |
|    value_loss           | 987         |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 310        |
|    ep_rew_mean          | 371        |
| time/                   |            |
|    fps                  | 226        |
|    iterations           | 5          |
|    time_elapsed         | 45         |
|    total_timesteps      | 10240      |
| train/                  |            |
|    approx_kl            | 0.01037117 |
|    clip_fraction        | 0.0233     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.83      |
|    explained_variance   | 0.101      |
|    learning_rate        | 0.0001     |
|    loss                 | 393        |
|    n_updates            | 40         |
|    policy_gradient_loss | -0.0112    |
|    std                  | 0.996      |
|    value_loss           | 940        |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 322        |
|    ep_rew_mean          | 398        |
| time/                   |            |
|    fps                  | 203        |
|    iterations           | 6          |
|    time_elapsed         | 60         |
|    total_timesteps      | 12288      |
| train/                  |            |
|    approx_kl            | 0.00898819 |
|    clip_fraction        | 0.0143     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.82      |
|    explained_variance   | 0.103      |
|    learning_rate        | 0.0001     |
|    loss                 | 345        |
|    n_updates            | 50         |
|    policy_gradient_loss | -0.0153    |
|    std                  | 0.988      |
|    value_loss           | 738        |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 333         |
|    ep_rew_mean          | 453         |
| time/                   |             |
|    fps                  | 188         |
|    iterations           | 7           |
|    time_elapsed         | 76          |
|    total_timesteps      | 14336       |
| train/                  |             |
|    approx_kl            | 0.008701171 |
|    clip_fraction        | 0.0207      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.81       |
|    explained_variance   | 0.151       |
|    learning_rate        | 0.0001      |
|    loss                 | 476         |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.0141     |
|    std                  | 0.98        |
|    value_loss           | 968         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 339         |
|    ep_rew_mean          | 494         |
| time/                   |             |
|    fps                  | 172         |
|    iterations           | 8           |
|    time_elapsed         | 94          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.012127824 |
|    clip_fraction        | 0.0269      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.79       |
|    explained_variance   | -0.0452     |
|    learning_rate        | 0.0001      |
|    loss                 | 391         |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0191     |
|    std                  | 0.976       |
|    value_loss           | 939         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 342         |
|    ep_rew_mean          | 529         |
| time/                   |             |
|    fps                  | 161         |
|    iterations           | 9           |
|    time_elapsed         | 114         |
|    total_timesteps      | 18432       |
| train/                  |             |
|    approx_kl            | 0.010972433 |
|    clip_fraction        | 0.0216      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.79       |
|    explained_variance   | 0.0244      |
|    learning_rate        | 0.0001      |
|    loss                 | 446         |
|    n_updates            | 80          |
|    policy_gradient_loss | -0.0166     |
|    std                  | 0.978       |
|    value_loss           | 1.1e+03     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 346         |
|    ep_rew_mean          | 560         |
| time/                   |             |
|    fps                  | 148         |
|    iterations           | 10          |
|    time_elapsed         | 137         |
|    total_timesteps      | 20480       |
| train/                  |             |
|    approx_kl            | 0.011174509 |
|    clip_fraction        | 0.0339      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.79       |
|    explained_variance   | 0.00748     |
|    learning_rate        | 0.0001      |
|    loss                 | 597         |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.016      |
|    std                  | 0.975       |
|    value_loss           | 1.34e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 343         |
|    ep_rew_mean          | 563         |
| time/                   |             |
|    fps                  | 138         |
|    iterations           | 11          |
|    time_elapsed         | 162         |
|    total_timesteps      | 22528       |
| train/                  |             |
|    approx_kl            | 0.008499541 |
|    clip_fraction        | 0.032       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.79       |
|    explained_variance   | 0.105       |
|    learning_rate        | 0.0001      |
|    loss                 | 560         |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.0164     |
|    std                  | 0.974       |
|    value_loss           | 1.35e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 347         |
|    ep_rew_mean          | 585         |
| time/                   |             |
|    fps                  | 132         |
|    iterations           | 12          |
|    time_elapsed         | 186         |
|    total_timesteps      | 24576       |
| train/                  |             |
|    approx_kl            | 0.011443188 |
|    clip_fraction        | 0.0329      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.78       |
|    explained_variance   | 0.0471      |
|    learning_rate        | 0.0001      |
|    loss                 | 769         |
|    n_updates            | 110         |
|    policy_gradient_loss | -0.0154     |
|    std                  | 0.972       |
|    value_loss           | 1.53e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 348         |
|    ep_rew_mean          | 610         |
| time/                   |             |
|    fps                  | 123         |
|    iterations           | 13          |
|    time_elapsed         | 215         |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.013720594 |
|    clip_fraction        | 0.0511      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.78       |
|    explained_variance   | 0.158       |
|    learning_rate        | 0.0001      |
|    loss                 | 462         |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.026      |
|    std                  | 0.965       |
|    value_loss           | 967         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 350         |
|    ep_rew_mean          | 626         |
| time/                   |             |
|    fps                  | 116         |
|    iterations           | 14          |
|    time_elapsed         | 246         |
|    total_timesteps      | 28672       |
| train/                  |             |
|    approx_kl            | 0.016406883 |
|    clip_fraction        | 0.0516      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.76       |
|    explained_variance   | 0.0356      |
|    learning_rate        | 0.0001      |
|    loss                 | 639         |
|    n_updates            | 130         |
|    policy_gradient_loss | -0.025      |
|    std                  | 0.956       |
|    value_loss           | 1.44e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 352         |
|    ep_rew_mean          | 663         |
| time/                   |             |
|    fps                  | 109         |
|    iterations           | 15          |
|    time_elapsed         | 279         |
|    total_timesteps      | 30720       |
| train/                  |             |
|    approx_kl            | 0.028834533 |
|    clip_fraction        | 0.12        |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.75       |
|    explained_variance   | 0.054       |
|    learning_rate        | 0.0001      |
|    loss                 | 461         |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.0346     |
|    std                  | 0.958       |
|    value_loss           | 972         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 354         |
|    ep_rew_mean          | 687         |
| time/                   |             |
|    fps                  | 104         |
|    iterations           | 16          |
|    time_elapsed         | 313         |
|    total_timesteps      | 32768       |
| train/                  |             |
|    approx_kl            | 0.016048599 |
|    clip_fraction        | 0.0551      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.75       |
|    explained_variance   | 0.0166      |
|    learning_rate        | 0.0001      |
|    loss                 | 812         |
|    n_updates            | 150         |
|    policy_gradient_loss | -0.0264     |
|    std                  | 0.958       |
|    value_loss           | 1.7e+03     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 356         |
|    ep_rew_mean          | 710         |
| time/                   |             |
|    fps                  | 100         |
|    iterations           | 17          |
|    time_elapsed         | 348         |
|    total_timesteps      | 34816       |
| train/                  |             |
|    approx_kl            | 0.021465305 |
|    clip_fraction        | 0.1         |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.75       |
|    explained_variance   | 0.0437      |
|    learning_rate        | 0.0001      |
|    loss                 | 781         |
|    n_updates            | 160         |
|    policy_gradient_loss | -0.0308     |
|    std                  | 0.955       |
|    value_loss           | 1.76e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 361         |
|    ep_rew_mean          | 752         |
| time/                   |             |
|    fps                  | 95          |
|    iterations           | 18          |
|    time_elapsed         | 387         |
|    total_timesteps      | 36864       |
| train/                  |             |
|    approx_kl            | 0.012898657 |
|    clip_fraction        | 0.036       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.74       |
|    explained_variance   | 0.226       |
|    learning_rate        | 0.0001      |
|    loss                 | 641         |
|    n_updates            | 170         |
|    policy_gradient_loss | -0.0239     |
|    std                  | 0.949       |
|    value_loss           | 1.47e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 363         |
|    ep_rew_mean          | 790         |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 19          |
|    time_elapsed         | 13833       |
|    total_timesteps      | 38912       |
| train/                  |             |
|    approx_kl            | 0.014423275 |
|    clip_fraction        | 0.0418      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.73       |
|    explained_variance   | 0.232       |
|    learning_rate        | 0.0001      |
|    loss                 | 1.1e+03     |
|    n_updates            | 180         |
|    policy_gradient_loss | -0.0219     |
|    std                  | 0.945       |
|    value_loss           | 2.18e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 365         |
|    ep_rew_mean          | 837         |
| time/                   |             |
|    fps                  | 1           |
|    iterations           | 20          |
|    time_elapsed         | 23089       |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.013646053 |
|    clip_fraction        | 0.0547      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.72       |
|    explained_variance   | 0.193       |
|    learning_rate        | 0.0001      |
|    loss                 | 609         |
|    n_updates            | 190         |
|    policy_gradient_loss | -0.0287     |
|    std                  | 0.945       |
|    value_loss           | 1.39e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 367         |
|    ep_rew_mean          | 886         |
| time/                   |             |
|    fps                  | 1           |
|    iterations           | 21          |
|    time_elapsed         | 23125       |
|    total_timesteps      | 43008       |
| train/                  |             |
|    approx_kl            | 0.020387396 |
|    clip_fraction        | 0.0674      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.72       |
|    explained_variance   | 0.123       |
|    learning_rate        | 0.0001      |
|    loss                 | 888         |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.0284     |
|    std                  | 0.94        |
|    value_loss           | 1.77e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 370         |
|    ep_rew_mean          | 926         |
| time/                   |             |
|    fps                  | 1           |
|    iterations           | 22          |
|    time_elapsed         | 23160       |
|    total_timesteps      | 45056       |
| train/                  |             |
|    approx_kl            | 0.014259534 |
|    clip_fraction        | 0.0513      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.71       |
|    explained_variance   | 0.124       |
|    learning_rate        | 0.0001      |
|    loss                 | 900         |
|    n_updates            | 210         |
|    policy_gradient_loss | -0.0239     |
|    std                  | 0.939       |
|    value_loss           | 2.17e+03    |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 374        |
|    ep_rew_mean          | 963        |
| time/                   |            |
|    fps                  | 2          |
|    iterations           | 23         |
|    time_elapsed         | 23226      |
|    total_timesteps      | 47104      |
| train/                  |            |
|    approx_kl            | 0.01373112 |
|    clip_fraction        | 0.0455     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.71      |
|    explained_variance   | 0.178      |
|    learning_rate        | 0.0001     |
|    loss                 | 666        |
|    n_updates            | 220        |
|    policy_gradient_loss | -0.0226    |
|    std                  | 0.936      |
|    value_loss           | 1.76e+03   |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 372        |
|    ep_rew_mean          | 990        |
| time/                   |            |
|    fps                  | 2          |
|    iterations           | 24         |
|    time_elapsed         | 23284      |
|    total_timesteps      | 49152      |
| train/                  |            |
|    approx_kl            | 0.01830166 |
|    clip_fraction        | 0.0779     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.7       |
|    explained_variance   | 0.178      |
|    learning_rate        | 0.0001     |
|    loss                 | 592        |
|    n_updates            | 230        |
|    policy_gradient_loss | -0.0242    |
|    std                  | 0.931      |
|    value_loss           | 1.48e+03   |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 371         |
|    ep_rew_mean          | 1.02e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 25          |
|    time_elapsed         | 23347       |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.027207045 |
|    clip_fraction        | 0.104       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.69       |
|    explained_variance   | 0.29        |
|    learning_rate        | 0.0001      |
|    loss                 | 545         |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.0214     |
|    std                  | 0.931       |
|    value_loss           | 1.08e+03    |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 367        |
|    ep_rew_mean          | 1.03e+03   |
| time/                   |            |
|    fps                  | 2          |
|    iterations           | 26         |
|    time_elapsed         | 23419      |
|    total_timesteps      | 53248      |
| train/                  |            |
|    approx_kl            | 0.01926349 |
|    clip_fraction        | 0.0708     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.69      |
|    explained_variance   | 0.166      |
|    learning_rate        | 0.0001     |
|    loss                 | 922        |
|    n_updates            | 250        |
|    policy_gradient_loss | -0.0294    |
|    std                  | 0.923      |
|    value_loss           | 2.14e+03   |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 364        |
|    ep_rew_mean          | 1.04e+03   |
| time/                   |            |
|    fps                  | 2          |
|    iterations           | 27         |
|    time_elapsed         | 23492      |
|    total_timesteps      | 55296      |
| train/                  |            |
|    approx_kl            | 0.02049969 |
|    clip_fraction        | 0.0865     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.67      |
|    explained_variance   | 0.0686     |
|    learning_rate        | 0.0001     |
|    loss                 | 705        |
|    n_updates            | 260        |
|    policy_gradient_loss | -0.0252    |
|    std                  | 0.921      |
|    value_loss           | 1.5e+03    |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 363         |
|    ep_rew_mean          | 1.05e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 28          |
|    time_elapsed         | 23556       |
|    total_timesteps      | 57344       |
| train/                  |             |
|    approx_kl            | 0.015524499 |
|    clip_fraction        | 0.0457      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.67       |
|    explained_variance   | 0.442       |
|    learning_rate        | 0.0001      |
|    loss                 | 691         |
|    n_updates            | 270         |
|    policy_gradient_loss | -0.0173     |
|    std                  | 0.919       |
|    value_loss           | 1.68e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 365         |
|    ep_rew_mean          | 1.08e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 29          |
|    time_elapsed         | 23627       |
|    total_timesteps      | 59392       |
| train/                  |             |
|    approx_kl            | 0.016506482 |
|    clip_fraction        | 0.0597      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.67       |
|    explained_variance   | 0.447       |
|    learning_rate        | 0.0001      |
|    loss                 | 531         |
|    n_updates            | 280         |
|    policy_gradient_loss | -0.0253     |
|    std                  | 0.921       |
|    value_loss           | 1.13e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 361         |
|    ep_rew_mean          | 1.1e+03     |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 30          |
|    time_elapsed         | 23700       |
|    total_timesteps      | 61440       |
| train/                  |             |
|    approx_kl            | 0.037335202 |
|    clip_fraction        | 0.138       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.67       |
|    explained_variance   | 0.302       |
|    learning_rate        | 0.0001      |
|    loss                 | 531         |
|    n_updates            | 290         |
|    policy_gradient_loss | -0.0306     |
|    std                  | 0.919       |
|    value_loss           | 1.18e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 361         |
|    ep_rew_mean          | 1.12e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 31          |
|    time_elapsed         | 23768       |
|    total_timesteps      | 63488       |
| train/                  |             |
|    approx_kl            | 0.024498407 |
|    clip_fraction        | 0.0903      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.67       |
|    explained_variance   | 0.346       |
|    learning_rate        | 0.0001      |
|    loss                 | 627         |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.0265     |
|    std                  | 0.921       |
|    value_loss           | 1.61e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 356         |
|    ep_rew_mean          | 1.14e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 32          |
|    time_elapsed         | 23843       |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.011854064 |
|    clip_fraction        | 0.0375      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.67       |
|    explained_variance   | 0.362       |
|    learning_rate        | 0.0001      |
|    loss                 | 947         |
|    n_updates            | 310         |
|    policy_gradient_loss | -0.0214     |
|    std                  | 0.92        |
|    value_loss           | 2.08e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 352         |
|    ep_rew_mean          | 1.14e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 33          |
|    time_elapsed         | 23919       |
|    total_timesteps      | 67584       |
| train/                  |             |
|    approx_kl            | 0.026554512 |
|    clip_fraction        | 0.111       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.66       |
|    explained_variance   | 0.0542      |
|    learning_rate        | 0.0001      |
|    loss                 | 736         |
|    n_updates            | 320         |
|    policy_gradient_loss | -0.0302     |
|    std                  | 0.915       |
|    value_loss           | 1.74e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 350         |
|    ep_rew_mean          | 1.15e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 34          |
|    time_elapsed         | 23994       |
|    total_timesteps      | 69632       |
| train/                  |             |
|    approx_kl            | 0.022667024 |
|    clip_fraction        | 0.102       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.65       |
|    explained_variance   | 0.264       |
|    learning_rate        | 0.0001      |
|    loss                 | 659         |
|    n_updates            | 330         |
|    policy_gradient_loss | -0.0338     |
|    std                  | 0.912       |
|    value_loss           | 1.35e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 347         |
|    ep_rew_mean          | 1.16e+03    |
| time/                   |             |
|    fps                  | 2           |
|    iterations           | 35          |
|    time_elapsed         | 24313       |
|    total_timesteps      | 71680       |
| train/                  |             |
|    approx_kl            | 0.025890555 |
|    clip_fraction        | 0.113       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.64       |
|    explained_variance   | 0.068       |
|    learning_rate        | 0.0001      |
|    loss                 | 644         |
|    n_updates            | 340         |
|    policy_gradient_loss | -0.0343     |
|    std                  | 0.905       |
|    value_loss           | 1.39e+03    |
-----------------------------------------


---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 345       |
|    ep_rew_mean          | 1.18e+03  |
| time/                   |           |
|    fps                  | 3         |
|    iterations           | 36        |
|    time_elapsed         | 24400     |
|    total_timesteps      | 73728     |
| train/                  |           |
|    approx_kl            | 0.0259847 |
|    clip_fraction        | 0.106     |
|    clip_range           | 0.3       |
|    entropy_loss         | -2.63     |
|    explained_variance   | 0.0732    |
|    learning_rate        | 0.0001    |
|    loss                 | 997       |
|    n_updates            | 350       |
|    policy_gradient_loss | -0.0314   |
|    std                  | 0.903     |
|    value_loss           | 2.02e+03  |
---------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 345         |
|    ep_rew_mean          | 1.19e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 37          |
|    time_elapsed         | 24481       |
|    total_timesteps      | 75776       |
| train/                  |             |
|    approx_kl            | 0.021847155 |
|    clip_fraction        | 0.0922      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.63       |
|    explained_variance   | 0.253       |
|    learning_rate        | 0.0001      |
|    loss                 | 678         |
|    n_updates            | 360         |
|    policy_gradient_loss | -0.0287     |
|    std                  | 0.899       |
|    value_loss           | 1.57e+03    |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 343        |
|    ep_rew_mean          | 1.19e+03   |
| time/                   |            |
|    fps                  | 3          |
|    iterations           | 38         |
|    time_elapsed         | 24569      |
|    total_timesteps      | 77824      |
| train/                  |            |
|    approx_kl            | 0.02285341 |
|    clip_fraction        | 0.0881     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.62      |
|    explained_variance   | 0.617      |
|    learning_rate        | 0.0001     |
|    loss                 | 688        |
|    n_updates            | 370        |
|    policy_gradient_loss | -0.0253    |
|    std                  | 0.898      |
|    value_loss           | 1.75e+03   |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 341         |
|    ep_rew_mean          | 1.21e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 39          |
|    time_elapsed         | 24660       |
|    total_timesteps      | 79872       |
| train/                  |             |
|    approx_kl            | 0.029467385 |
|    clip_fraction        | 0.123       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.0298      |
|    learning_rate        | 0.0001      |
|    loss                 | 570         |
|    n_updates            | 380         |
|    policy_gradient_loss | -0.0255     |
|    std                  | 0.894       |
|    value_loss           | 1.5e+03     |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 339         |
|    ep_rew_mean          | 1.21e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 40          |
|    time_elapsed         | 26293       |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.025982514 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.172       |
|    learning_rate        | 0.0001      |
|    loss                 | 653         |
|    n_updates            | 390         |
|    policy_gradient_loss | -0.0318     |
|    std                  | 0.894       |
|    value_loss           | 1.45e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 336         |
|    ep_rew_mean          | 1.22e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 41          |
|    time_elapsed         | 26388       |
|    total_timesteps      | 83968       |
| train/                  |             |
|    approx_kl            | 0.024635764 |
|    clip_fraction        | 0.081       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.341       |
|    learning_rate        | 0.0001      |
|    loss                 | 426         |
|    n_updates            | 400         |
|    policy_gradient_loss | -0.0242     |
|    std                  | 0.893       |
|    value_loss           | 1.04e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 332         |
|    ep_rew_mean          | 1.22e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 42          |
|    time_elapsed         | 26478       |
|    total_timesteps      | 86016       |
| train/                  |             |
|    approx_kl            | 0.023231039 |
|    clip_fraction        | 0.0912      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.117       |
|    learning_rate        | 0.0001      |
|    loss                 | 757         |
|    n_updates            | 410         |
|    policy_gradient_loss | -0.0265     |
|    std                  | 0.894       |
|    value_loss           | 1.53e+03    |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 334        |
|    ep_rew_mean          | 1.24e+03   |
| time/                   |            |
|    fps                  | 3          |
|    iterations           | 43         |
|    time_elapsed         | 26571      |
|    total_timesteps      | 88064      |
| train/                  |            |
|    approx_kl            | 0.03878351 |
|    clip_fraction        | 0.178      |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.62      |
|    explained_variance   | 0.1        |
|    learning_rate        | 0.0001     |
|    loss                 | 382        |
|    n_updates            | 420        |
|    policy_gradient_loss | -0.0306    |
|    std                  | 0.901      |
|    value_loss           | 1.01e+03   |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 331        |
|    ep_rew_mean          | 1.25e+03   |
| time/                   |            |
|    fps                  | 3          |
|    iterations           | 44         |
|    time_elapsed         | 26663      |
|    total_timesteps      | 90112      |
| train/                  |            |
|    approx_kl            | 0.01805212 |
|    clip_fraction        | 0.0611     |
|    clip_range           | 0.3        |
|    entropy_loss         | -2.62      |
|    explained_variance   | 0.588      |
|    learning_rate        | 0.0001     |
|    loss                 | 636        |
|    n_updates            | 430        |
|    policy_gradient_loss | -0.0254    |
|    std                  | 0.901      |
|    value_loss           | 1.51e+03   |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 329         |
|    ep_rew_mean          | 1.26e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 45          |
|    time_elapsed         | 26756       |
|    total_timesteps      | 92160       |
| train/                  |             |
|    approx_kl            | 0.031742916 |
|    clip_fraction        | 0.113       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.62       |
|    explained_variance   | 0.091       |
|    learning_rate        | 0.0001      |
|    loss                 | 584         |
|    n_updates            | 440         |
|    policy_gradient_loss | -0.0293     |
|    std                  | 0.895       |
|    value_loss           | 1.08e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 326         |
|    ep_rew_mean          | 1.26e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 46          |
|    time_elapsed         | 26852       |
|    total_timesteps      | 94208       |
| train/                  |             |
|    approx_kl            | 0.024482548 |
|    clip_fraction        | 0.0852      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.454       |
|    learning_rate        | 0.0001      |
|    loss                 | 538         |
|    n_updates            | 450         |
|    policy_gradient_loss | -0.0294     |
|    std                  | 0.894       |
|    value_loss           | 1.07e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 326         |
|    ep_rew_mean          | 1.27e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 47          |
|    time_elapsed         | 26949       |
|    total_timesteps      | 96256       |
| train/                  |             |
|    approx_kl            | 0.022802372 |
|    clip_fraction        | 0.084       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.419       |
|    learning_rate        | 0.0001      |
|    loss                 | 605         |
|    n_updates            | 460         |
|    policy_gradient_loss | -0.024      |
|    std                  | 0.893       |
|    value_loss           | 1.15e+03    |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 327         |
|    ep_rew_mean          | 1.27e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 48          |
|    time_elapsed         | 27050       |
|    total_timesteps      | 98304       |
| train/                  |             |
|    approx_kl            | 0.029262481 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.6        |
|    explained_variance   | 0.37        |
|    learning_rate        | 0.0001      |
|    loss                 | 405         |
|    n_updates            | 470         |
|    policy_gradient_loss | -0.028      |
|    std                  | 0.892       |
|    value_loss           | 925         |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 329         |
|    ep_rew_mean          | 1.28e+03    |
| time/                   |             |
|    fps                  | 3           |
|    iterations           | 49          |
|    time_elapsed         | 27148       |
|    total_timesteps      | 100352      |
| train/                  |             |
|    approx_kl            | 0.018253624 |
|    clip_fraction        | 0.0595      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.6        |
|    explained_variance   | 0.631       |
|    learning_rate        | 0.0001      |
|    loss                 | 555         |
|    n_updates            | 480         |
|    policy_gradient_loss | -0.0225     |
|    std                  | 0.889       |
|    value_loss           | 1.18e+03    |
-----------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x7f88aceaad90>

In [8]:
# Score the current policy on `env`: 10 episodes, deterministic actions.
# The two returned values are unpacked as the reward mean and standard deviation.
mean, std = evaluate_policy(
    model=model,
    env=env,
    n_eval_episodes=10,
    deterministic=True,
)
print(f'Mean: {mean:.2f}, Std: {std:.2f}')



Scene:  head-on
Scene:  overtaking
Scene:  head-on
Scene:  head-on
Scene:  static
Scene:  head-on
Scene:  crossing
Scene:  overtaking
Scene:  static
Scene:  crossing
Scene:  static
Mean: 1160.51, Std: 238.91


In [None]:
# Launch TensorBoard on the stage-1 training logs.
# NOTE(review): --host=0.0.0.0 presumably makes the dashboard reachable from other machines
# (useful inside a container / remote box) — confirm that exposure is intended.
%tensorboard --logdir ./stage_1_tensorboard_logs/ --host=0.0.0.0

In [9]:
# Persist the trained model under the name "ppo_marine_stage_2".
# NOTE(review): saving the environment normalization stats is currently disabled (next line),
# yet a later cell loads "ppo_normalized_env.pkl" — re-enable it if `env` is a VecNormalize wrapper.
# env.save("ppo_normalized_env.pkl")
model.save("ppo_marine_stage_2")
# To resume from an earlier checkpoint instead of the in-memory model:
# model = model.load('ppo_marine_stage_1.zip')

In [None]:
import cv2

# Wrap the current env with previously saved VecNormalize statistics.
# NOTE(review): the matching `env.save("ppo_normalized_env.pkl")` call is commented out in the
# save cell above, so this file must already exist on disk — confirm before running.
# NOTE(review): a .pkl stats file is presumably unpickled on load; only load files you created.
env = VecNormalize.load("ppo_normalized_env.pkl", env)

# Evaluation mode: stop updating the running statistics and disable reward normalization
env.training = False
env.norm_reward = False

MAX_EVAL_STEPS = 100
WINDOW_TITLE = "PPO MarineEnv Evaluation"

try:
    obs = env.reset()
    for _ in range(MAX_EVAL_STEPS):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, dones, _ = env.step(action)

        # get_images() may return an empty sequence or None entries; skip those
        images = env.get_images()
        if images and images[0] is not None:
            frame = images[0]

            # Only display frames with non-zero height and width
            if frame.shape[0] > 0 and frame.shape[1] > 0:
                cv2.imshow(WINDOW_TITLE, frame)
                cv2.waitKey(1)  # Display for 1ms
            else:
                print("Warning: Received an empty frame from env.get_images()")

        # `dones` holds one flag per wrapped env; a bare `if dones:` raises for more than
        # one env, so reduce it explicitly. Stops as soon as any env finishes an episode.
        if np.any(dones):
            break
finally:
    # Always release the env and the OpenCV window, even if a step or render call fails
    env.close()
    cv2.destroyAllWindows()


In [15]:
# Roll out the trained policy for one episode in a human-rendered stage-2 environment.
MAX_STEPS = 400        # hard cap on rollout length
FRAME_DELAY_S = 0.09   # pause after each step so the rendered window can be followed

env = gym.make('MarineEnv-v0', render_mode='human', continuous=True, training_stage=2)
episode_rewards = 0
try:
    state, _ = env.reset()
    print(state)
    for _ in range(MAX_STEPS):
        # predict() returns a pair; only the first element (the action) is used
        action, _ = model.predict(state, deterministic=True)
        observation, reward, terminated, truncated, info = env.step(action)
        env.render()
        time.sleep(FRAME_DELAY_S)
        episode_rewards += reward
        print('===========================')
        print(observation)

        # Advance the state before the termination check so that `state` holds the
        # final observation when the episode ends (previously it was one step stale).
        state = observation

        if terminated or truncated:
            break
finally:
    # The recorded run of this cell ended with an AttributeError; closing here
    # guarantees the window is released even when stepping/rendering fails.
    env.close()

print(episode_rewards)
print(state)

Scene:  overtaking
[352.84598     8.270623   23.717087  172.0578     -4.5568643 172.0578
   1.1522969   0.9461159 330.75168     4.819188   44.419704  205.94368
   3.788774    5.5005403  85.249214   74.8327      0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.       ]


  logger.warn(f"{pre} is not within the observation space.")


[347.47888      8.370624    23.670557   169.66878      0.81182843
 171.72447      2.6951344    1.8640248  330.75168      4.801336
  49.87619    194.51064      3.4833755    5.5005403  109.743675
  76.21453      0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.        ]
[341.47037     8.470624   23.6238    167.33456     6.8339014 171.39114
   5.987527    2.9183948 330.75168     4.78708    56.015423  179.92232
   3.2322266   5.5005403 167.48991    70.43965     0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.       ]
[337.37524     8.570623   23.577019  165.05464    10.950973  171.05782
  11.51214     3.5890958 330.75168     4.775433   60.269115  168.91737
   

AttributeError: 'NoneType' object has no attribute 'fill'

In [None]:
# Manually close the environment, e.g. after a cell failed before reaching its own close() call.
env.close()

In [None]:
# Reset the environment; keep the initial observation in `state` and discard the second return value.
state, _ = env.reset()

In [None]:
# Display the current observation for inspection.
state

In [None]:
# Show the action the model proposes for `state` (first element of predict()'s result).
# Note: `deterministic` is left at its default here, unlike the rollout cell which passes True.
model.predict(state)[0]

In [None]:
# Inspect the batch size configured on the model.
model.batch_size