<div align="center">
    <h1>🌩️ Nimbus Demo</h1>
</div>

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 20px;
    border-radius: 8px;
    margin-bottom: 20px;
    border-left: 3px solid #98C1D9;
    text-align: center;
">
    Nimbus is a high-performance, JAX-based flight simulation framework. This notebook demonstrates its core capabilities.
</div>

<div style="
    background: #293542;
    padding: 10px;
    border-radius: 8px;
    margin: 20px 0;
    text-align: center;
">
    <img src="https://raw.githubusercontent.com/Auxeno/nimbus/main/videos/clip_2.gif" 
         style="
             border-radius: 6px;
             max-width: 100%;
             width: 700px;
         " 
         alt="Nimbus Flight Simulation" />
</div>

---

## 📦 Installation

Install Nimbus and dependencies:

In [None]:
# Install Nimbus from GitHub
!pip install -q git+https://github.com/auxeno/nimbus

# Install JAX with GPU support (optional - will use CPU if not available)
!pip install -q "jax[cuda12]"

# Install visualisation dependencies
!pip install -q numpy plotly

## 🚀 Massive Parallelisation

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin-bottom: 20px;
    border-left: 3px solid #2F3E9E;
">
    Simulate <strong>1 million aircraft</strong> in parallel using JAX vectorisation.
</div>

In [None]:
import jax
from nimbus import (
    InitialConditions, SimulationConfig, generate_simulation, quick_scenario, step
)

# Set up a shared heightmap and waypoint route (the generated aircraft state is discarded)
_, heightmap, waypoint_route = quick_scenario(seed=0)

# Generate 1 million unique initial aircraft simulation states.
# Split the root key once so that initial-state generation and the simulation
# step draw from independent PRNG streams. JAX keys must never be reused for
# two different purposes.
num_aircraft = 1_000_000
init_key, key = jax.random.split(jax.random.PRNGKey(0))
keys = jax.random.split(init_key, num=num_aircraft)
simulation_states = jax.vmap(generate_simulation, in_axes=(0, None))(
    keys,
    InitialConditions.default()
)

# Compile the step function, then vectorise it over the aircraft axis.
# in_axes: only the simulation states are batched (0). The step key, heightmap,
# route and config are shared (None), so every aircraft receives the same step key.
config = SimulationConfig()
step_fn = jax.jit(step, static_argnames=("config",))
step_parallel = jax.vmap(step_fn, in_axes=(None, 0, None, None, None))

# Execute one simulation step for all aircraft. JAX dispatches work
# asynchronously, so wait for the result before reporting success.
stepped_states = jax.block_until_ready(
    step_parallel(key, simulation_states, heightmap, waypoint_route, config)
)

print(f"✅ Successfully simulated {num_aircraft:,} aircraft in parallel!")
print(f"Aircraft positions shape: {stepped_states[0].aircraft.body.position.shape}")

In [None]:
%%timeit

# Benchmark the parallel simulation
jax.block_until_ready(
    step_parallel(key, simulation_states, heightmap, waypoint_route, config)
)

## ⏱️ Extended Simulation

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin-bottom: 20px;
    border-left: 3px solid #2F3E9E;
">
    Simulate 1 minute of flight (3,600 steps at 60 FPS) across 1,000 vectorised environments, looping over time with JAX's efficient <code>lax.scan</code>.
</div>

In [None]:
from nimbus import quick_terrain, generate_route, generate_simulation, step

# Import everything this cell uses so it does not depend on earlier cells' namespace
import jax
from nimbus import InitialConditions, SimulationConfig

# Configuration
num_environments = 1000
num_seconds = 60
fps = 60
num_steps = num_seconds * fps

# Generate shared terrain. Split the root key once so that initial-condition
# generation and per-timestep stepping use independent PRNG streams. Splitting
# the same key for both would make the two sets of keys overlap.
heightmap = quick_terrain(seed=0)
rng = jax.random.PRNGKey(0)
init_rng, time_rng = jax.random.split(rng)

