# Mass-produce SBR phase 1 outputs

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import io
import json
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import rasterio
import torch
from lib.models.schemas import WatershedParameters
from tqdm import tqdm

from sbr_2025.utils.quantification import (
    get_wind_components,
    sbr_form_outputs,
)
from src.azure_wrap.ml_client_utils import initialize_blob_service_client
from src.inference.inference_functions import (
    crop_main_data_landsat,
    crop_reference_data_landsat,
    fetch_landsat_items_for_point,
    query_landsat_catalog_for_point,
)
from src.training.loss_functions import TwoPartLoss
from src.utils.parameters import LANDSAT_HAPI_DATA_PATH, SatelliteID
from src.utils.quantification_utils import calc_effective_wind_speed, calc_wind_direction
from src.utils.radtran_utils import RadTranLookupTable
from src.utils.utils import initialize_clients, load_model_and_concatenator

# Let pandas show up to 250 characters per column when rendering DataFrames,
# so long text values are not cut short in the notebook output.
pd.set_option("display.max_colwidth", 250)

🔔🔔🔔🔔

**HOW TO UPDATE THIS NOTEBOOK**

This is a template notebook that is under version control.
When making changes, before staging and committing the changes, run `nbstripout` to remove the output from the notebook.

```bash
# conda install -c conda-forge nbstripout  # install nbstripout if not already installed
nbstripout --drop-empty-cells --extra-keys="metadata.kernelspec.display_name metadata.kernelspec.name" Landsat_SBR_export_results.ipynb
git add -p Landsat_SBR_export_results.ipynb
```

🔔🔔🔔🔔

## Setup

In [None]:
# Create the service clients used by the rest of the notebook. Only the first and
# fourth return values of initialize_clients are kept; the other two are discarded.
# NOTE(review): the meaning of the positional `False` flag is not visible here — confirm
# against initialize_clients' signature.
ml_client, _, _, s3_client = initialize_clients(False)
# Blob-storage client derived from the ML client; passed to the cropping helpers below.
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# Release point for the SBRs.
lat = 32.82175
lon = -111.78581

# can we predict on larger chips?
crop_size = 256  # 128
center_buffer = 3  # 12  # 5  # number of pixels to search from center

# Date range of the SBRs, parsed directly into datetime objects.
# The end date is a month after the end of the SBRs to include reference images after.
start_date = datetime.strptime("2025-01-01", "%Y-%m-%d")
end_date = datetime.strptime("2025-04-30", "%Y-%m-%d")

# these dates have KNOWN releases during Phase 0
phase0_release_dates = [
    # "2024-11-14",
    "2024-11-15",  # LS
    "2024-11-16",  # LS
    # "2024-11-19",
    # "2024-11-22",
    "2024-11-24",  # LS
    "2024-12-02",  # LS
    # "2024-12-04",
    # "2024-12-07",
    "2024-12-09",  # LS
    # "2024-12-17",
    # "2024-12-19",
    # "2024-12-22",
    "2024-12-26",  # LS
    # "2024-12-29",
]

release_dates_2022 = [
    "2022-10-25",
    "2022-10-26",
    "2022-11-02",
    "2022-11-03",
    "2022-11-10",
    "2022-11-11",
    "2022-11-18",
]

# Settings object for the watershed step.
# NOTE(review): not referenced again in the visible cells — confirm it is still needed.
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)

# Cache of raster metadata keyed by item id; filled in when items are fetched.
item_meta_dict = dict()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Registry versions of the Landsat models to load.
model_ids = ["37", "155", "55", "57", "103", "168", "154", "14"]

models, band_concatenators = [], []
for model_id in model_ids:
    model, band_concatenator, train_params = load_model_and_concatenator(
        f"models:/landsat/{model_id}", device, SatelliteID.LANDSAT
    )
    model.eval()
    models += [model]
    band_concatenators += [band_concatenator]

# `train_params` still holds the value returned for the LAST model id above.
# NOTE(review): this assumes all models share the same loss settings — confirm.
lossFn = TwoPartLoss(train_params["binary_threshold"], train_params["MSE_multiplier"])

In [None]:
# Query the Landsat catalog for the release point over [start_date, end_date] and
# print one summary line per returned item: acquisition date, cloud cover, id and CRS.
stac_items = query_landsat_catalog_for_point(lat, lon, start_date, end_date)
for item in stac_items:
    cloud_cover = item.properties["eo:cloud_cover"]
    # The trailing space after the item id keeps the two adjacent f-strings from
    # running together as "(...id)and CRS ...".
    print(
        f"{item.datetime.date().isoformat()} with {cloud_cover:6.1f}% clouds - ({item.id}) "
        f"and CRS {item.properties['proj:code']}"
    )
