In [1]:
import os

from typing import Dict, Any, Optional, Callable

import gymnasium as gym
from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import RecordVideo
import time
import numpy as np

# from typing import Any, Dict
import torch
import torch.nn as nn

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

In [2]:
# Create folders for logs and videos
log_folder = './marl_logs'
video_folder = os.path.join(log_folder, 'video')
os.makedirs(video_folder, exist_ok=True)
# PPO.load(os.path.join(log_folder, 'optuna/best_model/best_model.zip'), device='cpu').policy

In [3]:
%matplotlib inline
%load_ext tensorboard

%tensorboard --logdir {log_folder} --host=0.0.0.0

Reusing TensorBoard on port 6006 (pid 1213), started 1 day, 2:00:28 ago. (Use '!kill 1213' to kill it.)

In [4]:
if 'MarineEnv-v0' not in registry:
    register(
        id='MarineEnv-v0',
        entry_point='environments:MarineEnv',  # String reference to the class
    )

def yield_random_seed():
    while True:
        yield np.random.randint(low=1, high=201)
seed_generator = yield_random_seed()

In [5]:
env_kwargs = dict(
        render_mode='rgb_array',
        training_stage=2,
        timescale=1/3,
        training=False,
        total_targets=2,
    )
    
env = gym.make(
        'MarineEnv-v0',
        **env_kwargs,
    )

In [6]:
env.observation_space.shape

(80,)

In [7]:
def run_video_rendering(
    agent: Any, 
    episodes: int = 3, 
    timescale: float = 1/6, 
    seed: Optional[int] = None, 
    record_video: bool = False, 
    episode_trigger: Optional[Callable[[int], bool]] = None, 
    name_prefix: Optional[str] = None, 
    video_folder: Optional[str] = None,
    marl: bool = False,
) -> None:
    """
    Runs a simulation of the MarineEnv-v0 environment using the given agent, with optional video recording.

    Parameters:
    - agent: The trained agent used for inference.
    - episodes (int): Number of episodes to run (default is 3).
    - timescale (float): The simulation timescale factor (default is 1/6).
    - seed (int, optional): Random seed for environment initialization. If None, a generated seed is used.
    - record_video (bool): If True, records the simulation as a video (default is False).
    - episode_trigger (function, optional): A function that determines which episodes get recorded.
    - name_prefix (str, optional): Prefix for recorded video file names.
    - video_folder (str, optional): Path to save recorded videos.
    - marl (bool): If True, running the env in MARL mode.

    Behavior:
    - If video recording is enabled, the environment is wrapped with RecordVideo.
    - Logs episode statistics, including total rewards, episode length, and termination status.
    - If recording, logs are saved to a text file in the specified video folder.
    - Displays simulation results in the console if video recording is disabled.

    Returns:
    - None. The function either logs the results to a file or prints them to the console.
    """
    
    if seed is None:
        seed = next(seed_generator)
        
    kwargs = dict(
        render_mode='rgb_array' if record_video else 'human',
        continuous=True,
        training_stage=2,
        timescale=timescale,
        training=False,
        total_targets=2,
        seed=seed,
    )
    
    env = gym.make(
        'MarineEnv-v0',
        **kwargs,
    )
    
    if record_video:
        from IPython.display import HTML
        from base64 import b64encode
        
        if episode_trigger is None:
            episode_trigger = lambda episode_id: True

        if video_folder is None:
            video_folder = video_folder   
        else:
            if name_prefix:
                video_folder = os.path.join(video_folder, name_prefix)
        
        # wrap environment for video recording
        env = RecordVideo(
            env=env, 
            video_folder=video_folder, 
            episode_trigger=lambda episode_id: True, 
            name_prefix=name_prefix)
        
    logged_episodes = []
    logged_rewards = []
    for episode in range(episodes):
        state, _ = env.reset()
        episode_rewards = 0

        if marl:
            eta = state[0][5]
        else:
            eta = state[5]
        
        for step in range(int(400 / timescale)):
            if marl:
                with torch.no_grad():
                    actions = [agent.predict(i, deterministic=True)[0] for i in state]
                next_state, reward, terminated, truncated, info = env.step(actions)
            else:
                with torch.no_grad():
                    action = agent.predict(state, deterministic=True)
            
                state, reward, terminated, truncated, info = env.step(action[0])
            
            if not record_video:
                env.render()
                
            episode_rewards += reward
            
            if terminated or truncated:
                break
                
            time.sleep(0.005)
         
        result_string = f'Episode: {episode}\nEpisode length: {step}, Elapsed real time: {round(step * timescale)} minutes, Initial WP ETA: {round(eta)} minutes\nEpisode total rewards: {episode_rewards :.2f}\nIs terminated: {info["terminated"]}, Is truncated: {info["truncated"]}\n============================\n'

        logged_episodes.append(result_string)
        logged_rewards.append(episode_rewards)

    evaluation = f'Mean: {np.array(logged_rewards).mean():.2f}, Std: {np.array(logged_rewards).std():.2f}, Initial seed: {seed}'
    logged_episodes.append(evaluation)
    # Open log file
    if record_video:
        log_file_path = os.path.join(video_folder, name_prefix + '.txt')
        with open(log_file_path, 'w') as log_file:
            log_file.write('\n'.join(logged_episodes))
            print(f"Training log saved at: {log_file_path}")
    else:
        print('\n'.join(logged_episodes))
    
    env.close()