# Create initial states and routes for each environment.
# Both generators receive the same per-environment key, presumably so that each
# route pairs with its aircraft. NOTE(review): confirm this coupling is intended.
initial_conditions = InitialConditions.default()
env_keys = jax.random.split(init_rng, num=num_environments)
simulation_states = jax.vmap(generate_simulation, in_axes=(0, None))(
    env_keys,
    initial_conditions
)
routes = jax.vmap(generate_route, in_axes=(0, None))(
    env_keys,
    initial_conditions
)

# Configure simulation with 60 FPS timestep.
# Keys, states and routes are batched per environment; heightmap and config are shared.
config = SimulationConfig(dt=1/fps)
step_fn = jax.jit(step, static_argnames=("config",))
step_parallel = jax.vmap(step_fn, in_axes=(0, 0, None, 0, None))


def simulate_timestep(carry, key):
    """Scan body: advance every environment by one timestep.

    ``carry`` is ``(sims, routes)``. ``key`` is this timestep's PRNG key, which
    is split into one key per environment. No per-step output is collected
    (returns ``None``).
    """
    sims, routes = carry
    env_keys = jax.random.split(key, num=num_environments)
    next_sims, next_routes = step_parallel(env_keys, sims, heightmap, routes, config)
    return (next_sims, next_routes), None


# Run the simulation using JAX's scan (efficient loop)
time_keys = jax.random.split(time_rng, num=num_steps)
(final_states, final_routes), _ = jax.lax.scan(
    f=simulate_timestep,
    init=(simulation_states, routes),
    xs=time_keys
)
# JAX dispatches asynchronously. Wait for the rollout so the summary below is
# printed only once the results exist.
jax.block_until_ready((final_states, final_routes))

print(f"✅ Simulated {num_seconds} seconds of flight")
print(f"   • {num_environments:,} parallel environments")
print(f"   • {num_steps:,} timesteps @ {fps} FPS")
print(f"   • {num_environments * num_steps:,} total simulation steps")

In [None]:
%%timeit

# Benchmark the full minute simulation
init_carry = (simulation_states, routes)
jax.block_until_ready(jax.lax.scan(simulate_timestep, init_carry, time_keys))

## 🗺️ Scenario Visualisation

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin-bottom: 20px;
    border-left: 3px solid #2F3E9E;
">
    Generate and visualise 3D flight scenarios with procedural terrain. Plots are interactive - drag to rotate and scroll to zoom.
</div>

In [None]:
from nimbus import quick_scenario, MapConfig

# Generate a complete scenario with:
# - Aircraft with customisable randomised initial conditions
# - Procedurally generated terrain using layered simplex noise
# - Waypoint route for navigation
simulation_state, heightmap, waypoint_route = quick_scenario(seed=4)

# Summarise the scenario. Positions are NED (Z points down), so altitude is -Z.
print("Generated scenario:")
print(f"  • Terrain: {heightmap.shape[0]}x{heightmap.shape[1]} heightmap")
print(f"  • Waypoints: {len(waypoint_route.positions)} navigation points")
print(f"  • Aircraft altitude: {-simulation_state.aircraft.body.position[2]:.1f}m")

In [None]:
# 3D visualisation function
import numpy as np
import plotly.graph_objects as go
from nimbus.core import quaternion

