In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import copy
import json
import pprint
from datetime import datetime

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
from lib.models.schemas import WatershedParameters
from tqdm import tqdm

from sbr_2025 import BANDS_LANDSAT
from sbr_2025.utils import intersects_center, select_reference_tiles_from_dates_str
from sbr_2025.utils.plotting import (
    Colorbar,
    all_error_analysis_plots,
    get_band_ratio_landsat,
    get_rgb_bands_landsat,
    plot_all_ratio,
    plot_all_rgb,
    plot_max_proba_center_buffer_heatmap,
    plot_normal_and_avg_strategy,
    plot_normal_and_avg_strategy_summary,
    plot_ratio_diffs,
    plot_rgb_ratio,
    plot_wind,
    validate_pred_retrievals,
)
from sbr_2025.utils.prediction import (
    PlumeInfo,
    TileInfo,
    get_center_buffer,
    get_reference_data,
    get_reference_data_before,
    predict,
    predict_for_all_pairs,
)
from sbr_2025.utils.quantification import (
    get_wind_components,
    sbr_form_outputs,
)
from src.azure_wrap.ml_client_utils import initialize_blob_service_client
from src.data.landsat_data import LandsatGranuleAccess
from src.data.sentinel2 import Sentinel2Item
from src.inference.inference_functions import (
    crop_main_data_landsat,
    crop_reference_data_landsat,
    fetch_landsat_items_for_point,
    query_landsat_catalog_for_point,
)
from src.inference.inference_target_location import (
    quantify_retrieval,
)
from src.plotting.plotting_functions import grid16
from src.training.loss_functions import TwoPartLoss
from src.utils.parameters import LANDSAT_HAPI_DATA_PATH, S2_HAPI_DATA_PATH, SatelliteID
from src.utils.quantification_utils import calc_effective_wind_speed, calc_wind_direction
from src.utils.radtran_utils import RadTranLookupTable
from src.utils.utils import initialize_clients, load_model_and_concatenator

pd.set_option("display.max_colwidth", 250)

🔔🔔🔔🔔

**HOW TO UPDATE THIS NOTEBOOK**

This is a template notebook that is under version control.
When making changes, before staging and committing the changes, run `nbstripout` to remove the output from the notebook.

```bash
# conda install -c conda-forge nbstripout  # install nbstripout if not already installed
nbstripout --drop-empty-cells --extra-keys="metadata.kernelspec.display_name metadata.kernelspec.name" Landsat_SBR_exploration.ipynb
git add -p Landsat_SBR_exploration.ipynb
```

🔔🔔🔔🔔

## 0. Setup

In [None]:
# Create the shared clients that every data-access call below receives.
# Only the first and last return values of `initialize_clients` are kept; the rest are discarded.
ml_client, _, _, _, s3_client = initialize_clients(False)
# The blob service client is derived from the ML client.
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# Notebook-wide configuration: target location, chip size, date window, known release dates,
# watershed segmentation parameters and the raster-metadata cache.
lat, lon = 32.82175, -111.78581  # release point for SBRs

# can we predict on larger chips?
crop_size = 256  # 128
center_buffer = 12  # 5  # number of pixels to search from center

# Date range of SBRs
start_date = "2025-01-01"
end_date = "2025-04-30"  # a month after the end of the SBRs to include reference images after

# these dates have KNOWN releases during Phase 0
phase0_release_dates = [
    # "2024-11-14",
    "2024-11-15",  # LS
    "2024-11-16",  # LS
    # "2024-11-19",
    # "2024-11-22",
    "2024-11-24",  # LS
    "2024-12-02",  # LS
    # "2024-12-04",
    # "2024-12-07",
    "2024-12-09",  # LS
    # "2024-12-17",
    # "2024-12-19",
    # "2024-12-22",
    "2024-12-26",  # LS
    # "2024-12-29",
]

# NOTE(review): this list is not referenced by any later cell in this notebook.
release_dates_2022 = [
    "2022-10-25",
    "2022-10-26",
    "2022-11-02",
    "2022-11-03",
    "2022-11-10",
    "2022-11-11",
    "2022-11-18",
]

# Default watershed segmentation parameters used by the prediction cells below
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)
# Cache of raster metadata keyed by item id; filled lazily by the reference-data cells
item_meta_dict = {}

# Convert the ISO date strings into datetime objects (this is what the catalog query below receives)
start_date = datetime.strptime(start_date, "%Y-%m-%d")
end_date = datetime.strptime(end_date, "%Y-%m-%d")

Models
- 37: Train for longer (Production model)
  - [Ep60] mAvgRecall: 40.3% (Hassi: 71.3%, Marcellus:  6.8%, Permian: 42.8%)
- 155: b3 encoder
  - [Ep50] mAvgRecall: 41.2% (Hassi: 72.3%, Marcellus:  7.4%, Permian: 43.8%)
- 55: Bigger encoder b2
  - [Ep68] mAvgRecall: 40.7% (Hassi: 71.4%, Marcellus:  8.0%, Permian: 42.6%)
- 168: Bigger encoder b2
  - [Ep70] mAvgRecall: 40.6% (Hassi: 71.0%, Marcellus:  7.5%, Permian: 43.4%)
- 57: More bands, b1
  - [Ep72] mAvgRecall: 40.3% (Hassi: 73.0%, Marcellus:  6.1%, Permian: 41.8%)
- 103: Only deserts model, Hassi val
  - [Ep28] mAvgRecall: 71.8% (Hassi: 71.8%)
