# JAXTrace Tutorial: Memory-Optimized Particle Tracking

This notebook demonstrates how to use JAXTrace for efficient particle tracking in computational fluid dynamics applications.

## Table of Contents
1. [Installation and Setup](#installation)
2. [Loading VTK Data](#loading-data)
3. [Basic Particle Tracking](#basic-tracking)
4. [Advanced Configuration](#advanced-config)
5. [Visualization](#visualization)
6. [Memory Optimization](#memory-optimization)
7. [Performance Comparison](#performance)
8. Saving and Loading Results
9. Advanced Usage Tips


## 1. Installation and Setup {#installation}

First, let's import JAXTrace and check the system configuration.

In [None]:
# Import JAXTrace components
import jaxtrace
from jaxtrace import VTKReader, ParticleTracker, ParticleVisualizer
from jaxtrace.utils import (
    get_memory_config, get_system_info, create_custom_particle_distribution,
    estimate_computational_requirements, create_progress_callback
)

import numpy as np
import matplotlib.pyplot as plt

print(f"JAXTrace version: {jaxtrace.__version__}")
print(f"JAX available: {jaxtrace.JAX_AVAILABLE}")

In [None]:
# Check system information
system_info = get_system_info()
print("System Information:")
for key, value in system_info.items():
    print(f"  {key}: {value}")

## 2. Loading VTK Data {#loading-data}

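JAXTrace's `VTKReader` keeps memory use bounded on large datasets in two ways: it limits how many time steps it considers (`max_time_steps`) and how many of them it holds in memory at once (`cache_size_limit`).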

In [None]:
# Define VTK file pattern
# Replace this with your actual VTK file pattern
vtk_pattern = "path/to/your/vtk/files/*.pvtu"  # or "*.vtu", "simulation_*.vtk", etc.

# Alternative: run the tutorial on synthetic data if no VTK files are available.
# No files are written here; the synthetic trajectories are generated in Section 3.
create_example_data = True  # Set to False if you have real VTK files

if create_example_data:
    print("Using synthetic example data...")
    print("(In real applications, you would use actual VTK files)")
    
    # Placeholder pattern only: no matching files exist and no reader is created
    vtk_pattern = "synthetic_data_*.vtk"
else:
    # Initialize VTK reader with your data
    reader = VTKReader(
        file_pattern=vtk_pattern,
        max_time_steps=50,  # Limit to last 50 time steps for memory efficiency
        cache_size_limit=5   # Keep only 5 time steps in memory
    )
    
    # Get grid information
    grid_info = reader.get_grid_info()
    print("Grid Information:")
    print(f"  Number of grid points: {grid_info['n_points']}")
    print(f"  Number of time steps: {grid_info['n_timesteps']}")
    print(f"  Velocity field: {grid_info['velocity_field_name']}")
    print(f"  Domain bounds: {grid_info['bounds']}")
    
    # Validate VTK files
    validation = reader.validate_files()
    print(f"\nValidation: {validation['valid_files']}/{validation['total_files']} files valid")
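
The `cache_size_limit` argument suggests that only a few time steps stay in memory at a time. The following is a minimal sketch of how such a bounded cache can work, using a least-recently-used eviction policy. `TimeStepCache` and `fake_loader` are hypothetical names invented for this illustration; the eviction policy JAXTrace actually uses is not shown in this tutorial and may differ.

```python
from collections import OrderedDict


class TimeStepCache:
    """Bounded least-recently-used cache for per-time-step data (hypothetical helper)."""

    def __init__(self, loader, size_limit=5):
        self._loader = loader            # callable: step index -> data
        self._size_limit = size_limit
        self._cache = OrderedDict()
        self.loads = 0                   # how many times the loader was called

    def get(self, step):
        if step in self._cache:
            self._cache.move_to_end(step)      # mark as most recently used
            return self._cache[step]
        data = self._loader(step)
        self.loads += 1
        self._cache[step] = data
        if len(self._cache) > self._size_limit:
            self._cache.popitem(last=False)    # evict the least recently used step
        return data

    def cached_steps(self):
        return list(self._cache)


def fake_loader(step):
    # Stand-in for reading one VTK file; returns a tiny list instead of a field
    return [float(step)] * 3


cache = TimeStepCache(fake_loader, size_limit=5)
for step in range(8):
    cache.get(step)
cache.get(7)  # already cached, so the loader is not called again

print(f"Steps in cache: {cache.cached_steps()}")
print(f"Loader calls: {cache.loads}")
```

Because particle tracking usually walks through time steps in order, a small cache like this is typically enough: each step is loaded once and dropped shortly after it is no longer needed.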

## 3. Basic Particle Tracking {#basic-tracking}

Now let's set up a basic particle tracking simulation.

In [None]:
# For demonstration, we'll create synthetic particle tracking
# In real applications, you would use the VTK reader from above

if create_example_data:
    # Create synthetic example for demonstration
    print("Running synthetic particle tracking example...")
    
    # Define simulation domain
    box_bounds = ((0.0, 10.0), (0.0, 5.0), (2.0, 8.0))  # (x_min, x_max), (y_min, y_max), (z_min, z_max)
    
    # Create initial particle distribution
    initial_positions = create_custom_particle_distribution(
        box_bounds=box_bounds,
        distribution_type='uniform',
        n_particles=1000,
        random_seed=42
    )
    
    # Simulate simple motion (in real case, this would be done by ParticleTracker)
    n_steps = 50
    dt = 0.1
    
    # Helical motion for demonstration: rotation in XY plus upward drift in Z
    rng = np.random.default_rng(42)      # seeded so the noise is reproducible
    center = np.array([5.0, 2.5, 5.0])   # rotation axis passes through the domain center
    positions = initial_positions.copy()
    trajectory_data = [positions.copy()]
    
    for step in range(n_steps):
        relative_pos = positions - center
        
        # Circular motion in XY plane, upward drift in Z
        velocities = np.zeros_like(positions)
        velocities[:, 0] = -relative_pos[:, 1] * 0.2  # -y component
        velocities[:, 1] = relative_pos[:, 0] * 0.2   # x component
        velocities[:, 2] = 0.1                        # upward drift
        
        # Add some noise
        velocities += rng.normal(0, 0.02, velocities.shape)
        
        # Explicit Euler update
        positions += velocities * dt
        trajectory_data.append(positions.copy())
    
    # Convert to trajectory array
    trajectories = np.array(trajectory_data).transpose(1, 0, 2)  # (n_particles, n_steps+1, 3)
    final_positions = trajectories[:, -1, :]
    
    print(f"Synthetic tracking complete: {len(final_positions)} particles, {n_steps} steps")
    
else:
    # Real particle tracking with VTK data
    print("Setting up real particle tracking...")
    
    # Initialize particle tracker
    tracker = ParticleTracker(
        vtk_reader=reader,
        max_gpu_memory_gb=8.0,
        k_neighbors=8,
        shape_function='linear',
        interpolation_method='finite_element'
    )
    
    print(f"Tracker initialized: {tracker}")
    
    # Create initial particle grid
    initial_positions = tracker.create_particle_grid(
        resolution=(20, 20, 20),  # 8000 particles
        bounds_padding=0.1
    )
    
    # Set up progress tracking
    progress_callback = create_progress_callback(
        update_frequency=50,
        show_memory=True,
        show_time=True
    )
    
    # Run particle tracking
    final_positions = tracker.track_particles(
        initial_positions=initial_positions,
        dt=0.01,
        n_steps=500,
        integration_method='euler',
        progress_callback=progress_callback,
        save_trajectories=False  # Set to True if you want full trajectories
    )
    
    print(f"Particle tracking complete: {len(final_positions)} particles")
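
Both branches above advance particles with explicit Euler steps. To give a feel for what the choice of integrator costs in accuracy, here is a standalone NumPy sketch that compares Euler with classical fourth-order Runge-Kutta (RK4) on the same kind of rotational field as the synthetic example, centered at the origin for simplicity. The function names are made up for this illustration, and whether JAXTrace accepts an RK4 option for `integration_method` is not shown in this tutorial.

```python
import numpy as np

OMEGA = 0.2  # angular rate, matching the 0.2 factor in the synthetic field


def velocity(pos):
    """Rigid rotation about the z-axis: v = OMEGA * (-y, x, 0)."""
    vel = np.zeros_like(pos)
    vel[:, 0] = -OMEGA * pos[:, 1]
    vel[:, 1] = OMEGA * pos[:, 0]
    return vel


def euler_step(pos, dt):
    return pos + dt * velocity(pos)


def rk4_step(pos, dt):
    k1 = velocity(pos)
    k2 = velocity(pos + 0.5 * dt * k1)
    k3 = velocity(pos + 0.5 * dt * k2)
    k4 = velocity(pos + dt * k3)
    return pos + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def radius_drift(step_fn, n_steps=500, dt=0.1):
    """Absolute change in orbit radius; the exact solution keeps the radius at 1."""
    pos = np.array([[1.0, 0.0, 0.0]])
    for _ in range(n_steps):
        pos = step_fn(pos, dt)
    return abs(np.linalg.norm(pos[0, :2]) - 1.0)


euler_drift = radius_drift(euler_step)
rk4_drift = radius_drift(rk4_step)
print(f"Euler radius drift: {euler_drift:.3e}")
print(f"RK4 radius drift:   {rk4_drift:.3e}")
```

In a purely rotational field, Euler steps push particles slowly outward, while RK4 keeps them on the circle to much higher precision at four velocity evaluations per step. For short runs with small `dt`, Euler is often acceptable; for long integrations the drift accumulates.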

## 4. Advanced Configuration {#advanced-config}

JAXTrace provides several preset configurations optimized for different memory and accuracy requirements.

In [None]:
# Explore available memory configurations
configs = ['low_memory', 'medium_memory', 'high_memory', 'high_accuracy']

print("Available Configuration Presets:")
for config_name in configs:
    config = get_memory_config(config_name)
    print(f"\n{config_name.upper()}:")
    for key, value in config.items():
        print(f"  {key}: {value}")

In [None]:
# Estimate computational requirements for different configurations
print("Computational Requirements Estimation:")

for config_name in ['medium_memory', 'high_accuracy']:
    config = get_memory_config(config_name)
    
    # Calculate number of particles
    n_particles = np.prod(config['particle_resolution'])
    n_timesteps = config['max_time_steps']
    
    requirements = estimate_computational_requirements(
        n_particles=n_particles,
        n_timesteps=n_timesteps,
        integration_method=config['integration_method'],
        interpolation_method=config['interpolation_method'],
        k_neighbors=config['k_neighbors']
    )
    
    print(f"\n{config_name.upper()}:")
    print(f"  Particles: {requirements['n_particles']:,}")
    print(f"  Estimated memory: {requirements['estimated_memory_gb']:.2f} GB")
    print(f"  Estimated runtime: {requirements['estimated_runtime_hours']:.2f} hours")
    if requirements['recommendations']:
        print(f"  Recommendations: {'; '.join(requirements['recommendations'])}")
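
A useful sanity check on any memory estimate is the raw size of the position arrays themselves. The sketch below is a back-of-the-envelope calculation, not the formula used by `estimate_computational_requirements`; it assumes three coordinates per particle and 4-byte floats, and the function name is invented for this illustration.

```python
def position_memory_gib(n_particles, n_steps, bytes_per_value=4, store_trajectories=True):
    """Size of the particle position data in GiB (3 coordinates per particle)."""
    n_snapshots = n_steps + 1 if store_trajectories else 1
    n_bytes = n_particles * n_snapshots * 3 * bytes_per_value
    return n_bytes / 1024**3


# 40 x 40 x 40 particle grid tracked for 1000 steps
full = position_memory_gib(64_000, 1000)
final_only = position_memory_gib(64_000, 1000, store_trajectories=False)

print(f"Full trajectories:    {full:.3f} GiB")
print(f"Final positions only: {final_only:.6f} GiB")
```

Storing full trajectories scales with the number of steps, which is why an option such as `save_trajectories=False` matters for long runs. The velocity field data and interpolation structures come on top of this figure.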

## 5. Visualization {#visualization}

JAXTrace's `ParticleVisualizer` provides 3D scatter plots, cross-sections, displacement analysis, and density plots for analyzing particle tracking results.

In [None]:
# Initialize visualizer
if create_example_data:
    # Use synthetic data
    visualizer = ParticleVisualizer(
        final_positions=final_positions,
        initial_positions=initial_positions,
        trajectories=trajectories
    )
else:
    # Use real tracking results
    visualizer = ParticleVisualizer(
        final_positions=final_positions,
        initial_positions=initial_positions
    )

print(f"Visualizer initialized: {visualizer}")
print("Visualizer info:")
viz_info = visualizer.get_visualization_info()
for key, value in viz_info.items():
    print(f"  {key}: {value}")

In [None]:
# 3D scatter plot of particle positions
print("Creating 3D position plot...")
visualizer.plot_3d_positions(
    show_initial=True,
    show_final=True,
    show_trajectories=True,  # Only works if trajectories are available
    n_show=500,  # Show subset for performance
    cam_view=(45, 30, 0),  # Azimuth, elevation, roll
    figsize=(12, 10)
)

In [None]:
# Cross-section analysis
print("Creating cross-section plots...")
visualizer.plot_cross_sections(
    plane='xy',
    position=5.0,  # Z = 5.0
    slab_thickness=1.0,
    show_initial=True,
    show_final=True,
    figsize=(15, 6)
)
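
A cross-section with a slab thickness amounts to selecting the particles close to a plane. The following NumPy sketch shows one way to do that selection, assuming the slab is centered on `position` and extends half the thickness to either side; `select_slab` is a hypothetical helper, and JAXTrace's own selection rule may differ in detail.

```python
import numpy as np


def select_slab(positions, plane="xy", position=5.0, slab_thickness=1.0):
    """Return the particles whose out-of-plane coordinate lies inside the slab."""
    normal_axis = {"xy": 2, "xz": 1, "yz": 0}[plane]
    half = 0.5 * slab_thickness
    mask = np.abs(positions[:, normal_axis] - position) <= half
    return positions[mask]


# Uniform random points in the tutorial's synthetic domain
rng = np.random.default_rng(0)
pts = rng.uniform([0.0, 0.0, 2.0], [10.0, 5.0, 8.0], size=(1000, 3))

slab = select_slab(pts, plane="xy", position=5.0, slab_thickness=1.0)
print(f"{len(slab)} of {len(pts)} particles lie in the slab 4.5 <= z <= 5.5")
```

A thicker slab includes more particles and gives smoother plots, at the cost of blurring structure along the normal direction.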

In [None]:
# Displacement analysis (if initial positions available)
if visualizer.initial_positions is not None:
    print("Creating displacement analysis...")
    visualizer.plot_displacement_analysis(figsize=(15, 12))
else:
    print("Displacement analysis requires initial positions")
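
The quantities behind a displacement analysis are simple to compute directly. As a rough sketch, assuming positions are stored as `(n_particles, 3)` arrays, the per-particle displacement magnitude and the mean drift vector can be obtained as follows; `displacement_stats` is a name invented for this illustration.

```python
import numpy as np


def displacement_stats(initial, final):
    """Summarize how far particles moved between two snapshots."""
    disp = final - initial                    # (n_particles, 3) displacement vectors
    magnitude = np.linalg.norm(disp, axis=1)  # Euclidean distance per particle
    return {
        "mean": float(magnitude.mean()),
        "max": float(magnitude.max()),
        "net_drift": disp.mean(axis=0),       # average displacement vector
    }


initial = np.zeros((4, 3))
final = np.array([
    [3.0, 4.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
])

stats = displacement_stats(initial, final)
print(stats)
```

A large gap between the mean magnitude and the length of the net drift vector indicates that particles move a lot but in different directions, as in a rotating flow.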

### Advanced Density Visualization

JAXTrace supports advanced density estimation methods including JAX-accelerated KDE and SPH.

In [None]:
# Density estimation with different methods
print("Creating density plots...")

# JAX KDE density plot
if jaxtrace.JAX_AVAILABLE:
    visualizer.plot_density(
        positions='final',
        plane='xy',
        position=5.0,
        slab_thickness=1.0,
        method='jax_kde',
        grid_resolution=150,
        bandwidth=0.3,
        levels=15,
        cmap='viridis'
    )
else:
    print("JAX not available, using fallback density method")
    visualizer.plot_density(
        positions='final',
        plane='xy', 
        position=5.0,
        slab_thickness=1.0,
        method='seaborn',
        grid_resolution=100,
        levels=15,
        cmap='viridis'
    )
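
For readers unfamiliar with kernel density estimation, here is a plain NumPy sketch of a 2D Gaussian KDE evaluated on a grid. It illustrates what `bandwidth` and `grid_resolution` control in general; it is not JAXTrace's `jax_kde` implementation, and `gaussian_kde_2d` is a hypothetical function written for this example.

```python
import numpy as np


def gaussian_kde_2d(points, grid_x, grid_y, bandwidth):
    """Evaluate an isotropic Gaussian KDE of 2D points on a rectangular grid."""
    gx, gy = np.meshgrid(grid_x, grid_y, indexing="xy")
    grid = np.stack([gx.ravel(), gy.ravel()], axis=1)   # (G, 2) grid nodes
    diff = grid[:, None, :] - points[None, :, :]        # (G, N, 2) pairwise offsets
    sq_dist = np.sum(diff**2, axis=2)                   # (G, N)
    kernel = np.exp(-0.5 * sq_dist / bandwidth**2)
    norm = 2.0 * np.pi * bandwidth**2 * len(points)     # makes the density integrate to 1
    return (kernel.sum(axis=1) / norm).reshape(gx.shape)


rng = np.random.default_rng(0)
points = rng.normal(loc=[5.0, 2.5], scale=0.5, size=(150, 2))

grid_x = np.linspace(0.0, 10.0, 81)
grid_y = np.linspace(-2.5, 7.5, 81)
density = gaussian_kde_2d(points, grid_x, grid_y, bandwidth=0.3)

cell_area = (grid_x[1] - grid_x[0]) * (grid_y[1] - grid_y[0])
integral = density.sum() * cell_area
print(f"Density integrates to {integral:.3f} over the grid")
```

This direct formulation needs memory proportional to the number of grid nodes times the number of particles, so for large particle sets the grid is usually processed in chunks. A smaller bandwidth resolves finer structure but produces noisier estimates.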

In [None]:
# Combined analysis plot
print("Creating combined analysis plot...")
visualizer.plot_combined_analysis(
    plane='xy',
    position=5.0,
    slab_thickness=1.0,
    method='jax_kde' if jaxtrace.JAX_AVAILABLE else 'seaborn',
    grid_resolution=100,
    figsize=(15, 12)
)

In [None]:
# Interactive 3D visualization (if Plotly is available)
try:
    print("Creating interactive 3D plot...")
    visualizer.plot_interactive_3d(
        show_initial=True,
        show_final=True,
        show_trajectories=True,
        n_show=500
    )
except ImportError:
    print("Plotly not available. Install with: pip install plotly")

## 6. Memory Optimization {#memory-optimization}

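JAXTrace provides several strategies for reducing memory use on large datasets. Two are shown here: tracking with spatial and temporal subsampling, and automatic parameter optimization against memory and runtime targets.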

In [None]:
# Example: Memory-optimized tracking with subsampling
if not create_example_data:  # Only for real VTK data
    print("Demonstrating memory optimization techniques...")
    
    # Create a larger initial grid
    large_initial_positions = tracker.create_particle_grid(
        resolution=(40, 40, 40),  # 64,000 particles
        bounds_padding=0.1
    )
    
    print(f"Large particle set: {len(large_initial_positions)} particles")
    
    # Track with spatial and temporal subsampling
    optimized_final_positions = tracker.track_particles_with_subsampling(
        initial_positions=large_initial_positions,
        dt=0.01,
        n_steps=1000,
        spatial_subsample_factor=2,  # Use every 2nd particle
        temporal_subsample_factor=2,  # Use every 2nd time step
        integration_method='euler'
    )
    
    print(f"Optimized tracking: {len(optimized_final_positions)} particles")

else:
    print("Memory optimization examples require real VTK data")
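
Following the comments in the cell above ("every 2nd particle", "every 2nd time step"), subsampling can be pictured as strided slicing of a trajectory array. The sketch below works on a plain NumPy array of shape `(n_particles, n_snapshots, 3)`; it is only an illustration of the idea, and how `track_particles_with_subsampling` applies the two factors internally is an assumption here.

```python
import numpy as np


def subsample(trajectories, spatial_factor=2, temporal_factor=2):
    """Keep every spatial_factor-th particle and every temporal_factor-th snapshot."""
    # .copy() detaches the result from the original so the large array can be freed
    return trajectories[::spatial_factor, ::temporal_factor, :].copy()


traj = np.zeros((8000, 101, 3), dtype=np.float32)
small = subsample(traj, spatial_factor=2, temporal_factor=2)

print(f"Original:   {traj.shape}, {traj.nbytes / 1024**2:.2f} MiB")
print(f"Subsampled: {small.shape}, {small.nbytes / 1024**2:.2f} MiB")
```

With both factors set to 2, the stored data shrinks to roughly a quarter. Without the `.copy()`, NumPy would return a view that keeps the full array alive in memory.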

In [None]:
# Demonstrate automatic parameter optimization
from jaxtrace.utils import optimize_parameters_for_system, validate_simulation_parameters

# Start with a high-memory configuration
base_config = get_memory_config('high_memory')
print("Base configuration:")
for key, value in base_config.items():
    print(f"  {key}: {value}")

# Validate parameters
is_valid, issues = validate_simulation_parameters(base_config)
print(f"\nConfiguration valid: {is_valid}")
if issues:
    for issue in issues:
        print(f"  - {issue}")

# Optimize for current system
optimized_config = optimize_parameters_for_system(
    base_config=base_config,
    target_memory_gb=4.0,  # Target 4GB memory usage
    target_runtime_hours=1.0  # Target 1 hour runtime
)

print("\nOptimized configuration:")
for key, value in optimized_config.items():
    if key in base_config and base_config[key] != value:
        print(f"  {key}: {base_config[key]} → {value} (changed)")
    else:
        print(f"  {key}: {value}")

## 7. Performance Comparison {#performance}

Compare interpolation and density estimation methods to find the best trade-off between speed and accuracy for your use case.

In [None]:
# Benchmark interpolation methods
if jaxtrace.JAX_AVAILABLE:
    from jaxtrace.utils import benchmark_interpolation_methods
    
    print("Benchmarking interpolation methods...")
    benchmark_results = benchmark_interpolation_methods(
        n_particles=5000,
        n_evaluations=1000,
        methods=['nearest_neighbor', 'finite_element']
    )
    
    print("\nBenchmark Results:")
    for method, results in benchmark_results.items():
        print(f"\n{method.upper()}:")
        print(f"  Time per evaluation: {results['time_per_evaluation_ms']:.3f} ms")
        print(f"  Evaluations per second: {results['evaluations_per_second']:.0f}")
        print(f"  Total time: {results['total_time_seconds']:.3f} s")
else:
    print("JAX not available for benchmarking")
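
If you want to time your own functions in the same way, a small harness built on `time.perf_counter` is enough. The sketch below reports the same three quantities as the result keys printed above; the `benchmark` function itself is hypothetical and unrelated to how `benchmark_interpolation_methods` is implemented.

```python
import time


def benchmark(fn, n_evaluations=1000):
    """Call fn repeatedly and report simple timing statistics."""
    start = time.perf_counter()
    for _ in range(n_evaluations):
        fn()
    total = time.perf_counter() - start
    return {
        "total_time_seconds": total,
        "time_per_evaluation_ms": 1000.0 * total / n_evaluations,
        "evaluations_per_second": n_evaluations / total if total > 0 else float("inf"),
    }


results = benchmark(lambda: sum(i * i for i in range(1000)), n_evaluations=200)
for key, value in results.items():
    print(f"{key}: {value:.4f}")
```

When timing JAX code, run one warm-up call first so that JIT compilation is excluded, and call `.block_until_ready()` on the result inside `fn`, because JAX dispatches work asynchronously.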

In [None]:
# Compare density estimation methods
if jaxtrace.JAX_AVAILABLE:
    from jaxtrace.visualizer import compare_density_methods
    
    print("Comparing density estimation methods...")
    compare_density_methods(
        final_positions=final_positions,
        plane='xy',
        position=5.0,
        slab_thickness=1.0,
        methods=['jax_kde', 'seaborn'],  # Add 'sph' if you want to test SPH
        figsize=(15, 6)
    )
else:
    print("JAX not available for density method comparison")

## 8. Saving and Loading Results

Save your simulation configuration and results for reproducibility.

In [None]:
# Save simulation configuration
from jaxtrace.utils import save_simulation_config, load_simulation_config

# Use the optimized configuration from earlier; copy it so that adding
# metadata does not modify optimized_config itself
config_to_save = dict(optimized_config)
config_to_save['simulation_name'] = 'tutorial_example'
config_to_save['description'] = 'JAXTrace tutorial simulation with optimized parameters'

# Save configuration
save_simulation_config(config_to_save, 'tutorial_config.json')

# Load configuration (demonstration)
loaded_config = load_simulation_config('tutorial_config.json')
print("Loaded configuration:")
for key, value in loaded_config.items():
    print(f"  {key}: {value}")
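
One detail to keep in mind when a configuration goes through a JSON file: JSON has no tuple type, so tuples come back as lists. The standard-library sketch below demonstrates this with a made-up configuration. It assumes the config is stored as plain JSON, as the `.json` file name suggests; whether `load_simulation_config` converts such values back is not shown in this tutorial.

```python
import json
import os
import tempfile

# Illustrative values only, not a real JAXTrace preset
config = {
    "particle_resolution": (20, 20, 20),
    "dt": 0.01,
    "integration_method": "euler",
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    with open(path) as f:
        loaded = json.load(f)

print(type(loaded["particle_resolution"]).__name__)

# Convert back explicitly where a tuple is required
restored = tuple(loaded["particle_resolution"])
print(restored)
```

If downstream code compares configurations or uses a value as a dictionary key, converting lists back to tuples after loading avoids surprises.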

In [None]:
# Save particle positions for later analysis
import os

# Create output directory
output_dir = 'tutorial_output'
os.makedirs(output_dir, exist_ok=True)

# Save positions as numpy arrays
np.save(os.path.join(output_dir, 'final_positions.npy'), final_positions)
if visualizer.initial_positions is not None:
    np.save(os.path.join(output_dir, 'initial_positions.npy'), visualizer.initial_positions)
if visualizer.trajectories is not None:
    np.save(os.path.join(output_dir, 'trajectories.npy'), visualizer.trajectories)

print(f"Results saved to {output_dir}/")
print("Files created:")
for file in sorted(os.listdir(output_dir)):
    filepath = os.path.join(output_dir, file)
    size_mb = os.path.getsize(filepath) / 1024**2
    print(f"  {file}: {size_mb:.2f} MB")

## 9. Advanced Usage Tips

Here are some advanced tips for getting the best performance from JAXTrace.

In [None]:
# Custom particle distributions for specific applications
print("Creating custom particle distributions...")

# Different distribution types
distributions = ['uniform', 'random', 'gaussian', 'stratified', 'clustered']
box_bounds = ((0, 5), (0, 5), (0, 5))

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, dist_type in enumerate(distributions):
    positions = create_custom_particle_distribution(
        box_bounds=box_bounds,
        distribution_type=dist_type,
        n_particles=1000,
        random_seed=42
    )
    
    ax = axes[i]
    ax.scatter(positions[:, 0], positions[:, 1], alpha=0.6, s=1)
    ax.set_title(f'{dist_type.capitalize()} Distribution')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

# Remove empty subplot
axes[-1].remove()

plt.tight_layout()
plt.show()

In [None]:
# Memory monitoring during simulation
from jaxtrace.utils import monitor_memory_usage

@monitor_memory_usage
def example_memory_intensive_operation():
    """Example function to demonstrate memory monitoring."""
    # Create large array
    large_array = np.random.randn(1000000, 3)
    
    # Do some computation
    result = np.mean(large_array, axis=0)
    
    # Clean up
    del large_array
    
    return result

print("Demonstrating memory monitoring:")
result = example_memory_intensive_operation()
print(f"Result: {result}")
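
To show what a memory-monitoring decorator can look like, here is a minimal sketch built on Python's standard `tracemalloc` module. It is not the implementation of `monitor_memory_usage`; `report_peak_memory` and `build_table` are hypothetical names used only in this example.

```python
import functools
import tracemalloc


def report_peak_memory(func):
    """Print the peak Python-level memory allocated while func runs."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tracemalloc.start()
        try:
            result = func(*args, **kwargs)
            _, peak = tracemalloc.get_traced_memory()
        finally:
            tracemalloc.stop()
        wrapper.last_peak_mb = peak / 1024**2
        print(f"{func.__name__}: peak {wrapper.last_peak_mb:.2f} MiB")
        return result

    wrapper.last_peak_mb = None  # filled in after the first call
    return wrapper


@report_peak_memory
def build_table(n):
    table = [float(i) for i in range(n)]  # temporary list that dominates memory use
    return sum(table)


total = build_table(200_000)
print(f"Sum: {total}")
```

`tracemalloc` only sees allocations reported through Python's memory hooks, which includes NumPy array buffers but not device memory used by JAX on a GPU, so treat the number as a host-side estimate.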

## Summary

This tutorial covered:

1. **Setting up JAXTrace** with system information checks
2. **Loading VTK data** with memory-optimized readers
3. **Basic particle tracking** with Euler integration, on synthetic data or through `ParticleTracker`
4. **Advanced configuration** using preset configurations
5. **Comprehensive visualization** including 3D plots, cross-sections, and density estimation
6. **Memory optimization** strategies for large datasets
7. **Performance comparison** of different methods
8. **Saving and loading** results for reproducibility

### Key Features of JAXTrace:

- **Memory-optimized**: Handles large VTK datasets efficiently
- **GPU-accelerated**: Uses JAX for high-performance computing
- **Flexible**: Multiple integration and interpolation methods
- **Visual**: Comprehensive visualization capabilities
- **Configurable**: Easy-to-use preset configurations
- **Scalable**: Automatic parameter optimization

### Next Steps:

1. Replace the synthetic data with your actual VTK files
2. Experiment with different configuration presets
3. Use the memory optimization features for large simulations
4. Explore advanced visualization options
5. Benchmark different methods for your specific use case

For more information, see the JAXTrace documentation and examples.