In [1]:
%matplotlib inline
%load_ext tensorboard

In [2]:
import numpy as np

from collections import deque

import matplotlib
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim


from tqdm import tqdm

# Gym
import gymnasium as gym
# import gym_pygame

from gymnasium.envs.registration import register, registry

import time

In [3]:
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
if 'MarineEnv-v0' not in registry:
    register(
        id='MarineEnv-v0',
        entry_point='environments:MarineEnv',  # String reference to the class
    )

In [5]:
timescale = 1 / 3
env_kwargs = dict(
    render_mode='rgb_array',
    continuous=True,
    max_episode_steps=int(400 / timescale),
    training_stage=2,
    timescale=timescale,
    training=True,
    total_targets=1,
)

env_id = 'MarineEnv-v0'
# Create the env
env = gym.make(env_id, **env_kwargs)

# Create the evaluation env
eval_env = gym.make(env_id, **env_kwargs)

# Get the state space and action space
s_size = env.observation_space.shape[0]
a_size = env.action_space.shape[0]

2025-02-05 19:53:08.740892: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738785188.754892   23943 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738785188.758858   23943 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-05 19:53:08.773320: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [6]:
print("_____OBSERVATION SPACE_____ \n")
print("The State Space is: ", s_size)
print("Sample observation", env.observation_space.sample()) # Get a random observation

_____OBSERVATION SPACE_____ 

The State Space is:  16
Sample observation [ 188.64545      7.8844295    7.5682063  200.92033    -44.695557
   42.987606    43.04377     31.751268    43.722347    40.838963
 -119.89502    306.49094     45.983223    -9.929177   -26.163208
   23.778082 ]


In [7]:
print("\n _____ACTION SPACE_____ \n")
print("The Action Space is: ", a_size)
print("Action Space Sample", env.action_space.sample()) # Take a random action


 _____ACTION SPACE_____ 

The Action Space is:  2
Action Space Sample [-0.5046333  -0.39191398]


In [8]:
marine_env_params = {
    "h_size": 128,
    "n_training_episodes": int(1e4),
    "n_evaluation_episodes": 100,
    "max_t": 1200,
    "lr": 1e-5,
    "env_id": env_id,
    "continuous": True,
    "state_space": s_size,
    "action_space": a_size,
    'gamma': 0.95,
    'clip_epsilon': 0.2,  # PPO Clipping
    'update_epochs': 4,  # Multiple updates per batch
    'print_every': 100
}