def plot_scenario_3d(simulation_state, heightmap, waypoint_route, map_config):
    """Build an interactive Plotly 3D figure of a single flight scenario.

    The figure contains the terrain surface, the waypoint route, and a cone
    marking the aircraft's position and nose direction. Nimbus states are NED
    (Z points down), so every Z value is negated to make altitude point up.

    Args:
        simulation_state: one (unbatched) simulation state. Its
            ``aircraft.body.position`` and ``aircraft.body.orientation`` are read.
        heightmap: terrain heightmap array. It is transposed before plotting.
        waypoint_route: route whose ``positions`` are stacked NED points.
        map_config: supplies ``size`` (terrain extent) and ``terrain_height``
            (elevation scale).

    Returns:
        plotly.graph_objects.Figure: the assembled figure (not yet shown).
    """
    # Terrain grid. A heightmap value of 0.5 maps to zero elevation, and each
    # unit of deviation from 0.5 scales to 2 * terrain_height.
    terrain = np.array(heightmap.T, dtype=np.float32)
    elevation = (terrain - 0.5) * 2.0 * float(map_config.terrain_height)

    n_rows, n_cols = terrain.shape
    half_extent = float(map_config.size) / 2.0
    north_axis = np.linspace(-half_extent, half_extent, n_cols, dtype=np.float32)
    east_axis = np.linspace(-half_extent, half_extent, n_rows, dtype=np.float32)
    grid_north, grid_east = np.meshgrid(north_axis, east_axis)

    # Aircraft and route positions, flipped from NED Down to plot-frame Up
    body = simulation_state.aircraft.body
    position = np.array(body.position, dtype=np.float32)
    plane_x, plane_y, plane_z = position[0], position[1], -position[2]

    waypoints = np.array(waypoint_route.positions, dtype=np.float32)
    route_x, route_y, route_z = waypoints[:, 0], waypoints[:, 1], -waypoints[:, 2]

    # Heading arrow: rotate the body-frame nose axis (+X) into the world frame,
    # flip Z for plotting, then rescale to a fixed 250-unit length
    # (1e-6 guards against a zero norm).
    orientation = np.array(body.orientation, dtype=np.float32)
    rotation = np.array(quaternion.to_rotation_matrix(orientation))
    nose_world = rotation @ np.array([1.0, 0.0, 0.0], dtype=np.float32)
    nose_plot = np.array([nose_world[0], nose_world[1], -nose_world[2]])
    nose_plot = (nose_plot / (np.linalg.norm(nose_plot) + 1e-6)) * 250.0

    terrain_trace = go.Surface(
        x=grid_north, y=grid_east, z=elevation,
        colorscale="ice",
        showscale=False,
        opacity=0.96,
        name="Terrain",
    )
    route_trace = go.Scatter3d(
        x=route_x, y=route_y, z=route_z,
        mode="markers+lines",
        marker=dict(size=6, color="white"),
        line=dict(color="white", width=3),
        name="Route",
    )
    aircraft_trace = go.Cone(
        x=[plane_x], y=[plane_y], z=[plane_z],
        u=[nose_plot[0]], v=[nose_plot[1]], w=[nose_plot[2]],
        anchor="tail",
        colorscale=[[0, "#98C1D9"], [1, "cyan"]],
        showscale=False,
        sizemode="absolute",
        sizeref=350.0,
        name="Aircraft",
    )

    # Traces are added in the same order: terrain, route, aircraft
    fig = go.Figure(data=[terrain_trace, route_trace, aircraft_trace])
    fig.update_layout(
        template="plotly_dark",
        title="Flight Scenario 3D View",
        scene=dict(
            xaxis_title="North (m)",
            yaxis_title="East (m)",
            zaxis_title="Altitude (m)",
            aspectmode="data",
            camera=dict(eye=dict(x=0.7, y=0.9, z=0.55)),
        ),
        margin=dict(l=0, r=0, t=50, b=0),
        height=600,
    )
    return fig

In [None]:
# Render the scenario generated above as an interactive 3D figure.
# NOTE(review): MapConfig() assumes quick_scenario built its map with the
# default configuration. Confirm this; otherwise the plotted extents and heights will be off.
scenario_map_config = MapConfig()
fig = plot_scenario_3d(simulation_state, heightmap, waypoint_route, scenario_map_config)
fig.show()

## 🎯 Custom Scenarios

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin-bottom: 20px;
    border-left: 3px solid #2F3E9E;
">
    Define precise initial conditions. Note: Nimbus uses North-East-Down (NED) coordinates, where positive Z points down, so an altitude of 1,000 m is written as Z = -1000.
</div>

In [None]:
import jax
from nimbus import Fixed, Uniform, InitialConditions, MapConfig, generate_scenario

# Hand-specified scenario. Every vector is in the NED frame (North, East, Down),
# so altitudes appear as negative Down components.

# Start 1 km North and 500 m West of the map centre [meters]
start_position = (
    Fixed(1000.0),          # North: 1km North of center
    Fixed(-500.0),          # East: negative, i.e. 500m West of center
    Uniform(0.0, -1000.0),  # Down: random altitude 0-1000m
)

# Cruise due South at 100 m/s in level flight [m/s]
start_velocity = (
    Fixed(-100.0),  # North component: negative means southbound
    Fixed(0.0),     # no lateral velocity
    Fixed(0.0),     # no vertical velocity (level flight)
)

