In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import copy
import pprint
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
from lib.models.schemas import WatershedParameters
from rasterio.crs import CRS
from tqdm import tqdm

from sbr_2025 import BANDS
from sbr_2025.utils import intersects_center, select_reference_tiles_from_dates_str
from sbr_2025.utils.plotting import (
    Colorbar,
    all_error_analysis_plots,
    get_band_ratio,
    get_rgb_bands,
    plot_all_ratio,
    plot_all_rgb,
    plot_max_proba_center_buffer_heatmap,
    plot_normal_and_avg_strategy,
    plot_normal_and_avg_strategy_summary,
    plot_ratio_diffs,
    plot_rgb_ratio,
    plot_wind,
    validate_pred_retrievals,
)
from sbr_2025.utils.prediction import (
    PlumeInfo,
    TileInfo,
    export_predition_and_mask_to_geotiff,
    get_center_buffer,
    get_reference_data,
    predict,
    predict_for_all_pairs,
)
from sbr_2025.utils.quantification import (
    get_wind_components,
    sbr_form_outputs,
)
from src.azure_wrap.ml_client_utils import initialize_blob_service_client
from src.data.landsat_data import LandsatGranuleAccess
from src.data.sentinel2 import Sentinel2Item
from src.inference.inference_functions import (
    crop_main_data,
    crop_reference_data,
    fetch_sentinel2_items_for_point,
    query_sentinel2_catalog_for_point,
)
from src.inference.inference_target_location import (
    quantify_retrieval,
)
from src.plotting.plotting_functions import grid16
from src.training.loss_functions import TwoPartLoss
from src.utils.parameters import LANDSAT_HAPI_DATA_PATH, S2_HAPI_DATA_PATH, SatelliteID
from src.utils.quantification_utils import calc_effective_wind_speed, calc_wind_direction
from src.utils.radtran_utils import RadTranLookupTable
from src.utils.utils import initialize_clients, load_model_and_concatenator

🔔🔔🔔🔔

**HOW TO UPDATE THIS NOTEBOOK**

This is a template notebook that is under version control.
When making changes, before staging and committing the changes, run `nbstripout` to remove the output from the notebook.

```bash
# conda install -c conda-forge nbstripout  # install nbstripout if not already installed
nbstripout --drop-empty-cells --extra-keys="metadata.kernelspec.display_name metadata.kernelspec.name" S2_SBR_exploration.ipynb
git add -p S2_SBR_exploration.ipynb
```

🔔🔔🔔🔔

## 0. Setup

In [None]:
# Create the service clients used by the data-loading calls below. Only the first and
# last values returned by initialize_clients are kept; the blob-storage client is then
# built from the ML client.
ml_client, _, _, _, s3_client = initialize_clients(False)
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# Release point for the SBRs, as (lat, lon).
lat, lon = 32.82175, -111.78581

# can we predict on larger chips?
crop_size = 256
# number of pixels to search from center
center_buffer = 5

# Date range of SBR Phase 1, parsed directly into datetime objects.
# The nominal end of the SBRs is "2025-03-31"; the end date used here is a month later
# so that reference images taken after the last release are included.
start_date = datetime.strptime("2024-10-01", "%Y-%m-%d")
end_date = datetime.strptime("2025-04-30", "%Y-%m-%d")

# these dates have KNOWN releases during Phase 0
phase0_release_dates = [
    "2024-11-14",
    "2024-11-19",
    "2024-11-22",
    "2024-12-02",
    "2024-12-04",
    "2024-12-07",
    "2024-12-09",
    "2024-12-17",
    "2024-12-19",
    "2024-12-22",
    "2024-12-29",
]

# Default watershed segmentation settings used by the prediction cells.
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)

item_meta_dict = {}

Models
- 1226: Production model
  - mAvgRecall: 44.6% (Hassi: 77.2%, Marcellus: 11.2%, Permian: 45.4%)
- 1475: b3 encoder
  - mAvgRecall: 46.4% (Hassi: 78.9%, Marcellus: 11.9%, Permian: 48.4%)
- 1486: b4 encoder
  - mAvgRecall: 45.9% (Hassi: 77.9%, Marcellus: 11.9%, Permian: 47.8%)
- 1395: Only deserts model, Hassi val
  - Hassi: 77.3%
- 1422: Weight Decay = 0
  - mAvgRecall: 44.1% (Hassi: 77.2%, Marcellus: 10.2%, Permian: 45.0%)
- 1340: Dropping hardest ~10% parquets
  - mAvgRecall: 44.1% (Hassi: 76.5%, Marcellus: 10.5%, Permian: 45.5%)

In [None]:
# Load every model listed in `model_ids` onto the GPU when available (CPU otherwise) and
# put it in eval mode. `models` and `band_concatenators` are index-aligned with
# `model_ids`, so a model is later addressed via its position in `model_ids`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ids = ["1226", "1475", "1486", "1395", "1422", "1340"]
models = []
band_concatenators = []
for model_id in model_ids:
    model, band_concatenator, train_params = load_model_and_concatenator(
        f"models:/torchgeo_pwr_unet/{model_id}", device, SatelliteID.S2
    )
    model.eval()
    models.append(model)
    band_concatenators.append(band_concatenator)

# NOTE(review): `train_params` here is the value left over from the LAST model loaded in
# the loop above, so the loss is configured from that model's parameters only — confirm
# this is intended for all models.
lossFn = TwoPartLoss(train_params["binary_threshold"], train_params["MSE_multiplier"])

### Overpass Dates

In [None]:
# Query the Sentinel-2 catalog for the release point over the configured date range.
stac_items = query_sentinel2_catalog_for_point(lat, lon, start_date, end_date, crop_size, sbr_notebook=True)

# List each returned item as "<timestamp> - <item id>".
for item in stac_items:
    print(f"{item.datetime.isoformat(timespec='seconds')} - {item.id}")

# ISO date strings ("YYYY-MM-DD") and time strings ("HH:MM:SS"), one per STAC item, in the
# same order as `stac_items`.
overpass_dates = [item.datetime.date().isoformat() for item in stac_items]
overpass_times = [item.datetime.time().isoformat(timespec="seconds") for item in stac_items]

# the following are printed in a format for easier copying into spreadsheets
# for d in overpass_dates:
#     print(d)
# for t in overpass_times:
#     print(t)

In [None]:
# Dates to avoid using for reference tiles: cloudy / emission.
# Maps an ISO date string to the list of reasons that date is unusable.
bad_reference_dates = {
    "2025-01-06": ["clouds"],
    "2025-01-08": ["clouds"],
}

In [None]:
%%time
# Run inference for a single date: the last entry of `overpass_dates`, parsed to a datetime.
target_date = datetime.strptime(overpass_dates[-1], "%Y-%m-%d")

