In [1]:

from abc import ABC, abstractmethod
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from rl.cstr.optimization.base_agent import compute_gae

In [2]:
def sample_trajectory_data():
    """
    Fixture providing sample trajectory data for testing compute_gae.
    
    Creates realistic trajectory data that would be returned by collect_trajectories:
    - rewards: Conversion efficiency rewards from CSTR control
    - dones: Episode termination flags
    - values: Critic's value estimates for each state
    
    Returns:
        tuple: (rewards, dones, values) for GAE computation
    """
    # Sample trajectory data (10 timesteps)
    rewards = [15.2, 12.8, 18.1, 14.5, 16.3, 13.7, 17.9, 15.8, 14.2, 16.7]
    dones = [False, False, False, False, False, False, False, False, False, True]
    values = [15.0, 13.0, 17.5, 14.0, 16.0, 13.5, 17.0, 15.5, 14.0, 16.5]
    
    return rewards, dones, values

In [3]:
# Given: Sample trajectory data (10 timesteps)
rewards, dones, values = sample_trajectory_data()

In [4]:
rewards

[15.2, 12.8, 18.1, 14.5, 16.3, 13.7, 17.9, 15.8, 14.2, 16.7]

In [5]:
dones

[False, False, False, False, False, False, False, False, False, True]

In [6]:
values

[15.0, 13.0, 17.5, 14.0, 16.0, 13.5, 17.0, 15.5, 14.0, 16.5]

In [7]:
# When: Computing GAE with default parameters
gae_lambda = 0.95
gae_advantages_normalized, total_expected_future_rewards, raw_gae_advantages = compute_gae(rewards, dones, values, gae_lambda)

In [8]:
gae_advantages_normalized

tensor([ 1.3105,  1.1703,  0.8678,  0.6285,  0.2950,  0.0233, -0.4003, -0.8421,
        -1.2547, -1.7983])

In [9]:
total_expected_future_rewards

tensor([105.8748,  99.7865,  95.4629,  84.9838,  77.2563,  66.8311,  57.9763,
         43.5901,  30.0555,  16.7000])

In [10]:
raw_gae_advantages

tensor([90.8748, 86.7865, 77.9629, 70.9838, 61.2563, 53.3311, 40.9763, 28.0901,
        16.0555,  0.2000])

In [11]:
advantages_normalized_manual = (raw_gae_advantages - raw_gae_advantages.mean()) / (raw_gae_advantages.std() + 1e-8)

In [12]:
advantages_normalized_manual

tensor([ 1.2432,  1.1102,  0.8233,  0.5963,  0.2799,  0.0221, -0.3797, -0.7989,
        -1.1903, -1.7060])

In [13]:
gae_advantages_normalized

tensor([ 1.3105,  1.1703,  0.8678,  0.6285,  0.2950,  0.0233, -0.4003, -0.8421,
        -1.2547, -1.7983])

In [16]:
advantages_mean_1 = gae_advantages_normalized.mean()

In [17]:
advantages_mean_1

tensor(5.9605e-09)

In [18]:
advantages_mean = gae_advantages_normalized.mean().item()

In [19]:
advantages_mean

5.9604645663569045e-09

In [20]:
assert advantages_mean < 1e-6

In [23]:
advantages_std_1 = gae_advantages_normalized.std()

In [24]:
advantages_std_1

tensor(1.0541)

In [21]:
advantages_std = gae_advantages_normalized.std().item()

In [22]:
advantages_std

1.054092526435852

In [26]:
advantages_std - 1.0

0.05409252643585205

In [25]:
assert abs(advantages_std - 1.0) < 1e-6

AssertionError: 

In [14]:
advantages_std_2 = advantages_normalized_manual.std().item()

In [15]:
advantages_std_2

1.0