# Attitude as Euler angles [degrees]: heading South, slight nose-up pitch, wings level
start_orientation = (Fixed(180.0), Fixed(2.0), Fixed(0.0))

# Route continues South, climbing to 1400m and then descending to 600m [NED positions]
route_waypoints = tuple(
    (Fixed(north), Fixed(0.0), Fixed(down))
    for north, down in (
        (-1500.0, -1000.0),
        (-2000.0, -1400.0),
        (-2500.0, -1000.0),
        (-3000.0, -600.0),
    )
)

custom_conditions = InitialConditions(
    position=start_position,
    velocity=start_velocity,
    orientation_euler=start_orientation,
    angular_velocity=(Fixed(0.0), Fixed(0.0), Fixed(0.0)),
    wind_speed=Fixed(0.0),
    wind_direction=Fixed(0.0),
    waypoints=route_waypoints,
)

# Taller mountains than the default map. Only the terrain sub-config is passed
# to scenario generation.
map_config = MapConfig(terrain_height=3000.0)
custom_sim, custom_terrain, custom_route = generate_scenario(
    key=jax.random.PRNGKey(0),
    initial_conditions=custom_conditions,
    terrain_config=map_config.terrain,
)

# Plot with the same map_config so the terrain's size/height scaling matches
fig = plot_scenario_3d(custom_sim, custom_terrain, custom_route, map_config)
fig.update_layout(title="Custom Scenario with High Mountains")
fig.show()

## ⚙️ Further Features

### Terrain Generation

<div style="
    background: #293542;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin: 20px 0;
">
    Procedural terrain using layered simplex noise. Adjust parameters for different terrain characteristics.
</div>

In [None]:
import jax
from nimbus import generate_terrain_map, TerrainConfig

from nimbus import MapConfig

# Terrains with different characteristics
terrain_configs = [
    TerrainConfig(),                                  # default balanced terrain
    TerrainConfig(mountain_gain=5.0, bump_gain=0.3),  # mountainous
    TerrainConfig(mountain_gain=0.5, bump_gain=1.5),  # rough/bumpy
]

# One PRNG key per variant, derived from the config list so the counts can't drift apart
terrain_keys = jax.random.split(jax.random.PRNGKey(0), len(terrain_configs))
terrains = [generate_terrain_map(k, c) for k, c in zip(terrain_keys, terrain_configs)]

# Convert the heightmap spread to metres with the same mapping plot_scenario_3d
# uses, elevation = (h - 0.5) * 2 * terrain_height. terrain_height is taken
# from the default MapConfig rather than a hard-coded guess.
terrain_height = float(MapConfig().terrain_height)

print("Generated terrain variations:")
for i, terrain in enumerate(terrains, start=1):
    elevation_range = float(terrain.max() - terrain.min()) * 2.0 * terrain_height
    print(f"  Terrain {i}: shape={terrain.shape}, elevation range≈{elevation_range:.0f}m")

### Aircraft Configuration

<div style="
    background: #293542;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin: 20px 0;
">
    Configure aircraft parameters to simulate different vehicle types.
</div>

In [None]:
import jax
from nimbus import AircraftConfig, InitialConditions, generate_simulation

