# Driven Turbulence with Energy Spectra

This tutorial runs a forced (driven) turbulence simulation, drives it toward a statistical steady state, and analyzes the result with energy-history and energy-spectrum diagnostics.

**What you'll learn:**
- Forced turbulence reaching statistical steady state
- Computing energy spectra: E(k), E(k⊥), E(k∥)
- Identifying inertial range scaling (k⁻⁵/³)
- Energy balance: injection vs dissipation
- Parameter scanning for different forcing amplitudes

**Runtime:** ~20 seconds on M1 Pro for 64×64×32 resolution

**Prerequisites:** Complete Tutorial 01 (getting_started.ipynb) first

## Setup and Imports

In [None]:
import time
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from pathlib import Path

from krmhd import (
    SpectralGrid3D,
    initialize_random_spectrum,
    gandalf_step,
    compute_cfl_timestep,
    energy as compute_energy,
    force_alfven_modes,
    compute_energy_injection_rate,
)
from krmhd.diagnostics import (
    EnergyHistory,
    energy_spectrum_1d,
    energy_spectrum_perpendicular,
    energy_spectrum_parallel,
)

# Configure plotting
plt.style.use('default')
%matplotlib inline

print("✓ Imports successful")

## 1. Simulation Parameters

We'll set up a moderately resolved turbulence simulation:

### Grid Parameters
- **Resolution**: 64×64×32 (suitable for laptops)
- **Domain**: Unit box (Lx = Ly = Lz = 1.0)

### Physics Parameters
- **v_A**: Alfvén velocity (set to 1.0, our velocity unit)
- **η**: Resistivity (magnetic diffusion coefficient)
- **β_i**: Ion plasma beta (ratio of thermal to magnetic pressure)
- **ν**: Collision frequency (for Hermite moments)

### Forcing Parameters
- **amplitude**: Controls energy injection rate (ε ∝ amplitude²)
- **n_min, n_max**: Mode numbers for forcing (1-2 = large scales)
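
To see what the forcing mode numbers mean in physical units, here is a minimal sketch that converts a mode number n into a wavenumber k = 2πn/L and compares the forcing band with the grid's Nyquist wavenumber πN/L. The helper names are illustrative and not part of `krmhd`, and the sketch makes no assumption about any dealiasing truncation the solver may apply.

```python
import math

def mode_to_wavenumber(n, L):
    """Wavenumber of the n-th Fourier mode in a periodic box of length L."""
    return 2.0 * math.pi * n / L

def nyquist_wavenumber(N, L):
    """Largest wavenumber representable with N grid points over length L."""
    return math.pi * N / L

box_length, n_points = 1.0, 64                    # same values as Lx and Nx in this tutorial
k_force_lo = mode_to_wavenumber(1, box_length)    # n_min = 1
k_force_hi = mode_to_wavenumber(2, box_length)    # n_max = 2
k_nyq = nyquist_wavenumber(n_points, box_length)

print(f"forcing band: {k_force_lo:.2f} to {k_force_hi:.2f}")
print(f"Nyquist wavenumber: {k_nyq:.2f}")
print(f"scale separation k_nyq / k_force_hi: {k_nyq / k_force_hi:.1f}")
```

With n = 1 to 2 in a unit box, the forcing sits at k ≈ 6.3 to 12.6, which leaves a factor of 16 between the top of the forcing band and the Nyquist wavenumber in x and y.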

In [None]:
# Grid resolution
Nx, Ny, Nz = 64, 64, 32
Lx = Ly = Lz = 1.0

# Physics parameters
v_A = 1.0          # Alfvén velocity
eta = 0.02         # Resistivity
beta_i = 1.0       # Ion plasma beta
nu = 0.02          # Collision frequency

# Forcing parameters
force_amplitude = 0.3   # Forcing strength
n_force_min = 1         # Minimum forcing mode
n_force_max = 2         # Maximum forcing mode

# Initial spectrum (weak, so we see forcing effect)
alpha = 5.0 / 3.0       # Spectral index (k⁻⁵/³)
amplitude = 0.1         # Initial amplitude
k_min = 1.0
k_max = 10.0

# Time integration
n_steps = 200           # Number of timesteps
cfl_safety = 0.3        # CFL safety factor
save_interval = 5       # Save diagnostics every N steps

print(f"Grid: {Nx} × {Ny} × {Nz}")
print(f"Physics: v_A={v_A}, η={eta}, β_i={beta_i}")
print(f"Forcing: amplitude={force_amplitude}, modes n ∈ [{n_force_min}, {n_force_max}]")
print(f"Time: {n_steps} steps with CFL safety = {cfl_safety}")

## 2. Initialize Grid and State

We start with a **weak random turbulent spectrum** so we can clearly observe the effect of forcing.

In [None]:
print("Initializing grid and state...")
start_init = time.time()

# Create spectral grid
grid = SpectralGrid3D.create(Nx=Nx, Ny=Ny, Nz=Nz, Lx=Lx, Ly=Ly, Lz=Lz)
print(f"✓ Created {Nx}×{Ny}×{Nz} spectral grid")

# Initialize with weak turbulent spectrum
state = initialize_random_spectrum(
    grid,
    M=20,              # Number of Hermite moments
    alpha=alpha,       # Spectral index
    amplitude=amplitude,
    k_min=k_min,
    k_max=k_max,
    v_th=1.0,
    beta_i=beta_i,
    nu=nu,
    Lambda=1.0,
    seed=42,
)
print(f"✓ Initialized weak k^(-{alpha:.2f}) spectrum")

# Initialize JAX random key for forcing
key = jax.random.PRNGKey(42)

# Compute initial energy
initial_energies = compute_energy(state)
print(f"✓ Initial energy: E_total = {initial_energies['total']:.4e}")
print(f"  - Magnetic: {initial_energies['magnetic']:.4e}")
print(f"  - Kinetic:  {initial_energies['kinetic']:.4e}")

elapsed_init = time.time() - start_init
print(f"\nInitialization time: {elapsed_init:.2f}s")

## 3. Run Forced Turbulence Simulation

The simulation loop:
1. **Force**: Inject energy at large scales (modes n=1-2)
2. **Evolve**: Nonlinear cascade + dissipation
3. **Diagnose**: Track energy, injection rate, spectra

In steady state, we expect: **ε_injection ≈ ε_dissipation**
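
The balance between injection and dissipation can be illustrated with a zero-dimensional toy model, dE/dt = ε_inj − γE, where γ = 2ηk² is the energy damping rate of a single mode whose amplitude decays as exp(−ηk²t). This is only a sketch under those assumptions (one representative wavenumber, constant injection, no nonlinear cascade); it is not how `gandalf_step` evolves the system.

```python
import math

def evolve_energy(E0, eps_inj, gamma, dt, n_steps):
    """Forward-Euler integration of dE/dt = eps_inj - gamma * E."""
    E = E0
    samples = [E]
    for _ in range(n_steps):
        E += dt * (eps_inj - gamma * E)
        samples.append(E)
    return samples

eta_toy = 0.02
k_toy = 2.0 * math.pi                 # single representative wavenumber (toy assumption)
gamma_toy = 2.0 * eta_toy * k_toy**2  # energy damping rate for amplitude ~ exp(-eta k^2 t)
eps_toy = 0.5                         # constant injection rate (arbitrary illustrative value)

energy_samples = evolve_energy(E0=0.01, eps_inj=eps_toy, gamma=gamma_toy,
                               dt=0.01, n_steps=1000)
E_steady = eps_toy / gamma_toy        # level at which injection balances dissipation

print(f"final E = {energy_samples[-1]:.4f}, predicted steady state = {E_steady:.4f}")
```

The energy rises from its weak initial value and saturates once γE matches ε_inj, which is the same qualitative behavior to look for in the energy history of the full simulation.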

In [None]:
print("Running forced turbulence simulation...\n")
start_sim = time.time()

# Initialize energy history
history = EnergyHistory()
history.append(state)

# Compute the CFL-limited timestep once from the initial state; dt then stays fixed for the whole run
dt = compute_cfl_timestep(state, v_A=v_A, cfl_safety=cfl_safety)
print(f"Using timestep dt = {dt:.4f} (CFL-limited)\n")

# Track energy injection
total_injection = 0.0
injection_history = []

# Main timestepping loop
for i in range(n_steps):
    # 1. Apply forcing
    state_before_forcing = state
    state, key = force_alfven_modes(
        state,
        amplitude=force_amplitude,
        n_min=n_force_min,
        n_max=n_force_max,
        dt=dt,
        key=key
    )
    
    # 2. Compute energy injection rate
    eps_inj = compute_energy_injection_rate(state_before_forcing, state, dt)
    total_injection += eps_inj * dt
    injection_history.append(eps_inj)
    
    # 3. Evolve dynamics (cascade + dissipation)
    state = gandalf_step(state, dt=dt, eta=eta, v_A=v_A)
    
    # 4. Save diagnostics
    if (i + 1) % save_interval == 0:
        history.append(state)
        E = compute_energy(state)["total"]
        print(f"Step {i+1:3d}/{n_steps}: t={state.time:.2f} τ_A, E={E:.4e}, ε_inj={eps_inj:.3e}")

elapsed_sim = time.time() - start_sim

print(f"\n✓ Completed {n_steps} timesteps")
print(f"  Runtime: {elapsed_sim:.1f}s ({elapsed_sim/60:.2f} min)")
print(f"  (Reference: ~20s on M1 Pro)")

## 4. Compute Final Diagnostics

Let's analyze the simulation results.
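
One quantity the energy budget gives for free is the mean dissipation rate: since dE/dt = ε_inj − ε_diss, integrating over the run gives ⟨ε_diss⟩ = (total injected − ΔE) / T. The helper below is a hypothetical sketch of that bookkeeping, assuming the only energy sinks are the dissipation terms and that injection is accumulated as in the loop above.

```python
def mean_dissipation_rate(E_initial, E_final, total_injected, elapsed_time):
    """Mean dissipation rate implied by the energy budget over a time window."""
    if elapsed_time <= 0:
        raise ValueError("elapsed_time must be positive")
    energy_change = E_final - E_initial
    return (total_injected - energy_change) / elapsed_time

# Illustrative numbers only (not output from the simulation above)
eps_diss_example = mean_dissipation_rate(
    E_initial=1.0, E_final=1.5, total_injected=2.0, elapsed_time=4.0
)
eps_inj_example = 2.0 / 4.0
print(f"<eps_inj> = {eps_inj_example:.3f}, <eps_diss> = {eps_diss_example:.3f}")
```

In this notebook the arguments would be `initial_energies['total']`, `final_energies['total']`, `total_injection`, and `state.time`. A dissipation rate well below the injection rate indicates the run is still accumulating energy rather than sitting in steady state.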

In [None]:
print("Computing diagnostics...\n")

# Final energy
final_energies = compute_energy(state)
print(f"Energy Budget:")
print(f"  Initial energy:  {initial_energies['total']:.4e}")
print(f"  Final energy:    {final_energies['total']:.4e}")
print(f"  Energy change:   {final_energies['total'] - initial_energies['total']:.4e}")
print(f"  Relative change: {(final_energies['total']/initial_energies['total'] - 1)*100:.1f}%")
print(f"  Total injection: {total_injection:.4e}\n")

# Energy partition
E_mag = final_energies["magnetic"]
E_kin = final_energies["kinetic"]
mag_fraction = E_mag / (E_mag + E_kin)
print(f"Energy Partition:")
print(f"  Magnetic: {E_mag:.4e} ({mag_fraction*100:.1f}%)")
print(f"  Kinetic:  {E_kin:.4e} ({(1-mag_fraction)*100:.1f}%)")
print(f"  Equipartition expected at 50%\n")

# Injection/dissipation balance
avg_injection_rate = total_injection / state.time
print(f"Steady State Balance:")
print(f"  Average injection rate: ⟨ε_inj⟩ = {avg_injection_rate:.3e}")
print(f"  (In steady state: ε_inj ≈ ε_diss)")

## 5. Compute Energy Spectra

Energy spectra reveal the distribution of energy across scales:

- **E(k)**: 1D spherically-averaged spectrum
- **E(k⊥)**: Perpendicular spectrum (across field)
- **E(k∥)**: Parallel spectrum (along field)

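For strong Alfvénic turbulence, we expect **k⊥⁻⁵/³ scaling** of the perpendicular spectrum in the inertial range (the Goldreich-Sridhar prediction). At this modest resolution the inertial range may span only a few wavenumber bins.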
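
To make "spherically averaged" concrete, the following NumPy sketch sums the energy of a periodic scalar field into integer wavenumber shells. It assumes a cubic grid and the energy convention E = ½⟨f²⟩; it is not the implementation of `energy_spectrum_1d`, whose binning and normalization may differ.

```python
import numpy as np

def shell_spectrum(field, L=1.0):
    """Shell-summed energy spectrum of a real, periodic field on a cubic grid."""
    n = field.shape[0]                             # assumes an N x N x N grid
    f_hat = np.fft.fftn(field) / field.size        # normalized Fourier coefficients
    energy_density = 0.5 * np.abs(f_hat) ** 2      # energy carried by each Fourier mode

    k1d = 2.0 * np.pi * np.fft.fftfreq(n, d=L / n)  # physical wavenumbers 2*pi*m/L
    kx, ky, kz = np.meshgrid(k1d, k1d, k1d, indexing="ij")
    k_mag = np.sqrt(kx**2 + ky**2 + kz**2)

    dk = 2.0 * np.pi / L                           # shell width = fundamental wavenumber
    shell = np.rint(k_mag / dk).astype(int)        # nearest-integer shell index per mode
    E_shell = np.bincount(shell.ravel(), weights=energy_density.ravel())
    k_shell = dk * np.arange(E_shell.size)
    return k_shell, E_shell

# Test field: a single cosine with mode number 3 along x
N, L = 16, 1.0
x = np.arange(N) * L / N
field = np.cos(2.0 * np.pi * 3 * x / L)[:, None, None] * np.ones((N, N, N))

k_shell, E_shell = shell_spectrum(field, L)
print("peak shell index:", int(np.argmax(E_shell)))
print("sum of spectrum:", E_shell.sum(), " 0.5 * mean(f^2):", 0.5 * np.mean(field**2))
```

Because of Parseval's theorem, the shells sum to the total energy ½⟨f²⟩, which is a useful sanity check for any spectrum routine.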

In [None]:
print("Computing energy spectra...")

# Compute all three spectra
spec_1d = energy_spectrum_1d(state)
spec_perp = energy_spectrum_perpendicular(state)
spec_parallel = energy_spectrum_parallel(state)

k_bins, E_k = spec_1d
k_perp_bins, E_k_perp = spec_perp
k_par_bins, E_k_par = spec_parallel

print(f"✓ 1D spectrum: {len(k_bins)} wavenumber bins")
print(f"✓ Perpendicular spectrum: {len(k_perp_bins)} k⊥ bins")
print(f"✓ Parallel spectrum: {len(k_par_bins)} k∥ bins")

## 6. Visualization: Energy History

First, let's plot how energy evolves over time.

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(10, 10))

times = jnp.array(history.times)
energies = jnp.array(history.E_total)
E_mag_hist = jnp.array(history.E_magnetic)
E_kin_hist = jnp.array(history.E_kinetic)

# Plot 1: Total energy
axes[0].plot(times, energies, 'b-', linewidth=2, label="Total energy")
axes[0].axhline(energies[-1], color='r', linestyle='--', alpha=0.5, label=f"Final: {energies[-1]:.3e}")
axes[0].axhline(initial_energies['total'], color='gray', linestyle=':', alpha=0.5, label=f"Initial: {initial_energies['total']:.3e}")
axes[0].set_ylabel("Total Energy", fontsize=12)
axes[0].set_title("Forced Turbulence: Energy Evolution", fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Plot 2: Magnetic vs Kinetic
axes[1].plot(times, E_mag_hist, 'r-', linewidth=2, label="Magnetic", alpha=0.8)
axes[1].plot(times, E_kin_hist, 'b-', linewidth=2, label="Kinetic", alpha=0.8)
axes[1].set_ylabel("Energy Components", fontsize=12)
axes[1].set_title("Magnetic vs Kinetic Energy", fontsize=12, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

# Plot 3: Magnetic fraction
mag_frac = E_mag_hist / (E_mag_hist + E_kin_hist)
axes[2].plot(times, mag_frac, 'g-', linewidth=2)
axes[2].axhline(0.5, color='k', linestyle='--', alpha=0.5, label="Equipartition (50%)")
axes[2].set_xlabel("Time (τ_A)", fontsize=12)
axes[2].set_ylabel("Magnetic Fraction", fontsize=12)
axes[2].set_title("Energy Partition", fontsize=12, fontweight='bold')
axes[2].set_ylim([0, 1])
axes[2].legend(fontsize=10)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal magnetic fraction: {mag_frac[-1]:.3f} (equipartition = 0.5)")

## 7. Visualization: Energy Spectra

Now let's examine the energy distribution across scales. The **inertial range** should show k⁻⁵/³ scaling.
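
Rather than judging the slope by eye against a reference line, one can fit it. The sketch below performs a least-squares fit of log E against log k over a chosen window; the function name and the window bounds are illustrative choices, not part of `krmhd`.

```python
import numpy as np

def fit_spectral_slope(k, E, k_lo, k_hi):
    """Least-squares slope of log E versus log k for k_lo < k < k_hi."""
    k = np.asarray(k, dtype=float)
    E = np.asarray(E, dtype=float)
    mask = (k > k_lo) & (k < k_hi) & (E > 0)   # skip empty bins to keep the logs finite
    if mask.sum() < 2:
        raise ValueError("need at least two points in the fit window")
    slope, _intercept = np.polyfit(np.log(k[mask]), np.log(E[mask]), 1)
    return slope

# Synthetic check: an exact k^(-5/3) power law should give back slope = -5/3
k_test = np.arange(1.0, 33.0)
E_test = 2.0 * k_test ** (-5.0 / 3.0)
slope_test = fit_spectral_slope(k_test, E_test, k_lo=3.0, k_hi=20.0)
print(f"fitted slope = {slope_test:.3f}")
```

For a real spectrum, the fit window should lie above the forcing band and below the dissipation cutoff; a fit that includes either will be biased away from the inertial-range value.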

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Forcing wavenumber bounds
k_force_min = 2 * jnp.pi * n_force_min / Lx
k_force_max = 2 * jnp.pi * n_force_max / Lx

# Plot 1: 1D spherically-averaged spectrum
axes[0].loglog(k_bins, E_k, 'o-', markersize=5, linewidth=2, label='E(k)')

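# Add k⁻⁵/³ reference line, anchored to the spectrum at the middle of a fixed window.
# NOTE: the 5 < k < 15 window is hard-coded; with Lx = 1 the forcing band above is
# k ≈ 6.3-12.6, so shift the window past k_force_max if k_bins use the same units.
inertial_mask = (k_bins > 5) & (k_bins < 15)
k_inertial = k_bins[inertial_mask]
if len(k_inertial) > 0:
    k_ref = k_inertial[len(k_inertial)//2]
    E_ref = E_k[inertial_mask][len(k_inertial)//2]
    k_range = k_bins[(k_bins > 3) & (k_bins < 20)]
    axes[0].loglog(k_range, E_ref * (k_range/k_ref)**(-5/3),
                   'k--', alpha=0.7, linewidth=2, label='k⁻⁵/³ (Kolmogorov)')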

# Mark forcing band
axes[0].axvspan(k_force_min, k_force_max, alpha=0.2, color='green', label='Forcing band')

axes[0].set_xlabel("Wavenumber |k|", fontsize=12)
axes[0].set_ylabel("E(|k|)", fontsize=12)
axes[0].set_title("1D Energy Spectrum", fontsize=13, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3, which="both")

# Plot 2: Perpendicular spectrum
axes[1].loglog(k_perp_bins, E_k_perp, 'o-', markersize=5, linewidth=2, 
               color='orange', label='E(k⊥)')
axes[1].axvspan(k_force_min, k_force_max, alpha=0.2, color='green', label='Forcing')
axes[1].set_xlabel("k⊥ (perpendicular)", fontsize=12)
axes[1].set_ylabel("E(k⊥)", fontsize=12)
axes[1].set_title("Perpendicular Spectrum", fontsize=13, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3, which="both")

# Plot 3: Parallel spectrum
axes[2].semilogy(k_par_bins, E_k_par, 'o-', markersize=6, linewidth=2, 
                 color='purple', label='E(k∥)')
axes[2].axvspan(k_force_min, k_force_max, alpha=0.2, color='green', label='Forcing')
axes[2].set_xlabel("k∥ (parallel)", fontsize=12)
axes[2].set_ylabel("E(k∥)", fontsize=12)
axes[2].set_title("Parallel Spectrum", fontsize=13, fontweight='bold')
axes[2].legend(fontsize=10)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nSpectral Features:")
print(f"  • Forcing injects energy at k ≈ {k_force_min:.1f}-{k_force_max:.1f}")
print(f"  • Energy cascades to small scales (high k)")
print(f"  • Inertial range (if present) shows k⁻⁵/³ scaling")
print(f"  • Dissipation cuts off spectrum at high k")

## 8. Energy Injection Rate Over Time

The stochastic forcing creates fluctuating injection rates.
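
A single-mode toy model shows why the injection rate fluctuates and why its mean scales as amplitude². The sketch assumes each forcing kick is `amplitude * sqrt(dt) * xi` with `xi` drawn from a unit Gaussian, and that the mode energy is a²/2. Both are assumptions made for illustration; the actual scaling used inside `force_alfven_modes` is not shown in this tutorial.

```python
import math
import random

def mean_injection_rate(a0, amplitude, dt, n_samples, seed=0):
    """Ensemble-averaged energy injection rate from random kicks to one mode.

    Toy model: kick = amplitude * sqrt(dt) * xi with xi ~ N(0, 1), and mode
    energy E = a**2 / 2. Not the krmhd forcing implementation.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        kick = amplitude * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        dE = a0 * kick + 0.5 * kick**2   # E(a0 + kick) - E(a0); can be negative
        total += dE / dt                 # injection rate for this kick
    return total / n_samples

dt_toy = 0.01
rates = {}
for amp_toy in (0.1, 0.2, 0.4):
    rates[amp_toy] = mean_injection_rate(a0=0.01, amplitude=amp_toy,
                                         dt=dt_toy, n_samples=100_000)
    print(f"amplitude={amp_toy}: mean rate ~ {rates[amp_toy]:.4f}, "
          f"amplitude^2 / 2 = {0.5 * amp_toy**2:.4f}")
```

Individual kicks can add or remove energy depending on their phase relative to the existing mode, but the average is set by the `kick**2` term, so doubling the amplitude roughly quadruples the mean injection rate in this model.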

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

time_steps = jnp.arange(1, len(injection_history)+1) * dt
ax.plot(time_steps, injection_history, 'r-', linewidth=1, alpha=0.6, label='ε_inj(t)')
ax.axhline(avg_injection_rate, color='darkred', linestyle='--', linewidth=2, 
           label=f'Mean: {avg_injection_rate:.3e}')

ax.set_xlabel('Time (τ_A)', fontsize=12)
ax.set_ylabel('Energy Injection Rate ε_inj', fontsize=12)
ax.set_title('Stochastic Energy Injection', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nInjection Statistics:")
print(f"  Mean: {jnp.mean(jnp.array(injection_history)):.3e}")
print(f"  Std:  {jnp.std(jnp.array(injection_history)):.3e}")
print(f"  Max:  {jnp.max(jnp.array(injection_history)):.3e}")

## 9. Parameter Scan: Effect of Forcing Amplitude

Let's explore how different forcing amplitudes affect the steady-state energy level.

**Expected behavior**: Higher forcing → higher steady-state energy
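
Comparing "steady-state" energies only makes sense if each run has actually saturated. One simple check, sketched here as a hypothetical helper, compares the mean of the last quarter of an energy time series with the mean of the quarter before it; the 5% tolerance is an arbitrary choice.

```python
import math

def has_plateaued(series, rel_tol=0.05):
    """True if the last-quarter mean is within rel_tol of the previous-quarter mean."""
    n = len(series)
    if n < 8:
        raise ValueError("need at least 8 samples")
    q = n // 4
    last_mean = sum(series[-q:]) / q
    prev_mean = sum(series[-2 * q:-q]) / q
    return abs(last_mean - prev_mean) <= rel_tol * abs(last_mean)

# Synthetic examples: a saturating curve and a curve that is still growing
saturating = [1.0 - math.exp(-0.1 * i) for i in range(201)]
growing = [0.01 * i for i in range(201)]

print("saturating series plateaued:", has_plateaued(saturating))
print("growing series plateaued:", has_plateaued(growing))
```

Applied to each entry of a scan (after converting the array to a list of floats), this flags runs that need more timesteps before their final energy can be read as a steady-state value.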

In [None]:
print("Running parameter scan over forcing amplitudes...\n")

amplitudes_to_test = [0.1, 0.2, 0.3, 0.4]
scan_results = {}

for amp in amplitudes_to_test:
    print(f"Testing amplitude = {amp}...")
    
    # Reset state
    state_scan = initialize_random_spectrum(
        grid, M=20, alpha=alpha, amplitude=0.1, k_min=k_min, k_max=k_max,
        v_th=1.0, beta_i=beta_i, nu=nu, Lambda=1.0, seed=42
    )
    key_scan = jax.random.PRNGKey(42)
    
    energies_scan = [compute_energy(state_scan)['total']]
    
    # Run shorter simulation
    for i in range(100):
        state_scan, key_scan = force_alfven_modes(
            state_scan, amplitude=amp, n_min=n_force_min, 
            n_max=n_force_max, dt=dt, key=key_scan
        )
        state_scan = gandalf_step(state_scan, dt=dt, eta=eta, v_A=v_A)
        energies_scan.append(compute_energy(state_scan)['total'])
    
    scan_results[amp] = jnp.array(energies_scan)
    print(f"  Final energy: {energies_scan[-1]:.4e}\n")

print("✓ Parameter scan complete")

In [None]:
# Plot scan results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left: Energy evolution for each amplitude
colors = plt.cm.viridis(jnp.linspace(0.2, 0.9, len(amplitudes_to_test)))
for amp, color in zip(amplitudes_to_test, colors):
    time_scan = jnp.arange(len(scan_results[amp])) * dt
    ax1.plot(time_scan, scan_results[amp], linewidth=2, 
             label=f'Amp = {amp}', color=color)

ax1.set_xlabel('Time (τ_A)', fontsize=12)
ax1.set_ylabel('Total Energy', fontsize=12)
ax1.set_title('Energy Evolution: Parameter Scan', fontsize=13, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right: Final energy vs forcing amplitude
final_energies_scan = [scan_results[amp][-1] for amp in amplitudes_to_test]
ax2.plot(amplitudes_to_test, final_energies_scan, 'o-', 
         markersize=10, linewidth=2, color='darkblue')
ax2.set_xlabel('Forcing Amplitude', fontsize=12)
ax2.set_ylabel('Final Energy', fontsize=12)
ax2.set_title('Final Energy vs Forcing Amplitude', fontsize=13, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nExpected trend:")
print("  Higher forcing amplitude → higher energy after the same number of steps")
print("  Energy grows until injection balances dissipation; 100 steps may not be enough to saturate")

## Summary and Key Results

In this tutorial, you learned:

✅ Running forced turbulence simulations with energy injection  
✅ Computing and visualizing energy spectra (1D, k⊥, k∥)  
✅ Identifying inertial range scaling (k⁻⁵/³)  
✅ Understanding energy balance: injection vs dissipation  
✅ Parameter scanning to explore forcing effects  

### Physical Insights

1. **Forced turbulence** reaches statistical steady state when ε_inj ≈ ε_diss
2. **Energy cascade** transfers energy from large scales (forcing) to small scales (dissipation)
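3. **Inertial range** is expected to follow k⊥⁻⁵/³ Kolmogorov-like scaling, although at 64×64×32 it may span only a few wavenumber bins
4. **Equipartition** between magnetic and kinetic energy is expected for Alfvénic fluctuations; the magnetic-fraction plot shows how closely this run approaches 50%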

### Next Steps

- **Tutorial 03**: Analyzing decaying turbulence and spectral slopes
- **Advanced**: See `examples/benchmarks/alfvenic_cascade_benchmark.py` for high-resolution production runs
- **Parameter exploration**: Try varying η, ν, resolution, forcing scales

### References

- See `docs/physics_validity.md` for RMHD validity regimes
- See `docs/numerical_methods.md` for spectral method details
- See `docs/parameter_scans.md` for research workflow guidance