In [None]:
import torch
from autoemulate.emulators import GaussianProcessRBF

In [None]:
from autoemulate.simulations.base import Simulator


class SignalGenerator(Simulator):
    """Toy simulator: a sinusoid whose phase argument grows linearly in time.

    Inputs (columns of ``x``): ``phase``, ``amplitude``, ``frequency`` and
    ``noise_level``. Each row produces the 100-point signal
    ``sin(2*pi*frequency*(amplitude*t + phase)*t)`` on ``t`` in [0, 1].

    NOTE(review): ``noise_level`` is part of the input space but is not used by
    ``_forward``; the noisy variant is kept commented out below.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        if parameters_range is None:
            parameters_range = {
                "phase": (-3.14, 3.14),
                "amplitude": (0, 1),
                "frequency": (0.1, 10),
                "noise_level": (0, 0.1)
            }
        if output_names is None:
            output_names = ["signal"]
        super().__init__(parameters_range, output_names, log_level)

    def _forward(self, x):
        """Evaluate the signal for every row of ``x``.

        Args:
            x: 2-D tensor whose columns are (phase, amplitude, frequency,
                noise_level).

        Returns:
            Tensor of shape ``(x.shape[0], 100)``, one signal per row.
        """
        # Keep each parameter as an (n, 1) column so it broadcasts against the
        # (100,) time grid for any number of rows. 1-D columns only broadcast
        # when n == 1.
        phase = x[:, 0:1]
        amplitude = x[:, 1:2]
        frequency = x[:, 2:3]
        noise_level = x[:, 3:4]  # only used by the commented-out noisy variant
        t = torch.linspace(0, 1, steps=100)
        # signal = amplitude * torch.sin(2 * torch.pi * frequency * t + phase)
        # signal += noise_level * torch.randn_like(signal)
        # # Make non-linear to test the DFT compared to PCA
        # signal += t**2 * torch.sin(2 * torch.pi * frequency * t + phase)
        f = amplitude * t + phase
        signal = torch.sin(2 * torch.pi * frequency * f * t)
        return signal

In [None]:
# Input ranges for this experiment. The commented-out entries are alternative
# settings kept for quick toggling.
parameters_range = {
    "phase": (-3.14, 3.14),
    # "phase": (0, 0),
    # "amplitude": (1, 1.001),
    "amplitude": (0, 1),
    # Difficult to compress since all frequencies are present
    # "frequency": (1, 10),
    # Range width is 1e-4, so frequency is effectively fixed at ~1.
    "frequency": (1, 1.0001),
    # NOTE(review): SignalGenerator._forward does not read noise_level, so this
    # range has no effect on the generated signals.
    "noise_level": (0, 1.5)
}
sim = SignalGenerator(parameters_range)
n_samples = 200
# Presumably draws n_samples input rows from parameters_range — confirm.
x = sim.sample_inputs(n_samples)
# forward_batch returns (outputs, inputs); x is rebound to the returned inputs.
y, x = sim.forward_batch(x)

In [None]:
import matplotlib.pyplot as plt
# Overlay (up to) the first five simulated signals.
for idx in range(min(5, len(y))):
    plt.plot(y[idx], label=f"Signal {idx}")
# The per-line labels above are only displayed once a legend is drawn.
plt.legend()

In [None]:
from autoemulate import AutoEmulate
from autoemulate.core.model_selection import evaluate
from autoemulate.emulators.transformed.base import TransformedEmulator
from autoemulate.transforms.discrete_fourier import DiscreteFourierTransform
from autoemulate.transforms import StandardizeTransform, PCATransform


# Wrap GaussianProcessRBF in a TransformedEmulator with no input transforms and
# a 2-component DiscreteFourierTransform on the outputs.
em = TransformedEmulator(
    x,
    y,
    model=GaussianProcessRBF,
    x_transforms=[],
    y_transforms=[DiscreteFourierTransform(n_components=2)]
)

# NOTE(review): x and y are also passed to the constructor above; confirm
# whether the constructor already fits, in which case this call refits.
em.fit(x, y)


In [None]:
# Shape check: draw 100 samples from em.model's prediction on x (presumably in
# the DFT-transformed output space), average them over the sample dimension, and
# map the mean back through the inverse of the first y-transform.
em.y_transforms[0].inv(em.model.predict(x).sample(torch.Size([100])).mean(0)).shape


In [None]:
# Compare GaussianProcessRBF under four output-transform chains: PCA and DFT
# (2 components each), with and without StandardizeTransform.
# n_splits=5 presumably sets the number of cross-validation folds — confirm.
ae = AutoEmulate(
    x,
    y,
    models=[GaussianProcessRBF],
    x_transforms_list=[[]],
    # Can learn for phase and amplitude with small number of components
    y_transforms_list=[
        [PCATransform(n_components=2), StandardizeTransform()],
        [PCATransform(n_components=2)],
        [DiscreteFourierTransform(n_components=2), StandardizeTransform()],
        [DiscreteFourierTransform(n_components=2)],
        # [PCATransform(n_components=10), StandardizeTransform()],
        # [PCATransform(n_components=10)],
        # [DiscreteFourierTransform(n_components=10), StandardizeTransform()],
        # [DiscreteFourierTransform(n_components=10)],
    ],
    n_splits=5,
    log_level="warning",
    model_params={},
    # NOTE(review): presumably the number of predictive samples drawn when
    # mapping predictions back through the output transforms — confirm.
    transformed_emulator_params={"n_samples": 100}
)

In [None]:
# Transpose the summary table so each configuration becomes a column.
ae.summarize().T


# Better Simulators for DFT vs PCA Comparison

The simulators below are designed to test whether DFT outperforms PCA as an output compression; each one targets a case where DFT is expected to have an advantage:

1. **MultiFrequencySignal**: Multiple sine waves with different frequencies
2. **ChirpSignal**: Frequency changes over time (frequency sweeps)
3. **BurstSignal**: Short bursts of oscillations
4. **ModulatedSignal**: Amplitude modulation with carrier frequencies
5. **ImpulseResponseSignal**: Decaying oscillations (damped sinusoids)

In [None]:
class MultiFrequencySignal(Simulator):
    """Sum of three independently parameterised sinusoids.

    Component ``k`` contributes ``amp_k * sin(2*pi*freq_k*t + phase_k)``, sampled
    on 128 points over [0, 2] seconds. The three tones occupy separate bands
    (low, high, very high), so each signal has a sparse spectrum. This is the
    case where a DFT output transform is expected to beat PCA, since the
    frequency content varies independently from row to row.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        default_ranges = {
            "freq1": (1, 5),      # low-frequency tone
            "freq2": (8, 15),     # high-frequency tone
            "freq3": (20, 30),    # very-high-frequency tone
            "amp1": (0.1, 1.0),   # amplitude of tone 1
            "amp2": (0.1, 1.0),   # amplitude of tone 2
            "amp3": (0.1, 1.0),   # amplitude of tone 3
            "phase1": (0, 2*torch.pi),
            "phase2": (0, 2*torch.pi),
            "phase3": (0, 2*torch.pi),
        }
        super().__init__(
            default_ranges if parameters_range is None else parameters_range,
            ["signal"] if output_names is None else output_names,
            log_level,
        )

    def _forward(self, x):
        # Columns are laid out as [freq1..3, amp1..3, phase1..3].
        time_grid = torch.linspace(0, 2, steps=128)  # 2 seconds, 128 points
        total = None
        for k in range(3):
            freq_k = x[:, k].unsqueeze(-1)
            amp_k = x[:, 3 + k].unsqueeze(-1)
            phase_k = x[:, 6 + k].unsqueeze(-1)
            tone = amp_k * torch.sin(2 * torch.pi * freq_k * time_grid + phase_k)
            # Accumulate left to right: ((tone1 + tone2) + tone3).
            total = tone if total is None else total + tone
        return total