# ISO date strings (YYYY-MM-DD) of all returned items, in catalog order.
overpass_dates = [item.datetime.date().isoformat() for item in stac_items]

In [None]:
%%time

# Target the date of the last item returned by the catalog query, as a naive
# midnight datetime (time of day and timezone are dropped on purpose).
target_date = datetime.combine(stac_items[-1].datetime.date(), datetime.min.time())

items = fetch_landsat_items_for_point(lat=lat, lon=lon, query_datetime=target_date, how_many_days_back=180)

# Crop the main chip around the release point; the item it came from is kept alongside.
main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size)
main_item = main_data["tile_item"]

# Use 25% clouds/shadows cutoff to include non-obvious reference images that may help
max_bad_pixel_perc = 25  # max % sum of clouds/cloud shadows/nodata in reference chips
num_snapshots = 40

for item in tqdm(items):
    if item.id in item_meta_dict:
        # Metadata already cached (e.g. by a previous run of this cell).
        continue
    # Get the Landsat metadata from the coastal band (arbitrary choice).
    try:
        item_meta_dict[item.id] = item.get_raster_meta("coastal", abs_client=abs_client)
    except Exception as e:  # best effort: report the failing item and carry on with the rest
        print(f"Could not read raster metadata for {item.id}: {e}")

# Crop the reference chips for the main chip selected above.
all_reference_data = crop_reference_data_landsat(
    items,
    main_data,
    abs_client,
    s3_client,
    lat,
    lon,
    crop_size,
    required_num_previous_snapshots=num_snapshots,
    max_bad_pixel_perc=max_bad_pixel_perc,
    item_meta_dict=item_meta_dict,
)

In [None]:
# Prepare wind data once: load the wind-vector table, keep only the rows whose
# sensor is "LS" (presumably Landsat — confirm), and add a `sensing_time` string column.
wind_data = pd.read_csv("../../src/data/ancillary/wind_vectors_gt_vs_models_2025_sbr_sites.csv")
# .copy() makes the filtered rows an independent frame, so the column assignment
# below does not write into a slice of the full table (avoids SettingWithCopyWarning).
wind_data = wind_data[wind_data["sensor"] == "LS"].copy()
# Combine the date and UTC overpass time columns into one timestamp, formatted
# as YYYY-MM-DDTHH:MM:SS+0000.
wind_data["sensing_time"] = pd.to_datetime(
    wind_data["date"] + " " + wind_data["overpass_time_utc"], utc=True
).dt.strftime("%Y-%m-%dT%H:%M:%S+0000")

## Produce

In [None]:
# Load the per-date decisions; keys are ISO date strings (YYYY-MM-DD) and each entry
# carries at least the "note" and "feasible" fields read by the visualisation cell.
# The encoding is given explicitly so the read does not depend on the platform locale.
with open("./landsat_phase1_decision_dict.json", encoding="utf-8") as fs:
    decisions_dict = json.load(fs)

In [None]:
%%time

# Settings forwarded unchanged to sbr_form_outputs.
max_distance_pixels = 10
pixel_width = 30  # presumably metres per Landsat pixel — TODO confirm