# Look back over the whole configured date range plus a 10-day margin.
number_of_days = (end_date - start_date).days + 10
items_for_entire_phase = fetch_sentinel2_items_for_point(
    lat=lat,
    lon=lon,
    query_datetime=target_date,
    crop_size=crop_size,
    how_many_days_back=number_of_days,
    sbr_notebook=True,
)

# Crop the main chip around the release point and keep a handle on its tile item.
main_data = crop_main_data(items_for_entire_phase, abs_client, s3_client, lat, lon, crop_size)
main_item = main_data["tile_item"]

In [None]:
# List the ids of all items fetched for the entire phase.
for item in items_for_entire_phase:
    print(item.id)

### Tile info for all overpasses

In [None]:
# Collect the acquisition metadata of the current main tile into a TileInfo record.
# The analysis date is today's date; start/end time are left unset.
tile_properties = TileInfo(
    instrument_name=main_item.item.properties["platform"],
    date_analysis=datetime.today().date().isoformat(),
    observation_date=main_item.time.date().isoformat(),
    observation_timestamp=main_item.time.time().isoformat(timespec="seconds"),
    start_time=None,
    end_time=None,
    imaging_mode=main_item.imaging_mode,
    off_nadir_angle=main_item.off_nadir_angle,
    viewing_azimuth=main_item.viewing_azimuth,
    solar_zenith=main_item.solar_zenith,
    solar_azimuth=main_item.solar_azimuth,
    orbit_state=main_item.orbit_state,
)

# Pretty-print the record, then show its dict form as the cell output.
pprint.pp(tile_properties)
tile_properties.asdict()

In [None]:
# Build one TileInfo record per overpass date and gather them in a DataFrame indexed by date.
# NOTE: this loop rebinds `main_data`, `main_item` and `tile_properties`; after the cell
# they refer to the LAST overpass date processed.
tile_infos = []
for t in tqdm(overpass_dates):
    # print(f"Date: {t}")
    date = datetime.strptime(t, "%Y-%m-%d")

    # Fetch the items for this date (default look-back window of the helper).
    items_for_tile_info = fetch_sentinel2_items_for_point(
        lat=lat, lon=lon, query_datetime=date, crop_size=crop_size, sbr_notebook=True
    )

    main_data = crop_main_data(items_for_tile_info, abs_client, s3_client, lat, lon, crop_size)
    main_item = main_data["tile_item"]

    # Same fields as the single-tile cell above.
    tile_properties = TileInfo(
        instrument_name=main_item.item.properties["platform"],
        date_analysis=datetime.today().date().isoformat(),
        observation_date=main_item.time.date().isoformat(),
        observation_timestamp=main_item.time.time().isoformat(timespec="seconds"),
        start_time=None,
        end_time=None,
        imaging_mode=main_item.imaging_mode,
        off_nadir_angle=main_item.off_nadir_angle,
        viewing_azimuth=main_item.viewing_azimuth,
        solar_zenith=main_item.solar_zenith,
        solar_azimuth=main_item.solar_azimuth,
        orbit_state=main_item.orbit_state,
    )

    tile_infos.append(tile_properties.asdict())

# One row per overpass date, one column per TileInfo field.
df = pd.DataFrame(tile_infos, index=overpass_dates)

In [None]:
# Show which tile-info fields are available as columns.
df.columns

In [None]:
# Print the solar azimuth of every overpass, one value per line (easy to copy out).
for x in list(df.solar_azimuth):
    print(x)

### Load in Reference Data

In [None]:
# Use 25% clouds/shadows cutoff to include non-obvious reference images that may help
max_bad_pixel_perc = 25  # max % sum of clouds/cloud shadows/nodata in reference chips
# Passed as `required_num_previous_snapshots` when predicting for multiple dates below.
num_snapshots = 40

In [None]:
# Print all dates in items_for_entire_phase (ISO "YYYY-MM-DD", one per item):
for item in items_for_entire_phase:
    print(item.item.datetime.date().isoformat())

In [None]:
# Crop reference chips for the main chip from all items of the phase.
# NOTE(review): `main_data` / `main_item` are whatever was bound last; if the tile-info
# loop cell was run, they come from its final iteration — confirm that is the intended
# main tile.
reference_data = crop_reference_data(
    items_for_entire_phase,
    main_data,
    abs_client,
    s3_client,
    lat,
    lon,
    crop_size,
    required_num_previous_snapshots=1000000,  # want all the reference images
    max_bad_pixel_perc=max_bad_pixel_perc,
)

# Summarise which tile is used as main and which tiles were kept as references.
print("\n---: Summary")
print(f"Main tile for {target_date}: USE {main_item.id}")
for reference in reference_data:
    tile_id = reference["tile_item"].id
    reference_date = reference["tile_item"].time.date().isoformat()
    print(f"    Reference tile on {reference_date}: USE {tile_id}")

## I. RGBs & Ratios

In [None]:
# Show all of them together: the main chip first, followed by every reference chip.
plot_all_rgb([main_data, *reference_data], SatelliteID.S2)
plot_all_ratio([main_data, *reference_data], SatelliteID.S2)

In [None]:
# Show them individually: one RGB/ratio figure per chip, main chip first.
# `data_items` is reused by the ratio-diff cell below.
data_items = [main_data, *reference_data]

for i, data in enumerate(data_items):
    plot_rgb_ratio(data, Colorbar.INDIVIDUAL, i, SatelliteID.S2)

## Ratio Diffs

In [None]:
# Plot the ratio differences for the chips collected in `data_items`; the return value is
# assigned to `_` so it is not echoed as cell output.
_ = plot_ratio_diffs(data_items, Colorbar.INDIVIDUAL, SatelliteID.S2, mean_adjust=True)

## II. Prediction

### Predict for a specific date

#### Select Reference Chips

In [None]:
# Classify every reference tile date: Phase 1 overpass, known Phase 0 release, or neither.
print("Reference tile dates:")
for x in reference_data:
    reference_date = x["tile_item"].time.date().isoformat()
    if reference_date in overpass_dates:
        note = "possible SBR release (date is in Phase 1)"
    elif reference_date in phase0_release_dates:
        note = "known SBR release (date is in Phase 0)"
    else:
        note = "not a Phase 1 overpass date"
    print(f"{reference_date} - {note}")

In [None]:
# Pick the model to run by id. `models` and `band_concatenators` are index-aligned with
# `model_ids`, so the index must be looked up in that list. (Looking the id up in itself,
# i.e. `model_id.index(model_id)`, always yields 0 and silently selects the first model.)
model_id = "1226"
model_idx = model_ids.index(model_id)

# Full timestamps of the two reference tiles to use: "before" and "earlier".
before_date = "2025-03-29T17:59:09.024000+00:00"  # "2025-03-29"
earlier_date = "2025-03-27T18:10:11.025000+00:00"  # "2025-03-27"