In [None]:
class ChirpSignal(Simulator):
    """Linear or exponential frequency sweep (chirp) from ``f0`` to ``f1``.

    Inputs (columns of ``x``): ``f0`` (start frequency), ``f1`` (end frequency),
    ``amplitude`` and ``chirp_type``. Rows with ``chirp_type < 0.5`` get a
    linear sweep; the others get an exponential sweep. Each signal covers 2
    seconds sampled at 256 points.

    Both phases are the exact integrals of the instantaneous frequency, so the
    phase is 0 at t=0. The grid has 256 points (Nyquist ~63.7 Hz) so that the
    default sweep up to 50 Hz is not aliased; 128 points (Nyquist ~31.7 Hz)
    aliased the upper part of the sweep.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        if parameters_range is None:
            parameters_range = {
                "f0": (1, 5),         # Starting frequency
                "f1": (10, 50),       # Ending frequency
                "amplitude": (0.5, 2.0),
                "chirp_type": (0, 1), # 0=linear, 1=exponential (we'll threshold at 0.5)
            }
        if output_names is None:
            output_names = ["signal"]
        super().__init__(parameters_range, output_names, log_level)

    def _forward(self, x):
        # (n, 1) columns broadcast against the (n_steps,) time grid.
        f0 = x[:, 0].unsqueeze(-1)
        f1 = x[:, 1].unsqueeze(-1)
        amplitude = x[:, 2].unsqueeze(-1)
        chirp_type = x[:, 3]

        duration = 2.0
        t = torch.linspace(0, duration, steps=256)

        # Linear sweep: f(t) = f0 + (f1 - f0) * t / T. The phase is the exact
        # integral 2*pi*(f0*t + (f1 - f0) * t**2 / (2*T)). The previous cumsum
        # was a right-endpoint Riemann sum: nonzero at t=0, with growing error.
        phase_linear = 2 * torch.pi * (f0 * t + (f1 - f0) * t**2 / (2 * duration))

        # Exponential sweep: f(t) = f0 * (f1/f0)**(t/T). The phase integral is
        # 2*pi*f0*T/ln(f1/f0) * ((f1/f0)**(t/T) - 1).
        # This requires f0 > 0 and f1 != f0, which the default ranges guarantee.
        ratio = f1 / f0
        phase_exp = 2 * torch.pi * f0 * (duration / torch.log(ratio)) * (ratio ** (t / duration) - 1)

        # Select the sweep type per row.
        linear_mask = (chirp_type < 0.5).unsqueeze(-1)
        phase = torch.where(linear_mask, phase_linear, phase_exp)
        return amplitude * torch.sin(phase)

In [None]:
class BurstSignal(Simulator):
    """A single decaying burst of oscillation surrounded by silence.

    For each row, a sinusoid at ``frequency`` starts at ``burst_start``. It is
    scaled by ``amplitude * exp(-decay_rate * (t - burst_start))`` and zeroed
    outside the window ``[burst_start, burst_start + burst_duration]``. The time
    grid has 128 points over [0, 2] seconds.

    The hypothesis: a DFT transform should cope better than PCA, because the
    burst has a clear frequency while its varying position in time (temporal
    sparsity) makes the outputs hard for PCA to compress.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        default_ranges = {
            "frequency": (5, 25),           # oscillation frequency inside the burst
            "burst_duration": (0.1, 0.5),   # length of the burst window
            "burst_start": (0.2, 1.0),      # onset time of the burst
            "amplitude": (0.5, 2.0),
            "decay_rate": (0, 10),          # exponential decay rate after onset
        }
        super().__init__(
            default_ranges if parameters_range is None else parameters_range,
            ["signal"] if output_names is None else output_names,
            log_level,
        )

    def _forward(self, x):
        freq, width, onset, amp, decay = (x[:, col] for col in range(5))

        t = torch.linspace(0, 2, steps=128)
        onset_col = onset.unsqueeze(-1)
        # Time measured from the burst onset (negative before the onset).
        since_onset = t - onset_col

        # Boolean window covering [onset, onset + width].
        inside = (t >= onset_col) & (t <= (onset + width).unsqueeze(-1))
        wave = torch.sin(2 * torch.pi * freq.unsqueeze(-1) * since_onset)
        # The clamp keeps the envelope at 1 before the onset instead of growing.
        envelope = torch.exp(-decay.unsqueeze(-1) * torch.clamp(since_onset, min=0))

        return amp.unsqueeze(-1) * wave * inside.float() * envelope