- 154: No modulation schedule (should be slightly worse in picking up extremely faint plumes)
  - [Ep56] mAvgRecall: 40.9% (Hassi: 68.8%, Marcellus: 11.3%, Permian: 42.6%)
- 14: First model we trained
  - [Ep68] mAvgRecall: 39.7% (Hassi: 70.0%, Marcellus:  6.7%, Permian: 42.4%)

In [None]:
# Load every candidate model together with its band concatenator.
# `models` and `band_concatenators` are filled in the same order as `model_ids`,
# so `model_ids.index(<id>)` is the position to use in both lists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ids = ["37", "155", "55", "57", "103", "168", "154", "14"]
models = []
band_concatenators = []
for model_id in model_ids:
    model, band_concatenator, train_params = load_model_and_concatenator(
        f"models:/landsat/{model_id}", device, SatelliteID.LANDSAT
    )
    model.eval()
    models.append(model)
    band_concatenators.append(band_concatenator)

# NOTE(review): `train_params` is overwritten on every iteration, so the loss is built from the
# parameters of the LAST model in `model_ids` - confirm all models share these two values.
lossFn = TwoPartLoss(train_params["binary_threshold"], train_params["MSE_multiplier"])

### Overpass Dates

In [None]:
# Query the Landsat catalog for all overpasses of the release point inside the date window
# and print one summary line (date, cloud cover, item id, CRS) per overpass.
stac_items = query_landsat_catalog_for_point(lat, lon, start_date, end_date)
for item in stac_items:
    cloud_cover = item.properties["eo:cloud_cover"]
    # The two f-strings are concatenated, so the first one must end with a space;
    # otherwise the output reads "(<id>)and CRS ...".
    print(
        f"{item.datetime.date().isoformat()} with {cloud_cover:6.1f}% clouds - ({item.id}) "
        f"and CRS {item.properties['proj:code']}"
    )
# ISO date strings of all overpasses, in the order returned by the catalog query
overpass_dates = [item.datetime.date().isoformat() for item in stac_items]

In [None]:
%%time
# choose a single date to do inference
# The last catalog item is used; the isoformat/strptime round trip drops the time of day,
# leaving a datetime at midnight of the overpass date.
target_date = stac_items[-1].datetime.date().isoformat()
target_date = datetime.strptime(target_date, "%Y-%m-%d")

# Fetch the items for the target date, looking up to 180 days back
items = fetch_landsat_items_for_point(lat=lat, lon=lon, query_datetime=target_date, how_many_days_back=180)

# Crop the main chip around the release point; the tile item is kept for metadata access
main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size)
main_item = main_data["tile_item"]

## Tile info for all overpasses

In [None]:
# Load the metadata into the item
main_item.load_metadata(s3_client, abs_client)

# Collect the acquisition properties of the main tile in a TileInfo record.
# `date_analysis` is today's date; start/end time and orbit state are left unset.
tile_properties = TileInfo(
    instrument_name=main_item.instrument_name,
    date_analysis=datetime.today().date().isoformat(),
    observation_date=main_item.time.date().isoformat(),
    observation_timestamp=main_item.time.time().isoformat(timespec="seconds"),
    start_time=None,
    end_time=None,
    imaging_mode=main_item.imaging_mode,
    off_nadir_angle=main_item.off_nadir_angle,
    viewing_azimuth=main_item.viewing_azimuth,
    solar_zenith=main_item.solar_zenith,
    solar_azimuth=main_item.solar_azimuth,
    orbit_state=None,
)

pprint.pp(tile_properties)

In [None]:
# Build a table with the acquisition properties (TileInfo) of every overpass, in chronological order.
#
# NOTE: the loop rebinds the notebook-level names `target_date`, `items`, `main_data` and
# `main_item`. After it finishes they refer to the last (i.e. latest) overpass date, and the
# cells below read those names.
sorted_overpass_dates = sorted(overpass_dates)
tile_infos = []
for t in tqdm(sorted_overpass_dates):
    target_date = datetime.strptime(t, "%Y-%m-%d")

    items = fetch_landsat_items_for_point(lat=lat, lon=lon, query_datetime=target_date)

    main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size)
    main_item = main_data["tile_item"]
    main_item.load_metadata(s3_client, abs_client)

    tile_properties = TileInfo(
        instrument_name=main_item.item.properties["platform"],
        date_analysis=datetime.today().date().isoformat(),
        observation_date=main_item.time.date().isoformat(),
        observation_timestamp=main_item.time.time().isoformat(timespec="seconds"),
        start_time=None,
        end_time=None,
        imaging_mode=main_item.imaging_mode,
        off_nadir_angle=main_item.off_nadir_angle,
        viewing_azimuth=main_item.viewing_azimuth,
        solar_zenith=main_item.solar_zenith,
        solar_azimuth=main_item.solar_azimuth,
    )

    tile_infos.append(tile_properties.asdict())

# The rows were produced in sorted order, so the index has to use the same sorted dates.
# Indexing with the unsorted `overpass_dates` would attach the wrong date to each row.
df = pd.DataFrame(tile_infos, index=sorted_overpass_dates)

In [None]:
# Show which tile properties are available as columns of the overpass table
df.columns

In [None]:
# Print one column of the overpass table, one value per line, for copy/paste.
# If the column is a date, convert it to mm/dd/yyyy first, e.g.:
# pd.to_datetime(df.observation_date).sort_values().dt.strftime('%m/%d/%Y')

for x in list(df.solar_azimuth):
    print(x)

## Load in Reference Data