reference_chips = select_reference_tiles_from_dates_str(
    reference_data, before_date=before_date, earlier_date=earlier_date
)
before_data = reference_chips[0]
earlier_data = reference_chips[1]

# Overview table of the three chips going into the prediction (shown as the cell output).
pd.DataFrame(
    {
        "tile": ["main", "before", "earlier"],
        "date": [
            main_data["tile_item"].time.date().isoformat(),
            before_data["tile_item"].time.date().isoformat(),
            earlier_data["tile_item"].time.date().isoformat(),
        ],
        "time": [
            main_data["tile_item"].time.time().isoformat(),
            before_data["tile_item"].time.time().isoformat(),
            earlier_data["tile_item"].time.time().isoformat(),
        ],
        "datetime": [main_data["tile_item"].time, before_data["tile_item"].time, earlier_data["tile_item"].time],
        "observation_angle": [
            main_data["tile_item"].observation_angle,
            before_data["tile_item"].observation_angle,
            earlier_data["tile_item"].observation_angle,
        ],
    }
)

#### Predict

In [None]:
%%time
# Watershed settings for this prediction (same values as in the setup cell; tweak here to
# experiment without touching the defaults).
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.1,
    watershed_floor_threshold=0.075,
    closing_footprint_size=0,
)
# Predict for the main chip using the two selected reference chips and the selected model.
prediction = predict(
    main_data,
    reference_chips,
    watershed_params,
    models[model_idx],
    device,
    band_concatenators[model_idx],
    lossFn,
    create_lookup_table=True,
)

# Or predict for all reference image combinations as inputs
# predictions = predict_for_all_pairs(
#     main_data, reference_data, models[model_idx], device, band_concatenators[model_idx],
#     lossFn, watershed_params, skip_retrieval=True,
# )

#### Plot Predictions

In [None]:
# Optional: if you have run the above code for all reference combinations
# preds = predictions[before_date, earlier_date]

# Window around the chip center, for when predicting on larger than 128x128 and wanting
# to crop to the center 128x128.
row = col = crop_size // 2
half_crop = 64  # we want middle 128x128
xmin, xmax = row - half_crop, row + half_crop
ymin, ymax = col - half_crop, col + half_crop
extent = [xmin, xmax, ymin, ymax]

# RGB bands, band ratio and ISO date for each of the three chips, in main/before/earlier order.
rgb_main, rgb_before, rgb_earlier = (
    get_rgb_bands(chip["crop_arrays"], BANDS) for chip in (main_data, before_data, earlier_data)
)
ratio_main, ratio_before, ratio_earlier = (
    get_band_ratio(chip["crop_arrays"], BANDS) for chip in (main_data, before_data, earlier_data)
)
date_main, date_before, date_earlier = (
    chip["tile_item"].time.date().isoformat() for chip in (main_data, before_data, earlier_data)
)

plot = all_error_analysis_plots(
    rgb_main=rgb_main,
    rgb_before=rgb_before,
    rgb_earlier=rgb_earlier,
    ratio_main=ratio_main,
    ratio_before=ratio_before,
    ratio_earlier=ratio_earlier,
    predicted_frac=prediction.marginal,
    predicted_mask=prediction.mask,
    conditional_pred=prediction.conditional,
    binary_probability=prediction.binary_probability,
    conditional_retrieval=prediction.conditional_retrieval,
    masked_conditional_retrieval=prediction.masked_conditional_retrieval,
    rescaled_retrieval=None,  # We will plot this in the Retrieval and quantification section
    marginal_retrieval=prediction.marginal_retrieval,
    watershed_segmentation_params=watershed_params,
    dates=(date_main, date_before, date_earlier),
    ratio_colorbar=Colorbar.SHARE,  # (0.65, 1.0),
    ratio_diff_colorbar=Colorbar.SHARE,  # (-0.1, 0.1)
)

#### Export prediction to GeoTIFF

In [None]:
# Write the prediction and its binary mask to GeoTIFF, georeferenced via the B12 raster
# transform of the main item and the crop offsets of the main chip.
export_predition_and_mask_to_geotiff(
    unmasked_retrieval_mol_m2=prediction.marginal,
    binary_mask=prediction.mask.astype(np.uint8),
    main_item_transform=main_item.get_raster_meta("B12")["transform"],
    crop_start_x=main_data["crop_params"]["B12"]["crop_start_x"],
    crop_start_y=main_data["crop_params"]["B12"]["crop_start_y"],
    main_item_crs=CRS.from_string(main_item.crs),
    observation_date=target_date.date().isoformat(),
    satellite_name=SatelliteID.S2.name,
)

#### Quantification

In [None]:
# Quantify the plumes in the retrieval and keep those that intersect the chip center.

# Attach the crop offsets of the main chip (B12 band) to the prediction.
crop_x = main_data["crop_params"]["B12"]["crop_start_x"]
crop_y = main_data["crop_params"]["B12"]["crop_start_y"]
prediction.crop_x = crop_x
prediction.crop_y = crop_y
pred_info = prediction.asdict()

timestamp = main_item.time

# Geo-reference the crop: the window of the crop within the B12 raster gives its transform.
crop_crs = CRS.from_string(main_item.crs)
window = rasterio.windows.Window(crop_x, crop_y, crop_size, crop_size)
crop_transform = rasterio.windows.transform(window, main_item.get_transform("B12"))

# Quantify plumes in retrieval
plume_list = quantify_retrieval(
    prediction.conditional_retrieval,
    crop_transform,
    crop_crs,
    prediction.binary_probability,
    timestamp.date().isoformat(),
    floor_t=watershed_params.watershed_floor_threshold,
    marker_t=watershed_params.marker_threshold,
)

# Raw plume entries (the items of `plume_list`, not PlumeInfo objects) whose bounding box
# intersects the center buffer. Defined before the branch so the name always exists,
# also when no plumes are detected.
center_plumes = []

if len(plume_list) == 0:
    print("No plumes detected")
else:
    print(f"Found {len(plume_list)} plumes")

    # Largest "Q" first, so the biggest plume in the center comes first.
    plume_list = sorted(plume_list, key=lambda x: x["properties"]["Q"])[::-1]
    for plume in plume_list:
        plume_info = PlumeInfo(**plume["properties"])

        intersects = intersects_center(*plume_info.bbox, buffer=center_buffer)
        if intersects:
            pprint.pp(plume_info)
            center_plumes.append(plume)

    print(f"{len(center_plumes)} plumes intersect the center with {center_buffer} pixels of buffer")

### Heatmap of Max Probability + Binary Plot for Each Reference Chip Pairing

In [None]:
# Pick the model by id. The index has to come from `model_ids` (index-aligned with
# `models` / `band_concatenators`); `model_id.index(model_id)` is always 0 and would
# silently pick the first model for any other id.
model_id = "1226"
model_idx = model_ids.index(model_id)

