In [None]:
import colorednoise
import jax
import jax.numpy as jnp
import jax.random as jrandom
from typing import Union, Tuple, Optional
import numpy as np
from functools import partial

In [None]:
@partial(jax.jit, static_argnames=("size",))
def powerlaw_psd_gaussian_jax(
        key: jnp.ndarray,
        exponent: float,
        size: Union[int, Tuple[int, ...]],
        fmin: float = 0.0,
) -> jnp.ndarray:
    """JAX implementation of Gaussian (1/f)**beta noise.

    Stand-alone port of ``colorednoise.powerlaw_psd_gaussian``; it takes the
    same ``(exponent, size, fmin)`` arguments, preceded by an explicit JAX
    PRNG key.

    Based on the algorithm in:
    Timmer, J. and Koenig, M.:
    On generating power law noise.
    Astron. Astrophys. 300, 707-710 (1995)

    Parameters
    ----------
    key : jax PRNG key
        The random key for JAX's random number generator.
    exponent : float
        The power-spectrum exponent (beta) where S(f) = (1/f)**beta.
    size : int or tuple of ints
        The output shape. The last dimension is taken as time. This argument
        is static under ``jax.jit`` (it fixes array shapes), so every new
        value triggers a recompilation.
    fmin : float, optional
        Low-frequency cutoff (default: 0.0). Values below ``1 / n_samples``
        are raised to ``1 / n_samples``. The upper bound of 0.5 is NOT
        validated because ``fmin`` is a traced value inside ``jit``.

    Returns
    -------
    jnp.ndarray
        Noise samples of shape ``size`` with the requested power-law spectrum,
        scaled by the theoretical standard deviation of the non-DC components.

    Raises
    ------
    ValueError
        If the time dimension has fewer than two samples (raised at trace
        time, since ``size`` is static).
    """
    # `size` is static, so it can be normalised with plain Python at trace time.
    if isinstance(size, (int, np.integer)):
        shape = (int(size),)
    else:
        shape = tuple(size)

    # Number of samples comes from the last (time) dimension.
    n_samps = shape[-1]
    if n_samps < 2:
        raise ValueError("the time dimension of `size` must contain at least 2 samples")

    # Frequencies of the one-sided spectrum (assuming a sample rate of 1).
    f = jnp.fft.rfftfreq(n_samps)
    n_freqs = f.shape[0]

    # The lowest resolvable non-zero frequency is 1 / n_samps.
    fmin = jnp.maximum(fmin, 1.0 / n_samps)

    # Build scaling factors: every frequency below the cutoff (including the
    # DC bin at f == 0) takes the value of the first frequency at or above it.
    below_cutoff = f < fmin
    # Clamp so an over-large fmin falls back to the highest frequency instead
    # of relying on JAX's implicit out-of-bounds index clamping.
    ix = jnp.minimum(jnp.sum(below_cutoff), n_freqs - 1)
    s_scale = jnp.where(below_cutoff, f[ix], f)
    s_scale = s_scale ** (-exponent / 2.0)

    # Theoretical output standard deviation (DC bin excluded).
    w = s_scale[1:]
    w = w.at[-1].multiply((1 + (n_samps % 2)) / 2.0)  # Correct f = ±0.5
    sigma = 2 * jnp.sqrt(jnp.sum(w ** 2)) / n_samps

    # Fourier coefficients share the batch dimensions of the output; only the
    # last axis changes from time samples to frequency bins, so `s_scale`
    # broadcasts along that last axis.
    fourier_shape = shape[:-1] + (n_freqs,)
    key_re, key_im = jrandom.split(key)
    sr = jrandom.normal(key_re, fourier_shape) * s_scale
    si = jrandom.normal(key_im, fourier_shape) * s_scale

    # For an even number of samples the Nyquist bin must be real. Parity is
    # known at trace time, so an ordinary Python branch is enough.
    if n_samps % 2 == 0:
        si = si.at[..., -1].set(0.0)
        sr = sr.at[..., -1].multiply(jnp.sqrt(2.0))

    # DC component must be real.
    si = si.at[..., 0].set(0.0)
    sr = sr.at[..., 0].multiply(jnp.sqrt(2.0))

    # Combine components, transform to the time domain and normalise.
    s = sr + 1j * si
    return jnp.fft.irfft(s, n=n_samps, axis=-1) / sigma