In [None]:
# Use 25% clouds/shadows cutoff to include non-obvious reference images that may help
max_bad_pixel_perc = 25  # max % sum of clouds/cloud shadows/nodata in reference chips
# Passed as `required_num_previous_snapshots` when cropping the reference data
num_snapshots = 40

In [None]:
# NOTE(review): `items`, `main_data`, `main_item` and `target_date` are whatever the most recently
# executed cell left behind (the tile-info loop above rebinds all of them) - confirm this is the
# intended main date before relying on the reference chips.

# This is the slowest part of the next function,
# getting the raster metadata once here speeds it up by a lot
for item in tqdm(items):
    if item.id not in item_meta_dict:
        # Get the Landsat metadata from the coastal band (arbitrary choice).
        try:
            item_meta_dict[item.id] = item.get_raster_meta("coastal", abs_client=abs_client)
        except Exception as e:
            # Best effort: report the failure and leave this item out of the cache
            print(e)
            continue

# Crop the reference chips for the main chip, reusing the cached raster metadata
reference_data = crop_reference_data_landsat(
    items,
    main_data,
    abs_client,
    s3_client,
    lat,
    lon,
    crop_size,
    required_num_previous_snapshots=num_snapshots,
    max_bad_pixel_perc=max_bad_pixel_perc,
    item_meta_dict=item_meta_dict,
)

# Summarize which tiles are used as main and as reference
print("\n---: Summary")
print(f"Main tile for {target_date}: USE {main_item.id}")
for reference in reference_data:
    tile_id = reference["tile_item"].id
    print(f"Reference tile for {target_date}: USE {tile_id}")

## I. RGBs & Ratios

In [None]:
### Show all of them together
# The main chip comes first, followed by every reference chip
plot_all_rgb([main_data, *reference_data], SatelliteID.LANDSAT)
plot_all_ratio([main_data, *reference_data], SatelliteID.LANDSAT)

In [None]:
### Show them individually
# `data_items` (main chip first, then the references) is reused by the ratio-diff cell below
data_items = [main_data, *reference_data]

for i, data in enumerate(data_items):
    # NOTE(review): `date` is computed but not passed to the plot call
    date = data["tile_item"].time.date().isoformat()
    plot_rgb_ratio(data, Colorbar.INDIVIDUAL, i, SatelliteID.LANDSAT)

## Ratio Diffs

In [None]:
# Plot the ratio differences for the chips in `data_items`; the return value is discarded
_ = plot_ratio_diffs(data_items, Colorbar.INDIVIDUAL, SatelliteID.LANDSAT, mean_adjust=True)

## II. Prediction

### Predict for a specific date

#### Select Reference Chips

In [None]:
# Rebuild the overpass date list in the reverse of the catalog order and display it
overpass_dates = [item.datetime.date().isoformat() for item in stac_items][::-1]
overpass_dates

In [None]:
# Classify every reference chip date: inside the queried overpass window (a release is possible),
# a known Phase 0 release date, or neither.
print("Reference tile dates:")
for x in reference_data:
    reference_date = x["tile_item"].time.date().isoformat()
    if reference_date in overpass_dates:
        print(f"{reference_date} - possible SBR release (date is in Phase 1)")
    elif reference_date in phase0_release_dates:
        print(f"{reference_date} - known SBR release (date is in Phase 0)")
    else:
        print(f"{reference_date} - not a Phase 1 overpass date")

In [None]:
# Pick the model and the two reference chips ("before" and "earlier") for a single prediction.
model_id = "55"
# Look the id up in the list of loaded models. Calling `.index` on the id string itself
# always yields 0, which would silently select the first model instead of the requested one.
model_idx = model_ids.index(model_id)

before_date = "2025-03-08"
earlier_date = "2025-02-28"

# The helper returns the chips for the two requested dates: "before" first, then "earlier"
reference_chips = select_reference_tiles_from_dates_str(
    reference_data, before_date=before_date, earlier_date=earlier_date
)
before_data = reference_chips[0]
earlier_data = reference_chips[1]

# Overview table of the three chips that go into the prediction (displayed as the cell output)
pd.DataFrame(
    {
        "tile": ["main", "before", "earlier"],
        "date": [
            main_data["tile_item"].time.date().isoformat(),
            before_data["tile_item"].time.date().isoformat(),
            earlier_data["tile_item"].time.date().isoformat(),
        ],
        "time": [
            main_data["tile_item"].time.time().isoformat(),
            before_data["tile_item"].time.time().isoformat(),
            earlier_data["tile_item"].time.time().isoformat(),
        ],
        "datetime": [main_data["tile_item"].time, before_data["tile_item"].time, earlier_data["tile_item"].time],
        "observation_angle": [
            main_data["tile_item"].observation_angle,
            before_data["tile_item"].observation_angle,
            earlier_data["tile_item"].observation_angle,
        ],
    }
)

#### Predict

In [None]:
%%time
# Watershed parameters for this prediction (re-declared here so they can be tuned per run)
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)
# Predict for the main chip with the selected model and the two selected reference chips
prediction = predict(
    main_data,
    reference_chips,
    watershed_params,
    models[model_idx],
    device,
    band_concatenators[model_idx],
    lossFn,
    create_lookup_table=True,
)

#### Plot Predictions

In [None]:
# Prepare RGB composites, band ratios and date labels for the main/before/earlier chips,
# then draw the combined error-analysis figure for the single prediction above.
rgb_main = get_rgb_bands_landsat(main_data["crop_arrays"], BANDS_LANDSAT)
rgb_before = get_rgb_bands_landsat(before_data["crop_arrays"], BANDS_LANDSAT)
rgb_earlier = get_rgb_bands_landsat(earlier_data["crop_arrays"], BANDS_LANDSAT)