# Predict for every pairing of reference chips, retrievals included.
predictions = predict_for_all_pairs(
    main_data,
    reference_data,
    models[model_idx],
    device,
    band_concatenators[model_idx],
    lossFn,
    watershed_params,
    skip_retrieval=False,
)
len(predictions)

In [None]:
# Heatmap of the max probability in the center buffer for every reference pairing.
# NOTE(review): `dates` holds ISO date strings here, while the per-model sweep further
# down passes the full timestamps — confirm which form matches the keys of `predictions`.
dates = [d["tile_item"].time.date().isoformat() for d in reference_data]
main_date = target_date.isoformat().split("T")[0]

plot_max_proba_center_buffer_heatmap(
    predictions=predictions,
    dates=dates,
    center_buffer=center_buffer,
    # Enum member, consistent with every other plotting call in this notebook.
    satellite_id=SatelliteID.S2,
    title=f"Model {model_id} for main={main_date}",
    show_topk=5,
    main_date=main_date,
    plot_binary_grid=True,
)

### Predict for multiple dates

In [None]:
%%time
# Crop main and reference chips for every overpass, keyed by the item's datetime.
# NOTE: this loop rebinds `target_date`, `main_data`, `main_item` and `reference_data`;
# after the cell they refer to the LAST item of `stac_items`.
main_data_all = {}
reference_data_all = {}
for item in stac_items[:]:
    target_date = item.datetime  # .date().isoformat()
    # target_date = datetime.strptime(target_date, "%Y-%m-%d")

    items = fetch_sentinel2_items_for_point(
        lat=lat, lon=lon, query_datetime=target_date, crop_size=crop_size, sbr_notebook=True
    )

    main_data = crop_main_data(items, abs_client, s3_client, lat, lon, crop_size)
    main_data_all[target_date] = main_data

    main_item = main_data["tile_item"]
    print(f"Main tile for {target_date}: USE {main_item.id}")
    # Unlike the single-date cell above, the number of reference snapshots is capped at
    # `num_snapshots` here.
    reference_data = crop_reference_data(
        items,
        main_data,
        abs_client,
        s3_client,
        lat,
        lon,
        crop_size,
        required_num_previous_snapshots=num_snapshots,
        max_bad_pixel_perc=max_bad_pixel_perc,
    )
    reference_data_all[target_date] = reference_data
    print()

In [None]:
# Run and plot the "normal" and "avg" strategies for all overpasses and all loaded models.
# `sums` and `avg_ref_counts` feed the summary plot in the next cell.
maxs, sums, avg_ref_counts = plot_normal_and_avg_strategy(
    stac_items,
    main_data_all,
    reference_data_all,
    phase0_release_dates,
    models,
    model_ids,
    band_concatenators,
    device,
    lossFn,
    watershed_params,
    center_buffer,
)

In [None]:
# Summary plot; the buffer width is the full side length of the center window
# (center_buffer pixels on each side of the center pixel).
plot_normal_and_avg_strategy_summary(sums, avg_ref_counts, model_ids, ylim=[-2, 60], buffer_width=center_buffer * 2 + 1)

# In detail: Inspection and Submission of individual dates (Q1 2025)

In [None]:
# Prepare wind data once
from src.utils import PROJECT_ROOT

wind_data = pd.read_csv(PROJECT_ROOT / "src" / "data" / "ancillary" / "wind_vectors_gt_vs_models_2025_sbr_sites.csv")
# Keep only the S2 rows. The explicit copy makes the filtered frame independent of the
# frame read from disk, so the column assignment below does not write to a slice
# (avoids pandas' SettingWithCopyWarning).
wind_data = wind_data[wind_data["sensor"] == "S2"].copy()
# Combine the date and UTC overpass time columns into one timestamp string per row.
wind_data["sensing_time"] = pd.to_datetime(
    wind_data["date"] + " " + wind_data["overpass_time_utc"], utc=True
).dt.strftime("%Y-%m-%dT%H:%M:%S+0000")

In [None]:
# Re-create the clients (same calls as in the Setup section).
# NOTE(review): presumably repeated here to refresh credentials before the detailed
# inspection — confirm, otherwise this cell duplicates the setup cell.
ml_client, _, _, _, s3_client = initialize_clients(False)
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# Choose a single date to do inference on.
# NOTE(review): 2025 is not a leap year, so "2025-02-29" can never match an overpass and
# the lookup below always falls back to the last item. Set this to a real overpass date.
# stac_idx = -9
target_date_str = "2025-02-29"
stac_idx = next(
    (i for i, item in enumerate(items_for_entire_phase) if item.item.datetime.date().isoformat() == target_date_str),
    None,
)
if stac_idx is None:
    # Same fallback as before (last item), but no longer silent.
    print(f"WARNING: no item found on {target_date_str}; falling back to the last item (index -1)")
    stac_idx = -1

# stac_idx=8
# From here on `target_date` is the full acquisition timestamp of the selected item.
target_date = items_for_entire_phase[stac_idx].time
# target_date = datetime.strptime(target_date, "%Y-%m-%d")
print(target_date)
print(stac_idx)

In [None]:
# Reference dates that must not be used: all known Phase 0 release dates plus dates
# rejected by hand (reason next to each entry). Commented-out entries are kept as a record
# of dates that were considered but are currently allowed.
# Note: "2024-12-07" is listed explicitly below and is also part of phase0_release_dates.
exclude_dates = [
    *phase0_release_dates,
    # "2025-04-05", # possible detection
    # "2025-03-24", # detection
    # "2025-03-22", # detection
    "2025-03-17",  # cloudy ref
    "2025-03-16",  # hazy
    "2025-03-14",  # dark ratio, mean 0.773 (normal is ~0.83)
    "2025-03-12",  # cloudy ref
    "2025-03-07",  # cloudy ref
    "2025-03-04",  # sus
    "2025-03-02",  # dark ref
    # "2025-02-25", # release date
    # "2025-02-22", # release date
    # "2025-02-20", # release date
    "2025-02-12",  # cloudy ref
    "2025-02-07",  # cloudy ref
    "2025-01-28",  # cloudy ref
    "2025-01-16",  # dark ratio, mean 0.814 (normal is ~0.83)
    "2025-01-08",  # cloudy ref
    "2025-01-06",  # cloudy ref
    "2024-12-27",  # dark ratio, mean 0.800 (normal is ~0.83)
    "2024-12-24",  # dark ratio, mean 0.814 (normal is ~0.83)
    "2024-12-07",  # dark ratio, mean 0.801 (normal is ~0.83)
]

In [None]:
# Prepare main data
main_data = crop_main_data(items_for_entire_phase, abs_client, s3_client, lat, lon, crop_size, main_idx=stac_idx)
main_item = main_data["tile_item"]
print(f"Main tile      {main_item.id}")

