# Developer's Guide to FWI: Using AcousticWave Class

This tutorial was prepared by Alexandre Olender (olender@usp.br).

This **advanced tutorial** is designed for developers who want to understand the inner workings of Full Waveform Inversion (FWI) and implement custom FWI algorithms using Spyro's low-level `AcousticWave` class. Instead of using the high-level `FullWaveformInversion` wrapper, we'll build our own FWI implementation from scratch.

## Why Use AcousticWave Directly?

The `AcousticWave` class provides direct access to:
- **Forward wave propagation**: `forward_solve()`
- **Adjoint wave propagation**: `gradient_solve()`
- **Mesh and model management**: Direct control over discretization
- **Custom optimization algorithms**: Implement your own optimization strategies
- **Advanced physics**: Add custom terms to the wave equation
- **Performance optimization**: Fine-tune memory usage and computational efficiency

## What You'll Learn:

1. **Manual FWI implementation** using only `AcousticWave` methods
2. **Custom optimization loops** with gradient-based methods
3. **Direct manipulation** of Firedrake functions and meshes
4. **Memory management** for large-scale problems
5. **Advanced debugging** and monitoring techniques
6. **Custom objective functions** beyond simple L2 misfit

## Prerequisites:
- Solid understanding of FWI theory and algorithms
- Experience with Firedrake/PETSc programming
- Knowledge of optimization theory (gradient descent, line search, etc.)
- Familiarity with the basic FWI tutorial

## 1. Setup and Imports

We'll import the core libraries plus optimization tools that we'll need to implement our custom FWI algorithm.

In [None]:
# Enable inline plotting
%matplotlib inline

# Core libraries
import spyro
import firedrake as fire
import numpy as np
import matplotlib.pyplot as plt
import time
import warnings
from scipy.optimize import minimize

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

print("Developer FWI Tutorial - Using AcousticWave Class")
print("=" * 50)

## 2. Problem Configuration

We'll define our problem parameters similar to the standard FWI tutorial, but we'll be working directly with `AcousticWave` objects.
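
The acquisition settings used in this section (`"source_type": "ricker"`, `frequency`, and `delay = 1.0/frequency`) describe a Ricker wavelet whose peak is shifted to `t = delay`. As a rough illustration, the NumPy sketch below evaluates the textbook Ricker formula on the tutorial's time axis. The helper name `ricker_wavelet` is hypothetical, and Spyro's internal wavelet may differ in sign or amplitude scaling.

```python
import numpy as np

def ricker_wavelet(t, peak_frequency, delay, amplitude=1.0):
    """Textbook Ricker wavelet centred at t = delay (hypothetical helper)."""
    arg = (np.pi * peak_frequency * (t - delay)) ** 2
    return amplitude * (1.0 - 2.0 * arg) * np.exp(-arg)

frequency = 5.0          # Hz, as in the dictionary
delay = 1.0 / frequency  # s, as in the dictionary
dt = 0.0001              # s, as in the dictionary
final_time = 1.3         # s, as in the dictionary

n_steps = round(final_time / dt)
t = np.arange(n_steps + 1) * dt
wavelet = ricker_wavelet(t, frequency, delay)

peak_index = int(np.argmax(wavelet))
print(f"Peak value {wavelet[peak_index]:.3f} at t = {t[peak_index]:.4f} s")
print(f"Value at t = 0: {wavelet[0]:.2e}")
```

Delaying the wavelet by one period keeps its value at `t = 0` small, which avoids an abrupt source onset at the first time step.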

In [None]:
# Problem parameters
degree = 4
frequency = 5.0
final_time = 1.3

# Base dictionary for both forward and inversion problems
base_dictionary = {
    "options": {
        "cell_type": "T",  # Triangular elements
        "variant": "lumped",
        "degree": degree,
        "dimension": 2,
    },
    "parallelism": {
        "type": "automatic",
    },
    "mesh": {
        "length_z": 2.0,
        "length_x": 2.0,
        "length_y": 0.0,
    },
    "acquisition": {
        "source_type": "ricker",
        "source_locations": spyro.create_transect((-0.35, 0.5), (-0.35, 1.5), 1),
        "frequency": frequency,
        "delay": 1.0/frequency,
        "delay_type": "time",
        "receiver_locations": spyro.create_transect((-1.65, 0.5), (-1.65, 1.5), 200),
    },
    "absorving_boundary_conditions": {
        "status": True,
        "damping_type": "local",
    },
    "time_axis": {
        "initial_time": 0.0,
        "final_time": final_time,
        "dt": 0.0001,
        "amplitude": 1,
        "output_frequency": 100,
        "gradient_sampling_frequency": 1,
    },
    "visualization": {
        "forward_output": False,  # Disable to save memory
        "forward_output_filename": "results/forward_output.pvd",
        "fwi_velocity_model_output": False,
        "velocity_model_filename": None,
        "gradient_output": False,
        "gradient_filename": "results/Gradient.pvd",
        "adjoint_output": False,
        "adjoint_filename": None,
        "debug_output": False,
    },
}

print(f"Problem setup:")
print(f"  Frequency: {frequency} Hz")
print(f"  Degree: {degree}")
print(f"  Domain: {base_dictionary['mesh']['length_x']} × {base_dictionary['mesh']['length_z']} km")
print(f"  Receivers: {len(base_dictionary['acquisition']['receiver_locations'])}")

## 3. Generate True Model and Synthetic Data

First, we'll create the "true" velocity model and generate synthetic observed data using the `AcousticWave` class directly.
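
Before building the model on a finite element mesh, it can help to check the geometry on a plain grid. The sketch below mirrors the nested conditionals used in this section with `np.where`, using the same anomaly parameters. The function name `true_velocity` and the 201 x 201 sampling grid are assumptions made for illustration only.

```python
import numpy as np

def true_velocity(z, x):
    """Evaluate the tutorial's velocity model (km/s) at points (z, x) in km."""
    z = np.asarray(z, dtype=float)
    x = np.asarray(x, dtype=float)

    center_z, center_x, radius = -1.0, 1.0, 0.4
    square_top_z, square_bot_z = -0.8, -1.2
    square_left_x, square_right_x = 0.8, 1.2

    in_circle = (z - center_z) ** 2 + (x - center_x) ** 2 < radius ** 2
    in_square = (
        (z < square_top_z) & (z > square_bot_z)
        & (x > square_left_x) & (x < square_right_x)
    )

    velocity = np.where(in_circle, 3.0, 2.5)       # circle over the background
    velocity = np.where(in_square, 3.5, velocity)  # square overrides the circle
    return velocity

# Sample on a regular grid covering the 2 km x 2 km domain (z is negative downwards)
z_grid, x_grid = np.meshgrid(
    np.linspace(-2.0, 0.0, 201), np.linspace(0.0, 2.0, 201), indexing="ij"
)
model = true_velocity(z_grid, x_grid)
print(f"Velocity range: {model.min()} to {model.max()} km/s")
```

Because the square (half-width 0.2 km) lies entirely inside the circle (radius 0.4 km), the order of the two conditionals matters: the square must be applied last so that it overrides the circular anomaly.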

In [None]:
def create_true_model():
    """Create the true velocity model using AcousticWave directly."""
    print("Creating true velocity model...")
    
    # Create AcousticWave object for forward modeling
    true_wave = spyro.AcousticWave(dictionary=base_dictionary)
    
    # Set up high-resolution mesh for accurate forward modeling
    true_wave.set_mesh(input_mesh_parameters={
        "edge_length": 0.05,
        "mesh_type": "firedrake_mesh"
    })
    
    # Get mesh coordinates for conditional velocity definition
    mesh_z = true_wave.mesh_z
    mesh_x = true_wave.mesh_x
    
    # Define complex velocity model with multiple anomalies
    center_z, center_x, radius = -1.0, 1.0, 0.4
    square_top_z, square_bot_z = -0.8, -1.2
    square_left_x, square_right_x = 0.8, 1.2
    
    # Build velocity model using nested conditionals
    velocity_model = fire.conditional(
        (mesh_z - center_z)**2 + (mesh_x - center_x)**2 < radius**2, 
        3.0,  # Circular anomaly
        2.5   # Background
    )
    
    velocity_model = fire.conditional(
        fire.And(
            fire.And(mesh_z < square_top_z, mesh_z > square_bot_z),
            fire.And(mesh_x > square_left_x, mesh_x < square_right_x)
        ),
        3.5,  # Rectangular anomaly (highest velocity)
        velocity_model
    )
    
    # Set the velocity model
    true_wave.set_initial_velocity_model(conditional=velocity_model, output=True)
    
    return true_wave

# Create true model
true_wave_obj = create_true_model()
print("True model created successfully!")

In [None]:
# Generate synthetic observed data
print("Generating synthetic observed data...")
print("This may take a few minutes...")

# Run forward simulation to generate "observed" data
true_wave_obj.forward_solve()

# Extract the observed data (shot records)
observed_data = true_wave_obj.receivers_output.copy()
print(f"Observed data shape: {observed_data.shape}")
print(f"Data range: [{np.min(observed_data):.6f}, {np.max(observed_data):.6f}]")

# Plot the true velocity model
try:
    spyro.plots.plot_model(true_wave_obj, 
                          filename="true_model_dev.png", 
                          show=True)
    print("True model plotted successfully!")
except Exception as e:
    print(f"Plotting error: {e}")

# Clean up memory - important for large problems
del true_wave_obj.forward_solution
print("Synthetic data generation complete!")

## 4. Create Inversion Wave Object

Now we'll create our inversion `AcousticWave` object with a different mesh resolution and initial velocity model.
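
A quick way to judge whether a coarser mesh is still reasonable is to count how many element edges span one wavelength. The sketch below does this arithmetic for the two edge lengths used in this tutorial. The helper `elements_per_wavelength` is hypothetical, and the number of elements per wavelength that degree-4 lumped elements actually need is not stated here, so treat the output as a sanity check rather than a rule.

```python
def elements_per_wavelength(velocity, frequency, edge_length):
    """Number of mesh edges spanning one wavelength (velocity / frequency)."""
    wavelength = velocity / frequency
    return wavelength / edge_length

frequency = 5.0     # Hz, Ricker peak frequency from the dictionary
min_velocity = 2.5  # km/s, slowest velocity in the tutorial's models

for label, edge_length in [("true model", 0.05), ("inversion", 0.08)]:
    epw = elements_per_wavelength(min_velocity, frequency, edge_length)
    print(f"{label}: {epw:.2f} elements per wavelength at {frequency} Hz")
```

Note that a Ricker wavelet carries energy well above its peak frequency, so the shortest wavelengths in the data are resolved by fewer elements than this estimate suggests.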

In [None]:
def create_inversion_wave():
    """Create AcousticWave object for inversion with coarser mesh."""
    print("Creating inversion wave object...")
    
    # Create new AcousticWave object for inversion
    inv_wave = spyro.AcousticWave(dictionary=base_dictionary)
    
    # Use coarser mesh for computational efficiency
    inv_wave.set_mesh(input_mesh_parameters={
        "edge_length": 0.08,  # Coarser than true model
        "mesh_type": "firedrake_mesh"
    })
    
    # Start with homogeneous initial guess
    initial_velocity = 2.5  # km/s
    inv_wave.set_initial_velocity_model(constant=initial_velocity)
    
    print(f"Inversion mesh created with edge length: 0.08")
    print(f"Initial velocity: {initial_velocity} km/s")
    print(f"Function space: {inv_wave.function_space}")
    print(f"DOFs: {inv_wave.function_space.dim()}")
    
    return inv_wave

# Create inversion wave object
inversion_wave = create_inversion_wave()

# Get the initial model for optimization
initial_model = inversion_wave.initial_velocity_model.copy(deepcopy=True)
print("Inversion wave object created successfully!")

## 5. Implement Custom FWI Functions

Now we'll implement the core FWI functions: objective function calculation, gradient computation, and optimization utilities.
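
One design point is worth illustrating before the implementation: when the objective and the gradient are exposed to an optimizer as two separate callables, each may trigger its own forward solve for the same model. A minimal sketch of one way to avoid that, assuming a combined `value_and_grad` callable, is to memoize the last evaluation. The class `CachedMisfit` and the toy quadratic are hypothetical and stand in for the expensive wave solves.

```python
import numpy as np

class CachedMisfit:
    """Memoize the last (misfit, gradient) pair so fun and jac share one solve."""

    def __init__(self, value_and_grad):
        self._value_and_grad = value_and_grad
        self._last_x = None
        self._last_result = None
        self.n_solves = 0  # counts how often the expensive callable actually ran

    def _evaluate(self, x):
        x = np.asarray(x, dtype=float)
        if self._last_x is None or not np.array_equal(x, self._last_x):
            self._last_result = self._value_and_grad(x)
            self._last_x = x.copy()
            self.n_solves += 1
        return self._last_result

    def fun(self, x):
        return self._evaluate(x)[0]

    def jac(self, x):
        return self._evaluate(x)[1]

def toy_value_and_grad(x):
    """Toy quadratic misfit standing in for a forward plus adjoint solve."""
    residual = x - np.array([3.0, 2.5])
    return 0.5 * float(residual @ residual), residual

cached = CachedMisfit(toy_value_and_grad)
x0 = np.array([2.5, 2.5])
f0 = cached.fun(x0)
g0 = cached.jac(x0)
print(f"misfit = {f0}, gradient = {g0}, expensive evaluations = {cached.n_solves}")
```

SciPy's `minimize` also accepts `jac=True`, in which case `fun` is expected to return the pair `(misfit, gradient)` directly; either approach keeps the work to one evaluation per trial model.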

In [None]:
class CustomFWI:
    """Custom FWI implementation using AcousticWave class."""
    
    def __init__(self, wave_obj, observed_data):
        self.wave = wave_obj
        self.observed_data = observed_data
        self.current_model = wave_obj.initial_velocity_model.copy(deepcopy=True)
        self.iteration = 0
        self.functional_history = []
        self.gradient_norm_history = []
        
        # Bounds for velocity
        self.vmin = 2.0
        self.vmax = 4.0
        
        print(f"CustomFWI initialized:")
        print(f"  Model DOFs: {self.current_model.dat.data.size}")
        print(f"  Observed data shape: {observed_data.shape}")
        print(f"  Velocity bounds: [{self.vmin}, {self.vmax}] km/s")
    
    def compute_misfit(self, velocity_array):
        """Compute data misfit for given velocity model."""
        # Update velocity model
        self.current_model.dat.data[:] = velocity_array
        self.wave.initial_velocity_model.assign(self.current_model)
        
        # Reset wave state
        self.wave.reset_pressure()
        
        # Run forward simulation
        self.wave.forward_solve()
        synthetic_data = self.wave.receivers_output
        
        # Compute L2 misfit
        residual = self.observed_data - synthetic_data
        misfit = 0.5 * np.sum(residual**2)
        
        return misfit, residual
    
    def compute_gradient(self, velocity_array):
        """Compute gradient using adjoint method."""
        # Compute misfit and residual
        misfit, residual = self.compute_misfit(velocity_array)
        
        # Run adjoint solve to get gradient
        gradient_function = self.wave.gradient_solve(misfit=residual)
        gradient_array = gradient_function.dat.data[:].copy()
        
        # Apply simple regularization (optional)
        # gradient_array *= self.current_model.dat.data[:]  # Velocity scaling
        
        return gradient_array
    
    def objective_function(self, velocity_array):
        """Objective function for scipy.optimize."""
        # Apply bounds
        velocity_array = np.clip(velocity_array, self.vmin, self.vmax)
        
        # Compute misfit
        misfit, _ = self.compute_misfit(velocity_array)
        
        # Store history (one entry per function evaluation, not per optimizer iteration)
        self.functional_history.append(misfit)

        print(f"Evaluation {self.iteration}: Misfit = {misfit:.6e}")
        self.iteration += 1
        
        return misfit
    
    def gradient_function(self, velocity_array):
        """Gradient function for scipy.optimize."""
        # Apply bounds
        velocity_array = np.clip(velocity_array, self.vmin, self.vmax)
        
        # Compute gradient
        gradient = self.compute_gradient(velocity_array)
        
        # Store gradient norm
        grad_norm = np.linalg.norm(gradient)
        self.gradient_norm_history.append(grad_norm)
        
        print(f"  Gradient norm: {grad_norm:.6e}")
        
        return gradient
    
    def get_current_model(self):
        """Get current velocity model as Firedrake function."""
        return self.current_model

# Create custom FWI object
print("Creating custom FWI object...")
fwi_solver = CustomFWI(inversion_wave, observed_data)
print("Custom FWI object created!")

## 6. Test Forward and Adjoint Operations

Before running the full inversion, let's test our forward and adjoint implementations to ensure they're working correctly.
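
A single-entry finite-difference comparison is a useful first check, but a Taylor remainder test along a direction is usually more informative: if the gradient is correct, the remainder `|J(m + h d) - J(m) - h g·d|` should shrink like `h²` as `h` is halved. The sketch below shows the idea on a toy nonlinear misfit. The functions `toy_misfit`, `toy_gradient`, and `taylor_remainders` are hypothetical; in the notebook, the misfit and gradient would come from the forward and adjoint solves.

```python
import numpy as np

def toy_misfit(m, d_obs):
    """Toy nonlinear least-squares misfit, standing in for a forward solve."""
    residual = np.sin(m) - d_obs
    return 0.5 * float(residual @ residual)

def toy_gradient(m, d_obs):
    """Exact gradient of toy_misfit with respect to m."""
    return (np.sin(m) - d_obs) * np.cos(m)

def taylor_remainders(misfit, gradient, m, direction, steps):
    """Second-order Taylor remainders |J(m + h d) - J(m) - h g.d| for each h."""
    j0 = misfit(m)
    slope = float(gradient(m) @ direction)
    return np.array(
        [abs(misfit(m + h * direction) - j0 - h * slope) for h in steps]
    )

d_obs = np.sin(np.full(3, 0.5))
m = np.array([0.3, 0.7, 1.1])
direction = np.array([1.0, 0.5, -0.5])
steps = np.array([1e-2, 5e-3, 2.5e-3, 1.25e-3])

remainders = taylor_remainders(
    lambda v: toy_misfit(v, d_obs),
    lambda v: toy_gradient(v, d_obs),
    m, direction, steps,
)
# Halving h should divide the remainder by about 4, i.e. a rate close to 2
rates = np.log2(remainders[:-1] / remainders[1:])
print("Convergence rates (expect about 2):", rates)
```

A rate near 1 instead of 2 would indicate that the gradient is inconsistent with the misfit, for example through a wrong sign or a missing scaling.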

In [None]:
# Test forward modeling
print("Testing forward modeling...")
initial_velocity_array = fwi_solver.current_model.dat.data[:].copy()
initial_misfit, initial_residual = fwi_solver.compute_misfit(initial_velocity_array)

print(f"Initial misfit: {initial_misfit:.6e}")
print(f"Initial model statistics:")
print(f"  Min velocity: {np.min(initial_velocity_array):.3f} km/s")
print(f"  Max velocity: {np.max(initial_velocity_array):.3f} km/s")
print(f"  Mean velocity: {np.mean(initial_velocity_array):.3f} km/s")

# Test gradient computation
print("\nTesting gradient computation...")
gradient = fwi_solver.compute_gradient(initial_velocity_array)

print(f"Gradient statistics:")
print(f"  Gradient norm: {np.linalg.norm(gradient):.6e}")
print(f"  Min gradient: {np.min(gradient):.6e}")
print(f"  Max gradient: {np.max(gradient):.6e}")
print(f"  Gradient shape: {gradient.shape}")

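# Verify the gradient with a simple one-sided finite-difference check.
# The gradient points towards increasing misfit; the descent direction is -gradient.
# Caveat: if gradient_solve() returns a mass-weighted (Riesz) representation, its
# entries can differ from raw partial derivatives by a scaling, so treat a mismatch
# here as a prompt to investigate rather than as proof of a bug.
eps = 1e-6
perturbed_velocity = initial_velocity_array.copy()
perturbed_velocity[0] += eps  # Perturb first parameter

misfit_plus, _ = fwi_solver.compute_misfit(perturbed_velocity)
finite_diff_grad = (misfit_plus - initial_misfit) / eps
analytical_grad = gradient[0]

# Guard against division by zero when the analytical entry vanishes
relative_error = abs(finite_diff_grad - analytical_grad) / max(abs(analytical_grad), 1e-30)

print(f"\nGradient verification (first parameter):")
print(f"  Finite difference: {finite_diff_grad:.6e}")
print(f"  Analytical: {analytical_grad:.6e}")
print(f"  Relative error: {relative_error:.2%}")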

print("\nForward and adjoint operations tested successfully!")

## 7. Run Custom FWI Optimization

Now we'll run our custom FWI using scipy's optimization routines. This demonstrates how to implement FWI with full control over the optimization process.
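
To make concrete what "full control over the optimization process" can look like, here is a minimal hand-written loop. Note that it is a different and simpler technique than the L-BFGS-B method used in this section: projected steepest descent with Armijo-type backtracking, where bounds are enforced by clipping. The function `projected_gradient_descent` and the toy quadratic misfit are hypothetical stand-ins for the wave-based misfit and gradient.

```python
import numpy as np

def projected_gradient_descent(value_and_grad, x0, lower, upper,
                               step=1.0, max_iter=50, tol=1e-10):
    """Minimize a function under box constraints by projected steepest descent."""
    x = np.clip(np.asarray(x0, dtype=float), lower, upper)
    f, g = value_and_grad(x)
    history = [f]
    for _ in range(max_iter):
        alpha, accepted = step, False
        for _ in range(30):  # backtracking: halve the step until the decrease is sufficient
            x_new = np.clip(x - alpha * g, lower, upper)  # project onto the bounds
            f_new, g_new = value_and_grad(x_new)
            if f_new <= f - 1e-4 * float(g @ (x - x_new)):  # Armijo-type test
                accepted = True
                break
            alpha *= 0.5
        if not accepted or np.linalg.norm(x_new - x) < tol:
            break  # no acceptable step, or the projected step no longer moves the model
        x, f, g = x_new, f_new, g_new
        history.append(f)
    return x, history

def toy_value_and_grad(x):
    """Toy quadratic misfit whose unconstrained minimum violates the upper bound."""
    residual = x - np.array([4.5, 3.0])
    return 0.5 * float(residual @ residual), residual

x_opt, history = projected_gradient_descent(
    toy_value_and_grad, x0=[2.5, 2.5], lower=2.0, upper=4.0
)
print("Optimized model:", x_opt)
print("Misfit history:", history)
```

In the toy problem the first parameter stops at the upper bound of 4.0 even though the unconstrained minimum sits at 4.5, which is the behaviour the velocity bounds are meant to enforce.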

In [None]:
# Set optimization parameters
max_iterations = 15  # Limit for tutorial purposes
optimization_method = 'L-BFGS-B'  # Limited-memory BFGS with bounds

# Set up velocity bounds for each parameter
n_params = len(initial_velocity_array)
bounds = [(fwi_solver.vmin, fwi_solver.vmax) for _ in range(n_params)]

print(f"Starting FWI optimization:")
print(f"  Method: {optimization_method}")
print(f"  Max iterations: {max_iterations}")
print(f"  Parameters: {n_params}")
print(f"  Bounds: [{fwi_solver.vmin}, {fwi_solver.vmax}] km/s")
print("\n" + "="*60)

# Reset iteration counter
fwi_solver.iteration = 0
fwi_solver.functional_history = []
fwi_solver.gradient_norm_history = []

# Record start time
start_time = time.time()

# Run optimization
result = minimize(
    fun=fwi_solver.objective_function,
    x0=initial_velocity_array,
    method=optimization_method,
    jac=fwi_solver.gradient_function,
    bounds=bounds,
    options={
        'maxiter': max_iterations,
        'disp': True,
        'gtol': 1e-6,
        'ftol': 1e-12,
    }
)

# Record end time
end_time = time.time()
total_time = end_time - start_time

print("="*60)
print(f"Optimization completed in {total_time:.2f} seconds")
print(f"Final misfit: {result.fun:.6e}")
print(f"Success: {result.success}")
print(f"Message: {result.message}")
print(f"Function evaluations: {result.nfev}")
print(f"Gradient evaluations: {result.njev}")

# Update the model with final result
final_velocity_array = result.x
fwi_solver.current_model.dat.data[:] = final_velocity_array
inversion_wave.initial_velocity_model.assign(fwi_solver.current_model)

## 8. Analyze Results and Visualize Convergence

Let's analyze the inversion results and visualize the convergence behavior.

In [None]:
# Plot convergence history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot functional (misfit) history
ax1.semilogy(fwi_solver.functional_history, 'b-o', linewidth=2, markersize=6)
ax1.set_xlabel('Function evaluation')
ax1.set_ylabel('Misfit (Log Scale)')
ax1.set_title('Functional Convergence')
ax1.grid(True, alpha=0.3)

# Plot gradient norm history  
ax2.semilogy(fwi_solver.gradient_norm_history, 'r-s', linewidth=2, markersize=6)
ax2.set_xlabel('Gradient evaluation')
ax2.set_ylabel('Gradient Norm (Log Scale)')
ax2.set_title('Gradient Norm Convergence')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Analyze velocity model statistics
print("\nVelocity Model Analysis:")
print("="*40)

initial_stats = {
    'min': np.min(initial_velocity_array),
    'max': np.max(initial_velocity_array), 
    'mean': np.mean(initial_velocity_array),
    'std': np.std(initial_velocity_array)
}

final_stats = {
    'min': np.min(final_velocity_array),
    'max': np.max(final_velocity_array),
    'mean': np.mean(final_velocity_array),
    'std': np.std(final_velocity_array)
}

print("Initial Model:")
for key, value in initial_stats.items():
    print(f"  {key.capitalize()}: {value:.3f} km/s")

print("\nFinal Model:")
for key, value in final_stats.items():
    print(f"  {key.capitalize()}: {value:.3f} km/s")

print(f"\nMisfit Reduction:")
initial_misfit = fwi_solver.functional_history[0]
# Use the optimizer's accepted value: the last logged evaluation may be a
# line-search trial rather than the final model
final_misfit = result.fun
reduction_factor = initial_misfit / final_misfit
reduction_percent = (1 - final_misfit/initial_misfit) * 100

print(f"  Initial: {initial_misfit:.6e}")
print(f"  Final: {final_misfit:.6e}")
print(f"  Reduction factor: {reduction_factor:.2f}x")
print(f"  Reduction: {reduction_percent:.1f}%")

In [None]:
# Visualize the final inverted model
print("\nVisualizing final inverted model...")
try:
    spyro.plots.plot_model(inversion_wave, 
                          filename="inverted_model_dev.png", 
                          show=True)
    print("Final inverted model plotted successfully!")
except Exception as e:
    print(f"Plotting error: {e}")
    print("You can visualize results using ParaView with .pvd files")

# Compare data fit
print("\nComparing data fit...")
final_misfit_test, final_residual = fwi_solver.compute_misfit(final_velocity_array)
final_synthetic_data = fwi_solver.wave.receivers_output

print(f"Final data fit statistics:")
print(f"  Misfit: {final_misfit_test:.6e}")
print(f"  RMS residual: {np.sqrt(np.mean(final_residual**2)):.6e}")
print(f"  Max residual: {np.max(np.abs(final_residual)):.6e}")
print(f"  Correlation with observed: {np.corrcoef(observed_data.flatten(), final_synthetic_data.flatten())[0,1]:.3f}")

# Memory usage info
print(f"\nMemory Management:")
print(f"  Function space DOFs: {inversion_wave.function_space.dim()}")
print(f"  Receiver count: {len(base_dictionary['acquisition']['receiver_locations'])}")
print(f"  Time steps: {int(final_time / base_dictionary['time_axis']['dt'])}")
print(f"  Total optimization time: {total_time:.2f} seconds")

## 9. Advanced Customizations

Let's demonstrate some advanced customizations that are possible when using the `AcousticWave` class directly.
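
One of these customizations is a roughness penalty on the model. As a small worked example, the sketch below evaluates a discrete version of `0.5 * ∫ |m'|² dx` and its gradient on a 1D grid with NumPy. The functions `roughness_penalty` and `roughness_gradient` are hypothetical and only illustrate the quantity; they are not part of Spyro or Firedrake.

```python
import numpy as np

def roughness_penalty(m, h):
    """Discrete 0.5 * integral of |dm/dx|^2 on a uniform 1D grid with spacing h."""
    dm = np.diff(m)
    return 0.5 * float(np.sum(dm ** 2)) / h

def roughness_gradient(m, h):
    """Gradient of roughness_penalty with respect to each grid value."""
    dm = np.diff(m)
    grad = np.zeros_like(m, dtype=float)
    grad[:-1] -= dm / h  # each difference pulls its left node up
    grad[1:] += dm / h   # and pushes its right node down (or vice versa)
    return grad

h = 0.5
flat_model = np.full(4, 2.5)
bumpy_model = np.array([2.5, 2.5, 3.0, 2.5])

print("Penalty (flat):", roughness_penalty(flat_model, h))
print("Penalty (bumpy):", roughness_penalty(bumpy_model, h))
print("Gradient (bumpy):", roughness_gradient(bumpy_model, h))
```

A constant model has zero penalty, and the gradient is largest at the isolated bump, so subtracting a multiple of it smooths the model.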

In [None]:
class AdvancedFWI(CustomFWI):
    """Extended FWI with advanced features."""
    
    def __init__(self, wave_obj, observed_data):
        super().__init__(wave_obj, observed_data)
        self.regularization_weight = 1e-6
        self.smoothing_operator = None
        self.setup_regularization()
    
    def setup_regularization(self):
        """Set up Tikhonov regularization."""
        # Stiffness (grad-grad) form: the weak form of the negative Laplacian
        V = self.wave.function_space
        u = fire.TrialFunction(V)
        v = fire.TestFunction(V)

        # First-order Tikhonov (H1-seminorm) penalty on model roughness, not a plain L2 penalty
        self.regularization_form = fire.inner(fire.grad(u), fire.grad(v)) * fire.dx
        self.reg_matrix = fire.assemble(self.regularization_form)
        
        print(f"Regularization setup with weight: {self.regularization_weight}")
    
    def compute_regularized_misfit(self, velocity_array):
        """Compute misfit with Tikhonov regularization."""
        # Standard data misfit
        data_misfit, residual = self.compute_misfit(velocity_array)
        
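        # Model roughness penalty: 0.5 * integral of |grad(m)|^2.
        # fire.action() on the bilinear form yields a 1-form (a vector), not a scalar,
        # so assemble the quadratic functional directly instead.
        self.current_model.dat.data[:] = velocity_array
        m = self.current_model
        reg_term = 0.5 * fire.assemble(fire.inner(fire.grad(m), fire.grad(m)) * fire.dx)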
        
        total_misfit = data_misfit + self.regularization_weight * reg_term
        
        return total_misfit, residual, reg_term
    
    def custom_line_search(self, model, direction, alpha_max=1.0):
        """Custom line search implementation."""
        alpha_values = [alpha_max * (0.5**i) for i in range(8)]
        best_alpha = 0.0
        best_misfit = float('inf')
        
        current_misfit, _ = self.compute_misfit(model)
        
        for alpha in alpha_values:
            test_model = model + alpha * direction
            test_model = np.clip(test_model, self.vmin, self.vmax)
            
            try:
                test_misfit, _ = self.compute_misfit(test_model)
                if test_misfit < best_misfit and test_misfit < current_misfit:
                    best_misfit = test_misfit
                    best_alpha = alpha
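            except Exception:
                # Skip step sizes whose forward solve fails; a bare `except`
                # would also swallow KeyboardInterrupt
                continue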
        
        return best_alpha
    
    def gradient_with_preconditioning(self, velocity_array):
        """Compute preconditioned gradient."""
        gradient = self.compute_gradient(velocity_array)
        
        # Simple diagonal preconditioning (velocity scaling)
        preconditioner = 1.0 / (velocity_array + 1e-6)
        preconditioned_grad = gradient * preconditioner
        
        return preconditioned_grad

# Demonstrate advanced features
print("Advanced FWI Features Demonstration:")
print("="*50)

# Create advanced FWI object
advanced_fwi = AdvancedFWI(inversion_wave, observed_data)

# Test regularized misfit
test_velocity = final_velocity_array.copy()
total_misfit, residual, reg_term = advanced_fwi.compute_regularized_misfit(test_velocity)

print(f"Regularized misfit analysis:")
print(f"  Data misfit: {total_misfit - advanced_fwi.regularization_weight * reg_term:.6e}")
print(f"  Regularization term: {reg_term:.6e}")
print(f"  Total misfit: {total_misfit:.6e}")

# Test preconditioned gradient
precond_grad = advanced_fwi.gradient_with_preconditioning(test_velocity)
regular_grad = advanced_fwi.compute_gradient(test_velocity)

print(f"\nGradient preconditioning:")
print(f"  Original gradient norm: {np.linalg.norm(regular_grad):.6e}")
print(f"  Preconditioned gradient norm: {np.linalg.norm(precond_grad):.6e}")

# Test custom line search
search_direction = -regular_grad / np.linalg.norm(regular_grad)  # Normalized descent direction
optimal_step = advanced_fwi.custom_line_search(test_velocity, search_direction)

print(f"\nCustom line search:")
print(f"  Optimal step size: {optimal_step:.6f}")

print("\nAdvanced features demonstrated successfully!")

## 10. Performance and Memory Optimization Tips

Here are some important considerations for developers implementing FWI with large-scale problems.
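
Checkpointing is one of the considerations discussed in this section, so a toy illustration may help. The sketch below stores only every fourth state during the forward pass and recomputes each segment from its checkpoint when the states are needed in reverse order, as an adjoint pass would. It uses simple uniform checkpointing, not an optimal schedule, and the `step` function is a hypothetical stand-in for a wave time step; Spyro's internals may work differently.

```python
import numpy as np

def step(u, c):
    """One step of a toy explicit update (hypothetical stand-in for a wave time step)."""
    return u + 0.01 * c * (np.roll(u, 1) - u)

def forward_with_checkpoints(u0, c, n_steps, interval):
    """Run forward, keeping only every `interval`-th state (plus the initial one)."""
    checkpoints = {0: u0.copy()}
    u = u0.copy()
    for n in range(1, n_steps + 1):
        u = step(u, c)
        if n % interval == 0:
            checkpoints[n] = u.copy()
    return u, checkpoints

def reverse_states(checkpoints, c, n_steps, interval):
    """Yield (n, u_n) for n = n_steps, ..., 0, recomputing one segment at a time."""
    end = n_steps
    while end > 0:
        start = ((end - 1) // interval) * interval  # last checkpoint before `end`
        states = [checkpoints[start]]
        for _ in range(start, end):
            states.append(step(states[-1], c))
        for n in range(end, start, -1):
            yield n, states[n - start]
        end = start
    yield 0, checkpoints[0]

n_steps, interval, c = 10, 4, 2.5
u0 = np.sin(np.linspace(0.0, 2.0 * np.pi, 16, endpoint=False))

# Reference: store every state (what checkpointing avoids)
reference = [u0]
for _ in range(n_steps):
    reference.append(step(reference[-1], c))

u_final, checkpoints = forward_with_checkpoints(u0, c, n_steps, interval)
recovered = list(reverse_states(checkpoints, c, n_steps, interval))
print(f"Stored {len(checkpoints)} checkpoints instead of {n_steps + 1} states")
```

The trade-off is one extra forward sweep in total, in exchange for holding at most one segment of states plus the checkpoints in memory at any time.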

In [None]:
def demonstrate_optimization_techniques():
    """Demonstrate performance optimization techniques."""
    
    print("Performance Optimization Techniques:")
    print("="*50)
    
    print("\n1. Memory Management:")
    print("   • Use reset_pressure() between forward solves")
    print("   • Manually delete large arrays when done: del wave.forward_solution")
    print("   • Use deepcopy=False when possible for temporary functions")
    print("   • Monitor memory usage with resource.getrusage()")
    
    print("\n2. Computational Efficiency:")
    print("   • Use coarser meshes for initial iterations (multi-scale FWI)")
    print("   • Disable unnecessary output: visualization['forward_output'] = False")
    print("   • Use matrix-free methods for large problems")
    print("   • Consider lower time accuracy for gradient computation")
    
    print("\n3. Parallelization:")
    print("   • Shot-level parallelism: distribute sources across MPI ranks")
    print("   • Domain decomposition: use PETSc parallel solvers")
    print("   • Frequency parallelism: solve multiple frequencies simultaneously")
    
    print("\n4. Algorithmic Improvements:")
    print("   • Use L-BFGS instead of steepest descent")
    print("   • Implement bound constraints to avoid non-physical velocities")
    print("   • Add regularization to avoid overfitting")
    print("   • Use checkpointing for adjoint computation in time")
    
    print("\n5. Debugging and Monitoring:")
    print("   • Always verify gradient accuracy with finite differences")
    print("   • Monitor functional decrease at each iteration")
    print("   • Check data residuals and model updates")
    print("   • Use try-except blocks for robust optimization")

# Memory usage monitoring example
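def memory_monitor_example():
    """Example of memory monitoring during FWI (Unix only: uses `resource`)."""
    import resource
    import sys

    def get_peak_memory_usage():
        """Get peak resident set size (a high-water mark) in MB."""
        usage = resource.getrusage(resource.RUSAGE_SELF)
        # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS
        divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
        return usage.ru_maxrss / divisor

    print("\nMemory Usage Example:")
    print(f"Peak memory so far: {get_peak_memory_usage():.1f} MB")

    # Create a temporary large array to demonstrate
    # (np.ones writes every element, so the pages are actually allocated)
    temp_data = np.ones((1000, 1000))  # 8 MB array
    print(f"Peak after allocation: {get_peak_memory_usage():.1f} MB")

    del temp_data  # Clean up
    # ru_maxrss is a peak value, so it does not drop after cleanup
    print(f"Peak after cleanup: {get_peak_memory_usage():.1f} MB")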

# Example of checkpointing strategy
class CheckpointingFWI(CustomFWI):
    """FWI with checkpointing for memory efficiency."""
    
    def __init__(self, wave_obj, observed_data):
        super().__init__(wave_obj, observed_data)
        self.checkpoint_interval = 100  # Save every 100 time steps
        self.checkpoints = {}
    
    def forward_with_checkpointing(self, velocity_array):
        """Forward solve with periodic checkpointing."""
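        # This is a conceptual placeholder: no checkpoints are written yet.
        # In practice, you would modify the time stepping loop and apply
        # velocity_array to the model first (as compute_misfit does).
        self.wave.forward_solve()  # Standard solve for now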
        
        # Save checkpoints at specified intervals
        # checkpoints would contain pressure fields at specific times
        return self.wave.receivers_output
    
    def adjoint_with_checkpointing(self, residual):
        """Adjoint solve using saved checkpoints."""
        # Reconstruct forward solution using checkpoints
        # This reduces memory requirements for large time domain problems
        return self.wave.gradient_solve(misfit=residual)

# Demonstrate techniques
demonstrate_optimization_techniques()
memory_monitor_example()

# Show problem scaling
print(f"\nProblem Scaling Analysis:")
print(f"Current problem size:")
print(f"  DOFs: {inversion_wave.function_space.dim()}")
print(f"  Time steps: {int(final_time / base_dictionary['time_axis']['dt'])}")
print(f"  Receivers: {len(base_dictionary['acquisition']['receiver_locations'])}")
print(f"  Storage per forward solve: ~{(inversion_wave.function_space.dim() * int(final_time / base_dictionary['time_axis']['dt']) * 8) / 1e6:.1f} MB")

print(f"\nScaling to larger problems:")
print(f"  3D problems: DOFs scale as O(N³)")
print(f"  Higher frequencies: finer meshes force dt ~ O(1/f), so time steps scale as O(f)")
print(f"  More sources: computational cost scales linearly")
print(f"  Longer time windows: memory scales linearly")
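
The storage estimate printed above is simple arithmetic: degrees of freedom times stored time steps times 8 bytes per double. The sketch below wraps it in a function and shows how storing only every k-th step would reduce it. The helper `forward_storage_mb` and the DOF count are hypothetical, and the assumption that `gradient_sampling_frequency` subsamples stored wavefields in exactly this way is not confirmed by this tutorial.

```python
def forward_storage_mb(n_dofs, final_time, dt, sampling_frequency=1, bytes_per_value=8):
    """Estimate memory (MB) for storing forward wavefields for the adjoint pass."""
    n_steps = round(final_time / dt)
    n_stored = n_steps // sampling_frequency  # one snapshot every `sampling_frequency` steps
    return n_dofs * n_stored * bytes_per_value / 1e6

# Hypothetical DOF count; in the notebook use inversion_wave.function_space.dim()
n_dofs = 10_000
for sampling in (1, 10, 100):
    mb = forward_storage_mb(n_dofs, final_time=1.3, dt=0.0001, sampling_frequency=sampling)
    print(f"Every {sampling:>3} steps: ~{mb:.1f} MB")
```

Using `round` rather than `int` for the step count avoids losing a step when `final_time / dt` lands just below an integer in floating point.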

## Conclusion and Advanced Topics

Congratulations! You've successfully implemented a custom FWI algorithm using Spyro's `AcousticWave` class. This approach gives you complete control over the inversion process and enables advanced customizations.

### What We Accomplished:
1. **Direct FWI Implementation**: Built FWI from scratch using only `AcousticWave` methods
2. **Custom Optimization**: Integrated with scipy.optimize for flexible optimization strategies
3. **Gradient Verification**: Implemented finite difference checking for gradient accuracy
4. **Advanced Features**: Added regularization, preconditioning, and custom line search
5. **Performance Optimization**: Demonstrated memory management and computational efficiency techniques

### Key Advantages of the AcousticWave Approach:
- **Full Control**: Complete access to all aspects of the wave simulation
- **Custom Physics**: Ability to modify the wave equation and add custom terms
- **Flexible Optimization**: Use any optimization library or custom algorithms
- **Memory Management**: Fine-grained control over memory usage and storage
- **Advanced Features**: Easy to implement regularization, constraints, and preconditioning
- **Debugging**: Direct access to all intermediate results for detailed analysis

### Advanced Extensions You Can Implement:

#### Multi-Scale FWI:
```python
# Start with low frequencies and coarse meshes
for frequency in [2, 3, 5, 8]:
    for mesh_size in [0.1, 0.08, 0.05]:
        # Update dictionary and re-run inversion
        pass
```

#### Multi-Parameter Inversion:
```python
# Simultaneously invert for velocity and density
class MultiParameterFWI(CustomFWI):
    def __init__(self, wave_obj, observed_data):
        # Set up multiple parameter spaces
        pass
```

#### Custom Objective Functions:
```python
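# Normalized zero-lag cross-correlation misfit, one alternative to L2
import numpy as np

def custom_misfit(observed, synthetic, eps=1e-12):
    # Assumes arrays shaped (time samples, receivers); eps avoids division by zero
    num = np.sum(observed * synthetic, axis=0)
    den = np.linalg.norm(observed, axis=0) * np.linalg.norm(synthetic, axis=0) + eps
    return float(np.sum(1.0 - num / den))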
```

#### Machine Learning Integration:
```python
# Use neural networks for velocity model parameterization
import torch
class NeuralFWI(CustomFWI):
    def __init__(self, neural_network):
        # Parameterize velocity using neural networks
        pass
```

### Performance Considerations for Production Use:
- **Parallel Computing**: Implement MPI parallelization for multiple sources
- **GPU Acceleration**: Consider GPU offloading for large problems where your Firedrake/PETSc build supports it
- **Checkpointing**: Implement optimal checkpointing strategies for memory efficiency
- **I/O Optimization**: Efficient reading/writing of large datasets
- **Fault Tolerance**: Handle computational failures gracefully

### Real-World Applications:
- **Seismic Exploration**: Oil and gas exploration with complex geology
- **Earthquake Seismology**: Regional and global Earth structure studies  
- **Medical Imaging**: Ultrasound and photoacoustic imaging
- **Non-Destructive Testing**: Material characterization and defect detection
- **Environmental Monitoring**: Groundwater and contamination studies

### Further Reading:
- Virieux & Operto (2009): "An overview of full-waveform inversion in exploration geophysics"
- Tape et al. (2009): "Adjoint tomography of the southern California crust"
- Fichtner (2010): "Full Seismic Waveform Modelling and Inversion"
- Ramos-Martínez et al. (2016): "A robust gradient for long wavelength FWI updates"

The `AcousticWave` approach provides the foundation for implementing state-of-the-art FWI algorithms with complete flexibility and control!