ratio_main = get_band_ratio_landsat(main_data["crop_arrays"], BANDS_LANDSAT)
ratio_before = get_band_ratio_landsat(before_data["crop_arrays"], BANDS_LANDSAT)
ratio_earlier = get_band_ratio_landsat(earlier_data["crop_arrays"], BANDS_LANDSAT)

date_main = main_data["tile_item"].time.date().isoformat()
date_before = before_data["tile_item"].time.date().isoformat()
date_earlier = earlier_data["tile_item"].time.date().isoformat()

plot = all_error_analysis_plots(
    rgb_main=rgb_main,
    rgb_before=rgb_before,
    rgb_earlier=rgb_earlier,
    ratio_main=ratio_main,
    ratio_before=ratio_before,
    ratio_earlier=ratio_earlier,
    predicted_frac=prediction.marginal,
    predicted_mask=prediction.mask,
    conditional_pred=prediction.conditional,
    binary_probability=prediction.binary_probability,
    conditional_retrieval=prediction.conditional_retrieval,
    masked_conditional_retrieval=prediction.masked_conditional_retrieval,
    rescaled_retrieval=None,  # We will plot this in the Retrieval and quantification section
    marginal_retrieval=prediction.marginal_retrieval,
    watershed_segmentation_params=watershed_params,
    dates=(date_main, date_before, date_earlier),
    ratio_colorbar=Colorbar.SHARE,  # (0.65, 1.0),
    ratio_diff_colorbar=Colorbar.SHARE,  # (-0.1, 0.1)
)

#### Retrieval

In [None]:
# Quantify the plumes in the retrieval and keep the ones that intersect the chip center.

# Pixel offset of the crop inside the swir22 raster; stored on the prediction as well
crop_x = main_data["crop_params"]["swir22"]["crop_start_x"]
crop_y = main_data["crop_params"]["swir22"]["crop_start_y"]
prediction.crop_x = crop_x
prediction.crop_y = crop_y
pred_info = prediction.asdict()

timestamp = main_item.time
raster_meta = main_item.get_raster_meta("swir22", abs_client=abs_client)
crop_crs = raster_meta["crs"]

# Affine transform of the cropped window, derived from the transform of the full raster
window = rasterio.windows.Window(crop_x, crop_y, crop_size, crop_size)
crop_transform = rasterio.windows.transform(window, raster_meta["transform"])

# Quantify plumes in retrieval
plume_list = quantify_retrieval(
    prediction.conditional_retrieval,
    crop_transform,
    crop_crs,
    prediction.binary_probability,
    timestamp.date().isoformat(),
    floor_t=watershed_params.watershed_floor_threshold,
    marker_t=watershed_params.marker_threshold,
    spatial_resolution=30.0,
)

# Plume records (as returned by `quantify_retrieval`) that intersect the center.
# Defined up front so the name also exists when no plume is detected.
center_plumes: list = []

if len(plume_list) == 0:
    print("No plumes detected")
else:
    print(f"Found {len(plume_list)} plumes")

    # Order the plumes by their "Q" property, largest first
    plume_list = sorted(plume_list, key=lambda x: x["properties"]["Q"])[::-1]
    for plume in plume_list:
        plume_info = PlumeInfo(**plume["properties"])

        intersects = intersects_center(*plume_info.bbox, buffer=center_buffer)
        if intersects:
            pprint.pp(plume_info)
            center_plumes.append(plume)

    print(f"{len(center_plumes)} plumes intersect the center with {center_buffer} pixels of buffer")

### Heatmap of Max Probability + Binary Plot for Each Reference Chip Pairing

In [None]:
# Select the model for the all-pairs prediction; `model_idx` is its position in the loaded lists
model_id = "55"
model_idx = model_ids.index(model_id)

In [None]:
# Predict for the main chip against the reference chips with the selected model.
# Retrieval is skipped here (`skip_retrieval=True`).
predictions = predict_for_all_pairs(
    main_data,
    reference_data,
    models[model_idx],
    device,
    band_concatenators[model_idx],
    lossFn,
    watershed_params,
    skip_retrieval=True,
)

In [None]:
# Acquisition times of the reference chips, in the same order as `reference_data`
dates = [d["tile_item"].time for d in reference_data]

# NOTE(review): `target_date` is whatever the most recently executed cell assigned - confirm it
# matches the main chip used for `predictions`.
plot_max_proba_center_buffer_heatmap(
    predictions=predictions,
    dates=dates,
    center_buffer=center_buffer,
    title=f'Model {model_id} for main={target_date.isoformat().split("T")[0]}',
    satellite_id=SatelliteID.LANDSAT,
    show_topk=5,
    main_date=target_date.isoformat().split("T")[0],
    plot_binary_grid=True,
)

### Predict for multiple dates with "Normal" and Average Reference images

Prepare data for all dates we want to predict for.