# Use reference images before/after the main date
# NOTE(review): `reference_data` is whatever was bound last; if the "Predict for multiple
# dates" cell was run, it is the reference set of its final iteration — confirm this is
# the set you want to filter here.
max_days_difference = 30  # 60
# A reference must be more than this many seconds away from the main acquisition
# (presumably to drop the main tile itself — TODO confirm).
max_seconds = 60
reference_data_ = []
for reference in reference_data:
    ref_date = reference["tile_item"].time
    ref_date_short = ref_date.isoformat().split("T")[0]
    if ref_date_short in exclude_dates:
        print(f"Reference img on {ref_date_short} is excluded         --> DONT USE")
        continue
    # Calculate the absolute difference in time
    time_difference = abs(main_item.time - ref_date)
    print(time_difference)
    print(time_difference.total_seconds())

    # Check if the difference in days is within the threshold
    # timedelta.days gives the difference purely in days (ignoring hours etc.)
    if time_difference.days <= max_days_difference and time_difference.total_seconds() > max_seconds:
        reference_data_.append(reference)
        print(f"Reference img on {ref_date_short} is {time_difference.days:3} days distant -->      USE")
    else:
        print(f"Reference img on {ref_date_short} is {time_difference.days:3} days distant --> DONT USE")
len(reference_data_)

In [None]:
# Plot Wind for Main Date
sensing_time = main_item.time  # .isoformat()
print(sensing_time)
# Look up the wind components for the main acquisition time; the result is indexed by
# wind source ("geos" and "era5" are used below).
wind_components = get_wind_components(wind_data, sensing_time)

# GEOS: effective wind speed and direction from the (u, v) components.
u_wind_component, v_wind_component = wind_components["geos"]
wind_speed_geos = calc_effective_wind_speed(u_wind_component, v_wind_component)
wind_direction_geos = calc_wind_direction(u_wind_component, v_wind_component)

# ERA5: same computation. Note that u/v are overwritten with the ERA5 values here.
u_wind_component, v_wind_component = wind_components["era5"]
wind_speed_era5 = calc_effective_wind_speed(u_wind_component, v_wind_component)
wind_direction_era5 = calc_wind_direction(u_wind_component, v_wind_component)

print("GEOS Wind")
plot_wind(
    lon,
    lat,
    main_item,
    abs_client,
    wind_speed_geos,
    wind_direction_geos,
    satellite_id=SatelliteID.S2,
    arrow_scale_factor=500,
)
print("ERA5 Wind")
plot_wind(
    lon,
    lat,
    main_item,
    abs_client,
    wind_speed_era5,
    wind_direction_era5,
    satellite_id=SatelliteID.S2,
    arrow_scale_factor=500,
)

In [None]:
# Pick the default "before" (t-1) and "earlier" (t-2) reference timestamps: the first two
# non-excluded references in `reference_data` that precede the main date.
# NOTE(review): "closest" relies on `reference_data` being ordered newest-first — confirm.
print(f"Main    date = {main_item.datetime_.isoformat().split('T')[0]}")
before_date = None
earlier_date = None
for reference in reference_data:
    # ref_date = datetime.strptime(tile_id.split("_")[3], "%Y%m%d")
    ref_date = reference[
        "tile_item"
    ].time  # datetime.strptime(reference["tile_item"].item.datetime.date().isoformat(), "%Y-%m-%d")
    ref_date_short = ref_date.isoformat().split("T")[0]
    # Signed difference (main minus reference): skip references that are not at least one
    # full day older than the main date.
    time_difference = target_date - ref_date
    if time_difference.days <= 0:
        continue
    if ref_date_short in exclude_dates:
        print(f"Reference img on {ref_date_short} is excluded         --> DONT USE")
        continue
    if before_date is None:  # Closest is Before (t-1)
        before_date = ref_date  # .isoformat().split("T")[0]
        print(f"Before  date = {before_date}")
    elif earlier_date is None:  # Next closest is Earlier (t-2)
        earlier_date = ref_date  # .isoformat().split("T")[0]
        print(f"Earlier date = {earlier_date}")
        break

# Optionally overwrite them (need the exact timestamp as well, or will need to manually modify code
# i.e the dict keys in predict_for_all_pairs)
# from datetime import timezone
# before_date = datetime(2024, 10, 20, 18, 3, 9, 24000, tzinfo=timezone.utc)
# earlier_date = datetime(2024, 9, 25, 18, 1, 11, 24000, tzinfo=timezone.utc)

# Either date stays None when fewer than two usable references precede the main date;
# report that explicitly instead of failing with an AttributeError on None.
if before_date is None:
    print("Before  date = None (no usable reference image found)")
else:
    print(f"Before  date = {before_date.isoformat()}")
if earlier_date is None:
    print("Earlier date = None (no usable reference image found)")
else:
    print(f"Earlier date = {earlier_date.isoformat()}")

## Inspecting every combination of model and reference images

In [None]:
# Window covering the middle 128x128 pixels of the chip, as [xmin, xmax, ymin, ymax].
row = col = crop_size // 2
half_crop = 64  # we want middle 128x128
xmin, xmax = row - half_crop, row + half_crop
ymin, ymax = col - half_crop, col + half_crop
extent = [xmin, xmax, ymin, ymax]

In [None]:
# For every model: (1) predict against the AVERAGE of the reference chips within several
# day windows around the main date, and (2) predict for all pairs of the filtered
# reference chips and plot the center-buffer heatmap.
avg_reference_days_diffs = [30, 20, 15, 10, 5]

