# REINFORCE Agent for gymnasium environments
In this notebook I present a very basic implementation of a REINFORCE agent for gymnasium environments. This is meant to work in continous observations and actions environments. The policy network is a multi-layer perceptron (MLP) that outputs a Gaussian distribution for the actions. The agent uses the REINFORCE algorithm to update the policy based on the rewards received from the environment.

For the baseline I use a simple MLP that outputs a single value for the state, which is used to compute the advantage function. The agent is designed to work with environments that have continuous action spaces, such as Pendulum-v1.

I used JAX as a framework for the implementation, which allows for efficient computation and automatic differentiation.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# Core imports for REINFORCE with Baseline implementation
import jax
import jax.numpy as jnp
from typing import List, Tuple

# Environment and utilities
import gymnasium as gym
import numpy as np
import time

In [4]:
class PolicyNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int], action_dim: int):
        """
        Multi-layer policy network for continuous action spaces using Gaussian policy.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
            action_dim: Number of continuous action dimensions (6 for HalfCheetah)
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.action_dim = action_dim
        
        # Output size: 2 * action_dim (mean and log_std for each action)
        self.output_dim = 2 * action_dim
        self.layer_dims = [input_dim] + hidden_dims + [self.output_dim]
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using proper Xavier initialization."""
        key = jax.random.PRNGKey(42)
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Proper Xavier initialization
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            w = jax.random.uniform(subkey, (fan_in, fan_out), minval=-limit, maxval=limit)
            b = jnp.zeros(fan_out)
            
            # Scale down the output layer for better initialization
            if i == len(self.layer_dims) - 2:  # Output layer
                w *= 0.1
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]], 
                action_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Forward pass returning action mean and log_std."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer
        output = x @ weights[-1] + biases[-1]
        
        # Split output into mean and log_std
        mean = output[:action_dim]
        log_std = output[action_dim:] - 1.0  # Initialize log_std around -1 for smaller initial std
        
        # Clamp log_std to prevent extreme values
        log_std = jnp.clip(log_std, -2.0, 2.0)  # std between ~0.135 and ~1.0
        
        return mean, log_std
    
    @staticmethod
    def log_prob(state: jnp.ndarray, actions: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]],
                action_dim: int) -> jnp.ndarray:
        """Compute log probability of actions under Gaussian policy."""
        mean, log_std = PolicyNetwork.forward(state, params, action_dim)
        std = jnp.exp(log_std)
        
        # Gaussian log probability: log p(a) = -0.5 * ((a - μ) / σ)² - log(σ) - 0.5 * log(2π)
        log_prob = -0.5 * jnp.sum(((actions - mean) / std) ** 2)
        log_prob -= jnp.sum(log_std)  # -log(σ) terms
        log_prob -= 0.5 * action_dim * jnp.log(2 * jnp.pi)  # constant term
        
        return log_prob


class ValueNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int]):
        """
        State-value function network for baseline estimation.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.layer_dims = [input_dim] + hidden_dims + [1]  # Output single value
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using Xavier initialization."""
        key = jax.random.PRNGKey(123)  # Different seed from policy network
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Xavier initialization for value network
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            w = jax.random.uniform(subkey, (fan_in, fan_out), minval=-limit, maxval=limit)
            b = jnp.zeros(fan_out)
            
            # Make the last layer even smaller to prevent extreme values
            if i == len(self.layer_dims) - 2:  # Last layer
                w *= 0.1  # Scale down output layer
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]]) -> jnp.ndarray:
        """Forward pass returning state value estimate."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer (linear activation for value function)
        value = x @ weights[-1] + biases[-1]
        return value.squeeze()  # Remove extra dimensions

In [None]:
class REINFORCE_Agent:
    def __init__(self, 
                 input_dim: int, 
                 policy_hidden_dims: List[int],
                 value_hidden_dims: List[int], 
                 action_dim: int,
                 learning_rate_theta: float = 1e-4,
                 learning_rate_w: float = 5e-4,
                 gradient_clip: float = 1.0):
        """
        Simplified REINFORCE agent that maximizes average reward per episode.
        
        Args:
            input_dim: State space dimension
            policy_hidden_dims: Hidden layer dimensions for policy network
            value_hidden_dims: Hidden layer dimensions for value network
            action_dim: Number of continuous action dimensions
            learning_rate_theta: Learning rate for policy parameters (α_θ)
            learning_rate_w: Learning rate for value function parameters (α_w)
            gradient_clip: Gradient clipping threshold
        """
        self.alpha_theta = learning_rate_theta
        self.alpha_w = learning_rate_w
        self.action_dim = action_dim
        self.gradient_clip = gradient_clip
        
        # Initialize policy and value networks
        self.policy_network = PolicyNetwork(input_dim, policy_hidden_dims, action_dim)
        self.value_network = ValueNetwork(input_dim, value_hidden_dims)
        
        self.key = jax.random.PRNGKey(0)
    
    def select_action(self, state: jnp.ndarray) -> jnp.ndarray:
        """Sample action from Gaussian policy."""
        self.key, subkey = jax.random.split(self.key)
        
        # Get mean and log_std from policy network
        mean, log_std = PolicyNetwork.forward(state, self.policy_network.params, self.action_dim)
        std = jnp.exp(log_std)
        
        # Sample from Gaussian distribution
        action = mean + std * jax.random.normal(subkey, shape=(self.action_dim,))
        
        # Light clipping to environment bounds
        action = jnp.clip(action, -3.0, 3.0)
        
        return action
    
    def policy_gradient(self, state: jnp.ndarray, action: jnp.ndarray) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of log π(a|s) with respect to policy parameters."""
        def log_prob_fn(params):
            return PolicyNetwork.log_prob(state, action, params, self.action_dim)
        
        grad_fn = jax.grad(log_prob_fn)
        return grad_fn(self.policy_network.params)
    
    def value_gradient(self, state: jnp.ndarray, target: float) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of squared error loss for value function."""
        def value_loss_fn(params):
            predicted_value = ValueNetwork.forward(state, params)
            return 0.5 * (predicted_value - target) ** 2
        
        grad_fn = jax.grad(value_loss_fn)
        return grad_fn(self.value_network.params)
    
    def update(self, states: jnp.ndarray, actions: jnp.ndarray, rewards: jnp.ndarray):
        """
        Simple REINFORCE update that maximizes average reward per episode.
        Uses the episode average reward as the target for all steps.
        """
        # Compute average reward for this episode (simple and direct)
        avg_reward = float(jnp.mean(rewards))
        
        # Compute baselines and advantages using current value network
        advantages = []
        baselines = []
        for t in range(len(states)):
            state_t = states[t]
            baseline = ValueNetwork.forward(state_t, self.value_network.params)
            advantage = avg_reward - baseline  # Use avg_reward instead of discounted return
            advantages.append(float(advantage))
            baselines.append(float(baseline))
        
        advantages = jnp.array(advantages)
        
        # Normalize advantages for stability
        advantage_mean = jnp.mean(advantages)
        advantage_std = jnp.std(advantages) + 1e-8
        advantages = (advantages - advantage_mean) / advantage_std
        
        # Initialize gradient accumulators
        policy_weights, policy_biases = self.policy_network.params
        value_weights, value_biases = self.value_network.params
        
        grad_policy_weights = [jnp.zeros_like(w) for w in policy_weights]
        grad_policy_biases = [jnp.zeros_like(b) for b in policy_biases]
        grad_value_weights = [jnp.zeros_like(w) for w in value_weights]
        grad_value_biases = [jnp.zeros_like(b) for b in value_biases]
        
        # Accumulate gradients for all steps
        for t in range(len(states)):
            state_t = states[t]
            action_t = actions[t]
            advantage = advantages[t]
            
            # Value function gradients (target is average reward)
            grad_v_w, grad_v_b = self.value_gradient(state_t, avg_reward)
            for i in range(len(grad_value_weights)):
                grad_value_weights[i] += grad_v_w[i]
                grad_value_biases[i] += grad_v_b[i]
            
            # Policy gradients (scaled by advantage)
            grad_p_w, grad_p_b = self.policy_gradient(state_t, action_t)
            for i in range(len(grad_policy_weights)):
                grad_policy_weights[i] += advantage * grad_p_w[i]
                grad_policy_biases[i] += advantage * grad_p_b[i]
        
        # Apply gradient clipping and normalization
        n_steps = len(states)
        for i in range(len(grad_policy_weights)):
            # Average gradients over episode
            grad_policy_weights[i] /= n_steps
            grad_policy_biases[i] /= n_steps
            grad_value_weights[i] /= n_steps
            grad_value_biases[i] /= n_steps
            
            # Clip gradients
            grad_policy_weights[i] = jnp.clip(grad_policy_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_policy_biases[i] = jnp.clip(grad_policy_biases[i], -self.gradient_clip, self.gradient_clip)
            grad_value_weights[i] = jnp.clip(grad_value_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_value_biases[i] = jnp.clip(grad_value_biases[i], -self.gradient_clip, self.gradient_clip)
        
        # Apply batch updates
        new_policy_weights = [w + self.alpha_theta * gw for w, gw in zip(policy_weights, grad_policy_weights)]
        new_policy_biases = [b + self.alpha_theta * gb for b, gb in zip(policy_biases, grad_policy_biases)]
        new_value_weights = [w + self.alpha_w * gw for w, gw in zip(value_weights, grad_value_weights)]
        new_value_biases = [b + self.alpha_w * gb for b, gb in zip(value_biases, grad_value_biases)]
        
        # Update parameters
        self.policy_network.params = (new_policy_weights, new_policy_biases)
        self.value_network.params = (new_value_weights, new_value_biases)
        
        # Compute training metrics (using original advantages for logging)
        orig_advantages = jnp.array([float(avg_reward - baselines[t]) for t in range(len(states))])
        avg_advantage = float(jnp.mean(orig_advantages))
        avg_baseline = float(jnp.mean(jnp.array(baselines)))
        advantage_std = float(jnp.std(orig_advantages) + 1e-8)
        baseline_std = float(jnp.std(jnp.array(baselines)) + 1e-8)
        
        # Compute action statistics for logging
        action_mean = float(jnp.mean(actions))
        action_std = float(jnp.std(actions) + 1e-8)
        action_min = float(jnp.min(actions))
        action_max = float(jnp.max(actions))
        
        # Check for NaNs
        if jnp.isnan(avg_advantage) or jnp.isnan(avg_baseline):
            print(f"WARNING: NaN detected!")
            return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        
        return avg_reward, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max

# Test the simplified REINFORCE for average reward maximization
agent = REINFORCE_Agent(
    input_dim=17,
    policy_hidden_dims=[64, 64], 
    value_hidden_dims=[64, 32],
    action_dim=6,
    learning_rate_theta=1e-4,    
    learning_rate_w=5e-4,        
    gradient_clip=1.0            
)

# Test Gaussian action selection
state = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7])  # 17-dim state
print("Testing Simplified REINFORCE (Average Reward) for HalfCheetah:")
for i in range(3):
    action = agent.select_action(state)
    baseline_value = ValueNetwork.forward(state, agent.value_network.params)
    mean, log_std = PolicyNetwork.forward(state, agent.policy_network.params, agent.action_dim)
    print(f"Action: {action}")
    print(f"Policy mean: {mean}")
    print(f"Policy log_std: {log_std}")
    print(f"Estimated state value: {baseline_value:.4f}")
    print()

# Test update with dummy data
dummy_states = jnp.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
    [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]
])
dummy_actions = jnp.array([[0.1, -0.2, 0.3, -0.1, 0.0, 0.2], [0.0, 0.1, -0.3, 0.2, -0.1, 0.0], [-0.1, 0.0, 0.1, 0.0, 0.3, -0.2]])  # Continuous actions
dummy_rewards = jnp.array([1.0, -0.5, 0.8])

result = agent.update(dummy_states, dummy_actions, dummy_rewards)
avg_reward, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max = result

print(f"Average reward: {avg_reward:.4f}")
print(f"Average advantage: {avg_advantage:.4f}")
print(f"Average baseline: {avg_baseline:.4f}")
print(f"Advantage std: {advantage_std:.4f}")
print(f"Baseline std: {baseline_std:.4f}")
print(f"Action mean: {action_mean:.4f}, std: {action_std:.4f}, min: {action_min:.4f}, max: {action_max:.4f}")

Testing Simplified REINFORCE (Average Reward) for HalfCheetah:
Action: [-0.55363524 -0.7422026  -0.50519276  0.44682175 -0.33735043 -0.52271694]
Policy mean: [-0.03758024 -0.01273906  0.02240782  0.02175305  0.0345523  -0.03523846]
Policy log_std: [-1.0387405  -1.0318427  -0.99308646 -1.0111479  -0.9646626  -0.9589811 ]
Estimated state value: -0.1375

Action: [ 0.29218614 -0.01417697 -0.25183186 -0.4974694   0.4384162   0.3209747 ]
Policy mean: [-0.03758024 -0.01273906  0.02240782  0.02175305  0.0345523  -0.03523846]
Policy log_std: [-1.0387405  -1.0318427  -0.99308646 -1.0111479  -0.9646626  -0.9589811 ]
Estimated state value: -0.1375

Action: [ 0.37598157  0.09857586 -0.18914562  0.07148103  0.3172931  -0.02049511]
Policy mean: [-0.03758024 -0.01273906  0.02240782  0.02175305  0.0345523  -0.03523846]
Policy log_std: [-1.0387405  -1.0318427  -0.99308646 -1.0111479  -0.9646626  -0.9589811 ]
Estimated state value: -0.1375

Average reward: 0.4333
Average advantage: 0.5707
Average baselin

In [None]:
# Test on a simpler environment: Pendulum-v1
import gymnasium as gym
import numpy as np

# Simple configuration for Pendulum
SIMPLE_CONFIG = {
    'episode_length': 100,  # The length of each episode
    'episodes': 1000,   # The number of episodes to train
    'learning_rate_theta': 3e-3,    # Learning rate for policy parameters (α_θ)
    'learning_rate_w': 3e-3,      # Learning rate for value function parameters (α_w)
    'policy_hidden_dims': [8, 8],   # Hidden layers for policy network 
    'value_hidden_dims': [8, 8],    # Hidden layers for value network
    'log_interval': 1,  # Every how many episodes to log
    'render_interval': 10,  # Every how many episodes to render the environment
    'gradient_clip': 1.0    # Gradient clipping threshold
}

def train_simple_env():
    """Train REINFORCE agent on Pendulum-v1"""
    # Create training environment (no rendering)
    env = gym.make("Pendulum-v1")
    # Create rendering environment for periodic demos
    env_render = gym.make("Pendulum-v1", render_mode='human')
    
    obs_dim = env.observation_space.shape[0]  # 3 for Pendulum
    action_dim = env.action_space.shape[0]    # 1 for Pendulum
    
    print(f"Environment: Pendulum-v1")
    print(f"Observation space: {obs_dim}")
    print(f"Action space: {action_dim} (continuous)")
    print(f"Action range: {env.action_space.low[0]:.1f} to {env.action_space.high[0]:.1f}")
    print(f"Training for {SIMPLE_CONFIG['episodes']} episodes with {SIMPLE_CONFIG['episode_length']} steps each")
    print(f"Rendering every {SIMPLE_CONFIG['render_interval']} episodes\n")
    
    # Create simple agent
    agent = REINFORCE_Agent(
        input_dim=obs_dim,
        policy_hidden_dims=SIMPLE_CONFIG['policy_hidden_dims'],
        value_hidden_dims=SIMPLE_CONFIG['value_hidden_dims'],
        action_dim=action_dim,
        learning_rate_theta=SIMPLE_CONFIG['learning_rate_theta'],
        learning_rate_w=SIMPLE_CONFIG['learning_rate_w'],
        gradient_clip=SIMPLE_CONFIG['gradient_clip']
    )
    
    episode_rewards = []
    
    for episode in range(SIMPLE_CONFIG['episodes']):
        # Choose environment (render or not)
        current_env = env_render if (episode + 1) % SIMPLE_CONFIG['render_interval'] == 0 else env
        
        observation, _ = current_env.reset()
        
        # Collect trajectory
        states, actions_list, rewards = [], [], []
        total_reward = 0
        
        for step in range(SIMPLE_CONFIG['episode_length']):
            # Convert observation to JAX array
            state = jnp.array(observation, dtype=jnp.float32)
            
            # Select action using Gaussian policy
            action = agent.select_action(state)
            
            # Store transition data
            states.append(state)
            actions_list.append(action)
            
            # Take step in environment
            observation, reward, terminated, truncated, _ = current_env.step(np.array(action))
            reward += 10
            # Don't modify the reward - use original Pendulum rewards
            rewards.append(reward)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        # Convert to JAX arrays
        states = jnp.stack(states)
        actions = jnp.stack(actions_list)
        rewards = jnp.array(rewards)
        
        # Update policy using REINFORCE (maximizing average reward)
        result = agent.update(states, actions, rewards)
        avg_reward, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max = result
        
        episode_rewards.append(total_reward)
        
        # Log progress (fix duplicate logging)
        if (episode + 1) % SIMPLE_CONFIG['log_interval'] == 0:
            recent_reward = np.mean(episode_rewards[-SIMPLE_CONFIG['log_interval']])
            
            print(f"Ep {episode + 1:3d} | Reward: {recent_reward:7.1f} | AvgRew: {avg_reward:6.1f} | Base: {avg_baseline:5.1f} | Adv: {avg_advantage:6.1f} | ActMean: {action_mean:5.2f}, Std: {action_std:4.2f}")
            
            # Special message for rendered episodes (only once)
            if (episode + 1) % SIMPLE_CONFIG['render_interval'] == 0:
                print(f"  --> Episode {episode + 1} was rendered (Reward: {total_reward:.1f})")
    
    env.close()
    env_render.close()
    return episode_rewards

# Run training on simple environment
print("=" * 60)
print("TESTING ON SIMPLE ENVIRONMENT: Pendulum-v1")
print("=" * 60)

try:
    simple_rewards = train_simple_env()
    
    print(f"\nSimple environment training completed!")
    print(f"Final 10-episode average reward: {np.mean(simple_rewards[-10:]):.1f}")
    print(f"Best 10-episode average reward: {max([np.mean(simple_rewards[i:i+10]) for i in range(len(simple_rewards)-9)]):.1f}")
    
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"Training error: {e}")
    import traceback
    traceback.print_exc()

TESTING ON SIMPLE ENVIRONMENT: Pendulum-v1
Environment: Pendulum-v1
Observation space: 3
Action space: 1 (continuous)
Action range: -2.0 to 2.0
Training for 1000000 episodes with 100 steps each
Rendering every 10 episodes

Ep   1 | Reward:   496.1 | AvgRew:    5.0 | Base:  -0.2 | Adv:    5.2 | ActMean:  0.07, Std: 0.46
Ep   2 | Reward:   217.4 | AvgRew:    2.2 | Base:  -0.1 | Adv:    2.3 | ActMean:  0.05, Std: 0.39
Ep   3 | Reward:   298.8 | AvgRew:    3.0 | Base:  -0.2 | Adv:    3.1 | ActMean:  0.02, Std: 0.38
Ep   4 | Reward:   249.0 | AvgRew:    2.5 | Base:  -0.1 | Adv:    2.6 | ActMean:  0.03, Std: 0.36
Ep   5 | Reward:   240.9 | AvgRew:    2.4 | Base:  -0.1 | Adv:    2.5 | ActMean: -0.00, Std: 0.38
Ep   6 | Reward:   250.6 | AvgRew:    2.5 | Base:  -0.2 | Adv:    2.7 | ActMean:  0.04, Std: 0.35
Ep   7 | Reward:   240.8 | AvgRew:    2.4 | Base:  -0.2 | Adv:    2.6 | ActMean:  0.08, Std: 0.35
Ep   8 | Reward:   514.1 | AvgRew:    5.1 | Base:  -0.3 | Adv:    5.5 | ActMean:  0.00, Std