In [1]:

from abc import ABC, abstractmethod
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from rl.cstr.optimization.base_agent import compute_gae

In [2]:
def sample_trajectory_data():
    """
    Fixture providing sample trajectory data for testing compute_gae.
    
    Creates realistic trajectory data that would be returned by collect_trajectories:
    - rewards: Conversion efficiency rewards from CSTR control
    - dones: Episode termination flags
    - values: Critic's value estimates for each state
    
    Returns:
        tuple: (rewards, dones, values) for GAE computation
    """
    # Sample trajectory data (10 timesteps)
    rewards = [15.2, 12.8, 18.1, 14.5, 16.3, 13.7, 17.9, 15.8, 14.2, 16.7]
    dones = [False, False, False, False, False, False, False, False, False, True]
    values = [15.0, 13.0, 17.5, 14.0, 16.0, 13.5, 17.0, 15.5, 14.0, 16.5]
    
    return rewards, dones, values

In [3]:
# Given: Sample trajectory data (10 timesteps)
rewards, dones, values = sample_trajectory_data()

In [4]:
rewards

[15.2, 12.8, 18.1, 14.5, 16.3, 13.7, 17.9, 15.8, 14.2, 16.7]

In [5]:
dones

[False, False, False, False, False, False, False, False, False, True]

In [6]:
values

[15.0, 13.0, 17.5, 14.0, 16.0, 13.5, 17.0, 15.5, 14.0, 16.5]

In [7]:
# When: Computing GAE with default parameters
gae_lambda = 0.95
gae_advantages_normalized, total_expected_future_rewards, raw_gae_advantages = compute_gae(rewards, dones, values, gae_lambda)

In [8]:
gae_advantages_normalized

tensor([ 1.3105,  1.1703,  0.8678,  0.6285,  0.2950,  0.0233, -0.4003, -0.8421,
        -1.2547, -1.7983])

In [9]:
total_expected_future_rewards

tensor([105.8748,  99.7865,  95.4629,  84.9838,  77.2563,  66.8311,  57.9763,
         43.5901,  30.0555,  16.7000])

In [10]:
raw_gae_advantages

tensor([90.8748, 86.7865, 77.9629, 70.9838, 61.2563, 53.3311, 40.9763, 28.0901,
        16.0555,  0.2000])

In [11]:
advantages_normalized_manual = (raw_gae_advantages - raw_gae_advantages.mean()) / (raw_gae_advantages.std() + 1e-8)

In [12]:
advantages_normalized_manual

tensor([ 1.2432,  1.1102,  0.8233,  0.5963,  0.2799,  0.0221, -0.3797, -0.7989,
        -1.1903, -1.7060])

In [13]:
gae_advantages_normalized

tensor([ 1.3105,  1.1703,  0.8678,  0.6285,  0.2950,  0.0233, -0.4003, -0.8421,
        -1.2547, -1.7983])

In [16]:
advantages_mean_1 = gae_advantages_normalized.mean()

In [17]:
advantages_mean_1

tensor(5.9605e-09)

In [18]:
advantages_mean = gae_advantages_normalized.mean().item()

In [19]:
advantages_mean

5.9604645663569045e-09

In [20]:
assert advantages_mean < 1e-6

In [23]:
advantages_std_1 = gae_advantages_normalized.std()

In [24]:
advantages_std_1

tensor(1.0541)

In [21]:
advantages_std = gae_advantages_normalized.std().item()

In [22]:
advantages_std

1.054092526435852

In [26]:
advantages_std - 1.0

0.05409252643585205

In [25]:
assert abs(advantages_std - 1.0) < 1e-6

AssertionError: 

In [14]:
advantages_std_2 = advantages_normalized_manual.std().item()

In [15]:
advantages_std_2

1.0