In [None]:


    # def test_compute_gae_episode_termination(self, sample_trajectory_data):
    #     """
    #     Test that compute_gae handles episode termination correctly.
    #     """
    #     # Given: Sample trajectory data with episode termination
    #     rewards, dones, values = sample_trajectory_data
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle episode termination correctly
    #     # Check that the last timestep has a done flag
    #     assert dones[-1] == True, \
    #         "Last timestep should have done=True"
        
    #     # Check that gae_values and gae_advantages are computed correctly
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_different_gamma(self, sample_trajectory_data):
    #     """
    #     Test that compute_gae works with different gamma values.
    #     """
    #     # Given: Sample trajectory data
    #     rewards, dones, values = sample_trajectory_data
        
    #     # When: Computing GAE with different gamma values
    #     gamma_0_9 = 0.9
    #     gamma_0_99 = 0.99
    #     gae_lambda = 0.95
        
    #     gae_values_0_9, gae_advantages_0_9 = compute_gae(rewards, dones, values, gamma_0_9, gae_lambda)
    #     gae_values_0_99, gae_advantages_0_99 = compute_gae(rewards, dones, values, gamma_0_99, gae_lambda)
        
    #     # Then: GAE should work with different gamma values
    #     # Check that both computations produce valid results
    #     assert torch.all(torch.isfinite(gae_values_0_9)), \
    #         "gae_values with gamma=0.9 should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages_0_9)), \
    #         "gae_advantages with gamma=0.9 should be finite"
    #     assert torch.all(torch.isfinite(gae_values_0_99)), \
    #         "gae_values with gamma=0.99 should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages_0_99)), \
    #         "gae_advantages with gamma=0.99 should be finite"
        
    #     # Check that results are different for different gamma values
    #     assert not torch.allclose(gae_values_0_9, gae_values_0_99), \
    #         "Different gamma values should produce different results"

    # def test_compute_gae_constant_rewards(self):
    #     """
    #     Test that compute_gae works with constant rewards.
    #     """
    #     # Given: Trajectory with constant rewards
    #     rewards = [10.0, 10.0, 10.0, 10.0, 10.0]
    #     dones = [False, False, False, False, True]
    #     values = [10.0, 10.0, 10.0, 10.0, 10.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle constant rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_increasing_rewards(self):
    #     """
    #     Test that compute_gae works with increasing rewards.
    #     """
    #     # Given: Trajectory with increasing rewards
    #     rewards = [5.0, 10.0, 15.0, 20.0, 25.0]
    #     dones = [False, False, False, False, True]
    #     values = [5.0, 10.0, 15.0, 20.0, 25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle increasing rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_decreasing_rewards(self):
    #     """
    #     Test that compute_gae works with decreasing rewards.
    #     """
    #     # Given: Trajectory with decreasing rewards
    #     rewards = [25.0, 20.0, 15.0, 10.0, 5.0]
    #     dones = [False, False, False, False, True]
    #     values = [25.0, 20.0, 15.0, 10.0, 5.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle decreasing rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_negative_rewards(self):
    #     """
    #     Test that compute_gae works with negative rewards.
    #     """
    #     # Given: Trajectory with negative rewards
    #     rewards = [-5.0, -10.0, -15.0, -20.0, -25.0]
    #     dones = [False, False, False, False, True]
    #     values = [-5.0, -10.0, -15.0, -20.0, -25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle negative rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_mixed_rewards(self):
    #     """
    #     Test that compute_gae works with mixed positive and negative rewards.
    #     """
    #     # Given: Trajectory with mixed rewards
    #     rewards = [5.0, -10.0, 15.0, -20.0, 25.0]
    #     dones = [False, False, False, False, True]
    #     values = [5.0, -10.0, 15.0, -20.0, 25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle mixed rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_zero_rewards(self):
    #     """
    #     Test that compute_gae works with zero rewards.
    #     """
    #     # Given: Trajectory with zero rewards
    #     rewards = [0.0, 0.0, 0.0, 0.0, 0.0]
    #     dones = [False, False, False, False, True]
    #     values = [0.0, 0.0, 0.0, 0.0, 0.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle zero rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"
        
    #     # Check that gae_values and gae_advantages are all zeros
    #     assert torch.all(gae_values == 0), \
    #         "gae_values should be all zeros for zero rewards"
    #     assert torch.all(gae_advantages == 0), \
    #         "gae_advantages should be all zeros for zero rewards"


