# Example of how to integrate PPO with gSDE 📚

## 🔖🦄 Modify the Policy Network

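The policy network should output not only the parameters of the action distribution (e.g., the mean and, optionally, the standard deviation for Gaussian actions) but also the parameters of the state-dependent noise. In the network below, the log standard deviation is a learnable parameter that does not depend on the state, while a separate linear layer maps the state to the noise term.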

In [None]:
import torch
import torch.nn as nn
import torch.distributions as distributions

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc = nn.Linear(state_dim, 64)
        self.mu = nn.Linear(64, action_dim)  # Mean of the action distribution
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # Log standard deviation for Gaussian
        self.state_dependent_noise = nn.Linear(state_dim, action_dim)  # Additional layer for state-dependent noise

    def forward(self, state):
        x = torch.relu(self.fc(state))
        mu = self.mu(x)
        std = torch.exp(self.log_std)
        noise = self.state_dependent_noise(state)
        return mu, std, noise


## 🔖🏮 Implement State-Dependent Noise

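Generate exploration noise as a function of the current state. This can be a separate network or a module inside the policy network that computes the noise parameters (e.g., a noise covariance matrix). In this example, the `noise` output of the policy network acts as a per-dimension, state-dependent scale applied to a standard normal sample `epsilon`.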


In [None]:
def sample_action(mu, std, noise, epsilon):
    # Scale the standard normal sample `epsilon` by the state-dependent noise term
    noise_sample = epsilon * noise
    # Perturb the mean; the result is distributed as Normal(mu, std * |noise|)
    action = mu + std * noise_sample
    return action

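The `noise` head used here is a simplified stand-in. For comparison, the gSDE paper builds the noise as the product of the policy features with an exploration matrix that is resampled only every few steps, so the same state yields the same perturbation in between. The following is a minimal sketch of that idea, not the reference implementation: `GSDENoise`, `resample`, `action_std`, and `log_std_init` are hypothetical names chosen for illustration, and the features are assumed to be the output of the 64-unit hidden layer.

```python
import torch
import torch.nn as nn


class GSDENoise(nn.Module):
    """Hypothetical helper: noise = features @ exploration_matrix."""

    def __init__(self, feature_dim, action_dim, log_std_init=-1.0):
        super().__init__()
        # One learnable log standard deviation per (feature, action) pair
        self.log_std = nn.Parameter(torch.full((feature_dim, action_dim), log_std_init))
        self.resample()

    def resample(self):
        # Draw a fresh exploration matrix. Intended to be called once every
        # few environment steps rather than at every step.
        std = self.log_std.exp().detach()
        self.exploration_matrix = torch.randn_like(std) * std

    def forward(self, features):
        # Deterministic for a fixed matrix: same features give the same noise
        return features @ self.exploration_matrix

    def action_std(self, features):
        # Std of the noise when the matrix is treated as random:
        # Var[noise_j] = sum_i features_i^2 * std_ij^2
        variance = features.pow(2) @ self.log_std.exp().pow(2)
        return torch.sqrt(variance + 1e-6)


torch.manual_seed(0)
noise_module = GSDENoise(feature_dim=64, action_dim=2)
features = torch.randn(5, 64)      # stand-in for the hidden layer output
noise_a = noise_module(features)
noise_b = noise_module(features)   # identical until resample() is called
noise_module.resample()
noise_c = noise_module(features)   # changes after resampling
```

In this variant, `action_std(features)` would take the place of the standard deviation when computing log-probabilities, since it describes how much the action varies across possible exploration matrices.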

## 🔖✨ Action Sampling with gSDE

Modify the action sampling process to incorporate the state-dependent noise: for each state, the policy network produces a noise term, a standard normal sample is scaled by it, and the result is added to the action mean to give the final action.

In [None]:
def get_action(policy_network, state):
    mu, std, noise = policy_network(state)
    epsilon = torch.randn_like(mu)  # Standard normal noise
    action = sample_action(mu, std, noise, epsilon)
    return action

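PPO also needs the log-probability of each sampled action, recorded at rollout time, to form the probability ratio during the update. A minimal sketch, assuming the sampler above, so that `mu + std * noise * epsilon` follows `Normal(mu, std * |noise|)`; `sample_with_log_prob` and `min_scale` are hypothetical names that do not appear in the original code.

```python
import torch
from torch.distributions import Normal


def sample_with_log_prob(mu, std, noise, min_scale=1e-6):
    # mu + std * noise * eps with eps ~ N(0, 1) is Normal(mu, std * |noise|);
    # min_scale keeps the scale strictly positive when noise is zero
    scale = std * noise.abs() + min_scale
    dist = Normal(mu, scale)
    action = dist.sample()                         # no gradient through the sample
    log_prob = dist.log_prob(action).sum(dim=-1)   # sum over action dimensions
    return action, log_prob


torch.manual_seed(0)
mu = torch.zeros(4, 2)             # batch of 4 states, 2 action dimensions
std = torch.ones(2)
noise = torch.full((4, 2), 0.5)    # stand-in for the network's noise output
action, log_prob = sample_with_log_prob(mu, std, noise)
```

The returned `log_prob` has one entry per state and is what would be stored alongside the action and passed to the update as the old log-probability.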

## 🔖🍥 Adapt the PPO Update

The core of PPO (the clipped surrogate objective) stays the same; only the log-probabilities change, since they are now evaluated for actions sampled with gSDE.

In [None]:
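def ppo_update(policy_network, optimizer, states, actions, log_probs, returns, advantages, epsilon=0.2, beta=0.01):
    # `log_probs` are the log-probabilities recorded at rollout time (no gradient);
    # `returns` is unused here because this function only updates the policy
    mu, std, noise = policy_network(states)
    # sample_action draws mu + std * noise * eps, i.e. Normal(mu, std * |noise|),
    # so the same distribution is used for the log-probability and the entropy
    dist = distributions.Normal(mu, std * noise.abs() + 1e-6)
    new_log_probs = dist.log_prob(actions).sum(dim=-1)
    ratio = torch.exp(new_log_probs - log_probs)

    # PPO objective with clipping
    surrogate1 = ratio * advantages
    surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -torch.min(surrogate1, surrogate2).mean()

    # Entropy bonus to encourage exploration
    entropy = dist.entropy().mean()
    loss = policy_loss - beta * entropy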

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

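To see what the clipping does, the clipped term can be evaluated on a few hand-picked ratios. This is an illustrative sketch only; `clipped_surrogate` is a hypothetical helper that mirrors the two surrogate lines in `ppo_update`.

```python
import torch


def clipped_surrogate(ratio, advantages, epsilon=0.2):
    # Element-wise minimum of the unclipped and clipped objectives
    surrogate1 = ratio * advantages
    surrogate2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    return torch.min(surrogate1, surrogate2)


ratio = torch.tensor([0.5, 1.0, 1.5])

# Positive advantage: the gain from raising the ratio is capped at 1 + epsilon
positive = clipped_surrogate(ratio, torch.ones(3))
# values: 0.5, 1.0, 1.2

# Negative advantage: the gain from lowering the ratio is capped at 1 - epsilon
negative = clipped_surrogate(ratio, -torch.ones(3))
# values: -0.8, -1.0, -1.5
```

In both cases the objective stops improving once the ratio leaves the interval `[1 - epsilon, 1 + epsilon]` in the direction the advantage favors, which removes the incentive for large policy changes in a single update.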