In [None]:
class ModulatedSignal(Simulator):
    """Amplitude-modulated sinusoid on 128 points over [0, 2] seconds.

    Each row gives
    ``carrier_amp * sin(2*pi*carrier_freq*t) * (1 + modulation_depth*sin(2*pi*modulation_freq*t))``.
    The spectrum contains the carrier plus sidebands at
    ``carrier_freq +/- modulation_freq``, which a DFT transform should capture.

    NOTE(review): at the top of the default ranges (30 Hz carrier, 3 Hz
    modulation) the upper sideband (33 Hz) exceeds this grid's Nyquist frequency
    (127 / 2 / 2 ~= 31.75 Hz) and aliases.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        default_ranges = {
            "carrier_freq": (10, 30),        # high-frequency carrier
            "modulation_freq": (0.5, 3),     # low-frequency envelope
            "modulation_depth": (0.2, 0.9),  # envelope depth
            "carrier_amp": (0.5, 1.5),
        }
        super().__init__(
            default_ranges if parameters_range is None else parameters_range,
            ["signal"] if output_names is None else output_names,
            log_level,
        )

    def _forward(self, x):
        f_carrier, f_mod, depth, amp = [x[:, col].unsqueeze(-1) for col in range(4)]
        t = torch.linspace(0, 2, steps=128)

        envelope = 1 + depth * torch.sin(2 * torch.pi * f_mod * t)
        return amp * torch.sin(2 * torch.pi * f_carrier * t) * envelope

In [None]:
class ImpulseResponseSignal(Simulator):
    """Delayed, damped oscillation (second-order impulse response).

    For each row, on 128 points over [0, 2] seconds, with ``tau = max(t - delay, 0)``:
    ``initial_amp * exp(-damping_ratio*natural_freq*tau) * sin(omega_d*tau)``,
    where ``omega_d = natural_freq * sqrt(max(1 - damping_ratio**2, 0.01))``.
    The output is zero before ``delay``.

    NOTE(review): ``natural_freq`` is used directly as an angular frequency
    (no 2*pi factor), unlike the other simulators in this notebook, which take Hz.
    For ``damping_ratio`` above ~0.995 the clamp keeps a slow oscillation at
    ``0.1 * natural_freq`` rather than the non-oscillatory overdamped response.
    """

    def __init__(self, parameters_range=None, output_names=None, log_level="progress_bar"):
        default_ranges = {
            "natural_freq": (2, 20),      # natural frequency of oscillation
            "damping_ratio": (0.1, 2.0),  # damping coefficient
            "initial_amp": (0.5, 2.0),    # initial amplitude
            "delay": (0, 0.5),            # delay before the impulse
        }
        super().__init__(
            default_ranges if parameters_range is None else parameters_range,
            ["signal"] if output_names is None else output_names,
            log_level,
        )

    def _forward(self, x):
        wn = x[:, 0].unsqueeze(-1)
        zeta = x[:, 1].unsqueeze(-1)
        a0 = x[:, 2].unsqueeze(-1)
        lag = x[:, 3].unsqueeze(-1)

        t = torch.linspace(0, 2, steps=128)
        # Time elapsed since the impulse; 0 before it.
        tau = torch.clamp(t - lag, min=0)
        # 1.0 once the impulse has started, 0.0 before.
        started = (t >= lag).float()

        wd = wn * torch.sqrt(torch.clamp(1 - zeta**2, min=0.01))
        return a0 * torch.exp(-zeta * wn * tau) * torch.sin(wd * tau) * started

In [None]:
def compare_dft_vs_pca(simulator_class, simulator_name, n_samples=200, n_components_list=(4, 8, 16)):
    """Compare PCA and DFT output compression for one simulator using AutoEmulate.

    For each component count, AutoEmulate is run with GaussianProcessRBF and two
    output-transform chains: PCA alone and DiscreteFourierTransform alone. The
    first ``r2_test`` of each is recorded, then the results are printed and
    plotted.

    Args:
        simulator_class: Simulator subclass that can be constructed with no
            arguments.
        simulator_name: Label used in printed output and plot titles.
        n_samples: Number of input samples drawn from the simulator.
        n_components_list: Iterable of component counts to try (read only).

    Returns:
        A DataFrame with columns ``n_components``, ``PCA_score``, ``DFT_score``
        and ``DFT_advantage`` (DFT minus PCA), or ``None`` if no run produced
        both scores.
    """
    print(f"\n=== Testing {simulator_name} ===")

    # Create simulator and generate data
    sim = simulator_class()
    x = sim.sample_inputs(n_samples)
    y, x = sim.forward_batch(x)

    # Visualize up to three examples (fewer if fewer rows were returned).
    plt.figure(figsize=(12, 4))
    for idx in range(min(3, len(y))):
        plt.subplot(1, 3, idx + 1)
        plt.plot(y[idx].numpy())
        plt.title(f"Example {idx+1}")
    plt.suptitle(f"{simulator_name} Examples")
    plt.tight_layout()
    plt.show()

    results = []

    for n_comp in n_components_list:
        print(f"\nTesting with {n_comp} components...")

        # Best-effort sweep: a failure at one component count is reported and
        # the remaining counts are still tried.
        try:
            # Each transform is used on its own (no StandardizeTransform) so the
            # comparison isolates the compression method.
            ae = AutoEmulate(
                x, y,
                models=[GaussianProcessRBF],
                x_transforms_list=[[]],  # No input transforms
                y_transforms_list=[
                    [PCATransform(n_components=n_comp)],                    # Simple PCA
                    [DiscreteFourierTransform(n_components=n_comp)],        # Simple DFT
                ],
                n_splits=3,  # Reduced for speed
                log_level="warning",  # Less verbose
                model_params={},
                n_bootstraps=None
            )

            summary = ae.summarize()
            print(summary.T)

            # Match rows by the string form of their output-transform chain.
            # A separate Series is used so the returned summary is not mutated.
            transform_names = summary["y_transforms"].map(str)
            pca_rows = summary[transform_names.str.contains('PCA')]
            dft_rows = summary[transform_names.str.contains('DiscreteFourier')]

            if len(pca_rows) > 0 and len(dft_rows) > 0:
                pca_score = pca_rows['r2_test'].iloc[0]
                dft_score = dft_rows['r2_test'].iloc[0]

                results.append({
                    'n_components': n_comp,
                    'PCA_score': pca_score,
                    'DFT_score': dft_score,
                    'DFT_advantage': dft_score - pca_score
                })

                print(f"  PCA R² score: {pca_score:.4f}")
                print(f"  DFT R² score: {dft_score:.4f}")
                print(f"  DFT advantage: {dft_score - pca_score:.4f}")

                if dft_score > pca_score:
                    print(f"  ✓ DFT wins!")
                else:
                    print(f"  ✗ PCA wins")
            else:
                print(f"  Error: Could not find PCA or DFT results")

        except Exception as e:
            print(f"  Error with {n_comp} components: {e}")

    if not results:
        print("No results to display")
        return None

    import pandas as pd
    results_df = pd.DataFrame(results)

    # Left: absolute scores; right: DFT minus PCA (green = DFT better).
    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    plt.plot(results_df['n_components'], results_df['PCA_score'], 'o-', label='PCA', linewidth=2)
    plt.plot(results_df['n_components'], results_df['DFT_score'], 's-', label='DFT', linewidth=2)
    plt.xlabel('Number of Components')
    plt.ylabel('Test Score (R²)')
    plt.title(f'{simulator_name}: PCA vs DFT')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    colors = ['green' if adv > 0 else 'red' for adv in results_df['DFT_advantage']]
    plt.bar(results_df['n_components'], results_df['DFT_advantage'], color=colors, alpha=0.7)
    plt.xlabel('Number of Components')
    plt.ylabel('DFT Advantage (R² difference)')
    plt.title('DFT Performance Advantage')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.8)

    plt.tight_layout()
    plt.show()

    print("\nSummary Results:")
    print(results_df.to_string(index=False, float_format='%.4f'))

    return results_df

In [None]:
# Test 1: MultiFrequencySignal
# Hypothesis: DFT should have an advantage here. The signal is a sum of
# independent sinusoids, which DFT decomposes naturally and PCA may find harder.
# Only 12 and 20 components are run (the 6-component setting is commented out).

multi_freq_results = compare_dft_vs_pca(
    MultiFrequencySignal, 
    "MultiFrequencySignal",
    # n_components_list=[6, 12, 20]
    n_components_list=[12, 20]
)

In [None]:
# Test 2: BurstSignal
# Hypothesis: DFT may have an advantage. Each burst has a clear oscillation
# frequency, while its varying position in time (temporal sparsity) may make
# the outputs harder for PCA to compress.

burst_results = compare_dft_vs_pca(
    BurstSignal, 
    "BurstSignal",
    n_components_list=[4, 8, 16]
)

In [None]:
# Test 3: ModulatedSignal
# Hypothesis: DFT may have an advantage. Amplitude modulation puts energy at the
# carrier and at sidebands (carrier +/- modulation frequency), a sparse spectrum
# that DFT can represent with few components.

modulated_results = compare_dft_vs_pca(
    ModulatedSignal, 
    "ModulatedSignal",  
    n_components_list=[4, 8, 16]
)

In [None]:
# Test 4: ImpulseResponseSignal
# Hypothesis: DFT may have an advantage. Damped oscillations have a frequency
# peak whose location varies independently of the damping.

impulse_results = compare_dft_vs_pca(
    ImpulseResponseSignal, 
    "ImpulseResponseSignal",
    n_components_list=[4, 8, 16]
)

In [None]:
# Let's test with a much simpler, more controlled comparison
import numpy as np
from sklearn.decomposition import PCA


# Banner for the reconstruction-only comparison below (no emulator involved).
print("=== Controlled Comparison: DFT vs PCA ===")

def simple_comparison(sim_class, n_samples=200, n_components_list=(4, 8, 12, 16)):
    """Compare PCA and DFT as pure compress-and-reconstruct methods (no emulator).

    For each component count, the simulator outputs are compressed and
    reconstructed with sklearn ``PCA`` and with ``DiscreteFourierTransform``,
    and the reconstruction R² of each is printed.

    R² is ``1 - SSE / SST``, with SST taken about each output column's own mean
    (the variance-weighted R² of ``sklearn.metrics.r2_score``). For PCA this
    matches ``explained_variance_ratio_.sum()`` up to numerical error, and that
    value is printed as a cross-check. The previous denominator was the pooled
    variance of all outputs. That also counted variation of the per-column
    means, so it inflated both scores.

    Args:
        sim_class: Simulator subclass that can be constructed with no arguments.
        n_samples: Number of input samples drawn from the simulator.
        n_components_list: Iterable of component counts to try (read only).
    """
    sim = sim_class()
    x = sim.sample_inputs(n_samples)
    y, x = sim.forward_batch(x)

    y_np = y.numpy()
    # Mean per-column variance == SST / number of elements, matching np.mean(err**2).
    total_var = np.mean(np.var(y_np, axis=0))

    print(f"Testing {sim_class.__name__} with {y.shape[1]} time points")

    for n_comp in n_components_list:
        # Best-effort: e.g. PCA rejects n_components above min(n_samples, n_features).
        try:
            pca = PCA(n_components=n_comp)
            pca.fit(y_np)
            y_pca_recon = pca.inverse_transform(pca.transform(y_np))
            pca_r2 = 1 - np.mean((y_np - y_pca_recon) ** 2) / total_var

            dft = DiscreteFourierTransform(n_components=n_comp)
            dft.fit(y)
            y_dft_recon = dft._inverse(dft(y))
            dft_r2 = 1 - np.mean((y_np - y_dft_recon.numpy()) ** 2) / total_var

            print(f"  {n_comp} components:")
            print(f"    PCA R²: {pca_r2:.4f}")
            print(f"    DFT R²: {dft_r2:.4f}")
            print(f"    PCA expl. var: {pca.explained_variance_ratio_.sum():.4f}")

            if dft_r2 > pca_r2:
                print(f"    ✓ DFT wins by {dft_r2 - pca_r2:.4f}")
            else:
                print(f"    ✗ PCA wins by {pca_r2 - dft_r2:.4f}")

        except Exception as e:
            print(f"  {n_comp} components: Error - {e}")
        print()

# Test with a very simple, clean signal first
class SimpleSineWave(Simulator):
    """Single sinusoid ``amplitude * sin(2*pi*frequency*t + phase)``.

    Sampled on 256 points over [0, 4] seconds, giving a Nyquist frequency of
    255 / 4 / 2 ~= 31.9 Hz. The previous 64 points (~7.9 Hz Nyquist) aliased
    the part of the default 2-10 Hz frequency range above ~7.9 Hz.

    Args:
        parameters_range: Optional input ranges; defaults to frequency (2, 10),
            amplitude (0.5, 1.5) and phase (0, 2*pi).
        output_names: Optional output names; defaults to ``["signal"]``.
    """

    def __init__(self, parameters_range=None, output_names=None):
        if parameters_range is None:
            parameters_range = {
                "frequency": (2, 10),
                "amplitude": (0.5, 1.5),
                "phase": (0, 2*torch.pi),
            }
        if output_names is None:
            output_names = ["signal"]
        # log_level is not passed, so the Simulator default is used (as before).
        super().__init__(parameters_range, output_names)

    def _forward(self, x):
        freq, amp, phase = x[:, 0], x[:, 1], x[:, 2]
        t = torch.linspace(0, 4, steps=256)  # 4 seconds, 256 points
        signal = amp.unsqueeze(-1) * torch.sin(2 * torch.pi * freq.unsqueeze(-1) * t + phase.unsqueeze(-1))
        return signal

# Run the reconstruction-only comparison on a single clean sinusoid and on the
# three-tone MultiFrequencySignal.
simple_comparison(SimpleSineWave)
simple_comparison(MultiFrequencySignal)

In [None]:
# Corrected DFT vs PCA comparison - fixing the transform chain issues
# Banner for the AutoEmulate-based comparison defined below.
print("=== Fixed DFT vs PCA Comparison ===")

def fixed_comparison(simulator_class, simulator_name, n_samples=200, n_components_list=(6, 12, 20, 40)):
    """Compare PCA and DFT output compression for one simulator using AutoEmulate.

    For each component count, AutoEmulate is run with GaussianProcessRBF and two
    single-transform output chains (PCA only, DiscreteFourierTransform only)
    using 2 splits. The first ``r2_test`` of each is recorded, then the results
    are printed and plotted.

    Args:
        simulator_class: Simulator subclass that can be constructed with no
            arguments.
        simulator_name: Label used in printed output and plot titles.
        n_samples: Number of input samples drawn from the simulator.
        n_components_list: Iterable of component counts to try (read only).

    Returns:
        A DataFrame with columns ``n_components``, ``PCA_score``, ``DFT_score``
        and ``DFT_advantage`` (DFT minus PCA), or ``None`` if no run produced
        both scores.
    """
    print(f"\n=== Testing {simulator_name} ===")

    # Create simulator and generate data
    sim = simulator_class()
    x = sim.sample_inputs(n_samples)
    y, x = sim.forward_batch(x)

    # Visualize up to three examples (fewer if fewer rows were returned).
    plt.figure(figsize=(12, 4))
    for idx in range(min(3, len(y))):
        plt.subplot(1, 3, idx + 1)
        plt.plot(y[idx].numpy())
        plt.title(f"Example {idx+1}")
    plt.suptitle(f"{simulator_name} Examples")
    plt.tight_layout()
    plt.show()

    results = []

    for n_comp in n_components_list:
        print(f"\nTesting with {n_comp} components...")

        # Best-effort sweep: a failure at one component count is reported and
        # the remaining counts are still tried.
        try:
            # Each transform is used on its own so the comparison isolates the
            # compression method.
            ae = AutoEmulate(
                x, y,
                models=[GaussianProcessRBF],
                x_transforms_list=[[]],  # No input transforms
                y_transforms_list=[
                    [PCATransform(n_components=n_comp)],                    # Simple PCA
                    [DiscreteFourierTransform(n_components=n_comp)],        # Simple DFT
                ],
                n_splits=2,  # Reduced for speed
                # Was "progress_Bar"; every other call in this notebook spells
                # the level "progress_bar".
                log_level="progress_bar",
                model_params={},
                n_bootstraps=None
            )

            summary = ae.summarize()
            print(summary)

            # Match rows by the string form of their output-transform chain.
            # A separate Series is used so the returned summary is not mutated.
            transform_names = summary["y_transforms"].map(str)
            pca_rows = summary[transform_names.str.contains('PCA')]
            dft_rows = summary[transform_names.str.contains('DiscreteFourier')]

            if len(pca_rows) > 0 and len(dft_rows) > 0:
                pca_score = pca_rows['r2_test'].iloc[0]
                dft_score = dft_rows['r2_test'].iloc[0]

                results.append({
                    'n_components': n_comp,
                    'PCA_score': pca_score,
                    'DFT_score': dft_score,
                    'DFT_advantage': dft_score - pca_score
                })

                print(f"  PCA R² score: {pca_score:.4f}")
                print(f"  DFT R² score: {dft_score:.4f}")
                print(f"  DFT advantage: {dft_score - pca_score:.4f}")

                if dft_score > pca_score:
                    print(f"  ✓ DFT wins!")
                else:
                    print(f"  ✗ PCA wins")
            else:
                print(f"  Error: Could not find PCA or DFT results")

        except Exception as e:
            print(f"  Error with {n_comp} components: {e}")

    if not results:
        print("No results to display")
        return None

    import pandas as pd
    results_df = pd.DataFrame(results)

    # Left: absolute scores; right: DFT minus PCA (green = DFT better).
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(results_df['n_components'], results_df['PCA_score'], 'o-', label='PCA', linewidth=2, markersize=8)
    plt.plot(results_df['n_components'], results_df['DFT_score'], 's-', label='DFT', linewidth=2, markersize=8)
    plt.xlabel('Number of Components')
    plt.ylabel('R² Test Score')
    plt.title(f'{simulator_name}: Performance Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    colors = ['green' if adv > 0 else 'red' for adv in results_df['DFT_advantage']]
    plt.bar(results_df['n_components'], results_df['DFT_advantage'], color=colors, alpha=0.7)
    plt.xlabel('Number of Components')
    plt.ylabel('DFT Advantage (R² difference)')
    plt.title('DFT Performance Advantage')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color='black', linestyle='-', alpha=0.8)

    plt.tight_layout()
    plt.show()

    print("\nSummary Results:")
    print(results_df.to_string(index=False, float_format='%.4f'))

    return results_df

# Test the fixed comparison
# Run the AutoEmulate-based comparison on MultiFrequencySignal with the default
# component counts of fixed_comparison.
fixed_comparison(MultiFrequencySignal, "MultiFrequencySignal")