In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Core imports for REINFORCE with Baseline implementation
import jax
import jax.numpy as jnp
from typing import List, Tuple

# Environment and utilities
import gymnasium as gym
import numpy as np
import time

In [3]:
class PolicyNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int], action_dim: int, discrete_bins: int = 5):
        """
        Multi-layer policy network for continuous action spaces via discretization.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
            action_dim: Number of continuous action dimensions (8 for Ant)
            discrete_bins: Number of discrete bins per action dimension
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.action_dim = action_dim
        self.discrete_bins = discrete_bins
        
        # Total output size: action_dim * discrete_bins
        self.output_dim = action_dim * discrete_bins
        self.layer_dims = [input_dim] + hidden_dims + [self.output_dim]
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using Xavier initialization."""
        key = jax.random.PRNGKey(42)
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Initialize weights and biases to 0
            w = jnp.zeros((fan_in, fan_out))
            b = jnp.zeros(fan_out)
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]], 
                action_dim: int, discrete_bins: int) -> jnp.ndarray:
        """Forward pass returning action probabilities for each dimension."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer
        logits = x @ weights[-1] + biases[-1]
        
        # Reshape to (action_dim, discrete_bins) and apply softmax per action dimension
        logits = logits.reshape((action_dim, discrete_bins))
        action_probs = jax.nn.softmax(logits, axis=1)
        
        return action_probs
    
    @staticmethod
    def log_prob(state: jnp.ndarray, actions: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]],
                action_dim: int, discrete_bins: int) -> jnp.ndarray:
        """Compute log probability of actions."""
        action_probs = PolicyNetwork.forward(state, params, action_dim, discrete_bins)
        
        # Get log probabilities for selected actions
        log_probs = jnp.log(action_probs + 1e-8)
        selected_log_probs = log_probs[jnp.arange(action_dim), actions]
        
        return jnp.sum(selected_log_probs)  # Sum log probs (product of probs)


class ValueNetwork:
    def __init__(self, input_dim: int, hidden_dims: List[int]):
        """
        State-value function network for baseline estimation.
        
        Args:
            input_dim: State space dimension
            hidden_dims: Hidden layer dimensions
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.layer_dims = [input_dim] + hidden_dims + [1]  # Output single value
        
        self.params = self._initialize_params()
    
    def _initialize_params(self) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Initialize weights and biases using Xavier initialization."""
        key = jax.random.PRNGKey(123)  # Different seed from policy network
        weights, biases = [], []
        
        for i in range(len(self.layer_dims) - 1):
            key, subkey = jax.random.split(key)
            fan_in, fan_out = self.layer_dims[i], self.layer_dims[i + 1]
            
            # Xavier initialization for value network
            limit = jnp.sqrt(6 / (fan_in + fan_out))
            w = jax.random.uniform(subkey, (fan_in, fan_out), minval=-limit, maxval=limit)
            b = jnp.zeros(fan_out)
            
            # Make the last layer even smaller to prevent extreme values
            if i == len(self.layer_dims) - 2:  # Last layer
                w *= 0.1  # Scale down output layer
            
            weights.append(w)
            biases.append(b)
        
        return weights, biases
    
    @staticmethod
    def forward(state: jnp.ndarray, params: Tuple[List[jnp.ndarray], List[jnp.ndarray]]) -> jnp.ndarray:
        """Forward pass returning state value estimate."""
        weights, biases = params
        x = state
        
        # Hidden layers with ReLU activation
        for w, b in zip(weights[:-1], biases[:-1]):
            x = jax.nn.relu(x @ w + b)
        
        # Output layer (linear activation for value function)
        value = x @ weights[-1] + biases[-1]
        return value.squeeze()  # Remove extra dimensions

In [4]:
class REINFORCE_Agent:
    def __init__(self, 
                 input_dim: int, 
                 policy_hidden_dims: List[int],
                 value_hidden_dims: List[int], 
                 action_dim: int,
                 discrete_bins: int = 5,
                 learning_rate_theta: float = 1e-4,
                 learning_rate_w: float = 5e-4,
                 gamma: float = 0.99,
                 gradient_clip: float = 1.0):
        """
        REINFORCE agent with baseline for variance reduction.
        Uses classical REINFORCE update rule with state-value function baseline.
        
        Args:
            input_dim: State space dimension
            policy_hidden_dims: Hidden layer dimensions for policy network
            value_hidden_dims: Hidden layer dimensions for value network
            action_dim: Number of continuous action dimensions
            discrete_bins: Number of discrete bins per action dimension
            learning_rate_theta: Learning rate for policy parameters (α_θ)
            learning_rate_w: Learning rate for value function parameters (α_w)
            gamma: Discount factor
            gradient_clip: Gradient clipping threshold
        """
        self.gamma = gamma
        self.alpha_theta = learning_rate_theta
        self.alpha_w = learning_rate_w
        self.action_dim = action_dim
        self.discrete_bins = discrete_bins
        self.gradient_clip = gradient_clip
        
        # Initialize policy and value networks with different architectures
        self.policy_network = PolicyNetwork(input_dim, policy_hidden_dims, action_dim, discrete_bins)
        self.value_network = ValueNetwork(input_dim, value_hidden_dims)
        
        self.key = jax.random.PRNGKey(0)
    
    def select_action(self, state: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Select discrete actions for each dimension, then convert to continuous."""
        self.key, subkey = jax.random.split(self.key)
        
        # Get action probabilities: shape (action_dim, discrete_bins)
        action_probs = PolicyNetwork.forward(
            state, self.policy_network.params, self.action_dim, self.discrete_bins
        )
        
        # Sample discrete action for each dimension
        keys = jax.random.split(subkey, self.action_dim)
        discrete_actions = jnp.array([
            jax.random.categorical(key, jnp.log(action_probs[i] + 1e-8))
            for i, key in enumerate(keys)
        ])
        
        # Convert discrete actions to continuous [-1, 1]
        continuous_actions = self._discrete_to_continuous(discrete_actions)
        
        return discrete_actions, continuous_actions
    
    def _discrete_to_continuous(self, discrete_actions: jnp.ndarray) -> jnp.ndarray:
        """Convert discrete actions to continuous [-1, 1] range for HalfCheetah."""
        # Map discrete bins [0, 1, 2, 3, 4] to continuous values [-1, -0.5, 0, 0.5, 1]
        continuous_actions = 2.0 * discrete_actions / (self.discrete_bins - 1) - 1.0
        return jnp.clip(continuous_actions, -0.8, 0.8)
    
    @staticmethod
    def compute_returns(rewards: jnp.ndarray, gamma: float) -> jnp.ndarray:
        """Compute discounted returns G_t for each time step."""
        returns = jnp.zeros_like(rewards)
        G = 0.0
        
        # Compute returns backwards: G_t = R_{t+1} + γ*R_{t+2} + γ^2*R_{t+3} + ...
        for t in reversed(range(len(rewards))):
            G = rewards[t] + gamma * G
            returns = returns.at[t].set(G)
        
        return returns
    
    def policy_gradient(self, state: jnp.ndarray, action: jnp.ndarray) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of log π(a|s) with respect to policy parameters."""
        def log_prob_fn(params):
            return PolicyNetwork.log_prob(state, action, params, self.action_dim, self.discrete_bins)
        
        grad_fn = jax.grad(log_prob_fn)
        return grad_fn(self.policy_network.params)
    
    def value_gradient(self, state: jnp.ndarray, target: float) -> Tuple[List[jnp.ndarray], List[jnp.ndarray]]:
        """Compute gradient of squared error loss for value function."""
        def value_loss_fn(params):
            predicted_value = ValueNetwork.forward(state, params)
            return 0.5 * (predicted_value - target) ** 2
        
        grad_fn = jax.grad(value_loss_fn)
        return grad_fn(self.value_network.params)
    
    def update(self, states: jnp.ndarray, actions: jnp.ndarray, rewards: jnp.ndarray):
        """
        REINFORCE with baseline update (proper episode-level batch updates):
        1. Compute returns G_t for all steps
        2. Compute advantages for all steps using current value network
        3. Compute gradients for all steps
        4. Apply batch updates to both networks
        """
        # Compute returns G_t for each time step
        returns = self.compute_returns(rewards, self.gamma)
        
        # Compute all advantages using current value network (before any updates)
        advantages = []
        baselines = []
        for t in range(len(states)):
            state_t = states[t]
            return_t = returns[t]
            baseline = ValueNetwork.forward(state_t, self.value_network.params)
            advantage = return_t - baseline
            advantages.append(float(advantage))
            baselines.append(float(baseline))
        
        advantages = jnp.array(advantages)
        
        # Optional: Normalize advantages for stability (commented out for pure REINFORCE)
        # advantage_mean = jnp.mean(advantages)
        # advantage_std = jnp.std(advantages) + 1e-8
        # advantages = (advantages - advantage_mean) / advantage_std
        
        # Initialize gradient accumulators
        policy_weights, policy_biases = self.policy_network.params
        value_weights, value_biases = self.value_network.params
        
        grad_policy_weights = [jnp.zeros_like(w) for w in policy_weights]
        grad_policy_biases = [jnp.zeros_like(b) for b in policy_biases]
        grad_value_weights = [jnp.zeros_like(w) for w in value_weights]
        grad_value_biases = [jnp.zeros_like(b) for b in value_biases]
        
        # Accumulate gradients for all steps
        for t in range(len(states)):
            state_t = states[t]
            action_t = actions[t]
            return_t = returns[t]
            advantage = advantages[t]
            
            # === Value Function Gradients (NOT scaled by advantage) ===
            grad_v_w, grad_v_b = self.value_gradient(state_t, return_t)
            for i in range(len(grad_value_weights)):
                grad_value_weights[i] += grad_v_w[i]  # No advantage scaling here!
                grad_value_biases[i] += grad_v_b[i]   # No advantage scaling here!
            
            # === Policy Gradients (scaled by advantage) ===
            grad_p_w, grad_p_b = self.policy_gradient(state_t, action_t)
            for i in range(len(grad_policy_weights)):
                grad_policy_weights[i] += advantage * grad_p_w[i]
                grad_policy_biases[i] += advantage * grad_p_b[i]
        
        # Apply gradient clipping and normalization
        n_steps = len(states)
        for i in range(len(grad_policy_weights)):
            # Average gradients over episode
            grad_policy_weights[i] /= n_steps
            grad_policy_biases[i] /= n_steps
            grad_value_weights[i] /= n_steps
            grad_value_biases[i] /= n_steps
            
            # Clip gradients
            grad_policy_weights[i] = jnp.clip(grad_policy_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_policy_biases[i] = jnp.clip(grad_policy_biases[i], -self.gradient_clip, self.gradient_clip)
            grad_value_weights[i] = jnp.clip(grad_value_weights[i], -self.gradient_clip, self.gradient_clip)
            grad_value_biases[i] = jnp.clip(grad_value_biases[i], -self.gradient_clip, self.gradient_clip)
        
        # Apply batch updates (single update per episode)
        # Policy update: θ ← θ + α_θ * ∇_θ
        new_policy_weights = [w + self.alpha_theta * gw for w, gw in zip(policy_weights, grad_policy_weights)]
        new_policy_biases = [b + self.alpha_theta * gb for b, gb in zip(policy_biases, grad_policy_biases)]
        
        # Value function update: w ← w + α_w * ∇_w
        new_value_weights = [w + self.alpha_w * gw for w, gw in zip(value_weights, grad_value_weights)]
        new_value_biases = [b + self.alpha_w * gb for b, gb in zip(value_biases, grad_value_biases)]
        
        # Update parameters
        self.policy_network.params = (new_policy_weights, new_policy_biases)
        self.value_network.params = (new_value_weights, new_value_biases)
        
        # Return metrics for logging
        avg_return = float(jnp.mean(returns))
        avg_advantage = float(jnp.mean(advantages))
        avg_baseline = float(jnp.mean(jnp.array(baselines)))
        
        # Check for NaNs and provide debugging info
        if jnp.isnan(avg_advantage) or jnp.isnan(avg_baseline):
            print(f"WARNING: NaN detected!")
            print(f"Returns: {returns}")
            print(f"Baselines: {baselines}")
            print(f"Advantages: {advantages}")
            # Return safe values
            return 0.0, 0.0, 0.0
        
        advantage_std = jnp.std(advantages) + 1e-8
        baseline_std = jnp.std(jnp.array(baselines)) + 1e-8
        
        return avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std

# Test the REINFORCE with baseline agent for HalfCheetah-v4
agent = REINFORCE_Agent(
    input_dim=17,                # HalfCheetah state space
    policy_hidden_dims=[64, 64], 
    value_hidden_dims=[64, 32],
    action_dim=6,                # HalfCheetah action space (6 joints)
    discrete_bins=5,
    learning_rate_theta=1e-4,    
    learning_rate_w=5e-4,        
    gamma=0.99,
    gradient_clip=1.0            
)

# Test action selection
state = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7])  # 17-dim state
print("Testing REINFORCE with Baseline for HalfCheetah:")
for i in range(3):
    discrete_actions, continuous_actions = agent.select_action(state)
    baseline_value = ValueNetwork.forward(state, agent.value_network.params)
    print(f"Discrete actions: {discrete_actions}")
    print(f"Continuous actions: {continuous_actions}")
    print(f"Estimated state value: {baseline_value:.4f}")
    print()

# Test update with dummy data (17-dim states, 6-dim actions)
dummy_states = jnp.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
    [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6]
])
dummy_actions = jnp.array([[0, 1, 2, 3, 4, 0], [1, 0, 3, 2, 1, 4], [2, 4, 1, 0, 3, 2]])  # 6-dim actions
dummy_rewards = jnp.array([1.0, -0.5, 0.8])

avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std = agent.update(dummy_states, dummy_actions, dummy_rewards)
print(f"Average return: {avg_return:.4f}")
print(f"Average advantage: {avg_advantage:.4f}")
print(f"Average baseline: {avg_baseline:.4f}")
print(f"Advantage std: {advantage_std:.4f}")
print(f"Baseline std: {baseline_std:.4f}")

Testing REINFORCE with Baseline for HalfCheetah:
Discrete actions: [3 4 4 4 4 1]
Continuous actions: [ 0.5  0.8  0.8  0.8  0.8 -0.5]
Estimated state value: -0.1375

Discrete actions: [3 0 1 4 2 0]
Continuous actions: [ 0.5 -0.8 -0.5  0.8  0.  -0.8]
Estimated state value: -0.1375

Discrete actions: [1 0 0 1 2 1]
Continuous actions: [-0.5 -0.8 -0.8 -0.5  0.  -0.5]
Estimated state value: -0.1375

Discrete actions: [3 4 4 4 4 1]
Continuous actions: [ 0.5  0.8  0.8  0.8  0.8 -0.5]
Estimated state value: -0.1375

Discrete actions: [3 0 1 4 2 0]
Continuous actions: [ 0.5 -0.8 -0.5  0.8  0.  -0.8]
Estimated state value: -0.1375

Discrete actions: [1 0 0 1 2 1]
Continuous actions: [-0.5 -0.8 -0.8 -0.5  0.  -0.5]
Estimated state value: -0.1375

Average return: 0.7937
Average advantage: 0.9311
Average baseline: -0.1374
Advantage std: 0.4030
Baseline std: 0.0083
Average return: 0.7937
Average advantage: 0.9311
Average baseline: -0.1374
Advantage std: 0.4030
Baseline std: 0.0083


In [5]:
import gymnasium as gym
import numpy as np
import time

# Training configuration for HalfCheetah-v4
CONFIG = {
    'gamma': 0.99,
    'episode_length': 250,       # Shorter episodes for faster training
    'episodes': 100000000000,            # Fewer episodes but more focused
    'learning_rate_theta': 5e-4, # Higher learning rate for faster convergence
    'learning_rate_w': 5e-4,     # Higher learning rate for value function
    'policy_hidden_dims': [17, 32],    # Good network size for HalfCheetah
    'value_hidden_dims': [17, 8],     # Smaller value network
    'discrete_bins': 7,          # 5 bins per action dimension
    'log_interval': 1,          # Log every 10 episodes
    'render_interval': 1,       # Render every 50 episodes
    'render_episodes': 1,
    'render_episode_length': 200,
    'gradient_clip': 1.0         # Standard gradient clipping
}

def run_demo_episode(agent, env_render, episode_num):
    """Run a single demonstration episode with rendering."""
    print(f"\n--- Demonstration Episode {episode_num} ---")
    observation, _ = env_render.reset()
    total_reward = 0
    step_count = 0
    
    for step in range(CONFIG['render_episode_length']):
        # Convert observation to JAX array
        state = jnp.array(observation, dtype=jnp.float32)
        
        # Select action using current policy
        discrete_actions, continuous_actions = agent.select_action(state)
        
        # Take step in environment
        observation, reward, terminated, truncated, _ = env_render.step(np.array(continuous_actions))
        total_reward += reward
        step_count += 1
        
        
        if terminated or truncated:
            break
    
    print(f"Demo episode reward: {total_reward:.2f} (steps: {step_count})")
    return total_reward

def train_agent():
    """Train REINFORCE agent on HalfCheetah environment with periodic rendering."""
    # Training environment (no rendering for speed)
    env = gym.make("HalfCheetah-v5", render_mode='human')
    
    # Separate environment for rendering demonstrations
    env_render = gym.make("HalfCheetah-v5", render_mode="human")
    
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    print(f"Environment: HalfCheetah-v5")
    print(f"Observation space: {obs_dim}")
    print(f"Action space: {action_dim} (continuous)")
    print(f"Discrete bins per action: {CONFIG['discrete_bins']}")
    print(f"Action range: {env.action_space.low[0]:.1f} to {env.action_space.high[0]:.1f}")
    print(f"Rendering every {CONFIG['render_interval']} episodes")
    
    # Create agent with baseline
    agent = REINFORCE_Agent(
        input_dim=obs_dim,
        policy_hidden_dims=CONFIG['policy_hidden_dims'],
        value_hidden_dims=CONFIG['value_hidden_dims'],
        action_dim=action_dim,
        discrete_bins=CONFIG['discrete_bins'],
        learning_rate_theta=CONFIG['learning_rate_theta'],
        learning_rate_w=CONFIG['learning_rate_w'],
        gamma=CONFIG['gamma'],
        gradient_clip=CONFIG['gradient_clip']
    )
    
    episode_rewards = []
    demo_rewards = []
    
    for episode in range(CONFIG['episodes']):
        observation, _ = env.reset()
        
        # Collect trajectory
        states, discrete_actions_list, rewards = [], [], []
        total_reward = 0
        
        for step in range(CONFIG['episode_length']):
            # Convert observation to JAX array
            state = jnp.array(observation, dtype=jnp.float32)
            
            # Select action (returns both discrete and continuous)
            discrete_actions, continuous_actions = agent.select_action(state)
            
            # Store transition data
            states.append(state)
            discrete_actions_list.append(discrete_actions)
            
            # Take step in environment using continuous actions
            observation, reward, terminated, truncated, _ = env.step(np.array(continuous_actions))
            rewards.append(reward)
            total_reward += reward
            
            if terminated or truncated:
                break
        
        # Convert to JAX arrays
        states = jnp.stack(states)
        discrete_actions_array = jnp.stack(discrete_actions_list)
        rewards = jnp.array(rewards)
        
        # Update policy using REINFORCE with baseline
        avg_return, avg_advantage, avg_baseline, advantage_std, baseline_std = agent.update(states, discrete_actions_array, rewards)
        episode_rewards.append(total_reward)
        
        # Log progress
        if (episode + 1) % CONFIG['log_interval'] == 0:
            recent_reward = np.mean(episode_rewards[-CONFIG['log_interval']])
            
            print(f"Episode {episode + 1:3d} | Reward: {recent_reward:6.2f} | Return: {avg_return:6.2f} | "
                  f"Baseline: {avg_baseline:6.2f} | Advantage: {avg_advantage:6.2f} | "
                  f"Advantage Std: {advantage_std:.2f} | Baseline Std: {baseline_std:.2f}")
        
        # # Render demonstration episodes
        # if (episode + 1) % CONFIG['render_interval'] == 0:
        #     for demo_ep in range(CONFIG['render_episodes']):
        #         demo_reward = run_demo_episode(agent, env_render, episode + 1)
        #         demo_rewards.append(demo_reward)
            
    
    env.close()
    env_render.close()
    
    return episode_rewards, demo_rewards

# Start training with rendering
try:
    print("Starting REINFORCE with baseline training...\n")
    
    training_rewards, demo_rewards = train_agent()
    
    print(f"\nTraining completed!")
    print(f"Final average training reward: {np.mean(training_rewards[-10:]):.2f}")
    if demo_rewards:
        print(f"Final average demo reward: {np.mean(demo_rewards[-3:]):.2f}")
        
except KeyboardInterrupt:
    print("\nTraining interrupted by user.")
    #Close environments gracefully
except Exception as e:
    print(f"Training error: {e}")
    print("This might be due to environment setup issues.")

Starting REINFORCE with baseline training...

Environment: HalfCheetah-v5
Observation space: 17
Action space: 6 (continuous)
Discrete bins per action: 7
Action range: -1.0 to 1.0
Rendering every 1 episodes
Episode   1 | Reward: -138.09 | Return: -36.71 | Baseline:  -0.24 | Advantage: -36.47 | Advantage Std: 12.89 | Baseline Std: 0.25
Episode   1 | Reward: -138.09 | Return: -36.71 | Baseline:  -0.24 | Advantage: -36.47 | Advantage Std: 12.89 | Baseline Std: 0.25
Episode   2 | Reward: -75.84 | Return: -18.78 | Baseline:  -0.23 | Advantage: -18.55 | Advantage Std: 11.93 | Baseline Std: 0.23
Episode   2 | Reward: -75.84 | Return: -18.78 | Baseline:  -0.23 | Advantage: -18.55 | Advantage Std: 11.93 | Baseline Std: 0.23
Episode   3 | Reward: -98.09 | Return: -23.44 | Baseline:  -0.24 | Advantage: -23.20 | Advantage Std: 13.75 | Baseline Std: 0.27
Episode   3 | Reward: -98.09 | Return: -23.44 | Baseline:  -0.24 | Advantage: -23.20 | Advantage Std: 13.75 | Baseline Std: 0.27
Episode   4 | Rewa