In [None]:
%%time
# Build the main chip and the reference chips for every overpass, keyed by the overpass date
# (a datetime at midnight).
# NOTE: the loop rebinds `target_date`, `items`, `main_data`, `main_item` and `reference_data`;
# afterwards they hold the values of the last entry of `stac_items`.
main_data_all = {}
reference_data_all = {}
for item in stac_items:
    target_date = item.datetime.date().isoformat()
    target_date = datetime.strptime(target_date, "%Y-%m-%d")

    items = fetch_landsat_items_for_point(lat=lat, lon=lon, query_datetime=target_date)

    main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size)
    main_data_all[target_date] = main_data

    main_item = main_data["tile_item"]
    print(f"Main tile for {target_date}: USE {main_item.id}")

    # This is the slowest part of the next function,
    # getting the raster metadata once here speeds it up by a lot
    # NOTE(review): unlike the single-date cell above, failures here are not caught - confirm intended.
    for item_ in items:
        if item_.id not in item_meta_dict:
            item_.prefetch_l1(s3_client, abs_client)
            # Get the Landsat metadata from the coastal band (arbitrary choice).
            item_meta_dict[item_.id] = item_.get_raster_meta("coastal", abs_client=abs_client)

    reference_data = crop_reference_data_landsat(
        items,
        main_data,
        abs_client,
        s3_client,
        lat,
        lon,
        crop_size,
        required_num_previous_snapshots=num_snapshots,
        max_bad_pixel_perc=max_bad_pixel_perc,
        item_meta_dict=item_meta_dict,
    )
    reference_data_all[target_date] = reference_data
    print()

In [None]:
# Run and plot the "normal" and the averaged-reference strategies for all overpasses and all
# loaded models; `sums` and `avg_ref_counts` feed the summary plot in the next cell.
maxs, sums, avg_ref_counts = plot_normal_and_avg_strategy(
    stac_items,
    main_data_all,
    reference_data_all,
    phase0_release_dates,
    models,
    model_ids,
    band_concatenators,
    device,
    lossFn,
    watershed_params,
    center_buffer,
)

In [None]:
# Summary plot; the buffer width is the full side length of the center window (2 * buffer + 1)
plot_normal_and_avg_strategy_summary(sums, avg_ref_counts, model_ids, ylim=[-2, 30], buffer_width=center_buffer * 2 + 1)

# In detail: Inspection and Submission of individual dates (Q1 2025)

In [None]:
# Prepare wind data once
wind_data = pd.read_csv("../../src/data/ancillary/wind_vectors_gt_vs_models_2025_sbr_sites.csv")
# Keep only the rows whose sensor is "LS". The explicit copy makes the filtered frame independent
# of the frame read from disk, so adding a column below does not raise SettingWithCopyWarning.
wind_data = wind_data[wind_data["sensor"] == "LS"].copy()
# Combine the date and overpass-time columns into a UTC timestamp string
wind_data["sensing_time"] = pd.to_datetime(
    wind_data["date"] + " " + wind_data["overpass_time_utc"], utc=True
).dt.strftime("%Y-%m-%dT%H:%M:%S+0000")

In [None]:
# Re-create the clients before the per-date inspection below.
# NOTE(review): presumably this refreshes credentials that expired during the long cells above - confirm.
ml_client, _, _, _, s3_client = initialize_clients(False)
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# choose a single date to do inference
# NOTE(review): despite its name, `stac_idx` indexes `items` (the list left behind by the most
# recently executed fetch cell), not `stac_items` - confirm `items` is the intended list.
stac_idx = 19
target_date = items[stac_idx].time
print(target_date.isoformat())

In [None]:
# Dates that must not be used as reference images, grouped by the reason for exclusion
cloudy_hazy_dates = [
    "2025-01-03",  # hazy
    "2025-03-16",  # hazy
    "2025-02-12",  # cloudy
    "2025-03-23",  # cloudy
    "2025-03-15",  # cloudy
    "2025-03-07",  # cloudy
    "2025-02-04",  # cloudy
]

dark_ratio_dates = [
    "2025-01-27",  # dark ratio, mean 0.792 (normal is ~0.84)
    "2025-03-08",  # dark ratio, mean 0.785 (normal is ~0.84)
]

# Following dates we determined there had been an emission
phase1_believed_release_dates = [
    "2025-01-26",
    "2025-02-28",
    "2025-03-16",
    "2025-03-24",
]

# Union of all exclusion reasons; entries are ISO date strings and may repeat
# (e.g. "2025-03-16" is both hazy and a believed release date)
exclude_dates = [
    # "2024-12-02",
    *phase0_release_dates,
    *cloudy_hazy_dates,
    *dark_ratio_dates,
    *phase1_believed_release_dates,
]

print(f"excluding {len(exclude_dates)} dates")
print("\n".join(exclude_dates))

In [None]:
# Crop the main chip for the selected entry of `items`
main_data = crop_main_data_landsat(items, abs_client, s3_client, lat, lon, crop_size, main_idx=stac_idx)
main_item = main_data["tile_item"]
print(f"Main tile      {main_item.id}")

# Keep the reference chips that are not excluded and lie at most `max_days_difference` days
# before or after the main date (chips from the main date itself are dropped).
max_days_difference = 60  # 75
reference_data_ = []
for reference in reference_data:
    tile_id = reference["tile_item"].id
    # The acquisition date is the 4th underscore-separated field of the tile id
    ref_date = datetime.strptime(tile_id.split("_")[3], "%Y%m%d")
    ref_date_short = ref_date.isoformat().split("T")[0]
    if ref_date_short in exclude_dates:
        print(f"Reference img on {ref_date_short} is excluded         --> DONT USE")
        continue

    # Distance between the two dates in whole days, regardless of direction
    time_difference = abs(target_date.date() - ref_date.date())

    if 0 < time_difference.days <= max_days_difference:
        reference_data_.append(reference)
        print(
            f"Reference img on {ref_date_short} is "
            f"{time_difference.days:3} days distant -->      USE"
        )
        continue

    print(
        f"Reference img on {ref_date_short} is "
        f"{time_difference.days:3} days distant --> DONT USE"
    )