# Visualize predictions using all reference image combinations as inputs
output_preds = {}
dates = [d["tile_item"].time for d in reference_data_]
main_date = target_date.isoformat().split("T")[0]
for model_idx, model_id in enumerate(model_ids):
    avg_preds = {}
    for avg_reference_days_diff in avg_reference_days_diffs:
        # Use reference images before/after the main date
        reference_data_for_avg = get_reference_data(
            reference_data, target_date, max_days_difference=avg_reference_days_diff
        )
        # Data preparation for predicting with the average of the non-excluded reference images
        valid_refs = [
            ref for ref in reference_data_for_avg if ref["tile_item"].time.date().isoformat() not in exclude_dates
        ]
        print(
            f"Ref images {avg_reference_days_diff} before/after: Filtered {len(reference_data_for_avg)} refs to "
            f"{len(valid_refs)} after excluding dates"
        )
        avg_ref_count = len(valid_refs)
        if avg_ref_count == 0:
            # Nothing to average: indexing [0] or averaging an empty list would fail / give NaN.
            print(f"Ref images {avg_reference_days_diff} before/after: no usable refs --> SKIP")
            continue

        # Shallow copy of the first reference in the window; only its "crop_arrays" entry
        # is replaced by the element-wise mean of the valid references.
        data_avg = copy.copy(reference_data_for_avg[0])
        data_avg["crop_arrays"] = np.mean([ref["crop_arrays"] for ref in valid_refs], axis=0)
        center_crop_shape = data_avg["crop_arrays"][:, xmin:xmax, ymin:ymax].shape
        center_y = center_crop_shape[1] // 2
        center_x = center_crop_shape[2] // 2

        # The averaged chip is used as both the "before" and the "earlier" reference.
        preds_avg = predict(
            main_data,
            [data_avg, data_avg],
            watershed_params,
            models[model_idx],
            device,
            band_concatenators[model_idx],
            lossFn,
        )
        ratio_main = get_band_ratio(preds_avg.x_dict["crop_main"][0], BANDS)[xmin:xmax, ymin:ymax]  # type:ignore
        # Shared color range for the three ratio panels, taken from the main ratio.
        vmin = np.percentile(ratio_main, 0.5)
        vmax = np.percentile(ratio_main, 99.5)

        ratio_before = get_band_ratio(preds_avg.x_dict["crop_before"][0], BANDS)[xmin:xmax, ymin:ymax]  # type:ignore
        ratio_earlier = get_band_ratio(preds_avg.x_dict["crop_earlier"][0], BANDS)[xmin:xmax, ymin:ymax]  # type:ignore

        f, ax = plt.subplots(1, 4, figsize=(20, 5))
        ax[0].imshow(ratio_earlier, vmin=vmin, vmax=vmax)
        ax[0].set_title(
            f"Avg({avg_reference_days_diff} days before/after main) = {avg_ref_count} refs\n"
            f"Min {ratio_earlier.min():.2f} Max {ratio_earlier.max():.2f} "
            f"Mean {ratio_earlier.mean():.3f}",
            fontsize=15,
        )
        ax[1].imshow(ratio_before, vmin=vmin, vmax=vmax)
        ax[1].set_title(
            f"Avg({avg_reference_days_diff} days before/after main) = {avg_ref_count} refs\n"
            f"Min {ratio_before.min():.2f} Max {ratio_before.max():.2f} "
            f"Mean {ratio_before.mean():.3f}",
            fontsize=15,
        )
        ax[2].imshow(ratio_main, vmin=vmin, vmax=vmax)
        ax[2].set_title(
            f"Main Ratio {main_date}\nMin {ratio_main.min():.2f} Max {ratio_main.max():.2f} "
            f"Mean {ratio_main.mean():.3f}",
            fontsize=15,
        )
        ax[3].imshow(preds_avg.binary_probability[xmin:xmax, ymin:ymax], vmin=0.0, vmax=1.0, cmap="hot_r")

        # Probability statistics inside the center buffer, in percent. "sum_prob" is the
        # sum divided by the number of pixels in the buffer.
        center = get_center_buffer(preds_avg.binary_probability, center_buffer)
        center_max_prob = 100 * center.max()
        center_sum_prob = 100 * center.sum() / (center.shape[0] * center.shape[1])
        avg_preds[avg_reference_days_diff] = {
            "avg_ref_count": avg_ref_count,
            "max_prob": center_max_prob,
            "sum_prob": center_sum_prob,
        }
        ax[3].set_title(
            f"Avg({avg_reference_days_diff} days before/after main)\nCenter sum(Prob): "
            f"{center_sum_prob:.1f}%, "
            f"Max: {center_max_prob:.0f}%",
            fontsize=16,
        )
        # Mark the chip center and overlay the grid on every panel.
        for panel in ax:
            panel.scatter(center_x, center_y, color="green", marker="x")
            grid16(panel)
        plt.tight_layout()
        plt.show()

    print(f"Model {model_id}")
    for k, v in avg_preds.items():
        print(
            f"Avg{k}   Center sum(Prob): {v['sum_prob']:4.1f}%, "
            f"Max: {v['max_prob']:3.0f}% (using avg of {v['avg_ref_count']} ref images)"
        )
    print("=" * 50)

    # All pairs of the filtered reference chips for this model (retrieval skipped).
    predictions = predict_for_all_pairs(
        main_data,
        reference_data_,
        models[model_idx],
        device,
        band_concatenators[model_idx],
        lossFn,
        watershed_params,
        skip_retrieval=True,
    )
    output_preds[model_id] = plot_max_proba_center_buffer_heatmap(
        predictions=predictions,
        dates=dates,
        center_buffer=center_buffer,
        title=f"Model {model_id} for main={main_date}",
        show_topk=5,
        topk_based_on="max",  # or "sum" to use the sum of center probabilities
        satellite_id=SatelliteID.S2,
        main_date=main_date,
        before_date=before_date,
        earlier_date=earlier_date,
        plot_topk=True,  # True,
        plot_max_grid=True,  # True,
        plot_binary_grid=False,
        extent=extent,
    )
    print("\n\n")

1. Validate and visualize the plumes and emission rates of "normal" predictions, topk predictions or any other combination. Build the combination list `combinations_to_visualize` you want to test and call `validate_pred_retrievals(combinations_to_visualize)`.
2. Manually select the combination of model id + before/earlier dates we will use for L/IME/Emission Rate Calculation + Unmasked Retrieval and plume mask .tif outputs. Put this into the decision_dict for the main date, e.g. `'selected_retrieval': {'model_id': '37', 'date_before': '2024-12-25', 'date_earlier': '2024-12-18'}`
    1. Set `feasible` to True if we are able to predict for this main date, i.e. it's not cloudy or affected by cloud shadows
    2. Set `note` to "no_detection" or "detection" and any other notes you have.
    3. Set `watershed_marker_t` and `watershed_floor_t` to your selected watershed parameters if you want different parameters from the default `watershed_marker_t=0.2`,`watershed_floor_t=0.15`
    4. Set `selected_retrieval` as described above
    5. If you have a detection: Set `wind_source` to "era5" or "geos"
    6. If you have a detection: Set `emission_ensemble_selections` to a list of model id + before/earlier dates whose emission rates will be used to calculate our emission rate uncertainty. __Don't put the selected retrieval in here.__
3. Run `sbr_form_outputs()` to get the values ready for copying into the form. This will also export the required unmasked retrieval and binary plume mask .tifs from your `selected_retrieval`.

## Retrieval and quantification

In [None]:
# Run once to initialize the RadTranLookup for the main date
granule_item = main_data["tile_item"]
# Pick the HAPI data path that matches the type of the main tile item.
if isinstance(granule_item, LandsatGranuleAccess):
    hapi_data_path = LANDSAT_HAPI_DATA_PATH
elif isinstance(granule_item, Sentinel2Item):
    hapi_data_path = S2_HAPI_DATA_PATH