# Illustrative aircraft archetypes.
# NOTE(review): surface_areas is assumed to be (frontal, side, wing). Confirm
# this; index 2 is treated as the wing area below.
aircraft_types = {
    "Fighter": AircraftConfig(
        mass=8000.0,                      # lightweight fighter
        surface_areas=(8.0, 15.0, 25.0),  # compact frontal area, moderate wings
        max_thrust=120_000.0,             # high thrust-to-weight ratio
        max_attack_angle=25.0,            # highly maneuverable
        coef_lift=12.0,                   # high lift capability
        coef_drag=0.25,                   # slightly higher drag
        coefs_torque=(30.0, 10.0, 2.0),   # very responsive controls
        g_limit_max=9.0,                  # fighter pilot limits
        g_limit_min=-3.0,
    ),
    "Transport": AircraftConfig(
        mass=70_000.0,                       # heavy transport aircraft
        surface_areas=(80.0, 150.0, 250.0),  # large cross-sections
        max_thrust=400_000.0,                # powerful but lower T/W ratio
        max_attack_angle=15.0,               # conservative flight envelope
        coef_lift=8.0,                       # moderate lift
        coef_drag=0.18,                      # optimized for efficiency
        coefs_torque=(15.0, 5.0, 1.0),       # slower control response
        g_limit_max=2.5,                     # passenger comfort limits
        g_limit_min=-1.0,
    ),
    "Aerobatic": AircraftConfig(
        mass=1200.0,                     # very light aerobatic plane
        surface_areas=(2.0, 4.0, 12.0),  # minimal frontal area
        max_thrust=15_000.0,             # good power-to-weight
        max_attack_angle=35.0,           # extreme maneuverability
        coef_lift=15.0,                  # very high lift
        coef_drag=0.3,                   # draggy at high angles
        coefs_torque=(50.0, 20.0, 5.0),  # extremely responsive
        g_limit_max=12.0,                # aerobatic limits
        g_limit_min=-6.0,
    ),
    "Glider": AircraftConfig(
        mass=600.0,                      # lightweight glider
        surface_areas=(1.5, 3.0, 35.0),  # very high aspect ratio wings
        max_thrust=0.0,                  # no engine
        max_attack_angle=12.0,           # narrow optimal range
        coef_lift=14.0,                  # excellent lift-to-drag
        coef_drag=0.08,                  # extremely low drag
        coefs_torque=(8.0, 3.0, 0.5),    # gentle controls
        g_limit_max=5.5,                 # structural limits
        g_limit_min=-2.5,
    ),
}

GRAVITY = 9.81  # m/s², converts mass to weight for the thrust-to-weight ratio

# Summarise derived performance metrics for each aircraft type
print("Aircraft configurations:")
for name, aircraft_config in aircraft_types.items():
    # The loop variable is deliberately not named `config`. That global holds
    # the SimulationConfig used by the simulation and timing cells above, and
    # overwriting it would break re-running them.
    thrust_weight = aircraft_config.max_thrust / (aircraft_config.mass * GRAVITY)  # 0 for unpowered gliders
    wing_loading = aircraft_config.mass / aircraft_config.surface_areas[2]  # mass / wing area
    print(f"\n{name}:")
    print(f"  • Mass: {aircraft_config.mass:,.0f} kg")
    print(f"  • Thrust/Weight: {thrust_weight:.2f}")
    print(f"  • Wing Loading: {wing_loading:.1f} kg/m²")
    # The :+ format takes the sign from the value itself. A literal "-" prefix
    # would print "--3.0" for negative limits.
    print(f"  • G-limits: {aircraft_config.g_limit_max:+.1f}/{aircraft_config.g_limit_min:+.1f} G")

## 📊 Performance Benchmarking

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 15px;
    border-radius: 6px;
    margin-bottom: 20px;
    border-left: 3px solid #2F3E9E;
">
    Benchmark Nimbus simulation throughput across different scales, varying both the number of vectorised aircraft and the number of simulated timesteps.
</div>

In [None]:
# Benchmarking function
import time
import jax
from nimbus import (
    InitialConditions, SimulationConfig, generate_simulation, generate_route, 
    quick_terrain, step
)

def _time_calls(fn, args: tuple, num_runs: int) -> list[float]:
    """Warm up ``fn(*args)`` once, then time ``num_runs`` further calls.

    The untimed warm-up absorbs tracing and compilation. Each timed call is
    followed by ``jax.block_until_ready`` so the measurement covers the device
    work, not just JAX's asynchronous dispatch. One dot is printed per timed run.

    Returns:
        Wall-clock durations in seconds, one per timed run.
    """
    _ = fn(*args)
    jax.block_until_ready(_)

    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        output = fn(*args)
        jax.block_until_ready(output)
        times.append(time.perf_counter() - start)
        print('.', end='', flush=True)
    return times