len(reference_data_)

In [None]:
# Plot Wind for Main Date
sensing_time = main_item.time  # .isoformat()
print(sensing_time)
# Look up the wind components for the sensing time; the result is keyed by wind source
wind_components = get_wind_components(wind_data, sensing_time)

# Speed and direction from the "geos" components
u_wind_component, v_wind_component = wind_components["geos"]
wind_speed_geos = calc_effective_wind_speed(u_wind_component, v_wind_component)
wind_direction_geos = calc_wind_direction(u_wind_component, v_wind_component)

# Speed and direction from the "era5" components (the u/v names are reused)
u_wind_component, v_wind_component = wind_components["era5"]
wind_speed_era5 = calc_effective_wind_speed(u_wind_component, v_wind_component)
wind_direction_era5 = calc_wind_direction(u_wind_component, v_wind_component)

print("GEOS Wind")
plot_wind(
    lon,
    lat,
    main_item,
    abs_client,
    wind_speed_geos,
    wind_direction_geos,
    satellite_id=SatelliteID.LANDSAT,
    arrow_scale_factor=500,
)
print("ERA5 Wind")
plot_wind(
    lon,
    lat,
    main_item,
    abs_client,
    wind_speed_era5,
    wind_direction_era5,
    satellite_id=SatelliteID.LANDSAT,
    arrow_scale_factor=500,
)

In [None]:
# Find Before/Earlier dates for "Normal" prediction
# The first two non-excluded reference chips that precede the main date are taken.
# NOTE(review): this assumes `reference_data` is ordered from newest to oldest - confirm.
print(f"Main    date = {main_item.date}")
before_date = None
earlier_date = None
for reference in reference_data:
    tile_id = reference["tile_item"].id
    ref_date = datetime.strptime(tile_id.split("_")[3], "%Y%m%d")
    ref_date_short = ref_date.isoformat().split("T")[0]
    # Signed difference in days: positive only when the reference precedes the main date
    time_difference = target_date.date() - ref_date.date()
    if time_difference.days <= 0:
        continue
    if ref_date_short in exclude_dates:
        print(f"Reference img on {ref_date_short} is excluded         --> DONT USE")
        continue
    if before_date is None:  # Closest is Before (t-1)
        before_date = ref_date.isoformat().split("T")[0]
    elif earlier_date is None:  # Next closest is Earlier (t-2)
        earlier_date = ref_date.isoformat().split("T")[0]
        break

# # Optionally overwrite them
# before_date = "2025-01-19"
# earlier_date = "2025-01-18"

# Either value stays None when not enough usable earlier chips were found
print(f"Before  date = {before_date}")
print(f"Earlier date = {earlier_date}")

In [None]:
# Pixel bounds of the central window of the crop, used to zoom the plots below
half_crop = 64  # we want middle 128x128
row = col = crop_size // 2

xmin, xmax = row - half_crop, row + half_crop
ymin, ymax = col - half_crop, col + half_crop
extent = [xmin, xmax, ymin, ymax]

In [None]:
# Day windows used for the averaged-reference predictions
# (the "before"-only variant below doubles each value)
avg_reference_days_diffs = [40, 25, 15]

# set up UI
# One output area per model, shown as a tab titled with the model id
widget_outputs = [widgets.Output(layout={"height": "1000px", "width": "1200px"}) for _ in model_ids]
for output in widget_outputs:
    with output:
        print("Loading...")
tabs = widgets.Tab(children=widget_outputs, titles=model_ids)
display(tabs)