else:
    raise ValueError(f"Unsupported granule access type: {type(granule_item)}")
# Build the lookup table from the instrument, geometry and band names of the main tile.
lookup_table = RadTranLookupTable.from_params(
    instrument=granule_item.instrument,
    solar_angle=granule_item.solar_angle,
    observation_angle=granule_item.observation_angle,
    hapi_data_path=hapi_data_path,
    min_ch4=0.0,
    max_ch4=21.0,  # this value was selected based on the common value ranges of the sim plume datasets
    spacing_resolution=40000,
    ref_band=granule_item.swir16_band_name,
    band=granule_item.swir22_band_name,
    full_sensor_name=granule_item.sensor_name,
)

In [None]:
# Build the combinations you want to validate/visualize: for every model, its top-N
# reference pairings followed by its "normal" pairing.
combinations_to_visualize = []
for model_id, v in output_preds.items():
    # skip certain models (add their ids to the list)
    if model_id in []:
        continue
    top_n = 3  # use topX dates
    combinations_to_visualize.extend(
        {
            "model_id": model_id,
            "date_before": v[f"top_{j}"]["date_before"],
            "date_earlier": v[f"top_{j}"]["date_earlier"],
        }
        for j in range(1, top_n + 1)
    )

    # visualize the "normal" predictions
    normal = v["normal"]
    combinations_to_visualize.append(
        {"model_id": model_id, "date_before": normal["date_before"], "date_earlier": normal["date_earlier"]}
    )
combinations_to_visualize

In [None]:
# Convenience cell: prepare the selected ensemble for copy-pasting into the decisions dict.
import rich  # may need to pip install

# Drop the last combination (remove selected model from ensemble), then order the rest
# by model_id, date_before and date_earlier.
combinations_for_ensemble = sorted(
    combinations_to_visualize[:-1],
    key=lambda x: (x["model_id"], x["date_before"], x["date_earlier"]),
)

rich.print(combinations_for_ensemble)

In [None]:
# Settings shared by both validation passes below.
watershed_params = WatershedParameters(
    marker_distance=1,
    marker_threshold=0.2,
    watershed_floor_threshold=0.05,
    closing_footprint_size=0,
)
max_distance_pixels = 10
# NOTE(review): pixel_width is 30 while satellite_id below is SatelliteID.S2 -
# confirm this is the intended value for the data being processed.
pixel_width = 30
wind_speed = wind_speed_era5  # or wind_speed_geos

# Positional and keyword arguments that are identical for both calls.
_validation_args = (
    combinations_to_visualize,
    main_data,
    reference_data,
    model_ids,
    watershed_params,
    models,
    device,
    band_concatenators,
    lossFn,
    lookup_table,
    wind_speed,
)
_validation_kwargs = {
    "satellite_id": SatelliteID.S2,
    "max_distance_pixels": max_distance_pixels,
    "pixel_width": pixel_width,
}

# Pass 1: run without plots.
validate_pred_retrievals(*_validation_args, **_validation_kwargs, show_plots=False)

# Pass 2: run again with plots enabled, restricted to `extent`, and without
# skipping combinations that have no central plume.
validate_pred_retrievals(
    *_validation_args,
    **_validation_kwargs,
    show_plots=True,
    extent=extent,
    skip_plot_if_no_central_plume=False,
)

In [None]:
# Echo the `before_date` / `earlier_date` pair currently held in the kernel
# (both are assigned outside this section) as the cell output.
before_date, earlier_date

## Decisions and outputs