def benchmark_nimbus(
    aircraft_counts: list[int],
    timestep_counts: list[int],
    num_runs: int = 10,
    fps: int = 60
) -> dict:
    """Measure Nimbus simulation throughput for each (aircraft, timesteps) pair.

    For every combination, fresh initial states and routes are generated. The
    workload runs once untimed as a warm-up and is then timed ``num_runs``
    times. A single timestep uses the vmapped step directly. Longer rollouts
    wrap the vmapped step in ``jax.lax.scan``.

    Args:
        aircraft_counts: numbers of parallel aircraft to test.
        timestep_counts: numbers of simulation steps per rollout.
        num_runs: timed repetitions per configuration (must be >= 1).
        fps: simulation rate. The step size is ``1 / fps`` seconds.

    Returns:
        Dict keyed by ``(num_aircraft, num_timesteps)``. Each value is a dict
        with ``aircraft``, ``timesteps``, ``avg_time`` (mean seconds per run),
        ``throughput`` (aircraft-steps per second) and ``total_steps``.

    Raises:
        ValueError: if ``num_runs`` is less than 1, since no average time could be computed.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    results = {}

    # Setup shared resources
    print("Setting up benchmark resources...")
    heightmap = quick_terrain(seed=0)
    config = SimulationConfig(dt=1/fps)
    initial = InitialConditions.default()

    # The compiled step and its vectorised form don't depend on the sweep, so
    # build them once. Keys, states and routes are batched; heightmap and config are shared.
    step_fn = jax.jit(step, static_argnames=("config",))
    step_parallel = jax.vmap(step_fn, in_axes=(0, 0, None, 0, None))

    for num_timesteps in timestep_counts:
        print(f"\n{'='*60}")
        print(f"Benchmarking {num_timesteps} timesteps ({num_timesteps/fps:.1f} seconds)")
        print('='*60)

        for num_aircraft in aircraft_counts:
            config_key = (num_aircraft, num_timesteps)

            print(f"  {num_aircraft:>8} aircraft: ", end='', flush=True)

            # Split the root key so initial-state generation and stepping draw
            # from independent PRNG streams instead of reusing one key.
            init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
            init_keys = jax.random.split(init_key, num_aircraft)
            simulations = jax.vmap(generate_simulation, in_axes=(0, None))(init_keys, initial)
            routes = jax.vmap(generate_route, in_axes=(0, None))(init_keys, initial)

            if num_timesteps == 1:
                # Single step benchmark - vmap only
                step_keys = jax.random.split(run_key, num_aircraft)
                times = _time_calls(
                    step_parallel,
                    (step_keys, simulations, heightmap, routes, config),
                    num_runs,
                )
            else:
                # Multi-step benchmark - vmap + scan. This closure is traced
                # within this iteration, so it sees the current num_aircraft.
                def simulate_timestep(carry, key):
                    sims, routes_carry = carry
                    env_keys = jax.random.split(key, num_aircraft)
                    next_sims, next_routes = step_parallel(env_keys, sims, heightmap, routes_carry, config)
                    return (next_sims, next_routes), None

                time_keys = jax.random.split(run_key, num_timesteps)
                times = _time_calls(
                    jax.lax.scan,
                    (simulate_timestep, (simulations, routes), time_keys),
                    num_runs,
                )

            # Calculate throughput in aircraft-steps per second
            avg_time = sum(times) / len(times)
            total_steps = num_aircraft * num_timesteps
            throughput = total_steps / avg_time

            results[config_key] = {
                'aircraft': num_aircraft,
                'timesteps': num_timesteps,
                'avg_time': avg_time,
                'throughput': throughput,
                'total_steps': total_steps
            }

            print(f" {throughput:>12,.0f} steps/sec")

    return results

In [None]:
# Sweep sizes: powers of ten from 1 to 10,000,000 parallel aircraft, crossed
# with rollouts of 1 frame, 1 second and 30 seconds at the default 60 FPS.
# NOTE(review): the 10M-aircraft configurations may exhaust accelerator
# memory on smaller devices. Trim the list if they do.
aircraft_counts = [10 ** exponent for exponent in range(8)]
timestep_counts = [1, 60, 1800]

benchmark_results = benchmark_nimbus(aircraft_counts, timestep_counts)

---

<div style="
    background: #1E2A38;
    color: #EEEEEE;
    padding: 20px;
    border-radius: 8px;
    margin-top: 40px;
    text-align: center;
    border: 1px solid #98C1D9;
">
    <p style="margin-top: 15px; color: #EEEEEE;">
        Check out the <a href="https://github.com/Auxeno/nimbus" style="color: #98C1D9;">GitHub repository</a> 
        for more examples and documentation.
    </p>
</div>