# U-Net–MF comparison

Compare EMIT U-Net predictions with matched-filter (MF) retrievals over "ground truth" sites.

In [1]:
import datetime
import os
from math import ceil

import fsspec
import joblib
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr

from src.inference.inference_functions import prepare_data_item, predict
from src.training.loss_functions import TwoPartLoss
from src.utils.parameters import SatelliteID
from src.utils.utils import load_model_and_concatenator

In [2]:
# Ground-truth sites, indexed by "dual_index": that label is what selects a site's MF retrieval
# (`mfda.sel(dual_index=...)`) and names its saved comparison figure.
GT_GRANULE_MAP_CSV = "emit_gt_granule_map.csv"
site_granule_map = pd.read_csv(GT_GRANULE_MAP_CSV, index_col="dual_index")

### Load the model, band extractor and loss function for inference

In [None]:
# Registry URI of the model under evaluation, assembled from its registered name and version.
MODEL_REGISTRY_NAME, MODEL_VERSION = "emit", 8
model_name = f"models:/{MODEL_REGISTRY_NAME}/{MODEL_VERSION}"

In [None]:
# Load the registered model (on CPU) together with its band extractor and training parameters.
model, band_extractor, model_params = load_model_and_concatenator(
    model_name,
    device="cpu",
    satellite_id=SatelliteID.EMIT,
)

# Loss is rebuilt from the model's own parameters so inference matches its training setup.
lossFn = TwoPartLoss(
    binary_threshold=model_params["binary_threshold"],
    MSE_multiplier=model_params["MSE_multiplier"],
)

### Load the matched-filter (MF) retrievals

In [11]:
# Load the MF retrievals eagerly. The dataset is read from a file object that is closed when the
# `with` block exits, so lazily-loaded values may no longer be readable later (e.g. inside
# `plot_single_site_comparison`). `.load()` pulls the values into memory before the file closes.
with fsspec.open(
    "azureml://subscriptions/6e71ce37-b9fe-4c43-942b-cf0f7e78c8ab/resourcegroups/orbio-ml-rg/workspaces/"
    "orbio-ml-ml-workspace/datastores/workspaceblobstore/paths/data/emit/emit_mf_gt_retrievals.nc"
) as fs:
    with xr.open_dataset(fs) as mf_ds:
        mfda = mf_ds["mf_retrievals"].load()

In [4]:
def plot_single_site_comparison(site: pd.Series) -> plt.Figure:
    """Plot the U-Net (CV) predictions next to the MF retrieval for one ground-truth site.

    Runs inference with the notebook-level ``model``, ``band_extractor`` and ``lossFn`` on the
    cached EMIT radiance crop for ``site``. It draws one panel per prediction output plus one
    panel for the site's MF retrieval taken from ``mfda``.

    Parameters
    ----------
    site:
        A row of ``site_granule_map``. Its fields are used as follows:
        - ``name`` (the ``dual_index`` label) selects the MF retrieval and titles the figure.
        - ``emit_id``, ``lat`` and ``lon`` locate the cached crop.
        - ``quantification_kg_h``, ``source``, ``lat``, ``lon`` and ``date`` are written in the footer.

    Returns
    -------
    The new matplotlib figure. It is left open; the caller is responsible for closing it.

    Raises
    ------
    KeyError
        Presumably when ``site.name`` is not in ``mfda`` or a footer column is missing. The
        "Plot all sites" loop catches KeyError to skip a site.
    """
    # Cached radiance crop for this site, used as the model input
    cv_rad_cache_uri = (
        "azureml://subscriptions/6e71ce37-b9fe-4c43-942b-cf0f7e78c8ab/resourcegroups/orbio-ml-rg/workspaces/orbio-ml-ml-workspace/"
        f"datastores/workspaceblobstore/paths/data/emit/crop_cache/{site.emit_id}/{site.lat}_{site.lon}_128.joblib.gz"
    )

    # NOTE(review): joblib.load unpickles the blob, so only load caches from a trusted datastore.
    with fsspec.open(cv_rad_cache_uri) as fs:
        cache = joblib.load(fs)

    # Prepare the data for our model
    data_item = prepare_data_item(
        cropped_data=[cache],
        crop_size=cache["crop_params"]["out_height"],
        satellite=SatelliteID.EMIT,
    )

    # Run inference
    prediction = predict(
        model=model, device="cpu", band_extractor=band_extractor, recycled_item=data_item, lossFn=lossFn
    )

    # Layout: one panel per CV output, followed by the MF retrieval, packed into rows of n_cols
    pred_plot_keys = ["binary_probability", "conditional_pred", "marginal_pred"]
    n_mf_plots = 1
    n_cols = 4
    n_rows = ceil((len(pred_plot_keys) + n_mf_plots) / n_cols)
    fig_scaling = 4  # figure inches per panel row/column

    fig = plt.figure(figsize=(n_cols * fig_scaling, n_rows * fig_scaling))

    # CV plots
    for panel, key in enumerate(pred_plot_keys, start=1):
        ax = fig.add_subplot(n_rows, n_cols, panel)
        im = ax.imshow(prediction[key].squeeze())
        fig.colorbar(im, ax=ax)
        ax.axis("off")
        ax.set_title(f"CV {key}")

    # MF plot. It goes in the panel after the last CV output; the position is computed explicitly
    # instead of reusing the loop variable after the loop.
    mf_retrieval = mfda.sel(dual_index=site.name).squeeze()

    ax = fig.add_subplot(n_rows, n_cols, len(pred_plot_keys) + 1)
    im = ax.imshow(mf_retrieval)
    fig.colorbar(im, ax=ax)
    ax.axis("off")
    ax.set_title("MF retrieval")

    props_str = ", ".join(f"{k}: {site[k]}" for k in ["quantification_kg_h", "source", "lat", "lon", "date"])
    fig.text(0.01, 0.02, f"{props_str}; Warning: MF retrieval has different orientation due to orthorectification.")

    # Title with the part after the registry scheme (e.g. "models:/emit/8" -> "emit/8").
    # [-1] keeps the whole name when it has no "/" instead of raising IndexError.
    fig.suptitle(f"{model_name.split('/', 1)[-1]}: {site.name}")
    fig.tight_layout()

    return fig