In [None]:
# Per-acquisition decisions for the submission.
#
# Keys are ISO-8601 acquisition timestamps; the export cell matches them against
# `item.time.isoformat()` of `items_for_entire_phase`, so they must match exactly.
# Each value holds:
#   feasible                     - entries that are not feasible are skipped by the export loop
#   note                         - "detection" or "no_detection" (underscore spelling throughout)
#   wind_source                  - name of the wind product to use (all entries use "era5")
#   emission_ensemble_selections - list of model id + before/earlier dates used for the
#                                  emission rate uncertainty; must not contain the selected retrieval
#   selected_retrieval           - model id + before/earlier dates of the chosen retrieval
#                                  ({} when nothing was selected)
#   watershed_marker_t / watershed_floor_t - watershed thresholds for this date
#
# NOTE(review): `emission_ensemble_selections` is empty for every entry, including the
# "detection" ones - fill these in before submitting if an ensemble is required.
decisions_dict = {
    "2025-01-01T18:16:49.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1422",
            "date_before": "2024-12-14T18:07:51.024000+00:00",
            "date_earlier": "2024-12-12T18:16:59.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-03T18:07:41.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-01T18:16:49.024000+00:00",
            "date_earlier": "2024-12-14T18:07:51.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-06T18:17:31.024000+00:00": {
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-08T18:06:29.024000+00:00": {
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-11T18:16:29.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        # NOTE(review): date_before and date_earlier are the same timestamp - confirm intended.
        "selected_retrieval": {
            "model_id": "1422",
            "date_before": "2025-01-26T18:16:51.025000+00:00",
            "date_earlier": "2025-01-26T18:16:51.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-13T18:07:11.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-11T18:16:29.024000+00:00",
            "date_earlier": "2025-01-03T18:07:41.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-16T18:17:11.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-01-26T18:16:51.025000+00:00",
            "date_earlier": "2025-01-21T18:15:49.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-18T18:05:59.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-13T18:07:11.024000+00:00",
            "date_earlier": "2025-01-11T18:16:29.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-21T18:15:49.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-18T18:05:59.024000+00:00",
            "date_earlier": "2025-01-13T18:07:11.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-26T18:16:51.025000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-21T18:15:49.024000+00:00",
            "date_earlier": "2025-01-18T18:05:59.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-28T18:05:19.024000+00:00": {
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-01-31T18:15:09.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1422",
            "date_before": "2025-01-01T18:16:49.024000+00:00",
            "date_earlier": "2025-02-15T18:15:01.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-02T18:06:01.024000+00:00": {
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-05T18:16:01.025000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-01-31T18:15:09.024000+00:00",
            "date_earlier": "2025-01-26T18:16:51.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-07T18:04:19.024000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-12T18:05:11.025000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-15T18:15:01.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1475",
            "date_before": "2025-02-05T18:16:01.025000+00:00",
            "date_earlier": "2025-01-31T18:15:09.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-17T18:03:19.024000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-20T18:13:09.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1340",
            "date_before": "2025-01-21T18:15:49.024000+00:00",
            "date_earlier": "2025-03-09T18:00:59.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-22T18:04:01.025000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-02-05T18:16:01.025000+00:00",
            "date_earlier": "2025-02-25T18:13:51.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-25T18:13:51.025000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-02-22T18:04:01.025000+00:00",
            "date_earlier": "2025-02-20T18:13:09.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-02-27T18:02:09.024000+00:00": {
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-02T18:11:49.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1475",
            "date_before": "2025-02-15T18:15:01.024000+00:00",
            "date_earlier": "2025-03-29T17:59:09.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        # NOTE(review): the only entry with a floor of 0.16 instead of 0.15 - confirm intended.
        "watershed_floor_t": 0.16,
    },
    "2025-03-04T18:02:51.025000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1475",
            "date_before": "2025-03-16T18:07:51.024000+00:00",
            "date_earlier": "2025-03-29T17:59:09.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-07T18:12:41.025000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-09T18:00:59.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-02-25T18:13:51.025000+00:00",
            "date_earlier": "2025-02-22T18:04:01.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-12T18:10:39.024000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-14T18:01:41.024000+00:00": {
        "feasible": False,  # WEIRD DARK ALBEDO
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-16T18:07:51.024000+00:00": {
        # NOTE(review): marked infeasible but still carries a selected_retrieval; the export
        # loop skips this entry, so the retrieval below is never used - confirm intended.
        "feasible": False,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-03-24T18:00:31.025000+00:00",
            "date_earlier": "2025-03-04T18:02:51.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-17T18:11:21.025000+00:00": {
        "feasible": False,  # CLOUDS
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {},
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-19T17:59:49.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1226",
            "date_before": "2025-03-29T17:59:09.024000+00:00",
            "date_earlier": "2025-04-06T18:09:41.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-22T18:09:29.024000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-04-11T18:09:19.024000+00:00",
            "date_earlier": "2025-03-16T18:07:51.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-24T18:00:31.025000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1486",
            "date_before": "2025-03-29T17:59:09.024000+00:00",
            "date_earlier": "2025-03-22T18:09:29.024000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-27T18:10:11.025000+00:00": {
        "feasible": True,
        "note": "detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1422",
            "date_before": "2025-04-11T18:09:19.024000+00:00",
            "date_earlier": "2025-04-06T18:09:41.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-29T17:59:09.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1422",
            "date_before": "2025-03-27T18:10:11.025000+00:00",
            "date_earlier": "2025-03-24T18:00:31.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
    "2025-03-29T18:17:51.024000+00:00": {
        "feasible": True,
        "note": "no_detection",
        "wind_source": "era5",
        "emission_ensemble_selections": [],
        "selected_retrieval": {
            "model_id": "1475",
            "date_before": "2025-03-27T18:10:11.025000+00:00",
            "date_earlier": "2025-03-24T18:00:31.025000+00:00",
        },
        "watershed_marker_t": 0.2,
        "watershed_floor_t": 0.15,
    },
}

In [None]:
# List the ISO timestamp of every item in the phase, one per line
# (these are the strings the decisions dict is keyed by).
_phase_timestamps = [phase_item.time.isoformat() for phase_item in items_for_entire_phase]
for _timestamp in _phase_timestamps:
    print(_timestamp)

In [None]:
# Export the submission outputs for every feasible decision in `decisions_dict`.
import shutil
from pathlib import Path

from src.utils import PROJECT_ROOT

# Index of each acquisition timestamp in `items_for_entire_phase`.
# `setdefault` keeps the first occurrence if a timestamp appears more than once.
stac_idx_by_time = {}
for idx, item in enumerate(items_for_entire_phase):
    stac_idx_by_time.setdefault(item.time.isoformat(), idx)

# Check every feasible decision against the available items BEFORE deleting anything.
# A key that matches no item must not fall through to index -1, which would silently
# process the last item of the list under the wrong date.
unmatched_dates = [
    dt for dt, decision in decisions_dict.items() if decision.get("feasible") and dt not in stac_idx_by_time
]
if unmatched_dates:
    raise ValueError(f"No item in items_for_entire_phase matches these decisions_dict keys: {unmatched_dates}")

# delete current submission outputs so we can overwrite
output_dir = PROJECT_ROOT / "sbr_2025" / "notebooks" / "data" / "submission_geotiffs" / "phase1_submission"
if output_dir.exists() and output_dir.is_dir():
    shutil.rmtree(output_dir)

output_dir.mkdir(parents=True, exist_ok=True)
assert output_dir.is_dir()

# this can take some time, and we don't want the client to expire,
# so create fresh clients right before the long-running work
ml_client, _, _, _, s3_client = initialize_clients(False)
abs_client = initialize_blob_service_client(ml_client)

# Crop the first item and build the reference data once; the same reference data is
# passed to every date in the loop below.
main_data = crop_main_data(items_for_entire_phase, abs_client, s3_client, lat, lon, crop_size, main_idx=0)
print(main_data["tile_item"].time)

reference_data = crop_reference_data(
    items_for_entire_phase,
    main_data,
    abs_client,
    s3_client,
    lat,
    lon,
    crop_size,
    required_num_previous_snapshots=num_snapshots,
    max_bad_pixel_perc=max_bad_pixel_perc,
)

for dt, decision in decisions_dict.items():
    if not decision.get("feasible"):
        continue

    stac_idx = stac_idx_by_time[dt]  # guaranteed to exist by the check above
    # NOTE: this rebinds the notebook-level `main_data` to the date being processed.
    main_data = crop_main_data(items_for_entire_phase, abs_client, s3_client, lat, lon, crop_size, main_idx=stac_idx)
    print(f"----: {dt} - {stac_idx}")

    # NOTE(review): `lookup_table` is the one built earlier in the notebook from a single
    # tile item and is reused unchanged for every date here - confirm that is acceptable.
    sbr_form_outputs(
        decision,
        wind_data,
        model_ids,
        models,
        device,
        band_concatenators,
        main_data,
        reference_data,  # send through all reference data
        lossFn,
        lookup_table,
        max_distance_pixels,
        pixel_width,
        SatelliteID.S2,
        abs_client,
    )
    print()

In [None]:
# Plot band 1 of every GeoTIFF found under the submission output folder.
# `output_dir` is the folder created by the export cell; using it (instead of a path
# relative to the current working directory) keeps both cells pointing at the same place.
# NOTE(review): every .tif under the folder is drawn with the enhancement colour scale,
# so any exported plume-mask files are plotted the same way - filter by name if unwanted.
tif_paths = sorted(output_dir.glob("**/*.tif"))
if not tif_paths:
    print(f"No GeoTIFFs found under {output_dir}")

for tif in tif_paths:
    with rasterio.open(tif) as ds:
        enhancement = ds.read(1)

    fig, ax = plt.subplots()
    image = ax.imshow(enhancement, cmap="plasma", vmin=0, vmax=2)  # using same color scale as in the 2024 Sherwin paper
    ax.set_title(tif.name)
    fig.colorbar(image, ax=ax, label="Enhancement (ppm)")
    plt.show()
    plt.close(fig)  # release the figure so a long list of files does not accumulate memory