In [None]:
# Work from the parent of the notebook's starting directory (presumably the
# repository root — confirm). Unlike a bare `%cd ..`, this is idempotent: the
# starting directory is recorded on the first run only, so re-running the cell
# does not keep climbing up the directory tree.
import os

if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))
os.getcwd()

## 🌩️ Nimbus Demo

For an explanation of the JAX concepts used here (such as `jit` and `vmap`), see the other notebook in this repository.

Sections:
- Install
- Generate and plot scenarios (showcase, custom and batch-generated)
- Stepping many aircraft in parallel with `vmap` (aiming for 1,000,000+)


### Install

In [None]:
!pip install git+https://github.com/auxeno/nimbus.git
!pip install numpy, plotly

### Generate and Plot Scenarios

In [None]:
import jax

from nimbus.core.config import MapConfig
from nimbus.core.scenario import InitialConditions, generate_scenario

# Fixed seed so the generated showcase scenario is the same on every run.
SHOWCASE_SEED = 4

# Generate the pre-defined showcase scenario using the default MapConfig terrain settings.
simulation_state, heightmap, waypoint_route = generate_scenario(
    key=jax.random.PRNGKey(SHOWCASE_SEED),
    initial_conditions=InitialConditions.showcase(),
    terrain_config=MapConfig().terrain,
)

In [None]:
# --- Code to view a scenario in 3D with Plotly ---

import numpy as np
import plotly.graph_objects as go
from nimbus.core import quaternion


def _terrain_surface_grid(heightmap, map_config):
    """Return (X, Y, Z) arrays for drawing the heightmap as a Plotly surface.

    X holds North and Y holds East coordinates, both spanning
    [-map_config.size / 2, map_config.size / 2]. Z is the heightmap rescaled
    as ``(h - 0.5) * 2 * map_config.terrain_height``. Presumably heightmap
    values lie in [0, 1], giving elevations in [-terrain_height,
    terrain_height] — confirm.
    """
    # Transpose so rows run along East and columns along North, matching the
    # meshgrid below: z[i, j] is drawn at (X[i, j], Y[i, j]) = (north[j], east[i]).
    heightmap_np = np.array(heightmap.T, dtype=np.float32)
    terrain_elevation = (heightmap_np - 0.5) * 2.0 * float(map_config.terrain_height)

    rows, cols = heightmap_np.shape
    half_size = float(map_config.size) / 2.0
    north = np.linspace(-half_size, half_size, cols, dtype=np.float32)
    east = np.linspace(-half_size, half_size, rows, dtype=np.float32)
    X, Y = np.meshgrid(north, east)
    return X, Y, terrain_elevation


def _ned_to_up(points_ned):
    """Copy NED point(s) to float32 and negate the Down axis so it becomes Up."""
    points_up = np.array(points_ned, dtype=np.float32)
    points_up[..., 2] = -points_up[..., 2]
    return points_up


def _forward_vector_up(orientation, length):
    """Return the body +X (forward) axis in (North, East, Up), scaled to `length`."""
    q = np.array(orientation, dtype=np.float32)
    forward_body = np.array([1.0, 0.0, 0.0], dtype=np.float32)
    # Assumes to_rotation_matrix maps body-frame vectors into the NED world frame.
    forward_ned = np.array(quaternion.to_rotation_matrix(q)) @ forward_body
    forward_up = _ned_to_up(forward_ned)
    # Renormalise to unit length; the epsilon avoids a 0/0 division.
    forward_up /= np.linalg.norm(forward_up) + 1e-6
    return forward_up * length


