In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Core imports for REINFORCE with Baseline implementation
import jax
import jax.numpy as jnp
from typing import List, Tuple

# Environment and utilities
import gymnasium as gym
import numpy as np
import time

In [3]:
class PolicyNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int], action_dim: int):
        """
        Multi-layer policy network for continuous action spaces using Gaussian policy.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
            action_dim: Number of continuous action dimensions (6 for HalfCheetah)
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.action_dim = action_dim
        
        # Output size: 2 * action_dim (mean and log_std for each action)
        self.output_dim = 2 * action_dim
        self.layer_dims = [input_dim] + hidden_dims + [self.output_dim]
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using Xavier initialization."""
        key = jax.random.PRNGKey(42)
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Xavier initialization
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            w = jnp.zeros((fan_in, fan_out))
            b = jnp.zeros(fan_out)
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]], 
                action_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Forward pass returning action mean and log_std."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer
        output = x @ weights[-1] + biases[-1]
        
        # Split output into mean and log_std
        mean = output[:action_dim]
        log_std = output[action_dim:] - 0.5  # Center log_std around 0.5 for better exploration
        
        # Clamp log_std to prevent extreme values
        log_std = jnp.clip(log_std, -2.0, 0.5)  # std between ~0.135 and ~1.65
        
        return mean, log_std
    
    @staticmethod
    def log_prob(state: jnp.ndarray, actions: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]],
                action_dim: int) -> jnp.ndarray:
        """Compute log probability of actions under Gaussian policy."""
        mean, log_std = PolicyNetwork.forward(state, params, action_dim)
        std = jnp.exp(log_std)
        
        # Gaussian log probability: log p(a) = -0.5 * ((a - μ) / σ)² - log(σ) - 0.5 * log(2π)
        log_prob = -0.5 * jnp.sum(((actions - mean) / std) ** 2)
        log_prob -= jnp.sum(log_std)  # -log(σ) terms
        log_prob -= 0.5 * action_dim * jnp.log(2 * jnp.pi)  # constant term
        
        return log_prob


class ValueNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int]):
        """
        State-value function network for baseline estimation.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.layer_dims = [input_dim] + hidden_dims + [1]  # Output single value
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using Xavier initialization."""
        key = jax.random.PRNGKey(123)  # Different seed from policy network
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Xavier initialization for value network
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            w = jax.random.uniform(subkey, (fan_in, fan_out), minval=-limit, maxval=limit)
            b = jnp.zeros(fan_out)
            
            # Make the last layer even smaller to prevent extreme values
            if i == len(self.layer_dims) - 2:  # Last layer
                w *= 0.1  # Scale down output layer
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]]) -> jnp.ndarray:
        """Forward pass returning state value estimate."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer (linear activation for value function)
        value = x @ weights[-1] + biases[-1]
        return value.squeeze()  # Remove extra dimensions

In [4]:
class REINFORCE_Agent:
    def __init__(self, 
                 input_dim: int, 
                 policy_hidden_dims: List[int],
                 value_hidden_dims: List[int], 
                 action_dim: int,
                 learning_rate_theta: float = 1e-4,
                 learning_rate_w: float = 5e-4,
                 gamma: float = 0.99,
                 gradient_clip: float = 1.0):
        """
        REINFORCE agent with baseline for continuous control using Gaussian policy.
        
        Args:
            input_dim: State space dimension
            policy_hidden_dims: Hidden layer dimensions for policy network
            value_hidden_dims: Hidden layer dimensions for value network
            action_dim: Number of continuous action dimensions
            learning_rate_theta: Learning rate for policy parameters (α_θ)
            learning_rate_w: Learning rate for value function parameters (α_w)
            gamma: Discount factor
            gradient_clip: Gradient clipping threshold
        """
        self.gamma = gamma
        self.alpha_theta = learning_rate_theta
        self.alpha_w = learning_rate_w
        self.action_dim = action_dim
        self.gradient_clip = gradient_clip
        
        # Initialize policy and value networks
        self.policy_network = PolicyNetwork(input_dim, policy_hidden_dims, action_dim)
        self.value_network = ValueNetwork(input_dim, value_hidden_dims)
        
        self.key = jax.random.PRNGKey(0)
    
    def select_action(self, state: jnp.ndarray) -> jnp.ndarray:
        """Sample action from Gaussian policy."""
        self.key, subkey = jax.random.split(self.key)
        
        # Get mean and log_std from policy network
        mean, log_std = PolicyNetwork.forward(state, self.policy_network.params, self.action_dim)
        std = jnp.exp(log_std)
        
        # Sample from Gaussian distribution
        action = mean + std * jax.random.normal(subkey, shape=(self.action_dim,))
        
        # Light clipping to environment bounds (less aggressive)
        action = jnp.clip(action, -3.0, 3.0)  # Reduced clipping range
        
        return action
    
    @staticmethod
    def compute_returns(rewards: jnp.ndarray, gamma: float) -> jnp.ndarray:
        """Compute discounted returns G_t for each time step."""
        returns = jnp.zeros_like(rewards)
        G = 0.0
        
        # Compute returns backwards: G_t = R_{t+1} + γ*R_{t+2} + γ^2*R_{t+3} + ...
        for t in reversed(range(len(rewards))):
            G = rewards[t] + gamma * G
            returns = returns.at[t].set(G)
        
        return returns
    
    def policy_gradient(self, state: jnp.ndarray, action: jnp.ndarray) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of log π(a|s) with respect to policy parameters."""
        def log_prob_fn(params):
            return PolicyNetwork.log_prob(state, action, params, self.action_dim)
        
        grad_fn = jax.grad(log_prob_fn)
        return grad_fn(self.policy_network.params)
    
    def value_gradient(self, state: jnp.ndarray, target: float) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of squared error loss for value function."""
        def value_loss_fn(params):
            predicted_value = ValueNetwork.forward(state, params)
            return 0.5 * (predicted_value - target) ** 2
        
        grad_fn = jax.grad(value_loss_fn)
        return grad_fn(self.value_network.params)
    
    def update(self, states: jnp.ndarray, actions: jnp.ndarray, rewards: jnp.ndarray):
        """
        REINFORCE with baseline update using Gaussian policy.
        Returns training metrics and action statistics.
        """
        # Compute returns G_t for each time step
        returns = self.compute_returns(rewards, self.gamma)
        
        # Compute all advantages using current value network
        advantages = []
        baselines = []
        for t in range(len(states)):
            state_t = states[t]
            return_t = returns[t]
            baseline = ValueNetwork.forward(state_t, self.value_network.params)
            advantage = return_t - baseline
            advantages.append(float(advantage))
            baselines.append(float(baseline))
        
        advantages = jnp.array(advantages)
        
        # Normalize advantages for stability (ENABLED)
        advantage_mean = jnp.mean(advantages)
        advantage_std = jnp.std(advantages) + 1e-8
        advantages = (advantages - advantage_mean) / advantage_std
        
        # Initialize gradient accumulators
        policy_weights, policy_biases = self.policy_network.params
        value_weights, value_biases = self.value_network.params
        
        grad_policy_weights = [jnp.zeros_like(w) for w in policy_weights]
        grad_policy_biases = [jnp.zeros_like(b) for b in policy_biases]
        grad_value_weights = [jnp.zeros_like(w) for w in value_weights]
        grad_value_biases = [jnp.zeros_like(b) for b in value_biases]
        
        # Accumulate gradients for all steps
        for t in range(len(states)):
            state_t = states[t]
            action_t = actions[t]
            return_t = returns[t]
            advantage = advantages[t]
            
            # Value function gradients
            grad_v_w, grad_v_b = self.value_gradient(state_t, return_t)
            for i in range(len(grad_value_weights)):
                grad_value_weights[i] += grad_v_w[i]
                grad_value_biases[i] += grad_v_b[i]
            
            # Policy gradients (scaled by advantage)
            grad_p_w, grad_p_b = self.policy_gradient(state_t, action_t)
            for i in range(len(grad_policy_weights)):
                grad_policy_weights[i] += advantage * grad_p_w[i]
                grad_policy_biases[i] += advantage * grad_p_b[i]
        
        # Apply gradient clipping and normalization
        n_steps = len(states)
        for i in range(len(grad_policy_weights)):
            # Average gradients over episode
            grad_policy_weights[i] /= n_steps
            grad_policy_biases[i] /= n_steps
            grad_value_weights[i] /= n_steps
            grad_value_biases[i] /= n_steps
            
            # Clip gradients
            grad_policy_weights[i] = jnp.clip(grad_policy_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_policy_biases[i] = jnp.clip(grad_policy_biases[i], -self.gradient_clip, self.gradient_clip)
            grad_value_weights[i] = jnp.clip(grad_value_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_value_biases[i] = jnp.clip(grad_value_biases[i], -self.gradient_clip, self.gradient_clip)
        
        # Apply batch updates
        new_policy_weights = [w + self.alpha_theta * gw for w, gw in zip(policy_weights, grad_policy_weights)]
        new_policy_biases = [b + self.alpha_theta * gb for b, gb in zip(policy_biases, grad_policy_biases)]
        new_value_weights = [w + self.alpha_w * gw for w, gw in zip(value_weights, grad_value_weights)]
        new_value_biases = [b + self.alpha_w * gb for b, gb in zip(value_biases, grad_value_biases)]
        
        # Update parameters
        self.policy_network.params = (new_policy_weights, new_policy_biases)
        self.value_network.params = (new_value_weights, new_value_biases)
        
        # Compute training metrics (using original advantages for logging)
        orig_advantages = jnp.array([float(returns[t] - baselines[t]) for t in range(len(states))])
        avg_return = float(jnp.mean(returns))
        avg_advantage = float(jnp.mean(orig_advantages))
        avg_baseline = float(jnp.mean(jnp.array(baselines)))
        advantage_std = float(jnp.std(orig_advantages) + 1e-8)
        baseline_std = float(jnp.std(jnp.array(baselines)) + 1e-8)
        
        # Compute action statistics for logging
        action_mean = float(jnp.mean(actions))
        action_std = float(jnp.std(actions) + 1e-8)
        action_min = float(jnp.min(actions))
        action_max = float(jnp.max(actions))
        
        # Check for NaNs
        if jnp.isnan(avg_advantage) or jnp.isnan(avg_baseline):
            print(f"WARNING: NaN detected!")
            return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        
        return avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max

# Test the updated REINFORCE with Gaussian policy for HalfCheetah
agent = REINFORCE_Agent(
    input_dim=17,                # HalfCheetah state space
    policy_hidden_dims=[64, 64], 
    value_hidden_dims=[64, 32],
    action_dim=6,                # HalfCheetah action space (6 joints)
    learning_rate_theta=1e-4,    
    learning_rate_w=5e-4,        
    gamma=0.99,
    gradient_clip=1.0            
)

# Test Gaussian action selection
state = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7])  # 17-dim state
print("Testing REINFORCE with Gaussian Policy for HalfCheetah:")
for i in range(3):
    action = agent.select_action(state)
    baseline_value = ValueNetwork.forward(state, agent.value_network.params)
    mean, log_std = PolicyNetwork.forward(state, agent.policy_network.params, agent.action_dim)
    print(f"Action: {action}")
    print(f"Policy mean: {mean}")
    print(f"Policy log_std: {log_std}")
    print(f"Estimated state value: {baseline_value:.4f}")
    print()

# Test update with dummy data
dummy_states = jnp.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
    [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]
])
dummy_actions = jnp.array([[0.1, -0.2, 0.3, -0.1, 0.0, 0.2], [0.0, 0.1, -0.3, 0.2, -0.1, 0.0], [-0.1, 0.0, 0.1, 0.0, 0.3, -0.2]])  # Continuous actions
dummy_rewards = jnp.array([1.0, -0.5, 0.8])

result = agent.update(dummy_states, dummy_actions, dummy_rewards)
avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max = result

print(f"Average return: {avg_return:.4f}")
print(f"Average advantage: {avg_advantage:.4f}")
print(f"Average baseline: {avg_baseline:.4f}")
print(f"Advantage std: {advantage_std:.4f}")
print(f"Baseline std: {baseline_std:.4f}")
print(f"Action mean: {action_mean:.4f}, std: {action_std:.4f}, min: {action_min:.4f}, max: {action_max:.4f}")

Testing REINFORCE with Gaussian Policy for HalfCheetah:
Action: [-0.8844393  -1.241595   -0.86387324  0.70867616 -0.5918747  -0.77141565]
Policy mean: [0. 0. 0. 0. 0. 0.]
Policy log_std: [-0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
Estimated state value: -0.1375

Action: [ 0.56516916 -0.00244742 -0.44902968 -0.8656496   0.6427402   0.5636934 ]
Policy mean: [0. 0. 0. 0. 0. 0.]
Policy log_std: [-0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
Estimated state value: -0.1375

Action: [ 0.70878166  0.18946531 -0.3463896   0.08290668  0.44997555  0.02333077]
Policy mean: [0. 0. 0. 0. 0. 0.]
Policy log_std: [-0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
Estimated state value: -0.1375

Action: [-0.8844393  -1.241595   -0.86387324  0.70867616 -0.5918747  -0.77141565]
Policy mean: [0. 0. 0. 0. 0. 0.]
Policy log_std: [-0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
Estimated state value: -0.1375

Action: [ 0.56516916 -0.00244742 -0.44902968 -0.8656496   0.6427402   0.5636934 ]
Policy mean: [0. 0. 0. 0. 0. 0.]
Policy log_std: [-0.5 -0.5 -0.5 -0.5 -0.5 -0.5]
E

In [None]:
import gymnasium as gym
import numpy as np
import time

# Quick test configuration for HalfCheetah with Gaussian policy
CONFIG = {
    'gamma': 0.99,
    'episode_length': 300,       # Increased episode length for better learning
    'episodes': 50,              # Quick test with 50 episodes
    'learning_rate_theta': 1e-4, # Reduced policy learning rate
    'learning_rate_w': 2e-4,     # Reduced value function learning rate
    'policy_hidden_dims': [16, 16],    # Larger policy network
    'value_hidden_dims': [16, 16],     # Larger value network
    'log_interval': 1,          # Log every 10 episodes
    'gradient_clip': 0.5         # Reduced gradient clipping
}

def run_demo_episode(agent, env_render, episode_num):
    """Run a single demonstration episode with rendering."""
    print(f"\n--- Demonstration Episode {episode_num} ---")
    observation, _ = env_render.reset()
    total_reward = 0
    step_count = 0
    episode_actions = []
    
    for step in range(CONFIG['episode_length']):
        # Convert observation to JAX array
        state = jnp.array(observation, dtype=jnp.float32)
        
        # Select action using Gaussian policy
        action = agent.select_action(state)
        episode_actions.append(action)
        
        # Take step in environment
        observation, reward, terminated, truncated, _ = env_render.step(np.array(action))
        total_reward += reward
        step_count += 1
        
        if terminated or truncated:
            break
    
    # Action statistics for demo episode
    episode_actions = jnp.stack(episode_actions)
    action_mean = float(jnp.mean(episode_actions))
    action_std = float(jnp.std(episode_actions))
    
    print(f"Demo episode reward: {total_reward:.2f} (steps: {step_count})")
    print(f"Demo action mean: {action_mean:.3f}, std: {action_std:.3f}")
    return total_reward

def train_agent():
    """Train REINFORCE agent with Gaussian policy on HalfCheetah."""
    # Training environment
    env = gym.make("HalfCheetah-v5", render_mode='human')
    
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    print(f"Environment: HalfCheetah-v5")
    print(f"Observation space: {obs_dim}")
    print(f"Action space: {action_dim} (continuous, Gaussian policy)")
    print(f"Action range: {env.action_space.low[0]:.1f} to {env.action_space.high[0]:.1f}")
    print(f"Training for {CONFIG['episodes']} episodes with {CONFIG['episode_length']} steps each")
    
    # Create agent with Gaussian policy
    agent = REINFORCE_Agent(
        input_dim=obs_dim,
        policy_hidden_dims=CONFIG['policy_hidden_dims'],
        value_hidden_dims=CONFIG['value_hidden_dims'],
        action_dim=action_dim,
        learning_rate_theta=CONFIG['learning_rate_theta'],
        learning_rate_w=CONFIG['learning_rate_w'],
        gamma=CONFIG['gamma'],
        gradient_clip=CONFIG['gradient_clip']
    )
    
    episode_rewards = []
    
    for episode in range(CONFIG['episodes']):
        observation, _ = env.reset()
        
        # Collect trajectory
        states, actions_list, rewards = [], [], []
        total_reward = 0
        
        for step in range(CONFIG['episode_length']):
            # Convert observation to JAX array
            state = jnp.array(observation, dtype=jnp.float32)
            
            # Select action using Gaussian policy
            action = agent.select_action(state)
            
            # Store transition data
            states.append(state)
            actions_list.append(action)
            
            # Take step in environment
            observation, reward, terminated, truncated, _ = env.step(np.array(action))
            rewards.append(reward)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        # Convert to JAX arrays
        states = jnp.stack(states)
        actions = jnp.stack(actions_list)
        rewards = jnp.array(rewards)
        
        # Update policy using REINFORCE with baseline
        result = agent.update(states, actions, rewards)
        avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std, action_mean, action_std, action_min, action_max = result
        
        episode_rewards.append(total_reward)
        
        # Log progress with action statistics (FIXED: single print statement)
        if (episode + 1) % CONFIG['log_interval'] == 0:
            recent_reward = np.mean(episode_rewards[-CONFIG['log_interval']])
            
            print(f"Episode {episode + 1:4d} | Reward: {recent_reward:8.2f} | Return: {avg_return:8.2f} | Baseline: {avg_baseline:6.2f} | Advantage: {avg_advantage:8.2f} | Adv Std: {advantage_std:6.3f} | Base Std: {baseline_std:6.3f} | Act Mean: {action_mean:6.3f}, Std: {action_std:6.3f}, Min: {action_min:6.3f}, Max: {action_max:6.3f}")
    
    env.close()
    return episode_rewards

# Start training with Gaussian policy
try:
    print("Starting REINFORCE with Gaussian policy training...\n")
    
    training_rewards = train_agent()
    
    print(f"\nTraining completed!")
    print(f"Final average training reward: {np.mean(training_rewards[-10:]):.2f}")
        
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
except Exception as e:
    print(f"Training error: {e}")
    print("This might be due to environment setup issues.")

Starting REINFORCE with Gaussian policy training...

Environment: HalfCheetah-v5
Observation space: 17
Action space: 6 (continuous, Gaussian policy)
Action range: -1.0 to 1.0
Training for 50 episodes with 300 steps each
Episode    1 | Reward:   -80.06 | Return:   -16.74 | Baseline:  -0.06 | Advantage:   -16.68 | Adv Std: 13.196 | Base Std:  0.116 | Act Mean:  0.004, Std:  0.616, Min: -2.288, Max:  2.222
Episode    1 | Reward:   -80.06 | Return:   -16.74 | Baseline:  -0.06 | Advantage:   -16.68 | Adv Std: 13.196 | Base Std:  0.116 | Act Mean:  0.004, Std:  0.616, Min: -2.288, Max:  2.222
Episode    2 | Reward:  -138.34 | Return:   -29.61 | Baseline:  -0.06 | Advantage:   -29.55 | Adv Std: 10.720 | Base Std:  0.109 | Act Mean: -0.019, Std:  0.611, Min: -2.111, Max:  2.033
Episode    2 | Reward:  -138.34 | Return:   -29.61 | Baseline:  -0.06 | Advantage:   -29.55 | Adv Std: 10.720 | Base Std:  0.109 | Act Mean: -0.019, Std:  0.611, Min: -2.111, Max:  2.033
Episode    3 | Reward:    -9.29 