# Visualize predictions using all reference image combinations as inputs.
#
# For every model (one tab each) this cell
#   1. predicts with an averaged reference chip for several day windows, once using references
#      before/after the main date and once using only references before it, and
#   2. predicts for all reference pairs in `reference_data_` and plots the heatmap, whose result
#      is stored in `output_preds[model_id]`.
output_preds = {}
dates = [d["tile_item"].time for d in reference_data_]
for model_idx, (model_id, output) in enumerate(zip(model_ids, widget_outputs, strict=True)):
    main_date = target_date.isoformat().split("T")[0]
    avg_preds = {}

    with output:
        output.clear_output()
        for avg_reference_days_diff in avg_reference_days_diffs:
            # Use reference images before/after the main date
            for timeframe in ["before/after", "before"]:
                if timeframe == "before":
                    # Only one side of the main date is available, so the window is doubled
                    avg_reference_days_diff_ = int(avg_reference_days_diff * 2)
                    reference_data_for_avg = get_reference_data_before(
                        reference_data, target_date, max_days_difference=avg_reference_days_diff_
                    )
                else:
                    avg_reference_days_diff_ = avg_reference_days_diff
                    reference_data_for_avg = get_reference_data(
                        reference_data, target_date, max_days_difference=avg_reference_days_diff_
                    )
                # `exclude_dates` holds ISO date strings, so the acquisition time has to be
                # converted to the same form; comparing the datetime itself never matches
                # and would leave every excluded date in the average.
                valid_refs = [
                    ref
                    for ref in reference_data_for_avg
                    if ref["tile_item"].time.date().isoformat() not in exclude_dates
                ]
                print(
                    f"Ref images {avg_reference_days_diff_} {timeframe}: "
                    f"Filtered {len(reference_data_for_avg)} refs to "
                    f"{len(valid_refs)} after excluding dates"
                )
                avg_ref_count = len(valid_refs)
                if avg_ref_count == 0:
                    continue
                # Use a chip that is part of the average as the template for the averaged chip
                # (shallow copy, so only "crop_arrays" is replaced on the copy).
                data_avg = copy.copy(valid_refs[0])
                data_avg["crop_arrays"] = np.mean([ref["crop_arrays"] for ref in valid_refs], axis=0)
                # Center of the zoomed window, in the coordinates of the cropped plots
                zoomed_shape = data_avg["crop_arrays"][:, xmin:xmax, ymin:ymax].shape
                center_y = zoomed_shape[1] // 2
                center_x = zoomed_shape[2] // 2

                f, ax = plt.subplots(1, 4, figsize=(20, 5))
                # The averaged chip is passed for both reference slots
                preds_avg = predict(
                    main_data,
                    [data_avg, data_avg],
                    watershed_params,
                    models[model_idx],
                    device,
                    band_concatenators[model_idx],
                    lossFn,
                )
                ratio_main = get_band_ratio_landsat(preds_avg.x_dict["crop_main"][0], BANDS_LANDSAT)[
                    xmin:xmax, ymin:ymax
                ]  # type:ignore
                # The color scale of all ratio panels is taken from the main ratio
                vmin = np.percentile(ratio_main, 0.5)
                vmax = np.percentile(ratio_main, 99.5)

                ratio_before = get_band_ratio_landsat(preds_avg.x_dict["crop_before"][0], BANDS_LANDSAT)[
                    xmin:xmax, ymin:ymax
                ]  # type:ignore
                ratio_earlier = get_band_ratio_landsat(preds_avg.x_dict["crop_earlier"][0], BANDS_LANDSAT)[
                    xmin:xmax, ymin:ymax
                ]  # type:ignore

                ax[0].imshow(ratio_earlier, vmin=vmin, vmax=vmax)
                # The statistics in the title describe the array shown in this panel
                ax[0].set_title(
                    f"Avg({avg_reference_days_diff_} days {timeframe} main) = {avg_ref_count} refs\n"
                    f"Min {ratio_earlier.min():.2f} Max {ratio_earlier.max():.2f} "
                    f"Mean {ratio_earlier.mean():.3f}",
                    fontsize=15,
                )
                ax[1].imshow(ratio_before, vmin=vmin, vmax=vmax)
                ax[1].set_title(
                    f"Avg({avg_reference_days_diff_} days {timeframe} main) = {avg_ref_count} refs\n"
                    f"Min {ratio_before.min():.2f} Max {ratio_before.max():.2f} "
                    f"Mean {ratio_before.mean():.3f}",
                    fontsize=15,
                )
                ax[2].imshow(ratio_main, vmin=vmin, vmax=vmax)
                ax[2].set_title(
                    f"Main Ratio {main_date}\nMin {ratio_main.min():.2f} Max {ratio_main.max():.2f} "
                    f"Mean {ratio_main.mean():.3f}",
                    fontsize=15,
                )
                ax[3].imshow(preds_avg.binary_probability[xmin:xmax, ymin:ymax], vmin=0.0, vmax=1.0, cmap="hot_r")

                # Probability statistics inside the center window, as percentages
                center = get_center_buffer(preds_avg.binary_probability, center_buffer)
                center_max_pct = 100 * center.max()
                center_sum_pct = 100 * center.sum() / (center.shape[0] * center.shape[1])
                avg_preds[f"{avg_reference_days_diff_}_{timeframe}"] = {
                    "avg_ref_count": avg_ref_count,
                    "max_prob": center_max_pct,
                    "sum_prob": center_sum_pct,
                }
                ax[3].set_title(
                    f"Avg({avg_reference_days_diff_} days {timeframe} main)\nCenter sum(Prob): "
                    f"{center_sum_pct:.1f}%, "
                    f"Max: {center_max_pct:.0f}%",
                    fontsize=16,
                )
                # Mark the center and add the grid on every panel
                for panel in ax:
                    panel.scatter(center_x, center_y, color="green", marker="x")
                    grid16(panel)
                plt.tight_layout()
                plt.show()

        print(f"Model {model_id}")
        for k, v in avg_preds.items():
            print(
                f"Avg{k:15}   Center sum(Prob): {v['sum_prob']:4.1f}%, "
                f"Max: {v['max_prob']:3.0f}% (using avg of {v['avg_ref_count']} ref images)"
            )
        print("=" * 50)

        predictions = predict_for_all_pairs(
            main_data,
            reference_data_,
            models[model_idx],
            device,
            band_concatenators[model_idx],
            lossFn,
            watershed_params,
            skip_retrieval=True,
        )

        print(f"Model {model_id}")
        output_preds[model_id] = plot_max_proba_center_buffer_heatmap(
            predictions=predictions,
            dates=dates,
            center_buffer=center_buffer,
            title=f'Model {model_id} for main={target_date.isoformat().split("T")[0]}',
            show_topk=6,
            topk_based_on="marg_sum",  # "sum", #"max",  # or "sum" to use the sum of center probabilities
            satellite_id=SatelliteID.LANDSAT,
            main_date=target_date.isoformat().split("T")[0],
            before_date=before_date,
            earlier_date=earlier_date,
            plot_topk=True,  # True,
            plot_max_grid=True,  # True,
            plot_binary_grid=False,
            dates_to_exclude=[],
            extent=extent,
        )

