# Gradio Application

This notebook explores integrating `nrtk` with `Gradio` to build an interactive interface for applying pyBSM-based perturbations to images.

## Table of Contents

* [Environment Setup](#environment-setup)
* [Defining the Application](#defining-application)
* [Running the Application](#running-application)

## Environment Setup <a name="environment-setup"></a>

In [1]:
import sys  # noqa: F401

!{sys.executable} -m pip install -qU pip
print("Installing nrtk...")
!{sys.executable} -m pip install -q nrtk
print("Installing headless OpenCV...")
!{sys.executable} -m pip uninstall -qy opencv-python opencv-python-headless  # make sure they're both gone.
!{sys.executable} -m pip install -q opencv-python-headless
print("Installing gradio...")
!{sys.executable} -m pip install -q gradio
# Pin fastapi to work around a typing_extensions error when importing gradio on Python 3.8 and 3.9
print("Installing fastapi...")
!{sys.executable} -m pip install -q "fastapi==0.100.0"
print("Done!")

Installing nrtk...
Installing headless OpenCV...
Installing gradio...
Installing fastapi...
Done!
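
To double-check what the install cell produced (for example, that only the headless OpenCV build remains and which `fastapi` version ended up installed), the installed versions can be queried with the standard library. This is a small sketch; `installed_versions` is a hypothetical helper, not part of `nrtk` or Gradio.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# "opencv-python" is expected to map to None after the uninstall/reinstall steps above
print(installed_versions(["nrtk", "gradio", "fastapi", "opencv-python", "opencv-python-headless"]))
```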


In [2]:
from pathlib import Path

import numpy as np
import yaml

from nrtk.impls.perturb_image.pybsm.perturber import PybsmPerturber
from nrtk.impls.perturb_image.pybsm.scenario import PybsmScenario
from nrtk.impls.perturb_image.pybsm.sensor import PybsmSensor

## Defining the Application <a name="defining-application"></a>

pyBSM has many configuration parameters. Here we define defaults so the user doesn't have to provide every value.

Note that this application focuses on satellite imagery, so we pick default values that work with images similar to those in the xView dataset. Perturbing images from other datasets, or for other operational tasks, may not succeed without modifying these values; because these perturbations are physics-based, defining broadly applicable defaults is extremely difficult, if not impossible.

In [3]:
# Default field values, used at initialization and when fields are reset via button clicks
default_values = {
    "gsd": 0.3,
    "scenario": {
        "aircraft_speed": 100.0,
        "altitude": 4000.0,
        "background_reflectance": 0.07,
        "background_temperature": 293.0,
        "cn2_at_1m": 1.7e-14,
        "ground_range": 1000,
        "ha_wind_speed": 21.0,
        "ihaze": 1,
        "name": "",
        "target_reflectance": 0.15,
        "target_temperature": 295.0,
    },
    "sensor": {
        "D": 0.029,
        "bit_depth": 11.9,
        "dark_current": 0.0,
        "da_x": 0.0001,
        "da_y": 0.0001,
        "eta": 0.4,
        "f": 0.27,
        "int_time": 0.03,
        "max_n": 96000,
        "max_well_fill": 0.6,
        "name": "",
        "opt_trans_wavelengths": [3.8e-07, 7.0e-07],
        "optics_transmission": [],
        "p_x": 2.0e-05,
        "qe": [0.05, 0.6, 0.75, 0.85, 0.85, 0.75, 0.5, 0.2, 0],
        "qe_wavelengths": [3.0e-7, 4.0e-7, 5.0e-7, 6.0e-7, 7.0e-7, 8.0e-7, 9.0e-7, 1.0e-6, 1.1e-6],
        "read_noise": 25.0,
        "s_x": 0.0,
        "s_y": 0.0,
    },
}
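
As a rough consistency check on these defaults, a simple pinhole-camera model estimates ground sample distance as range × pixel pitch / focal length. This sketch is an assumption for illustration only: it is not how pyBSM models the imaging chain, and it ignores the extra stretching from oblique viewing angles. With the default altitude, pixel pitch, and focal length it lands close to the default `gsd` of 0.3 m.

```python
import math


def approx_gsd(range_m: float, pixel_pitch_m: float, focal_length_m: float) -> float:
    """Pinhole-camera estimate of ground sample distance in meters per pixel."""
    return range_m * pixel_pitch_m / focal_length_m


# Values copied from default_values above (distinct names avoid clobbering UI variables)
default_alt_m = 4000.0          # scenario "altitude"
default_ground_range_m = 1000.0  # scenario "ground_range"
default_pitch_m = 2.0e-05        # sensor "p_x"
default_focal_m = 0.27           # sensor "f"

nadir_gsd = approx_gsd(default_alt_m, default_pitch_m, default_focal_m)
# Straight-line range to a target offset by the ground range (flat-earth approximation)
slant_gsd = approx_gsd(math.hypot(default_alt_m, default_ground_range_m), default_pitch_m, default_focal_m)
print(f"nadir GSD ~ {nadir_gsd:.3f} m, slant-range GSD ~ {slant_gsd:.3f} m")
# prints "nadir GSD ~ 0.296 m, slant-range GSD ~ 0.305 m"
```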

It may also be helpful to start from an existing configuration or to generate a new configuration from the input fields, so we'll define some functions to enable that capability:

In [None]:
from gradio import Error  # type: ignore


def load_config(config: dict) -> dict:
    """Loads configuration from config dictionary to UI elements.

    Missing scalar values fall back to ``default_values``; missing or empty lists become empty text.
    """
    scenario_config = config["scenario"]
    default_scenario = default_values["scenario"]
    sensor_config = config["sensor"]
    default_sensor = default_values["sensor"]

    def scenario_value(key: str) -> object:
        """Scenario value from the config, else the default."""
        return scenario_config.get(key, default_scenario[key])

    def sensor_value(key: str) -> object:
        """Sensor value from the config, else the default."""
        return sensor_config.get(key, default_sensor[key])

    def sensor_list_str(key: str) -> str:
        """Convert a sensor list to comma separated text."""
        values = sensor_config.get(key)
        return ", ".join(str(x) for x in values) if values else ""

    return {
        input_group: Group(visible=True),
        img_gsd: config.get("gsd", default_values["gsd"]),
        scenario_name: scenario_value("name"),
        ihaze: scenario_value("ihaze"),
        altitude_m: scenario_value("altitude"),
        ground_range_m: scenario_value("ground_range"),
        aircraft_speed_m_per_s: scenario_value("aircraft_speed"),
        target_reflectance: scenario_value("target_reflectance"),
        target_temperature_k: scenario_value("target_temperature"),
        bkgd_reflectance: scenario_value("background_reflectance"),
        bkgd_temperature_k: scenario_value("background_temperature"),
        ha_windspeed_m_per_s: scenario_value("ha_wind_speed"),
        cn2_at_1m: scenario_value("cn2_at_1m"),
        sensor_name: sensor_value("name"),
        D_m: sensor_value("D"),
        f_m: sensor_value("f"),
        p_x_m: sensor_value("p_x"),
        opt_trans_wavelengths_str: sensor_list_str("opt_trans_wavelengths"),
        optics_transmission_str: sensor_list_str("optics_transmission"),
        eta: sensor_value("eta"),
        int_time_s: sensor_value("int_time"),
        dark_current: sensor_value("dark_current"),
        read_noise: sensor_value("read_noise"),
        max_n: sensor_value("max_n"),
        bit_depth: sensor_value("bit_depth"),
        max_well_fill: sensor_value("max_well_fill"),
        s_x: sensor_value("s_x"),
        s_y: sensor_value("s_y"),
        da_x: sensor_value("da_x"),
        da_y: sensor_value("da_y"),
        qe_str: sensor_list_str("qe"),
        qe_wavelengths_str: sensor_list_str("qe_wavelengths"),
    }


def _generate_config_error_check(data: dict) -> None:
    """Validate ihaze, altitude, and ground range against the values PybsmScenario supports."""
    if data[ihaze] not in PybsmScenario.ihaze_values:
        raise Error("Invalid ihaze value!")
    if data[altitude_m] not in PybsmScenario.altitude_values:
        raise Error("Invalid altitude value!")
    if data[ground_range_m] not in PybsmScenario.ground_range_values:
        raise Error("Invalid ground range value!")


def generate_config(data: dict) -> dict:
    """Generate dictionary that can easily be transformed into sensor and scenario objects."""
    _generate_config_error_check(data)

    scenario_config = {
        "name": data[scenario_name],
        "ihaze": data[ihaze],
        "altitude": data[altitude_m],
        "ground_range": data[ground_range_m],
        "aircraft_speed": data[aircraft_speed_m_per_s],
        "target_reflectance": data[target_reflectance],
        "target_temperature": data[target_temperature_k],
        "background_reflectance": data[bkgd_reflectance],
        "background_temperature": data[bkgd_temperature_k],
        "ha_wind_speed": data[ha_windspeed_m_per_s],
        "cn2_at_1m": data[cn2_at_1m],
    }

    # Convert text fields into lists of floats
    opt_trans_wavelengths = [float(w.strip()) for w in data[opt_trans_wavelengths_str].split(",") if w.strip()]
    optics_transmission = [float(t.strip()) for t in data[optics_transmission_str].split(",") if t.strip()]
    optics_transmission = optics_transmission if optics_transmission else None
    qe = [float(q.strip()) for q in data[qe_str].split(",") if q.strip()]
    qe = qe if qe else None
    qe_wavelengths = [float(q.strip()) for q in data[qe_wavelengths_str].split(",") if q.strip()]
    qe_wavelengths = qe_wavelengths if qe_wavelengths else None

    # More input validation
    if len(opt_trans_wavelengths) < 2:
        raise Error("At least 2 optical transmission wavelengths required!")
    if opt_trans_wavelengths[0] >= opt_trans_wavelengths[-1]:
        raise Error("Optical transmission wavelengths should be entered least to greatest!")
    if optics_transmission is not None and len(optics_transmission) != len(opt_trans_wavelengths):
        raise Error("If provided, Optical Transmission must have the same number of values as Spectral Bandpass!")

    sensor_config = {
        "name": data[sensor_name],
        "D": data[D_m],
        "f": data[f_m],
        "p_x": data[p_x_m],
        "opt_trans_wavelengths": opt_trans_wavelengths,
        "optics_transmission": optics_transmission,
        "eta": data[eta],
        "int_time": data[int_time_s],
        "dark_current": data[dark_current],
        "read_noise": data[read_noise],
        "max_n": data[max_n],
        "bit_depth": data[bit_depth],
        "max_well_fill": data[max_well_fill],
        "s_x": data[s_x],
        "s_y": data[s_y],
        "da_x": data[da_x],
        "da_y": data[da_y],
        "qe": qe,
        "qe_wavelengths": qe_wavelengths,
    }

    return {"scenario": scenario_config, "sensor": sensor_config, "gsd": data[img_gsd]}
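
The list-valued sensor fields make a round trip between Python lists and comma separated text: `load_config` joins them and `generate_config` splits them back apart and validates them. The sketch below isolates that logic so it can be exercised without Gradio or nrtk. `ConfigError`, `parse_float_list`, `list_to_text`, and `validate_spectral_lists` are hypothetical names, and the quantum-efficiency length check is an extra assumption that the notebook itself does not perform.

```python
from typing import List, Optional


class ConfigError(ValueError):
    """Hypothetical stand-in for gradio.Error so this sketch runs without Gradio."""


def parse_float_list(text: str) -> List[float]:
    """Parse comma separated text into floats, skipping blank entries."""
    return [float(v.strip()) for v in text.split(",") if v.strip()]


def list_to_text(values: Optional[List[float]]) -> str:
    """Render a list as comma separated text; None or [] becomes ""."""
    return ", ".join(str(v) for v in values) if values else ""


def validate_spectral_lists(
    bandpass: List[float],
    transmission: Optional[List[float]],
    qe: Optional[List[float]],
    qe_wavelengths: Optional[List[float]],
) -> None:
    """Mirror the checks in generate_config, plus one extra (hypothetical) QE check."""
    if len(bandpass) < 2:
        raise ConfigError("At least 2 optical transmission wavelengths required!")
    if bandpass[0] >= bandpass[-1]:
        raise ConfigError("Optical transmission wavelengths should be entered least to greatest!")
    if transmission is not None and len(transmission) != len(bandpass):
        raise ConfigError("Optical transmission must have the same number of values as the bandpass!")
    # Extra check, not in the notebook: each QE value needs a matching wavelength
    if qe is not None and qe_wavelengths is not None and len(qe) != len(qe_wavelengths):
        raise ConfigError("Quantum efficiency and its wavelengths must have the same length!")


bandpass = parse_float_list("3.8e-07, 7.0e-07")
validate_spectral_lists(bandpass, None, None, None)  # passes silently
print(list_to_text(bandpass))  # prints "3.8e-07, 7e-07"
```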

To expose these capabilities, and to actually apply a perturbation, we provide buttons for the user to click. Each button needs a listener function that carries out its task, so we define those here:

In [None]:
from gradio import Column, Info  # type: ignore


def gen_new_config() -> dict:
    """Resets all fields to default values."""
    return load_config(default_values)


def load_config_from_file(data: dict) -> dict:
    """Loads configuration from given file to UI elements."""
    if not data[config_file]:
        raise Error("A file must be uploaded to load existing configuration!")
    with open(data[config_file]) as file:
        config = yaml.safe_load(file)

    return load_config(config)


def submit(data: dict) -> dict:
    """Apply the perturbation and hide/show relevant UI elements as needed."""
    config = generate_config(data)
    scenario_config = config["scenario"]
    sensor_config = config["sensor"]

    # PybsmSensor expects numpy arrays, but generate_config keeps plain lists (they serialize to YAML cleanly)
    if sensor_config["opt_trans_wavelengths"]:
        sensor_config["opt_trans_wavelengths"] = np.asarray(sensor_config["opt_trans_wavelengths"])
    if sensor_config["optics_transmission"]:
        sensor_config["optics_transmission"] = np.asarray(sensor_config["optics_transmission"])
    if sensor_config["qe"]:
        sensor_config["qe"] = np.asarray(sensor_config["qe"])
    if sensor_config["qe_wavelengths"]:
        sensor_config["qe_wavelengths"] = np.asarray(sensor_config["qe_wavelengths"])
    gsd = config["gsd"]

    sensor = PybsmSensor(**sensor_config)
    scenario = PybsmScenario(**scenario_config)
    perturber = PybsmPerturber(sensor=sensor, scenario=scenario)

    # Apply the perturbation and display
    return {
        output_col: Column(visible=True),
        out_img: perturber(image=data[input_img], additional_params={"img_gsd": gsd}),
    }


def save(data: dict) -> None:
    """Saves current configuration at given path."""
    if not data[file_path]:
        raise Error("A filename must be provided to save the configuration!")
    path = Path(data[file_path])
    path.parent.mkdir(parents=True, exist_ok=True)
    config = generate_config(data)
    with open(path, "w") as yaml_file:
        yaml.dump(config, yaml_file)
    Info(f"Saved config: {data[file_path]}")
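
The upload widget defined in the layout cell only accepts `.yaml` files (`file_types=[".yaml"]`), but `save` writes to whatever filename the user types, so a configuration saved as `run1` or `run1.yml` may not be selectable when loading it back through the UI. One possible fix, sketched here with a hypothetical helper, is to normalize the filename before saving, e.g. `path = normalize_config_path(data[file_path])` in place of `Path(data[file_path])`.

```python
from pathlib import Path


def normalize_config_path(filename: str) -> Path:
    """Hypothetical helper: return a save path that ends in '.yaml'."""
    name = filename.strip()
    if not name:
        raise ValueError("A filename must be provided to save the configuration!")
    path = Path(name)
    if path.suffix == ".yml":
        # Swap the short extension for the one the upload widget accepts
        return path.with_suffix(".yaml")
    if path.suffix != ".yaml":
        # Append rather than replace, so names like "sensor.v2" keep their dot
        return path.with_name(path.name + ".yaml")
    return path


print(normalize_config_path("configs/run1"))  # configs/run1.yaml on POSIX systems
```

Appending the extension, rather than calling `with_suffix` unconditionally, avoids silently renaming dotted filenames.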

Lastly, we define the layout of the application and register the button click listener functions:

In [6]:
from gradio import (  # type: ignore
    Accordion,
    Blocks,
    Button,
    Dropdown,
    Examples,
    File,
    Group,
    Image,
    Number,
    Row,
    Textbox,
)

with Blocks() as demo:
    with Row():
        with Column() as input_col:
            sample_img_path = "../daml/data/unperturbed/92_1920_2201_2432_2713.jpg"
            input_img = Image(label="Input Image", value=sample_img_path)

            Examples(
                examples=[
                    sample_img_path,
                    "../daml/data/unperturbed/99_384_0_896_512.jpg",
                    "../daml/data/unperturbed/125_1152_768_1664_1280.jpg",
                    "../daml/data/unperturbed/126_1920_1920_2432_2432.jpg",
                ],
                inputs=input_img,
            )

            with Row():
                gen_config_btn = Button("Generate New Configuration")
                with Column():
                    config_file = File(label="Configuration File", file_types=[".yaml"])
                    load_config_btn = Button("Load Configuration from File")

            with Group(visible=False) as input_group:
                with Accordion("Image Parameters", open=False):
                    img_gsd = Number(
                        label="Image Ground Sample Distance (GSD) (m)",
                        info="The size of one pixel on the ground",
                        value=default_values["gsd"],
                    )

                with Accordion("Scenario Parameters") as scenario_params:
                    altitude_m = Number(
                        label="Altitude (m)",
                        info="Sensor height above ground level in meters. The database includes the following "
                        "altitude options: 2m, 32.55m, 75m, 150m, 225m, 500m, 1000m to 12000m in 1000m steps, "
                        "14000m to 20000m in 2000m steps, and 24500m.",
                        value=default_values["scenario"]["altitude"],
                    )
                    ground_range_m = Number(
                        label="Ground Range (m)",
                        info="Distance on the ground between the target and sensor in meters. The following "
                        "ground ranges are included in the database at each altitude until the ground "
                        "range exceeds the distance to the spherical earth horizon: 0m, 100m, 500m, 1000m to "
                        "20000m in 1000m steps, 22000m to 80000m in 2000m steps, and 85000m to "
                        "300000m in 5000m steps.",
                        value=default_values["scenario"]["ground_range"],
                    )

                    with Accordion("Additional Scenario Parameters", open=False) as opt_scenario_params:
                        scenario_name = Textbox(label="Scenario Name", value=default_values["scenario"]["name"])
                        ihaze = Dropdown(
                            label="IHAZE",
                            info="MODTRAN code for visibility",
                            choices=[1, 2],
                            value=default_values["scenario"]["ihaze"],
                        )
                        aircraft_speed_m_per_s = Number(
                            label="Aircraft Speed (m/s)",
                            info="Ground speed of the aircraft",
                            value=default_values["scenario"]["aircraft_speed"],
                        )
                        with Row():
                            target_reflectance = Number(
                                label="Target Reflectance",
                                info="Object reflectance",
                                value=default_values["scenario"]["target_reflectance"],
                            )
                            target_temperature_k = Number(
                                label="Target Temperature (K)",
                                info="Object temperature (Kelvin)",
                                value=default_values["scenario"]["target_temperature"],
                            )
                        with Row():
                            bkgd_reflectance = Number(
                                label="Background Reflectance",
                                info="Background reflectance",
                                value=default_values["scenario"]["background_reflectance"],
                            )
                            bkgd_temperature_k = Number(
                                label="Background Temperature (K)",
                                info="Background temperature (Kelvin)",
                                value=default_values["scenario"]["background_temperature"],
                            )
                        ha_windspeed_m_per_s = Number(
                            label="High Altitude Windspeed (m/s)",
                            info="Used to calculate the turbulence profile",
                            value=default_values["scenario"]["ha_wind_speed"],
                        )
                        cn2_at_1m = Number(
                            label="Refractive Index Structure Parameter",
                            info='The refractive index structure parameter "near the ground" (e.g. '
                            "at h = 1m). Used to calculate the turbulence profile",
                            value=default_values["scenario"]["cn2_at_1m"],
                        )

                with Accordion("Sensor Parameters") as sensor_params:
                    D_m = Number(label="Effective Aperture Diameter (m)", value=default_values["sensor"]["D"])
                    f_m = Number(label="Focal Length (m)", value=default_values["sensor"]["f"])
                    with Accordion("Additional Sensor Parameters", open=False) as opt_sensor_parameters:
                        sensor_name = Textbox(label="Sensor Name", value=default_values["sensor"]["name"])
                        p_x_m = Number(
                            label="Detector Center-to-Center Spacing (Pitch) (m)",
                            value=default_values["sensor"]["p_x"],
                        )
                        opt_trans_wavelengths_str = Textbox(
                            label="Spectral Bandpass of the Camera (m)",
                            info="Enter a comma separated list. "
                            "At minimum, a start and end wavelength should be specified.",
                            value=(
                                ", ".join(map(str, default_values["sensor"]["opt_trans_wavelengths"]))
                                if default_values["sensor"]["opt_trans_wavelengths"]
                                else ""
                            ),
                        )
                        optics_transmission_str = Textbox(
                            label="Full System In-Band Optical Transmission",
                            info="Enter a comma separated list. "
                            "Loss due to any telescope obscuration should not be included.",
                            value=(
                                ", ".join(map(str, default_values["sensor"]["optics_transmission"]))
                                if default_values["sensor"]["optics_transmission"]
                                else ""
                            ),
                        )
                        eta = Number(label="Relative Linear Obscuration", value=default_values["sensor"]["eta"])
                        int_time_s = Number(
                            label="Integration Time (s)",
                            info="Maximum integration time",
                            value=default_values["sensor"]["int_time"],
                        )
                        dark_current = Number(
                            label="Detector Dark Current (e-/s)",
                            value=default_values["sensor"]["dark_current"],
                        )
                        read_noise = Number(
                            label="RMS Read Noise (RMS e-)",
                            value=default_values["sensor"]["read_noise"],
                        )
                        max_n = Number(label="Maximum ADC Level (e-)", value=default_values["sensor"]["max_n"])
                        bit_depth = Number(
                            label="Bit Depth (bits)",
                            info="Resolution of the detector ADC",
                            value=default_values["sensor"]["bit_depth"],
                        )
                        max_well_fill = Number(
                            label="Max Well Fill",
                            info="Desired well fill, i.e. maximum well size x desired fill fraction",
                            value=default_values["sensor"]["max_well_fill"],
                        )
                        with Row():
                            s_x = Number(
                                label="RMS Jitter Amplitude, X Direction (rad)",
                                value=default_values["sensor"]["s_x"],
                            )
                            s_y = Number(
                                label="RMS Jitter Amplitude, Y Direction (rad)",
                                value=default_values["sensor"]["s_y"],
                            )
                        with Row():
                            da_x = Number(
                                label="Line of Sight Angular Drift Rate, X Direction (rad/s)",
                                info="Drift rate during one integration time",
                                value=default_values["sensor"]["da_x"],
                            )
                            da_y = Number(
                                label="Line of Sight Angular Drift Rate, Y Direction (rad/s)",
                                info="Drift rate during one integration time",
                                value=default_values["sensor"]["da_y"],
                            )
                        qe_str = Textbox(
                            label="Quantum Efficiency as a function of Wavelength (e-/photon)",
                            info="Enter a comma separated list",
                            value=(
                                ", ".join(map(str, default_values["sensor"]["qe"]))
                                if default_values["sensor"]["qe"]
                                else ""
                            ),
                        )
                        qe_wavelengths_str = Textbox(
                            label="Wavelengths Corresponding to the Quantum Efficiency Array (m)",
                            info="Enter a comma separated list",
                            value=(
                                ", ".join(map(str, default_values["sensor"]["qe_wavelengths"]))
                                if default_values["sensor"]["qe_wavelengths"]
                                else ""
                            ),
                        )

                submit_btn = Button("Perturb Image")

        with Column(visible=False) as output_col:
            out_img = Image(label="Perturbed Image")
            file_path = Textbox(label="Config Filename")
            save_btn = Button("Save Configuration")

    # Button listeners
    gen_config_btn.click(
        fn=gen_new_config,
        inputs=None,
        outputs=[
            input_group,
            img_gsd,
            scenario_name,
            ihaze,
            altitude_m,
            ground_range_m,
            aircraft_speed_m_per_s,
            target_reflectance,
            target_temperature_k,
            bkgd_reflectance,
            bkgd_temperature_k,
            ha_windspeed_m_per_s,
            cn2_at_1m,
            sensor_name,
            D_m,
            f_m,
            p_x_m,
            opt_trans_wavelengths_str,
            optics_transmission_str,
            eta,
            int_time_s,
            dark_current,
            read_noise,
            max_n,
            bit_depth,
            max_well_fill,
            s_x,
            s_y,
            da_x,
            da_y,
            qe_str,
            qe_wavelengths_str,
        ],
    )
    load_config_btn.click(
        fn=load_config_from_file,
        inputs={config_file},
        outputs=[
            input_group,
            img_gsd,
            scenario_name,
            ihaze,
            altitude_m,
            ground_range_m,
            aircraft_speed_m_per_s,
            target_reflectance,
            target_temperature_k,
            bkgd_reflectance,
            bkgd_temperature_k,
            ha_windspeed_m_per_s,
            cn2_at_1m,
            sensor_name,
            D_m,
            f_m,
            p_x_m,
            opt_trans_wavelengths_str,
            optics_transmission_str,
            eta,
            int_time_s,
            dark_current,
            read_noise,
            max_n,
            bit_depth,
            max_well_fill,
            s_x,
            s_y,
            da_x,
            da_y,
            qe_str,
            qe_wavelengths_str,
        ],
    )
    submit_btn.click(
        fn=submit,
        inputs={
            input_img,
            img_gsd,
            scenario_name,
            ihaze,
            altitude_m,
            ground_range_m,
            aircraft_speed_m_per_s,
            target_reflectance,
            target_temperature_k,
            bkgd_reflectance,
            bkgd_temperature_k,
            ha_windspeed_m_per_s,
            cn2_at_1m,
            sensor_name,
            D_m,
            f_m,
            p_x_m,
            opt_trans_wavelengths_str,
            optics_transmission_str,
            eta,
            int_time_s,
            dark_current,
            read_noise,
            max_n,
            bit_depth,
            max_well_fill,
            s_x,
            s_y,
            da_x,
            da_y,
            qe_str,
            qe_wavelengths_str,
        },
        outputs=[out_img, output_col],
    )
    save_btn.click(
        fn=save,
        inputs={
            input_img,
            img_gsd,
            scenario_name,
            ihaze,
            altitude_m,
            ground_range_m,
            aircraft_speed_m_per_s,
            target_reflectance,
            target_temperature_k,
            bkgd_reflectance,
            bkgd_temperature_k,
            ha_windspeed_m_per_s,
            cn2_at_1m,
            sensor_name,
            D_m,
            f_m,
            p_x_m,
            opt_trans_wavelengths_str,
            optics_transmission_str,
            eta,
            int_time_s,
            dark_current,
            read_noise,
            max_n,
            bit_depth,
            max_well_fill,
            s_x,
            s_y,
            da_x,
            da_y,
            qe_str,
            qe_wavelengths_str,
            file_path,
        },
        outputs=None,
    )
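
The altitude and ground-range help text in the layout cell describes the discrete values available in the pyBSM atmosphere database, and `_generate_config_error_check` rejects anything else. As a sketch, those lists can be rebuilt from the help text and used to suggest the closest supported value instead of only raising an error. The lists are reconstructed from the prose, not read from nrtk, so they may differ from `PybsmScenario.altitude_values` and `PybsmScenario.ground_range_values`; the horizon cutoff on ground range is not modeled, and `nearest` is a hypothetical helper.

```python
from typing import List, Sequence

# Altitudes in meters, reconstructed from the altitude help text
ALTITUDES: List[float] = [2, 32.55, 75, 150, 225, 500]
ALTITUDES += list(range(1000, 12001, 1000))   # 1000 m to 12000 m in 1000 m steps
ALTITUDES += list(range(14000, 20001, 2000))  # 14000 m to 20000 m in 2000 m steps
ALTITUDES += [24500]

# Ground ranges in meters, reconstructed from the ground range help text
GROUND_RANGES: List[float] = [0, 100, 500]
GROUND_RANGES += list(range(1000, 20001, 1000))     # 1000 m to 20000 m in 1000 m steps
GROUND_RANGES += list(range(22000, 80001, 2000))    # 22000 m to 80000 m in 2000 m steps
GROUND_RANGES += list(range(85000, 300001, 5000))   # 85000 m to 300000 m in 5000 m steps


def nearest(value: float, allowed: Sequence[float]) -> float:
    """Return the allowed value closest to `value`; ties resolve to the lower one."""
    return min(allowed, key=lambda v: abs(v - value))


print(nearest(4100.0, ALTITUDES))      # 4000
print(nearest(1234.0, GROUND_RANGES))  # 1000
```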

## Running the Application <a name="running-application"></a>

Now our application is ready for exploration!

In [None]:
demo.launch(show_error=True)