## Parallel REINFORCE
---

Recall the REINFORCE policy gradient:

$$
\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta} \left[ \sum_{t=0}^T \gamma^t R_t \nabla_\theta \log \pi_\theta(A_t|S_t) \right],
$$

which can be approximated with Monte Carlo sampling over $N$ trajectories:

$$
\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^N \sum_{t=0}^T \gamma^t R_t^{(i)} \nabla_\theta \log \pi_\theta(A_t^{(i)}|S_t^{(i)}).
$$
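
As a minimal sketch (not the notebook's actual `train` code), the Monte Carlo estimate above can be obtained through automatic differentiation by building a surrogate loss whose gradient is the negated estimate. The function name `reinforce_surrogate_loss` and the `[N, T]` tensor layout are assumptions made for illustration.

```python
import torch


def reinforce_surrogate_loss(
    log_probs: torch.Tensor, rewards: torch.Tensor, gamma: float
) -> torch.Tensor:
    """Surrogate loss whose gradient is minus the Monte Carlo estimate.

    log_probs : [N, T] tensor of log pi_theta(A_t | S_t), attached to the graph
    rewards   : [N, T] tensor of R_t, treated as constants
    """
    n_steps = log_probs.shape[1]
    # gamma^t for t = 0..T-1, broadcast over the N trajectories
    discounts = gamma ** torch.arange(n_steps, dtype=log_probs.dtype)
    weighted = discounts * rewards.detach() * log_probs
    # Sum over time, average over trajectories, negate for gradient ascent
    return -weighted.sum(dim=1).mean()


# Toy check with N = 2 trajectories and T = 3 steps.
# `theta` stands in directly for the log-probabilities.
theta = torch.zeros(2, 3, requires_grad=True)
rewards = torch.ones(2, 3)
loss = reinforce_surrogate_loss(theta, rewards, gamma=0.5)
loss.backward()
```

In the toy check, the gradient with respect to each entry is $-\gamma^t R_t / N$, which matches the formula term by term.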

The sum over $t$ cannot be computed in parallel, because each state of the environment depends on the previous one. The sum over $i$, however, can be computed in parallel, because the trajectories are independent. We therefore launch $N$ environments in parallel, collect the trajectories, and then compute the gradient.

Unfortunately, this parallelization is not straightforward, because independent environments reach their terminal states at different times. We therefore need to synchronize the environments at the end of each episode. There are two ways to do that: truncation and padding.

With truncation, we terminate the current episode in all environments as soon as the first environment reaches a terminal state. This approach is not preferable, because it discards the information that would be produced by the environments that have not yet reached a terminal state. Padding is preferable, because it collects the full trajectory from every environment. The only thing we need to do is mask the gradients of the environments that have already reached a terminal state.
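
To make the difference concrete, the following sketch counts how many environment steps contribute to the gradient under each strategy. The episode lengths are made-up numbers chosen only for illustration.

```python
import numpy as np

# Hypothetical episode lengths of N = 4 parallel environments
lengths = np.array([12, 30, 7, 25])

# Truncation: every environment is cut at the shortest episode
steps_truncation = int(lengths.min() * len(lengths))
# Padding: every environment runs to its own end, later steps are masked
steps_padding = int(lengths.sum())

print(steps_truncation, steps_padding)
```

With these lengths, truncation keeps 28 steps while padding keeps all 74.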

### Mathematical rationale

Recall that the gradient operator commutes with multiplication by a constant $\alpha$:

$$ \nabla_\theta \left( \alpha f(\theta) \right) = \alpha \nabla_\theta f(\theta), $$

therefore 

$$ 
\begin{aligned}
\nabla_\theta \left(0 \cdot f(\theta) \right) &\equiv 0 \cdot \nabla_\theta f(\theta) \equiv 0, \\
\nabla_\theta \left(1 \cdot f(\theta) \right) &\equiv 1 \cdot \nabla_\theta f(\theta) \equiv \nabla_\theta f(\theta).
\end{aligned}
$$
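
The two identities can be checked numerically with PyTorch's autograd. This is only an illustrative sketch: the function `f` and the parameter values are arbitrary choices, not part of the original method.

```python
import torch

theta = torch.tensor([1.0, -2.0], requires_grad=True)


def f(x: torch.Tensor) -> torch.Tensor:
    # Any differentiable scalar function works here
    return (x ** 2).sum()


# A factor of 1 keeps the gradient, a factor of 0 removes it
(grad_kept,) = torch.autograd.grad(1.0 * f(theta), theta)
(grad_masked,) = torch.autograd.grad(0.0 * f(theta), theta)
```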

Hence, by introducing the padding function $H\left(S_t^{(i)}\right)$ we can nullify the gradients of the environments that have reached the terminal state without stopping other environments. The padding function is defined as follows:

$$
H\left(S_t^{(i)}\right) = \begin{cases}
1, & \text{if } S_k^{(i)} \text{ is not terminal } \forall k \in \{0, \dots, t-1\} \\
0, & \text{otherwise}.
\end{cases}
$$

Thus, the contribution of a given environment to the gradient with respect to $\theta$ becomes zero as soon as that environment reaches a terminal state. The final gradient is computed as follows:

$$
\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{i=1}^N \sum_{t=0}^T \gamma^t R_t^{(i)} H\left(S_t^{(i)}\right) \nabla_\theta \log \pi_\theta\left(A_t^{(i)}|S_t^{(i)}\right).
$$ 
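
When the termination flags of a whole rollout are already stored, the values of $H$ can be computed for all steps at once. The sketch below assumes a boolean array `done` of shape `[N, T]`, where `done[i, t]` is `True` if the step taken at time `t` ended episode `i`. The function name `padding_mask` and this layout are assumptions for illustration.

```python
import numpy as np


def padding_mask(done: np.ndarray) -> np.ndarray:
    """Compute the padding values for every environment and time step.

    Returns an [N, T] float array that is 1 up to and including the step
    that ended the episode and 0 for every step after it.
    """
    # Number of episode endings seen strictly before step t
    ended_before = np.cumsum(done, axis=1) - done
    return (ended_before == 0).astype(np.float32)


done = np.array(
    [
        [False, False, True, False],  # episode ends at the third step
        [False, False, False, False],  # episode never ends in the window
    ]
)
mask = padding_mask(done)
```

The step that ends the episode still receives a mask of 1, so its reward and log-probability are kept. Only the steps collected afterwards are zeroed out.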

### Implementation

Gymnasium environments return a tuple of five elements at each step:

```python
obs, reward, terminated, truncated, info = env.step(action.item())
```

For a vectorized environment, the first dimension of each element equals the number of environments, which is $N$ in the formula above. With that in mind, the padding function can be implemented as follows:

```python
import numpy as np
import torch


class Padding:

    def __init__(self, size: int):
        """Initialize the padding function"""
        self._size = size
        self.reset()

    def reset(self):
        """Reset the padding function"""
        # No environment has terminated at the start of an episode
        self._is_terminated = np.zeros(self._size, dtype=bool)

    def __call__(
        self, terminated: np.ndarray, truncated: np.ndarray
    ) -> tuple[torch.Tensor, bool]:
        """Yields the padding values
        
        Parameters
        ----------
        terminated : np.ndarray
            The array of the terminal states
        truncated : np.ndarray
            The array of the truncated states
        
        Returns
        -------
        torch.Tensor
            The padding values (1 if value should be used, 0 otherwise)
        bool
            The total termination flag. If True, then all the states have reached the
            terminal states and the training loop should be stopped.
        """
        # As the function has a lag, we first yield the mask built from the previous
        # flags: 1 for environments still running, 0 for those already finished
        output = torch.as_tensor(~self._is_terminated, dtype=torch.float32)
        # Then we update the flags
        done = np.logical_or(terminated, truncated)
        self._is_terminated = np.logical_or(self._is_terminated, done)
        # As the termination flag is checked after the step, we can yield it using the
        # updated values
        total_termination = bool(np.all(self._is_terminated))
        return output, total_termination

```

In the training loop, the function can be used as follows:

```python

padding = Padding(size=N)
policy: torch.nn.Module = Policy()

...

gamma = 1
for i in range(iterations):
    optimizer.zero_grad()

    padding.reset()
    memory.reset()
    obs, _ = envs.reset()
    losses = torch.zeros(N)
    
    for t in range(T):

        action, log_probs = policy(obs)
        obs, rewards, terminated, truncated, _ = envs.step(action.detach().numpy())
        padding_values, total_termination = padding(terminated, truncated)

        memory.append(log_probs, rewards, padding_values)

        ...

        if total_termination:
            break
    
    loss = memory.loss()
    loss.backward()
    optimizer.step()
```

For simplicity, the padding module can be integrated into the `memory` object, so that it is reset together with the memory instead of through a separate call.
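
One possible shape of such an integration is sketched below. `MaskedMemory` is a hypothetical class written for illustration; it is not the notebook's actual `memory` object, and its `append` signature is an assumption.

```python
import numpy as np
import torch


class MaskedMemory:
    """Hypothetical rollout buffer that tracks the padding mask itself."""

    def __init__(self, size: int, gamma: float):
        self._size = size
        self._gamma = gamma
        self.reset()

    def reset(self) -> None:
        # One reset clears both the stored steps and the termination flags
        self._is_terminated = np.zeros(self._size, dtype=bool)
        self._terms = []

    def append(self, log_probs, rewards, terminated, truncated) -> bool:
        # The mask uses the flags from before this step (same lag as `Padding`)
        mask = torch.as_tensor(~self._is_terminated, dtype=torch.float32)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        discount = self._gamma ** len(self._terms)
        self._terms.append(discount * rewards * mask * log_probs)
        self._is_terminated |= np.logical_or(terminated, truncated)
        # True once every environment has finished its episode
        return bool(self._is_terminated.all())

    def loss(self) -> torch.Tensor:
        # Sum over time, average over environments, negate for gradient ascent
        return -torch.stack(self._terms, dim=1).sum(dim=1).mean()


# Toy rollout with N = 2: environment 0 ends after the first step
memory = MaskedMemory(size=2, gamma=1.0)
log_probs = torch.zeros(2, requires_grad=True)
no_trunc = np.array([False, False])
done1 = memory.append(log_probs, np.ones(2), np.array([True, False]), no_trunc)
done2 = memory.append(log_probs, np.ones(2), np.array([False, True]), no_trunc)
loss = memory.loss()
loss.backward()
```

In the toy rollout, the second step of environment 0 is masked out, so that environment contributes one term to the gradient while environment 1 contributes two.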


In [None]:
import os
import random

import gymnasium as gym
import torch

from src.agent import PolicyNetworkContinuous, PolicyNetworkDiscrete, train, validate
from src.envs import make_env, make_envs
from src.utils import mp4_to_gif

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

## Implementation for the discrete environment
---

### Training

In [None]:
THREADS = 5
ITERATIONS = 10
GAMMA = 0.90


FOLDER = "./results/reinforce"

CONFIG = {
    "observation": {
        # Use an occupancy grid. The grid size and features can be adjusted.
        "type": "OccupancyGrid",  # or "Kinematics" / "TimeToCollision"
        "grid_size": [
            [-5, 5],
            [-5, 5],
        ],  # Two dimensions: x from -5 to 5 and y from -5 to 5
        "grid_step": [2.0, 2.0],  # Specify step for each dimension
        "features": ["presence", "vx"],  # presence and relative speed features
    },
    "simulation_frequency": 15,  # adjust as needed
    "policy_frequency": 5,
    "duration": 40,  # initial episode duration in seconds
    "action": {"type": "DiscreteMetaAction"},  # use the discrete meta-action space
    "offscreen_rendering": True,
}

MAX_LENGTH = CONFIG["duration"] * CONFIG["policy_frequency"]

In [None]:
envs = make_envs("highway-v0", THREADS, config=CONFIG)
input_dim = envs.observation_space.shape[1]
output_dim = envs.action_space[0].n

In [5]:
HIDDEN_DIM = 64
LEARNING_RATE = 1e-3

policy = PolicyNetworkDiscrete(input_dim, HIDDEN_DIM, output_dim)
optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)

In [6]:
model, results = train(policy, envs, optimizer, GAMMA, ITERATIONS, device, SEED)
torch.save(model.state_dict(), os.path.join(FOLDER, "discrete-policy.pt"))

Iteration  1/5:  59%|█████▉    | 117/199 [00:33<00:23,  3.50it/s]
Iteration  2/5:  47%|████▋     | 94/199 [00:29<00:32,  3.24it/s, length=55.7, reward=44.1]
Iteration  3/5:  93%|█████████▎| 185/199 [01:20<00:06,  2.30it/s, length=37.4, reward=29.8]
Iteration  4/5:  71%|███████   | 141/199 [00:55<00:23,  2.52it/s, length=70.9, reward=56.5]
Iteration  5/5:  53%|█████▎    | 105/199 [00:40<00:36,  2.58it/s, length=51.3, reward=40.3]


### Inference

In [None]:
env = make_env("highway-v0", config=CONFIG)
env = gym.wrappers.RecordVideo(
    env,
    video_folder=FOLDER,
    episode_trigger=lambda x: x == 0 or x == (ITERATIONS - 2),
    name_prefix="discrete-agent",
    video_length=MAX_LENGTH,
)
model = PolicyNetworkDiscrete(input_dim, HIDDEN_DIM, output_dim)
model.load_state_dict(
    torch.load(os.path.join(FOLDER, "discrete-policy.pt"), weights_only=False)
)
results = validate(model, env, n_episodes=ITERATIONS, device=device)
mp4_to_gif(FOLDER)


Validation: 100%|██████████| 5/5 [01:11<00:00, 14.30s/it]


<p align="center">
    <img width="600" src="results/reinforce/discrete-agent-episode-8.gif" alt="Discrete action space policy">
    <p align="center">Fig. 1 - Policy in 5Hz environment with discrete action space</p>
</p>

In [8]:
results

{'reward': deque([np.float64(142.5368151561697),
        np.float64(146.98125960061418),
        np.float64(133.64792626728143),
        np.float64(133.64792626728143),
        np.float64(133.64792626728143)],
       maxlen=5),
 'length': deque([200, 200, 200, 200, 200], maxlen=5),
 'norm_length': [1.0, 1.0, 1.0, 1.0, 1.0]}

In [None]:
envs = make_envs("highway-v0", THREADS, config=CONFIG)
results = validate(model, env, n_episodes=ITERATIONS, device=device)

Validation: 100%|██████████| 5/5 [01:17<00:00, 15.41s/it]


## Implementation for the continuous environment
---

### Training

In [None]:
CONFIG["action"] = {"type": "ContinuousAction"}
envs = make_envs("highway-v0", THREADS, config=CONFIG)
output_dim = envs.action_space.shape[1]

In [None]:
policy = PolicyNetworkContinuous(input_dim, HIDDEN_DIM, output_dim)
optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)

model, results = train(policy, envs, optimizer, GAMMA, ITERATIONS, device, SEED)
torch.save(model.state_dict(), os.path.join(FOLDER, "continious-policy.pt"))

Iteration  1/10: 100%|██████████| 199/199 [01:09<00:00,  2.87it/s]
Iteration  2/10: 100%|██████████| 199/199 [01:19<00:00,  2.49it/s, length=199, reward=11.1]
Iteration  3/10: 100%|██████████| 199/199 [01:13<00:00,  2.69it/s, length=199, reward=10.1]
Iteration  4/10: 100%|██████████| 199/199 [01:16<00:00,  2.61it/s, length=199, reward=12.9]
Iteration  5/10:  11%|█         | 21/199 [00:08<01:09,  2.56it/s, length=199, reward=10.2]

### Inference