In [None]:

    # @pytest.mark.parametrize(
    #     "clip_value,description",
    #     [
    #         (0.01, "very_conservative"),
    #         (0.1, "conservative"),
    #         (0.2, "standard"),
    #         (0.3, "aggressive"),
    #         (0.5, "very_aggressive"),
    #     ],
    #     ids=["clip_0.01_very_conservative", "clip_0.1_conservative", "clip_0.2_standard", 
    #          "clip_0.3_aggressive", "clip_0.5_very_aggressive"]
    # )
    # def test_ppo_update_different_clip_values(self, mock_actor_critic_model, sample_ppo_data, 
    #                                         clip_value, description):
    #     """
    #     Test that PPO update works correctly with different clipping values.
        
    #     This test verifies that:
    #     1. Different clip values produce different update behaviors
    #     2. Conservative clips (small values) produce smaller parameter changes
    #     3. Aggressive clips (large values) allow larger parameter changes
    #     4. All clip values maintain training stability
        
    #     For CSTR context: Tests different levels of policy change conservatism
    #     in temperature control strategy updates.
    #     """
    #     # Given: A mock model and sample data
    #     model = mock_actor_critic_model
    #     states, actions, log_probs_old, returns, advantages = sample_ppo_data
        
    #     # Create optimizers
    #     actor_params = list(model.actor.parameters()) + [model.log_std]
    #     actor_optimizer = optim.Adam(actor_params, lr=3e-4)
    #     critic_optimizer = optim.Adam(model.critic.parameters(), lr=1e-3)
        
    #     # Store initial parameters
    #     initial_actor_params = {name: param.clone() for name, param in model.actor.named_parameters()}
    #     initial_critic_params = {name: param.clone() for name, param in model.critic.named_parameters()}
    #     initial_log_std = model.log_std.clone()
        
    #     # When: Performing PPO update with the specified clip value
    #     ppo_update(
    #         model=model,
    #         states=states,
    #         actions=actions,
    #         log_probs_old=log_probs_old,
    #         returns=returns,
    #         advantages=advantages,
    #         actor_optimizer=actor_optimizer,
    #         critic_optimizer=critic_optimizer,
    #         clip=clip_value,
    #         epochs=3  # Fewer epochs for faster testing
    #     )
        
    #     # Then: Different clip values should produce different behaviors
    #     # Check that parameters changed
    #     actor_params_changed = False
    #     total_actor_change = 0.0
        
    #     for name, param in model.actor.named_parameters():
    #         initial_param = initial_actor_params[name]
    #         if not torch.allclose(param, initial_param):
    #             actor_params_changed = True
    #             # Calculate total parameter change
    #             param_change = torch.sum(torch.abs(param - initial_param)).item()
    #             total_actor_change += param_change
        
    #     assert actor_params_changed, \
    #         f"Actor parameters should be updated with clip={clip_value} ({description})"
        
    #     # Check that critic parameters changed
    #     critic_params_changed = False
    #     for name, param in model.critic.named_parameters():
    #         if not torch.allclose(param, initial_critic_params[name]):
    #             critic_params_changed = True
    #             break
        
    #     assert critic_params_changed, \
    #         f"Critic parameters should be updated with clip={clip_value} ({description})"
        
    #     # Check that log_std parameter changed
    #     log_std_changed = not torch.allclose(model.log_std, initial_log_std)
    #     assert log_std_changed, \
    #         f"Log std parameter should be updated with clip={clip_value} ({description})"
        
    #     # Verify model still works after update
    #     test_states = torch.FloatTensor([[0.8, 0.2, 350.0], [0.5, 0.5, 340.0]])
    #     mean, std, values = model(test_states)
        
    #     assert torch.all(torch.isfinite(mean)), \
    #         f"Actor mean output should be finite with clip={clip_value} ({description})"
    #     assert torch.all(torch.isfinite(std)), \
    #         f"Actor std output should be finite with clip={clip_value} ({description})"
    #     assert torch.all(torch.isfinite(values)), \
    #         f"Critic values output should be finite with clip={clip_value} ({description})"
        
    #     # Store total change for potential comparison (could be used in future tests)
    #     # For now, just verify that changes occurred
    #     assert total_actor_change > 0, \
    #         f"Total actor parameter change should be positive with clip={clip_value} ({description})"

    # def test_ppo_update_extreme_advantages(self, mock_actor_critic_model, extreme_advantages_data):
    #     """
    #     Test that PPO update handles extreme advantage values correctly.
        
    #     This test verifies that:
    #     1. Very large positive advantages don't cause numerical instability
    #     2. Very large negative advantages don't cause numerical instability
    #     3. Zero advantages are handled correctly
    #     4. Mixed extreme advantages don't crash the training
    #     5. All computations remain finite and stable
        
    #     For CSTR context: Tests robustness when temperature adjustments
    #     have unexpectedly good or bad outcomes.
    #     """
    #     # Given: A mock model and data with extreme advantages
    #     model = mock_actor_critic_model
    #     states, actions, log_probs_old, returns, advantages = extreme_advantages_data
        
    #     # Create optimizers
    #     actor_params = list(model.actor.parameters()) + [model.log_std]
    #     actor_optimizer = optim.Adam(actor_params, lr=3e-4)
    #     critic_optimizer = optim.Adam(model.critic.parameters(), lr=1e-3)
        
    #     # Store initial parameters
    #     initial_actor_params = {name: param.clone() for name, param in model.actor.named_parameters()}
    #     initial_critic_params = {name: param.clone() for name, param in model.critic.named_parameters()}
    #     initial_log_std = model.log_std.clone()
        
    #     # When: Performing PPO update with extreme advantages
    #     ppo_update(
    #         model=model,
    #         states=states,
    #         actions=actions,
    #         log_probs_old=log_probs_old,
    #         returns=returns,
    #         advantages=advantages,
    #         actor_optimizer=actor_optimizer,
    #         critic_optimizer=critic_optimizer,
    #         clip=0.2,
    #         epochs=3  # Fewer epochs for faster testing
    #     )
        
    #     # Then: Extreme advantages should be handled without numerical issues
    #     # Check that parameters changed (training occurred)
    #     actor_params_changed = False
    #     for name, param in model.actor.named_parameters():
    #         if not torch.allclose(param, initial_actor_params[name]):
    #             actor_params_changed = True
    #             break
        
    #     assert actor_params_changed, \
    #         "Actor parameters should be updated even with extreme advantages"
        
    #     # Check that critic parameters changed
    #     critic_params_changed = False
    #     for name, param in model.critic.named_parameters():
    #         if not torch.allclose(param, initial_critic_params[name]):
    #             critic_params_changed = True
    #             break
        
    #     assert critic_params_changed, \
    #         "Critic parameters should be updated even with extreme advantages"
        
    #     # Check that log_std parameter changed
    #     log_std_changed = not torch.allclose(model.log_std, initial_log_std)
    #     assert log_std_changed, \
    #         "Log std parameter should be updated even with extreme advantages"
        
    #     # Verify that all model parameters are finite
    #     for name, param in model.named_parameters():
    #         assert torch.all(torch.isfinite(param)), \
    #             f"Parameter {name} should be finite after extreme advantages"
        
    #     # Verify model can still perform forward passes
    #     test_states = torch.FloatTensor([[0.8, 0.2, 350.0], [0.5, 0.5, 340.0]])
    #     mean, std, values = model(test_states)
        
    #     # Check that outputs are finite
    #     assert torch.all(torch.isfinite(mean)), \
    #         "Actor mean output should be finite after extreme advantages"
    #     assert torch.all(torch.isfinite(std)), \
    #         "Actor std output should be finite after extreme advantages"
    #     assert torch.all(torch.isfinite(values)), \
    #         "Critic values output should be finite after extreme advantages"
        
    #     # Check that outputs have correct shapes
    #     assert mean.shape == (2, 1), \
    #         f"Actor output should have shape (2, 1), got {mean.shape}"
    #     assert std.shape == (1,), \
    #         f"Actor std should have shape (1,), got {std.shape}"
    #     assert values.shape == (2, 1), \
    #         f"Critic output should have shape (2, 1), got {values.shape}"
        
    #     # Verify that std is positive (as expected for standard deviation)
    #     assert torch.all(std > 0), \
    #         "Action std should be positive after extreme advantages"

    # def test_ppo_update_clipping_effectiveness(self, mock_actor_critic_model, clipping_test_data):
    #     """
    #     Test that PPO clipping actually prevents excessive policy changes.
        
    #     This test verifies that:
    #     1. The policy doesn't change too drastically between updates
    #     2. Clipping actually constrains the policy updates
    #     3. The ratio between old and new policies stays within reasonable bounds
    #     4. PPO's conservative update mechanism is working
        
    #     This is a more rigorous test of PPO's core innovation.
    #     """
    #     # Given: A mock model and data designed to trigger clipping
    #     model = mock_actor_critic_model
    #     states, actions, log_probs_old, returns, advantages = clipping_test_data
        
    #     # Create optimizers
    #     actor_params = list(model.actor.parameters()) + [model.log_std]
    #     actor_optimizer = optim.Adam(actor_params, lr=3e-4)
    #     critic_optimizer = optim.Adam(model.critic.parameters(), lr=1e-3)
        
    #     # Convert data to tensors for analysis
    #     states_tensor = torch.FloatTensor(np.array(states))
    #     actions_tensor = torch.FloatTensor(np.array(actions))
    #     log_probs_old_tensor = torch.FloatTensor(log_probs_old)
        
    #     # Get initial policy predictions
    #     with torch.no_grad():
    #         initial_mean, initial_std, _ = model(states_tensor)
    #         initial_dist = torch.distributions.Normal(initial_mean, initial_std)
    #         initial_log_probs = initial_dist.log_prob(actions_tensor).sum(dim=-1)
        
    #     # Store initial parameters
    #     initial_params = {name: param.clone() for name, param in model.named_parameters()}
        
    #     # When: Performing PPO update with clipping
    #     clip_value = 0.2
    #     ppo_update(
    #         model=model,
    #         states=states,
    #         actions=actions,
    #         log_probs_old=log_probs_old,
    #         returns=returns,
    #         advantages=advantages,
    #         actor_optimizer=actor_optimizer,
    #         critic_optimizer=critic_optimizer,
    #         clip=clip_value,
    #         epochs=3
    #     )
        
    #     # Then: Check that PPO clipping is actually working
    #     # Get new policy predictions
    #     with torch.no_grad():
    #         new_mean, new_std, _ = model(states_tensor)
    #         new_dist = torch.distributions.Normal(new_mean, new_std)
    #         new_log_probs = new_dist.log_prob(actions_tensor).sum(dim=-1)
        
    #     # Calculate actual policy ratios
    #     actual_ratios = torch.exp(new_log_probs - log_probs_old_tensor)
        
    #     # Check that ratios are within reasonable bounds (PPO clipping should help here)
    #     # Even with extreme log_probs_old, the actual ratios should be reasonable
    #     max_ratio = torch.max(actual_ratios).item()
    #     min_ratio = torch.min(actual_ratios).item()
        
    #     # Print debugging information to understand PPO behavior
    #     print(f"\nPPO Clipping Debug Info:")
    #     print(f"  Clip value: {clip_value}")
    #     print(f"  Max ratio: {max_ratio:.4f}")
    #     print(f"  Min ratio: {min_ratio:.4f}")
    #     print(f"  Mean ratio: {torch.mean(actual_ratios).item():.4f}")
    #     print(f"  Ratio std: {torch.std(actual_ratios).item():.4f}")
    #     print(f"  Policy divergence: {torch.mean(torch.abs(actual_ratios - 1.0)).item():.4f}")
        
    #     # PPO should prevent extremely large ratios
    #     assert max_ratio < 10.0, \
    #         f"PPO clipping should prevent extremely large ratios, got max ratio of {max_ratio}"
        
    #     # PPO should prevent extremely small ratios (but allow some small ratios)
    #     # The clipping test data has extreme log_probs_old values, so some small ratios are expected
    #     assert min_ratio > 0.01, \
    #         f"PPO clipping should prevent extremely small ratios, got min ratio of {min_ratio}"
        
    #     # Check that policy changes are reasonable
    #     mean_change = torch.mean(torch.abs(new_mean - initial_mean)).item()
    #     std_change = torch.mean(torch.abs(new_std - initial_std)).item()
        
    #     print(f"  Mean policy change: {mean_change:.4f}")
    #     print(f"  Std policy change: {std_change:.4f}")
        
    #     # Policy changes should be moderate (not extreme)
    #     assert mean_change < 5.0, \
    #         f"Policy mean changes should be moderate, got {mean_change}"
    #     assert std_change < 2.0, \
    #         f"Policy std changes should be moderate, got {std_change}"
        
    #     # Verify that the model still works correctly
    #     test_states = torch.FloatTensor([[0.8, 0.2, 350.0], [0.5, 0.5, 340.0]])
    #     mean, std, values = model(test_states)
        
    #     assert torch.all(torch.isfinite(mean)), \
    #         "Actor mean output should be finite after PPO update"
    #     assert torch.all(torch.isfinite(std)), \
    #         "Actor std output should be finite after PPO update"
    #     assert torch.all(torch.isfinite(values)), \
    #         "Critic values output should be finite after PPO update"
        
    #     # Check that std is still positive
    #     assert torch.all(std > 0), \
    #         "Action std should remain positive after PPO update