1. Validate and visualize the plumes and emission rates of "normal" predictions, topk predictions or any other combination. Build the combination list `combinations_to_visualize` you want to test and call `validate_pred_retrievals(combinations_to_visualize)`.
2. Manually select the combination of model id + before/earlier dates we will use for L/IME/Emission Rate Calculation + Unmasked Retrieval and plume mask .tif outputs. Put this into the decision_dict for the main date, e.g. `'selected_retrieval': {'model_id': '37', 'date_before': '2024-12-25', 'date_earlier': '2024-12-18'}`
    1. Set `feasible` to True if we are able to predict for this main date, i.e. it is not cloudy or covered by cloud shadows.
    2. Set `note` to "no_detection" or "detection" and any other notes you have.
    3. Set `watershed_marker_t` and `watershed_floor_t` to your selected watershed parameters if you want different parameters from the default `watershed_marker_t=0.2`,`watershed_floor_t=0.15`
    4. Set `selected_retrieval` as described above
    5. If you have a detection: Set `wind_source` to "era5" or "geos"
    6. If you have a detection: Set `emission_ensemble_selections` to a list of model id + before/earlier dates whose emission rates will be used to calculate our emission rate uncertainty. __Don't put the selected retrieval in here.__
3. Run `sbr_form_outputs()` to get the values ready for copying into the form. This will also export the required unmasked retrieval and the binary plume mask .tifs from your `selected_retrieval`.

In [None]:
# Run once to initialize the RadTranLookup for the main date
granule_item = main_data["tile_item"]
# The HAPI data path depends on which satellite the main tile comes from
if isinstance(granule_item, LandsatGranuleAccess):
    hapi_data_path = LANDSAT_HAPI_DATA_PATH
elif isinstance(granule_item, Sentinel2Item):
    hapi_data_path = S2_HAPI_DATA_PATH
else:
    raise ValueError(f"Unsupported granule access type: {type(granule_item)}")
# Build the lookup table from the instrument, angles and band names of the main tile
lookup_table = RadTranLookupTable.from_params(
    instrument=granule_item.instrument,
    solar_angle=granule_item.solar_angle,
    observation_angle=granule_item.observation_angle,
    hapi_data_path=hapi_data_path,
    min_ch4=0.0,
    max_ch4=21.0,  # this value was selected based on the common value ranges of the sim plume datasets
    spacing_resolution=40000,
    ref_band=granule_item.swir16_band_name,
    band=granule_item.swir22_band_name,
    full_sensor_name=granule_item.sensor_name,
)

In [None]:
# Build the combinations you want to validate/visualize
# For every model: its top-1 to top-3 reference pairs from the heatmap result, plus the "normal" pair
combinations_to_visualize = []
for model_id, v in output_preds.items():
    # if model_id in ["55"]: # skip certain models
    #     continue
    for j in range(1, 4):  # 6): # use topX dates
        combinations_to_visualize.append(
            {
                "model_id": model_id,
                "date_before": v[f"top_{j}"]["date_before"],
                "date_earlier": v[f"top_{j}"]["date_earlier"],
            }
        )

    # visualize the "normal" predictions
    combinations_to_visualize.append(
        {"model_id": model_id, "date_before": v["normal"]["date_before"], "date_earlier": v["normal"]["date_earlier"]}
    )

# format output for copy/paste
# Each entry is printed as one JSON object, indented by 12 spaces and followed by a comma
for d in combinations_to_visualize:
    print(12 * " " + f"{json.dumps(d)},")

In [None]:
# Display both wind speed estimates side by side (ERA5 first, then GEOS)
wind_speed_era5, wind_speed_geos

- Model 154 for B=2025-01-19 and E=2025-01-11: L= 454.0, IME= 2824.3, Emission Rate:  395

In [None]:
# Validate the selected combinations: the same call is made twice with identical arguments,
# first with `show_plots=False` and then with `show_plots=True`.
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)
max_distance_pixels = 10
pixel_width = 30
# The GEOS wind speed is used for quantification here; the alternatives are kept in the comment
wind_speed = wind_speed_geos  # (wind_speed_era5 + wind_speed_geos) / 2  # or wind_speed_geos
# print("[WARN] using average windspeed")

validate_pred_retrievals(
    combinations_to_visualize,
    main_data,
    reference_data,
    model_ids,
    watershed_params,
    models,
    device,
    band_concatenators,
    lossFn,
    lookup_table,
    wind_speed,
    SatelliteID.LANDSAT,
    max_distance_pixels,
    pixel_width,
    show_plots=False,
    extent=extent,
)
validate_pred_retrievals(
    combinations_to_visualize,
    main_data,
    reference_data,
    model_ids,
    watershed_params,
    models,
    device,
    band_concatenators,
    lossFn,
    lookup_table,
    wind_speed,
    SatelliteID.LANDSAT,
    max_distance_pixels,
    pixel_width,
    show_plots=True,
    extent=extent,
)

In [None]:
# Load the manually curated decisions; the next cell looks entries up by main date string
with open("./landsat_phase1_decision_dict.json") as fs:
    decisions_dict = json.load(fs)

In [None]:
# Produce the form values for one main date.
# NOTE(review): the decision key is hard-coded to "2025-02-19" while `main_data`, `main_item` and
# `target_date` come from the cells above - confirm they all refer to the same date before running.
sbr_form_outputs(
    decisions_dict["2025-02-19"],
    wind_speed_geos,
    wind_direction_geos,
    wind_speed_era5,
    wind_direction_era5,
    model_ids,
    models,
    device,
    band_concatenators,
    main_data,
    reference_data,
    lossFn,
    lookup_table,
    max_distance_pixels,
    pixel_width,
    main_item,
    target_date,
    SatelliteID.LANDSAT,
    abs_client,
)