In [None]:
def compare_implementations(seed: int, exponent: float, size: Union[int, Tuple[int, ...]], fmin: float = 0.0, num_samples: int = 1000) -> dict:
    """
    Compare the JAX and NumPy implementations of power law noise.

    Both implementations draw ``num_samples`` independent realisations; the
    two sets are compared statistically (they use different random number
    generators, so individual samples are never expected to match).

    Parameters
    ----------
    seed : int
        Random seed for reproducibility. It seeds both the JAX PRNG key and
        the NumPy ``Generator`` handed to ``colorednoise``.
    exponent : float
        Power spectrum exponent
    size : int or tuple
        Size of the output; the last dimension is time.
    fmin : float
        Low-frequency cutoff
    num_samples : int
        Number of samples to generate for statistical comparison

    Returns
    -------
    dict
        Dictionary containing comparison metrics: ``mean_difference``,
        ``std_difference``, ``jax_mean``, ``numpy_mean``, ``jax_std``,
        ``numpy_std`` and ``psd_correlation``.
    """
    # Import the original implementation
    from colorednoise import powerlaw_psd_gaussian as powerlaw_numpy

    # Initialize random states. A single Generator is threaded through every
    # NumPy draw so that `seed` really determines the NumPy samples.
    key = jrandom.PRNGKey(seed)
    rng = np.random.default_rng(seed)

    # Generate samples from both implementations
    jax_samples = jnp.stack([
        powerlaw_psd_gaussian_jax(jrandom.fold_in(key, i), exponent, size, fmin)
        for i in range(num_samples)
    ])

    numpy_samples = np.stack([
        powerlaw_numpy(exponent, size, fmin, random_state=rng)
        for _ in range(num_samples)
    ])

    # Summary statistics over all generated values
    jax_mean = float(jnp.mean(jax_samples))
    jax_std = float(jnp.std(jax_samples))
    numpy_mean = float(np.mean(numpy_samples))
    numpy_std = float(np.std(numpy_samples))

    metrics = {
        'mean_difference': abs(jax_mean - numpy_mean),
        'std_difference': abs(jax_std - numpy_std),
        'jax_mean': jax_mean,
        'numpy_mean': numpy_mean,
        'jax_std': jax_std,
        'numpy_std': numpy_std,
    }

    # Average power spectra over the sample axis
    jax_psd = jnp.mean(jnp.abs(jnp.fft.rfft(jax_samples, axis=-1))**2, axis=0)
    numpy_psd = np.mean(np.abs(np.fft.rfft(numpy_samples, axis=-1))**2, axis=0)

    # Flatten before correlating: for a tuple `size` the PSDs are
    # multi-dimensional, and np.corrcoef would otherwise treat each row as a
    # separate variable, making [0, 1] a JAX-vs-JAX correlation.
    metrics['psd_correlation'] = float(
        np.corrcoef(np.ravel(np.asarray(jax_psd)), np.ravel(numpy_psd))[0, 1]
    )

    return metrics