def plot_terrain_route_aircraft(
    simulation_state,
    heightmap,
    waypoint_route,
    map_config,
    arrow_length=250.0,
    title="Scenario 3D View",
):
    """Build an interactive 3D Plotly figure of a single scenario.

    Draws the terrain surface, the waypoint route (markers joined by lines)
    and the aircraft as a cone pointing along its body +X (forward) axis.
    Positions are given in North-East-Down; Down is negated for plotting so
    the vertical axis is Up.

    Args:
        simulation_state: a single (unbatched) simulation state. Its
            ``aircraft.body.position`` (NED, metres) and
            ``aircraft.body.orientation`` (quaternion) are read.
        heightmap: 2D terrain heightmap whose first axis runs North and
            second axis runs East (it is transposed before plotting).
        waypoint_route: object whose ``positions`` is an (N, 3) array of NED
            waypoint positions in metres.
        map_config: map configuration providing ``size`` (side length of the
            square map) and ``terrain_height`` (elevation scale).
        arrow_length: length in metres of the aircraft heading vector.
        title: figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: if the aircraft position is not a single 3-vector, e.g.
            when a batched (vmapped) simulation state is passed.
    """
    X, Y, terrain_elevation = _terrain_surface_grid(heightmap, map_config)

    # Aircraft position (NED -> North, East, Up)
    aircraft_up = _ned_to_up(simulation_state.aircraft.body.position)
    if aircraft_up.shape != (3,):
        raise ValueError(
            "plot_terrain_route_aircraft expects a single (unbatched) simulation "
            f"state with a position of shape (3,), got {aircraft_up.shape}"
        )
    aircraft_n, aircraft_e, aircraft_u = aircraft_up

    # Waypoint route (NED -> North, East, Up)
    route_up = _ned_to_up(waypoint_route.positions)
    rx, ry, rz_up = route_up[:, 0], route_up[:, 1], route_up[:, 2]

    # Heading arrow from the orientation quaternion
    u, v, w = _forward_vector_up(simulation_state.aircraft.body.orientation, arrow_length)

    fig = go.Figure()

    # Terrain surface
    fig.add_trace(
        go.Surface(
            x=X, y=Y, z=terrain_elevation,
            colorscale="ice",
            showscale=False,
            opacity=0.96,
            name="Terrain",
        )
    )

    # Waypoints
    fig.add_trace(
        go.Scatter3d(
            x=rx, y=ry, z=rz_up,
            mode="markers+lines",
            marker=dict(size=6, color="white"),
            line=dict(color="white", width=3),
            name="Waypoints",
        )
    )

    # Aircraft
    fig.add_trace(
        go.Cone(
            x=[aircraft_n], y=[aircraft_e], z=[aircraft_u],
            u=[u], v=[v], w=[w],
            anchor="tail",
            colorscale=[[0, "black"], [1, "blue"]],
            showscale=False,
            sizemode="absolute",
            sizeref=350.0,
            name="Aircraft",
        )
    )

    # Plot formatting
    fig.update_layout(
        template="plotly_dark",
        title=title,
        scene=dict(
            xaxis_title="North (m)",
            yaxis_title="East (m)",
            zaxis_title="Up (m)",
            aspectmode="data",
            xaxis=dict(showbackground=False),
            yaxis=dict(showbackground=False),
            zaxis=dict(showbackground=False),
            camera=dict(eye=dict(x=0.7, y=0.9, z=0.55)),
        ),
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig

In [None]:
# 3D view of the showcase scenario, drawn with the default map configuration
# (arguments in the function's positional order: state, heightmap, route, map config).
fig = plot_terrain_route_aircraft(simulation_state, heightmap, waypoint_route, MapConfig())
fig.show()

#### Custom Scenario

In [None]:
from nimbus.core.scenario import Fixed, Uniform


# Position specified in North East Down world-frame [m]
aircraft_position = (
    Fixed(1000.0),          # place at 1000m North from centre
    Fixed(-500.0),          # place at 500m West from centre
    # Down in [-1000m, 0m], i.e. somewhere between 0m and 1000m altitude.
    # Bounds are written low-to-high: samplers such as jax.random.uniform do
    # not sample the interval when minval > maxval (they clamp to minval).
    # NOTE(review): assumes Uniform(low, high) — confirm against nimbus.
    Uniform(-1000.0, 0.0),
)

# Velocity specified in North East Down world-frame [m/s]
aircraft_velocity = (
    Fixed(-100.0),  # flying South at 100 m/s
    Fixed(0.0),
    Fixed(0.0),
)


# Orientation specified as yaw, pitch, roll [degrees]
aircraft_orientation = (
    Fixed(180.0),  # facing south
    Fixed(2.0),    # pitched slightly up
    Fixed(0.0),    # level wings
)

# Waypoint positions defined in NED world-frame [m]; negative Down is above the origin
waypoint_positions = (
    (Fixed(-1500.0), Fixed(0.0), Fixed(-1000.0)),
    (Fixed(-2000.0), Fixed(0.0), Fixed(-1400.0)),
    (Fixed(-2500.0), Fixed(0.0), Fixed(-1000.0)),
    (Fixed(-3000.0), Fixed(0.0), Fixed(-600.0)),
)

# Define initial scenario conditions from data specified above
custom_initial_conditions = InitialConditions(
    position=aircraft_position,
    velocity=aircraft_velocity,
    orientation_euler=aircraft_orientation,
    angular_velocity=(Fixed(0.0), Fixed(0.0), Fixed(0.0)),
    wind_speed=Fixed(0.0),
    wind_direction=Fixed(0.0),
    waypoints=waypoint_positions
)

# Higher mountains; the same map_config is used for generation and plotting
map_config = MapConfig(terrain_height=3000.0)


# Generate custom scenario (overwrites the showcase state, heightmap and route)
simulation_state, heightmap, waypoint_route = generate_scenario(
    key=jax.random.PRNGKey(0),
    initial_conditions=custom_initial_conditions,
    terrain_config=map_config.terrain
)

fig = plot_terrain_route_aircraft(
    simulation_state=simulation_state,
    heightmap=heightmap,
    waypoint_route=waypoint_route,
    map_config=map_config,
)
fig.show()

#### Batch Generate 100 Scenarios

Each scenario is generated from a different RNG key, so the randomly sampled aircraft and waypoint placements differ between scenarios.

In [None]:
from nimbus.core.scenario import generate_simulation

# Map over the key only: every scenario gets its own PRNG key, while the
# default initial conditions are broadcast unchanged to all of them.
batched_generate_simulation = jax.vmap(generate_simulation, in_axes=(0, None))
scenario_keys = jax.random.split(jax.random.PRNGKey(0), num=100)
simulation_states = batched_generate_simulation(scenario_keys, InitialConditions.default())

In [None]:
# TODO: visualise a selection of the batch-generated scenarios here.

### VMAP Capabilities

- Demonstrate the capability to `vmap` 1,000,000+ aircraft in parallel (the cell below uses 100,000 aircraft)

In [None]:
import jax
from nimbus.core.config import SimulationConfig
from nimbus.core.scenario import InitialConditions, generate_simulation
from nimbus.core.simulation import step


# Number of aircraft simulated in parallel. The stated goal is 1,000,000+;
# raise this if the accelerator has enough memory.
NUM_AIRCRAFT = 100_000

# Derive separate keys for scenario generation and for the step, so no PRNG
# key is used twice.
key = jax.random.PRNGKey(0)
init_key, step_key = jax.random.split(key)

simulation_states = jax.vmap(generate_simulation, in_axes=(0, None))(
    jax.random.split(init_key, num=NUM_AIRCRAFT),
    InitialConditions.default(),
)
config = SimulationConfig()

# Treat `config` as a static (compile-time) argument. Note the trailing comma:
# ("config",) is a tuple, whereas ("config") is just the string "config".
step_fn = jax.jit(step, static_argnames=("config",))

# Map across the second argument (simulation states); the step key, heightmap,
# waypoint route and config are shared by every aircraft.
step_vmap = jax.vmap(step_fn, in_axes=(None, 0, None, None, None))

# Vectorised step. `heightmap` and `waypoint_route` come from the most recent
# generate_scenario call above (the custom scenario). All five arguments must be
# passed to match the five `in_axes` entries. Blocking waits for JAX's
# asynchronous dispatch to finish; assigning avoids printing the huge batched result.
step_result = jax.block_until_ready(
    step_vmap(step_key, simulation_states, heightmap, waypoint_route, config)
)


### Map Generation and VMAP Over Maps

- Demonstrate generating multiple maps and VMAP over them
- Point out that maps use considerably more memory than the rest of the simulation state, so far fewer of them can be parallelised at once

### VMAP Different Aircraft Configs

### VMAP Performance

### Wind Generation