### Check single site

In [None]:
# Pick one site by row position to sanity-check the plotting function before the full run.
SINGLE_SITE_ILOC = 2
site = site_granule_map.iloc[SINGLE_SITE_ILOC]
granule_id, date_str = site.loc["emit_id"], site.loc["date"]
date = datetime.datetime.fromisoformat(date_str)

In [None]:
# Keep a named handle on the figure. Assigning the result also stops the returned Figure from
# being echoed as the cell output.
single_site_fig = plot_single_site_comparison(site)

### Plot all sites

For a single model (the one set in `model_name` above). One comparison figure per site is saved to `figures/<model>/`.

In [None]:
%%time

model_cleanname = model_name.split("/", 1)[1].replace("/", "_")
out_dir = f"figures/{model_cleanname}"
os.makedirs(out_dir, exist_ok=True)

for _, site in site_granule_map.iterrows():
    try:
        fig = plot_single_site_comparison(site)
        fig.savefig(f"{out_dir}/{site.name}.png")
    except KeyError:
        print(f"Failed to plot {site.name}. Skipping.")

    plt.close()

### Compare preplotted

Requires that the ground-truth sites have already been plotted (see the section above) for every model listed in `compare_models`.

May require installing `ipywidgets` (e.g. `%pip install ipywidgets`, which installs into the kernel's environment).

In [7]:
import ipywidgets as widgets
from IPython.display import display

In [5]:
# Models to compare, written as the registered model name and version with "_" in place of "/".
# These must match the figures/<name>/ directories written by the "Plot all sites" cell.
#
# Previously compared, currently disabled:
#   "torchgeo_pwr_unet_emit_54"  -- gray_egg: unet, resnet50, but pretrained
#   "torchgeo_pwr_unet_emit_58"  -- gray_cat: unet, resnet50, no pretraining
compare_models = [
    "torchgeo_pwr_unet_emit_56",  # nifty_collar: unet++ b1
    "emit_8",  # upbeat_rose spectralunet++ b1 (training in progress)
]

In [8]:
current_iloc = 0
max_iloc = site_granule_map.shape[0] - 1

back_button = widgets.Button(tooltip="Previous", icon="arrow-left")

next_button = widgets.Button(tooltip="Next", icon="arrow-right")

# One image widget per compared model. The iteration variable is deliberately NOT `model_name`.
# A plain `for model_name in ...` loop here would overwrite the global `model_name` read by
# `plot_single_site_comparison` and the "Plot all sites" cell, breaking them on re-run.
image_widgets = {
    compare_model: widgets.Image(format="png", width=900, height=300) for compare_model in compare_models
}


def render_comparison_images(site: pd.Series):
    """Load each compared model's saved PNG for ``site`` into that model's image widget.

    Raises FileNotFoundError if a model has no figure for this site (e.g. it was skipped
    when plotting all sites).
    """
    for compare_model, image_widget in image_widgets.items():
        with open(f"figures/{compare_model}/{site.name}.png", "rb") as fs:
            image_widget.value = fs.read()


output = widgets.Output()


def _show_site(iloc: int) -> None:
    """Redraw ``output`` with the counter and comparison images for row position ``iloc``."""
    with output:
        output.clear_output()
        print(f"Image {iloc + 1} of {max_iloc + 1}")
        try:
            render_comparison_images(site_granule_map.iloc[iloc])
        except FileNotFoundError as err:
            # Skip the image box so images left over from the previous site are not shown.
            print(f"No saved figure at {err.filename}; skipping this site.")
            return
        display(widgets.VBox([*image_widgets.values()]))


def _step(offset: int) -> None:
    """Move ``offset`` sites forward (negative = backward), wrapping around at either end."""
    global current_iloc
    current_iloc = (current_iloc + offset) % (max_iloc + 1)
    _show_site(current_iloc)


def render_next(*args):
    """Button callback: show the next site, wrapping from the last site to the first."""
    _step(1)


def render_previous(*args):
    """Button callback: show the previous site, wrapping from the first site to the last."""
    _step(-1)


back_button.on_click(render_previous)
next_button.on_click(render_next)

_show_site(current_iloc)

buttons = widgets.HBox([back_button, next_button])

display(buttons, widgets.Box([output]))

HBox(children=(Button(icon='arrow-left', style=ButtonStyle(), tooltip='Previous'), Button(icon='arrow-right', …

Box(children=(Output(),))

`torchgeo_pwr_unet_emit_` 54 vs 56 vs 58: 56 appears to be the best model so far:
* less background noise
* only one to get images 10, 25
* nice example of fewer FPs in images 20, 21, 27
* all models do poorly: 26
* but 56 is the least confident on image 2!

The matched filter produces more plume-like shapes ("plume-ier" plumes) but has worse SNR, so it is hard to say whether those plumes could be reliably extracted.

`emit_8` vs `torchgeo_pwr_unet_emit_56`

- generally pretty much the same (but 8 appears to have trained much faster!)
- 56 does better: image 29, 34, 37
- 8 does better: 33, 42, 54
- all do poorly: 24