In [None]:


    # def test_compute_gae_episode_termination(self, sample_trajectory_data):
    #     """
    #     Test that compute_gae handles episode termination correctly.
    #     """
    #     # Given: Sample trajectory data with episode termination
    #     rewards, dones, values = sample_trajectory_data
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle episode termination correctly
    #     # Check that the last timestep has a done flag
    #     assert dones[-1] == True, \
    #         "Last timestep should have done=True"
        
    #     # Check that gae_values and gae_advantages are computed correctly
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_different_gamma(self, sample_trajectory_data):
    #     """
    #     Test that compute_gae works with different gamma values.
    #     """
    #     # Given: Sample trajectory data
    #     rewards, dones, values = sample_trajectory_data
        
    #     # When: Computing GAE with different gamma values
    #     gamma_0_9 = 0.9
    #     gamma_0_99 = 0.99
    #     gae_lambda = 0.95
        
    #     gae_values_0_9, gae_advantages_0_9 = compute_gae(rewards, dones, values, gamma_0_9, gae_lambda)
    #     gae_values_0_99, gae_advantages_0_99 = compute_gae(rewards, dones, values, gamma_0_99, gae_lambda)
        
    #     # Then: GAE should work with different gamma values
    #     # Check that both computations produce valid results
    #     assert torch.all(torch.isfinite(gae_values_0_9)), \
    #         "gae_values with gamma=0.9 should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages_0_9)), \
    #         "gae_advantages with gamma=0.9 should be finite"
    #     assert torch.all(torch.isfinite(gae_values_0_99)), \
    #         "gae_values with gamma=0.99 should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages_0_99)), \
    #         "gae_advantages with gamma=0.99 should be finite"
        
    #     # Check that results are different for different gamma values
    #     assert not torch.allclose(gae_values_0_9, gae_values_0_99), \
    #         "Different gamma values should produce different results"

    # def test_compute_gae_constant_rewards(self):
    #     """
    #     Test that compute_gae works with constant rewards.
    #     """
    #     # Given: Trajectory with constant rewards
    #     rewards = [10.0, 10.0, 10.0, 10.0, 10.0]
    #     dones = [False, False, False, False, True]
    #     values = [10.0, 10.0, 10.0, 10.0, 10.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle constant rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_increasing_rewards(self):
    #     """
    #     Test that compute_gae works with increasing rewards.
    #     """
    #     # Given: Trajectory with increasing rewards
    #     rewards = [5.0, 10.0, 15.0, 20.0, 25.0]
    #     dones = [False, False, False, False, True]
    #     values = [5.0, 10.0, 15.0, 20.0, 25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle increasing rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_decreasing_rewards(self):
    #     """
    #     Test that compute_gae works with decreasing rewards.
    #     """
    #     # Given: Trajectory with decreasing rewards
    #     rewards = [25.0, 20.0, 15.0, 10.0, 5.0]
    #     dones = [False, False, False, False, True]
    #     values = [25.0, 20.0, 15.0, 10.0, 5.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle decreasing rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_negative_rewards(self):
    #     """
    #     Test that compute_gae works with negative rewards.
    #     """
    #     # Given: Trajectory with negative rewards
    #     rewards = [-5.0, -10.0, -15.0, -20.0, -25.0]
    #     dones = [False, False, False, False, True]
    #     values = [-5.0, -10.0, -15.0, -20.0, -25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle negative rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_mixed_rewards(self):
    #     """
    #     Test that compute_gae works with mixed positive and negative rewards.
    #     """
    #     # Given: Trajectory with mixed rewards
    #     rewards = [5.0, -10.0, 15.0, -20.0, 25.0]
    #     dones = [False, False, False, False, True]
    #     values = [5.0, -10.0, 15.0, -20.0, 25.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle mixed rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"

    # def test_compute_gae_zero_rewards(self):
    #     """
    #     Test that compute_gae works with zero rewards.
    #     """
    #     # Given: Trajectory with zero rewards
    #     rewards = [0.0, 0.0, 0.0, 0.0, 0.0]
    #     dones = [False, False, False, False, True]
    #     values = [0.0, 0.0, 0.0, 0.0, 0.0]
        
    #     # When: Computing GAE with default parameters
    #     gamma = 0.99
    #     gae_lambda = 0.95
    #     gae_values, gae_advantages = compute_gae(rewards, dones, values, gamma, gae_lambda)
        
    #     # Then: GAE should handle zero rewards correctly
    #     # Check that gae_values and gae_advantages have correct length
    #     assert len(gae_values) == len(rewards), \
    #         "gae_values length should match rewards length"
    #     assert len(gae_advantages) == len(rewards), \
    #         "gae_advantages length should match rewards length"
        
    #     # Check that gae_values and gae_advantages are finite
    #     assert torch.all(torch.isfinite(gae_values)), \
    #         "gae_values should be finite"
    #     assert torch.all(torch.isfinite(gae_advantages)), \
    #         "gae_advantages should be finite"
        
    #     # Check that gae_values and gae_advantages are all zeros
    #     assert torch.all(gae_values == 0), \
    #         "gae_values should be all zeros for zero rewards"
    #     assert torch.all(gae_advantages == 0), \
    #         "gae_advantages should be all zeros for zero rewards"