# Each processed date contributes one CSV row to an in-memory buffer; the buffer is
# parsed into the `results` DataFrame once the loop has finished.
with io.StringIO() as fs:
    results_index = []
    for stac_idx, stac_item in enumerate(items):
        target_date = stac_item.time.date().isoformat()
        _target_date = datetime.strptime(target_date, "%Y-%m-%d")
        print(f"Running for date {target_date}")
        # Only dates with an entry in the decisions dict are processed.
        if target_date not in decisions_dict:
            print("    Not present in decisions dict!")
            print("=" * 100)
            continue

        # Load main_data
        main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size, main_idx=stac_idx)
        main_item = main_data["tile_item"]

        # Get wind speed for observation
        sensing_time = main_item.time.isoformat()
        wind_components = get_wind_components(wind_data, sensing_time)
        u_wind_component, v_wind_component = wind_components["geos"]
        wind_speed_geos = calc_effective_wind_speed(u_wind_component, v_wind_component)
        wind_direction_geos = calc_wind_direction(u_wind_component, v_wind_component)
        u_wind_component, v_wind_component = wind_components["era5"]
        wind_speed_era5 = calc_effective_wind_speed(u_wind_component, v_wind_component)
        wind_direction_era5 = calc_wind_direction(u_wind_component, v_wind_component)

        # Get correct lookup table
        lookup_table = RadTranLookupTable.from_params(
            instrument=main_item.instrument,
            solar_angle=main_item.solar_angle,
            observation_angle=main_item.observation_angle,
            hapi_data_path=LANDSAT_HAPI_DATA_PATH,
            min_ch4=0.0,
            max_ch4=21.0,  # this value was selected based on the common value ranges of the sim plume datasets
            spacing_resolution=40000,
            ref_band=main_item.swir16_band_name,
            band=main_item.swir22_band_name,
            full_sensor_name=main_item.sensor_name,
        )

        # NOTE(review): all_reference_data is not recomputed per date here, while
        # main_data is — confirm sbr_form_outputs expects the same reference set
        # for every date.
        csv_row = sbr_form_outputs(
            decisions_dict[target_date],
            wind_speed_geos,
            wind_direction_geos,
            wind_speed_era5,
            wind_direction_era5,
            model_ids,
            models,
            device,
            band_concatenators,
            main_data,
            all_reference_data,
            lossFn,
            lookup_table,
            max_distance_pixels,
            pixel_width,
            main_item,
            _target_date,
            SatelliteID.LANDSAT,
            abs_client,
        )
        print("=" * 100)

        fs.write(csv_row)
        results_index.append(target_date)

    # Rewind the buffer and parse the accumulated rows (they carry no header line).
    fs.seek(0)
    results = pd.read_csv(fs, header=None)
    # Index the rows by date first, so that only the data columns are renamed below.
    results["date"] = results_index
    results.set_index("date", inplace=True)
    results.columns = [
        "PlumeLength",
        "IME",
        "EmissionRate",
        "EmissionRateUpper",
        "EmissionRateLower",
        "EmissionRateUncertaintyType",
        "U10WindSpeed",
        "UeffWindSpeed",
        "WindSpeedUpper",
        "WindSpeedLower",
        "WindSpeedUncertaintyType",
        "WindDirection",
        "notes",
    ]

In [None]:
# Display the results ordered by their date index. sort_index returns a sorted
# copy, so `results` itself keeps its original row order.
results.sort_index(ascending=True)

## Visualize

In [None]:
# only required if we are plotting single; the loop below reuses this name and overrides it
inspect_datestr = "2025-03-23"
version = "04302025"  # The date on which the images were produced

print(f"CAUTION: showing images produced on {version}")

plt.close()

# Directory holding the submission GeoTIFFs (relative to the notebook's working directory).
geotiff_dir = "data/submission_geotiffs/phase1_submission/LANDSAT"

# loop and plot all: one three-panel figure per date in the decisions dict
for inspect_datestr in sorted(decisions_dict):
    decision = decisions_dict[inspect_datestr]
    # Path prefix shared by the Enhancement and Mask rasters of this date.
    geotiff_prefix = f"{geotiff_dir}/DateGenerated{version}_{inspect_datestr}_LANDSAT_OrbioEarth"

    fig = plt.figure(figsize=(15, 4))
    detection_note = decision["note"]
    feasibility = "feasible" if decision["feasible"] else "infeasible"
    fig.suptitle(f"Landsat result for {inspect_datestr}: {detection_note} ({feasibility})")

    # Read band 1 of the enhancement raster once; both retrieval panels use it.
    with rasterio.open(f"{geotiff_prefix}_Enhancement.tif") as reader:
        data = reader.read(1)

    # Panel 1: linear colour scale, capped at 150 unless the data reaches 150 or more.
    ax = fig.add_subplot(131)
    im = ax.imshow(
        data,
        interpolation="none",
        vmin=1,
        vmax=150 if data.max() < 150 else 1500,  # noqa: PLR2004
        cmap="Reds",
    )
    plt.colorbar(im, ax=ax)
    ax.set_title("Retrieval (lin)")

    # Panel 2: logarithmic colour scale; +1 keeps zero-valued pixels at the lower bound.
    ax = fig.add_subplot(132)
    im = ax.imshow(
        data + 1,
        interpolation="none",
        vmin=1,
        vmax=1500,
        norm="log",
        cmap="Reds",
    )
    plt.colorbar(im, ax=ax)
    ax.set_title("Retrieval (log)")

    # Panel 3: band 1 of the mask raster, shown on a fixed 0..1 grey scale.
    ax = fig.add_subplot(133)
    with rasterio.open(f"{geotiff_prefix}_Mask.tif") as reader:
        ax.imshow(reader.read(1), interpolation="none", vmin=0, vmax=1, cmap="Greys")
        ax.set_title("Mask")