In [9]:
class ContinuousPolicy(nn.Module):
    def __init__(self, s_size, a_size, h_size):
        super(ContinuousPolicy, self).__init__()
        self.fc1 = nn.Linear(s_size, h_size)
        self.fc2 = nn.Linear(h_size, h_size)

        # Actor (outputs mean and standard deviation for actions)
        self.mu_layer = nn.Linear(h_size, a_size)
        self.sigma_layer = nn.Linear(h_size, a_size)

        # Critic (outputs a single scalar value)
        self.value_layer = nn.Linear(h_size, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        mu = torch.tanh(self.mu_layer(x))  # Mean (bounded [-1,1]) in order to sample the action according to the observation space limits
        sigma = F.softplus(self.sigma_layer(x)) + 1e-5  # Ensure sigma > 0

        value = self.value_layer(x)  # Value function output
        return mu, sigma, value

    def act(self, state):
        """
        Given a state, take an action for a continuous action space.
        Returns: (action, log_prob, value)
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
    
        mu, sigma, value = self.forward(state)  # Forward pass now returns a tensor for value
        dist = Normal(mu, sigma)
    
        action = dist.sample()  # Sample from Gaussian distribution
        log_prob = dist.log_prob(action).sum(dim=-1)  # Sum log probs over multiple actions
    
        return action.detach().cpu().numpy()[0], log_prob, value  # Return `value` as a tensor

In [10]:
def ppo_train(policy, optimizer, n_training_episodes, max_t, gamma, clip_epsilon, update_epochs, print_every):
    """
    Implements PPO with Gaussian policy and proper detach() usage.
    """

    scores_deque = deque(maxlen=100)
    scores = []

    for i_episode in range(1, n_training_episodes + 1):
        saved_log_probs = []
        saved_values = []
        rewards = []
        states = []
        actions = []

        state, _ = env.reset()
        
        for t in range(max_t):
            action, log_prob, value = policy.act(state)  

            # Detach `log_prob` and `value` before storing
            saved_log_probs.append(log_prob.detach())  
            saved_values.append(value.detach())
            actions.append(action)
            states.append(state)

            state, reward, terminated, truncated, _ = env.step(action)
            rewards.append(reward)

            if terminated or truncated:
                break 

        scores_deque.append(sum(rewards))
        scores.append(sum(rewards))

        # Compute advantage function
        returns = deque(maxlen=max_t)
        advantages = deque(maxlen=max_t)
        n_steps = len(rewards)

        last_advantage = 0
        last_value = saved_values[-1]

        for t in range(n_steps - 1, -1, -1):
            delta = rewards[t] + (gamma * saved_values[t + 1] if t + 1 < n_steps else last_value) - saved_values[t]
            last_advantage = delta + gamma * 0.95 * last_advantage  

            returns.appendleft(last_advantage + saved_values[t])  
            advantages.appendleft(last_advantage)  

        # Convert to tensors
        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        advantages = torch.tensor(advantages, dtype=torch.float32).to(device)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO Update
        for _ in range(update_epochs):
            optimizer.zero_grad()  
            
            new_log_probs = []
            new_values = []

            for state, action in zip(states, actions):
                _, new_log_prob, new_value = policy.act(state)
                new_log_probs.append(new_log_prob)
                new_values.append(new_value)

            new_log_probs = torch.stack(new_log_probs).to(device)
            new_values = torch.stack(new_values).squeeze().to(device)

            # Compute ratio correctly (log_prob must be detached!)
            ratio = torch.exp(new_log_probs - torch.tensor(saved_log_probs, dtype=torch.float32).to(device))

            # Clipped Surrogate Loss
            clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
            policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
            # policy_loss = torch.min(ratio * advantages, clipped_ratio * advantages).mean()

            # Separate Value Function Update
            value_loss = nn.MSELoss()(new_values, returns)

            # Combine losses for backpropagation
            total_loss = policy_loss + 0.5 * value_loss

            total_loss.backward()
            optimizer.step()

        if i_episode % print_every == 0:
            print(f"Episode {i_episode}\tAverage Score: {np.mean(scores_deque):.2f}")

    return scores


In [11]:
# Create policy and place it to the device
marine_env_policy = ContinuousPolicy(marine_env_params["state_space"], marine_env_params["action_space"], marine_env_params["h_size"]).to(device)
marine_env_optimizer = optim.Adam(marine_env_policy.parameters(), lr=marine_env_params["lr"])

In [12]:
scores = ppo_train(marine_env_policy,
                   marine_env_optimizer,
                   marine_env_params["n_training_episodes"], 
                   marine_env_params["max_t"],
                   marine_env_params["gamma"], 
                   marine_env_params["clip_epsilon"],
                   marine_env_params["update_epochs"], 
                   marine_env_params["print_every"]
                  )

Episode 100	Average Score: -1157622.40
Episode 200	Average Score: -2298453.03
Episode 300	Average Score: -839353.76
Episode 400	Average Score: -775543.48
Episode 500	Average Score: -1323305.15
Episode 600	Average Score: -433279.20
Episode 700	Average Score: -422513.33
Episode 800	Average Score: -371539.43
Episode 900	Average Score: -649480.67
Episode 1000	Average Score: -943380.89
Episode 1100	Average Score: -315596.05
Episode 1200	Average Score: -430708.71
Episode 1300	Average Score: -1957556.38
Episode 1400	Average Score: -415237.90
Episode 1500	Average Score: -429666.09


KeyboardInterrupt: 

In [None]:
timescale = 1/3
env_test = gym.make('MarineEnv-v0', render_mode='human', continuous=True, training_stage=2, timescale=timescale, training=False, total_targets=1)
for _ in range(5):
    
    state, _ = env_test.reset()
    print(f'Detected targets:', [target for target in env_test.unwrapped.own_ship.detected_targets])
    print(state)
    episode_rewards = 0 
    for _ in range(int(400 / timescale)):
        action = marine_env_policy.act(state)
        state, reward, terminated, truncated, info = env_test.step(action[0])
        env_test.render()
        time.sleep(0.005)
        episode_rewards += reward
        print('===========================')
        print(state)
        print(f'Step reward: {reward:.2f}')
        print(f'Current Total reward: {episode_rewards:.2f}')
        print(f'Dangerous targets: ', [target for target in env_test.unwrapped.own_ship.dangerous_targets])
        
        if terminated or truncated:
            print('Episode total reward: ', episode_rewards)
            print(info)
            break
            
    print('Episode total rewards: ', episode_rewards)
    print('Episode final state: ', state)
    print(f'============================\n' * 10)
    env_trn.close()