In [3]:
import torch
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter

from burgers import BurgersDataset
from transforms import SafetyTransform

In [19]:
def _padded_limits(values, frac=0.05):
    """Return ``(lo, hi)`` axis limits covering ``values`` with a small margin.

    Args:
        values: Array-like supporting ``.min()`` and ``.max()``.
        frac: Margin added on each side, as a fraction of the data span.

    Returns:
        Tuple of floats ``(lo, hi)`` with ``lo < hi``.
    """
    lo, hi = float(values.min()), float(values.max())
    span = hi - lo
    # A constant channel (e.g. an all-zero force) would otherwise produce
    # identical limits, which matplotlib warns about and expands arbitrarily.
    pad = frac * span if span > 0 else frac * max(abs(lo), 1.0)
    return lo - pad, hi + pad


def visualize_sample(sample, transform_name,
                     save_dir="../experiments/dataset/test_results", fps=5):
    """Animate the u, f and s channels of one sample over time and save a GIF.

    Each animation frame shows one time step: the three channels are drawn
    as curves over x in [0, 1], one subplot per channel, with y-limits fixed
    over all time steps so frames are directly comparable.

    Args:
        sample: Tensor/array indexed as ``sample[channel, t, x]``; channels
            0, 1 and 2 are plotted as u, f and s respectively.
        transform_name: Name of the safety transform; used in the figure
            title and in the output file name.
        save_dir: Directory the GIF is written to (created if missing).
        fps: Frames per second of the saved GIF.

    Returns:
        Path of the written GIF file.
    """
    # Move to CPU, drop autograd history and use a numpy-compatible dtype so
    # matplotlib can consume the data even for GPU / requires-grad tensors.
    data = torch.as_tensor(sample).detach().cpu().float()
    u, f, s = (data[c].numpy() for c in range(3))
    nt, nx = u.shape
    x = torch.linspace(0, 1, nx).numpy()

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))
    try:
        fig.suptitle(f'Burgers Dataset Sample ({transform_name} transform)')

        # Empty lines that animate() fills in frame by frame.
        line1, = ax1.plot([], [], 'b-', label='u(x,t)')
        line2, = ax2.plot([], [], 'r-', label='f(x,t)')
        line3, = ax3.plot([], [], 'g-', label='s(x,t)')

        for ax in (ax1, ax2, ax3):
            ax.set_xlim(0, 1)
            ax.set_xlabel('x')
            ax.grid(True)
            ax.legend()

        ax1.set_ylabel('Solution (u)')
        ax2.set_ylabel('Force (f)')
        ax3.set_ylabel('Safety Score (s)')

        ax1.set_ylim(*_padded_limits(u))
        ax2.set_ylim(*_padded_limits(f))
        ax3.set_ylim(*_padded_limits(s))

        def init():
            """Clear all three lines before the first frame."""
            line1.set_data([], [])
            line2.set_data([], [])
            line3.set_data([], [])
            return line1, line2, line3

        def animate(frame):
            """Draw time-step index ``frame`` of every channel."""
            line1.set_data(x, u[frame])
            line2.set_data(x, f[frame])
            line3.set_data(x, s[frame])
            # `t` here is the time-step index, not a physical time.
            fig.suptitle(f'Burgers Dataset Sample ({transform_name} transform) - t={frame}')
            return line1, line2, line3

        # blit=False: animate() also rewrites the suptitle, which is not among
        # the returned artists, so blitting would leave a stale time label on
        # interactive backends. The interval is kept in sync with the GIF fps.
        anim = FuncAnimation(fig, animate, init_func=init, frames=nt,
                             interval=1000 / fps, blit=False)

        fig.tight_layout()

        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f'waveform_{transform_name}.gif')
        anim.save(out_path, writer=PillowWriter(fps=fps))
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # saving failed, so repeated calls do not accumulate open figures.
        plt.close(fig)
    return out_path

In [4]:
def test_burgers_dataset(root_path="../datasets", split="train", scaler=10.0,
                         visualize=False):
    """Smoke-test loading of ``BurgersDataset`` under several safety transforms.

    For each transform, builds the dataset, fetches the first sample, checks
    basic invariants and prints its size, shape and per-channel value ranges.

    Args:
        root_path: Dataset root directory forwarded to ``BurgersDataset``.
        split: Dataset split forwarded to ``BurgersDataset``.
        scaler: ``scaler`` argument forwarded to ``BurgersDataset``.
        visualize: If True, also save an animated GIF of each sample via
            ``visualize_sample``.

    Raises:
        AssertionError: If a sample does not have exactly three channels or
            its values fall outside [-1, 1].
    """
    # `None` lets BurgersDataset use its built-in default safety score
    # (noted by the original author as u²).
    transforms = {
        "default": None,
        "square": SafetyTransform(method="square"),
        "abs": SafetyTransform(method="abs"),
    }
    # Channel layout shared with visualize_sample: 0 = u, 1 = f, 2 = s.
    channel_labels = ("solution (u)", "force (f)", "safety score")

    for transform_name, transform in transforms.items():
        print(f"\nTesting {transform_name} transform:")

        dataset = BurgersDataset(
            split=split,
            safety_transform=transform,
            root_path=root_path,
            scaler=scaler,
        )
        print(f"Dataset size: {len(dataset)}")

        sample = dataset[0]
        print(f"Sample shape: {sample.shape}")
        assert sample.shape[0] == len(channel_labels), (
            f"expected {len(channel_labels)} channels (u, f, s), "
            f"got shape {tuple(sample.shape)}"
        )

        # The data is expected to be normalized into [-1, 1].
        print(f"Data range: [{sample.min():.3f}, {sample.max():.3f}]")
        assert -1.0 <= float(sample.min()) and float(sample.max()) <= 1.0, (
            f"{transform_name}: data outside [-1, 1]"
        )
        for channel, label in enumerate(channel_labels):
            values = sample[channel]
            print(f"{label} range: [{values.min():.3f}, {values.max():.3f}]")

        if visualize:
            visualize_sample(sample, transform_name)
            print(f"Saved animation to test_results/waveform_{transform_name}.gif")

# Run the smoke test once for each configured safety transform.
test_burgers_dataset()


Testing default transform:
Dataset size: 39000
Sample shape: torch.Size([3, 16, 128])
Data range: [-0.098, 0.303]
control range: [-0.074, 0.108]
trajectory range: [-0.098, 0.303]
safety score range: [0.000, 0.116]

Testing square transform:
Dataset size: 39000
Sample shape: torch.Size([3, 16, 128])
Data range: [-0.098, 0.303]
control range: [-0.074, 0.108]
trajectory range: [-0.098, 0.303]
safety score range: [0.000, 0.116]

Testing abs transform:
Dataset size: 39000
Sample shape: torch.Size([3, 16, 128])
Data range: [-0.098, 0.303]
control range: [-0.074, 0.108]
trajectory range: [-0.098, 0.303]
safety score range: [0.000, 0.108]