In [8]:
# setup of the training envs
train_env = make_vec_env(env_id='MarineEnv-v0', n_envs=8, env_kwargs=env_kwargs)
eval_env = gym.make('MarineEnv-v0', **env_kwargs)
eval_env = Monitor(eval_env)

In [9]:
# setup of the defaut kwargs for initial training of PPO agent
default_kwargs = {
    'learning_rate': 3e-4,
    'n_steps': 2048,
    'batch_size': 64,
    'n_epochs': 10,
    'gamma': 0.99, 
    'gae_lambda': 0.95,
    'clip_range': 0.2,
    'ent_coef': 0.0,
    'vf_coef': 0.5, 
    'max_grad_norm': 0.5,
    'target_kl': None,
    'tensorboard_log': log_folder,
}

In [10]:
# establish initial agent
agent = PPO(
        policy='MlpPolicy',
        env=train_env,
        verbose=0,
        device='cpu',
        **default_kwargs,
    )

In [11]:
# eval callback - used to track the statistics of the training
eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=os.path.join(log_folder, 'default/best_model/'),
        log_path=os.path.join(log_folder, 'default/results/'),
        eval_freq=5000,
        deterministic=True,
        render=False
    )

In [12]:
# env.observation_space

In [13]:
# agent.observation_space

In [14]:
agent.learn(total_timesteps=int(1e5), reset_num_timesteps=True, tb_log_name='default/default_best', progress_bar=True, callback=eval_callback)

Output()

Seed set to 57847507
Seed set to 57847508
Seed set to 57847509
Seed set to 57847510
Seed set to 57847511
Seed set to 57847512
Seed set to 57847513
Seed set to 57847514


  logger.warn(f"{pre} is not within the observation space.")


<stable_baselines3.ppo.ppo.PPO at 0x7f6d8438b9d0>

In [None]:
env.reset()

In [17]:
run_video_rendering(agent, episodes=2)

Seed set to 129
Episode: 0
Episode length: 315, Elapsed real time: 52 minutes, Initial WP ETA: 0 minutes
Episode total rewards: -57.56
Is terminated: WP Reached!, Is truncated: False

Episode: 1
Episode length: 421, Elapsed real time: 70 minutes, Initial WP ETA: 0 minutes
Episode total rewards: -254.25
Is terminated: WP Reached!, Is truncated: False

Mean: -155.91, Std: 98.34, Initial seed: 129


In [None]:
run_video_rendering(agent, episodes=1, marl=False, seed=42)

In [65]:
env = agent.env

In [66]:
raw_env = env.envs[5]

In [67]:
env = raw_env
while hasattr(env, 'env'):
    env = env.env

In [70]:
env.reset()

(array([99.11236572, 13.30774117,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        , 12.69857311, 57.25347137, -2.44946313,
        57.25347137,  1.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0. 

In [73]:
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print("Step obs:", obs)



In [78]:
info


{'total_steps': 4, 'terminated': False, 'truncated': False}