In [None]:
# Plotting helper for the JAX vs. NumPy comparison
def plot_comparison(
    jax_samples: jnp.ndarray,
    numpy_samples: np.ndarray,
    params: dict,
    save_path: Optional[str] = None
) -> None:
    """
    Create comparison plots between JAX and NumPy implementations.

    The left panel overlays histograms of all sample values; the right panel
    shows the max-normalised power spectral density on log-log axes, averaged
    over every axis except the last (time) axis.

    Parameters
    ----------
    jax_samples : jnp.ndarray
        Samples from JAX implementation; the last axis is time.
    numpy_samples : np.ndarray
        Samples from NumPy implementation; the last axis is time.
    params : dict
        Parameters used for generation. The keys ``"exponent"``, ``"size"``
        and ``"fmin"`` are read for the titles.
    save_path : str, optional
        If provided, save the plot to this path
    """
    import matplotlib.pyplot as plt

    def _mean_psd(samples: np.ndarray) -> np.ndarray:
        # Average the periodogram over all leading axes so that a single
        # series, a (num_samples, time) stack and a stack of batched series
        # all reduce to one 1-D spectrum.
        power = np.abs(np.fft.rfft(samples, axis=-1)) ** 2
        return power.reshape(-1, power.shape[-1]).mean(axis=0)

    # Convert once so the rest of the function works on plain NumPy arrays.
    jax_values = np.asarray(jax_samples)
    numpy_values = np.asarray(numpy_samples)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot 1: Histogram comparison
    ax1.hist(jax_values.ravel(), bins=50, alpha=0.5, label='JAX', density=True)
    ax1.hist(numpy_values.ravel(), bins=50, alpha=0.5, label='NumPy', density=True)
    ax1.set_title('Distribution Comparison')
    ax1.set_xlabel('Value')
    ax1.set_ylabel('Density')
    ax1.legend()

    # Plot 2: Power Spectral Density
    freqs_jax = np.fft.rfftfreq(jax_values.shape[-1])
    freqs_numpy = np.fft.rfftfreq(numpy_values.shape[-1])

    psd_jax = _mean_psd(jax_values)
    psd_numpy = _mean_psd(numpy_values)

    # Normalize PSDs for comparison
    psd_jax = psd_jax / np.max(psd_jax)
    psd_numpy = psd_numpy / np.max(psd_numpy)

    # The DC bin (f == 0) cannot be drawn on a log axis, so skip index 0.
    ax2.loglog(freqs_jax[1:], psd_jax[1:], label='JAX', alpha=0.7)
    ax2.loglog(freqs_numpy[1:], psd_numpy[1:], '--', label='NumPy', alpha=0.7)
    ax2.set_title(f'Power Spectral Density (β={params["exponent"]})')
    ax2.set_xlabel('Frequency')
    ax2.set_ylabel('Normalized Power')
    ax2.legend()

    fig.suptitle(f'Comparison for β={params["exponent"]}, size={params["size"]}, fmin={params["fmin"]}')
    fig.tight_layout()

    # Save before showing: some backends release the figure once shown.
    if save_path:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Test parameters
SEED = 42
NUM_SAMPLES = 1000

test_params = [
    {'exponent': 1.0, 'size': 1024, 'fmin': 0.0},  # Pink noise
    {'exponent': 2.0, 'size': 1024, 'fmin': 0.0},  # Brown noise
    {'exponent': 0.5, 'size': 1024, 'fmin': 0.1},  # With frequency cutoff
]

for params in test_params:
    print(f"\nTesting with parameters: {params}")

    # Fresh, identically seeded random sources for every parameter set
    key = jrandom.PRNGKey(SEED)
    rng = np.random.default_rng(SEED)

    # Generate samples for plotting. Arguments are passed positionally so the
    # call does not depend on the generator's parameter being named `size`.
    jax_samples = jnp.stack([
        powerlaw_psd_gaussian_jax(
            jrandom.fold_in(key, i),
            params['exponent'], params['size'], params['fmin'],
        )
        for i in range(NUM_SAMPLES)
    ])

    numpy_samples = np.stack([
        colorednoise.powerlaw_psd_gaussian(params['exponent'], params['size'],
                                           params['fmin'], random_state=rng)
        for _ in range(NUM_SAMPLES)
    ])

    # Compute metrics (compare_implementations draws its own samples)
    metrics = compare_implementations(seed=SEED, num_samples=NUM_SAMPLES, **params)

    print("Comparison metrics:")
    print(f"Mean difference: {metrics['mean_difference']:.6f}")
    print(f"Std difference: {metrics['std_difference']:.6f}")
    print(f"PSD correlation: {metrics['psd_correlation']:.6f}")
    print(f"JAX mean/std: {metrics['jax_mean']:.6f}/{metrics['jax_std']:.6f}")
    print(f"NumPy mean/std: {metrics['numpy_mean']:.6f}/{metrics['numpy_std']:.6f}")

    # Generate comparison plots
    plot_comparison(jax_samples, numpy_samples, params)