In [None]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from lib.models.schemas import WatershedParameters
from lib.plume_masking import retrieval_mask_using_watershed_algo
from scipy import ndimage as ndi
from skimage.measure import label, regionprops
from skimage.segmentation import watershed

from sbr_2025.utils.plotting import (
    Colorbar,
    plot_rgb_ratio,
)
from sbr_2025.utils.quantification import (
    calculate_circle_distance_quantification,
    calculate_major_axis_quantification,
    calculate_sqrtA_quantification,
    find_central_plume,
    quantification_interval,
)
from src.azure_wrap.ml_client_utils import (
    initialize_blob_service_client,
)
from src.inference.inference_functions import (
    crop_main_data,
    crop_reference_data,
    fetch_sentinel2_items_for_point,
    generate_predictions,
)
from src.inference.inference_target_location import (
    add_retrieval_to_pred,
)
from src.plotting.plotting_functions import grid16
from src.training.loss_functions import TwoPartLoss
from src.utils import PACKAGE_ROOT
from src.utils.parameters import SatelliteID
from src.utils.quantification_utils import calc_wind_direction
from src.utils.utils import initialize_clients, load_model_and_concatenator

## 0. Setup

In [None]:
# Create the service clients used by the data-loading cells below.
# NOTE(review): the meaning of the positional `False` flag and of the two discarded return
# values is defined inside `initialize_clients` and is not visible here — confirm before changing.
ml_client, _, _, s3_client = initialize_clients(False)
# Blob-storage client derived from the ML client; it is handed to crop_main_data and
# crop_reference_data together with s3_client.
abs_client = initialize_blob_service_client(ml_client)

In [None]:
# Release point for SBRs (latitude, longitude)
lat = 32.82175
lon = -111.78581

# Identifier of the model passed to load_model_and_concatenator
model_identifier = "models:/torchgeo_pwr_unet/1226"

# Chip settings. Open question: can we predict on larger chips?
crop_size = 128
num_snapshots = 6  # rebound to 24 by the data-loading cell further down
center_buffer = 10  # number of pixels to search from center

# Date range of SBR Phase 1, parsed straight into datetimes (midnight)
start_date, end_date = (datetime.strptime(day, "%Y-%m-%d") for day in ("2025-01-01", "2025-03-31"))

In [None]:
# Prefer the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# The loader is asked for a CPU copy; the model is moved to the selected device afterwards
# and put in eval mode.
model, band_concatenator, train_params = load_model_and_concatenator(model_identifier, "cpu", SatelliteID.S2)
model = model.to(device)
model.eval()

# Loss object built from the training parameters; generate_predictions takes it as an argument.
lossFn = TwoPartLoss(train_params["binary_threshold"], train_params["MSE_multiplier"])

### Overpass Dates

In [None]:
# Dates (YYYY-MM-DD) with KNOWN releases during Phase 0.
# The list is used twice further down: its last entry is the target date for which the main tile
# and the reference snapshots are downloaded, and every entry is excluded from the automatic
# selection of reference tiles.
phase0_release_dates = [
    "2024-11-14",
    "2024-11-19",
    "2024-11-22",
    "2024-12-02",
    "2024-12-04",
    "2024-12-07",
    "2024-12-09",
    "2024-12-17",
    "2024-12-19",
    "2024-12-22",
    "2024-12-29",
]

In [None]:
# Dates (YYYY-MM-DD) flagged as cloudy; the automatic reference-tile selection skips them.
cloudy_dates = ["2025-01-06", "2025-01-08", "2024-12-14", "2024-11-29"]

In [None]:
%%time
# Choose a single date to do inference: the last entry of phase0_release_dates.
target_date = phase0_release_dates[-1]
target_date = datetime.strptime(target_date, "%Y-%m-%d")

# Query the Sentinel-2 items around the release point for the target date.
items = fetch_sentinel2_items_for_point(
    lat=lat, lon=lon, query_datetime=target_date, crop_size=crop_size, sbr_notebook=True
)

# Crop the target-date tile around (lat, lon).
main_data = crop_main_data(items, abs_client, s3_client, lat, lon, crop_size)
main_item = main_data["tile_item"]

# Load all the reference data once; the prediction loop later looks up both the main tile of
# each plume date and its reference tiles in this list.
# NOTE: this rebinds num_snapshots (6 in the setup cell) to 24 for the rest of the notebook.
num_snapshots = 24
reference_data = crop_reference_data(
    items, main_data, abs_client, s3_client, lat, lon, crop_size, required_num_previous_snapshots=num_snapshots
)

# Summary of the tiles that were selected.
print("\n---: Summary")
print(f"Main tile for {target_date}: USE {main_item.id}")
for reference in reference_data:
    tile_id = reference["tile_item"].id
    reference_date = reference["tile_item"].time.date().isoformat()
    print(f"    Reference tile on {reference_date}: USE {tile_id}")

In [None]:
# List every fetched item: its id, its time, then a separator line.
for item in items:
    print(f"ID: {item.id}", f"Time: {item.time}", "---", sep="\n")

## I. RGBs & Ratios

In [None]:
# Main tile first, followed by every reference tile.
data_items = [main_data, *reference_data]

# One plot_rgb_ratio call per tile; the running index is passed through to the plotting helper.
# (The per-tile `date` string that used to be computed here was never used and has been dropped.)
for i, data in enumerate(data_items):
    plot_rgb_ratio(data, Colorbar.INDIVIDUAL, i, SatelliteID.S2)

# Iterate over all Phase 0 release dates

In [None]:
# Wind table stored alongside the package sources; get_wind_data_from_csv filters it by
# its `date` and `sensor` columns.
wind_data_csv_path = PACKAGE_ROOT.parent.joinpath("src", "data", "ancillary", "wind_data_with_era5.csv")
wind_data_df = pd.read_csv(wind_data_csv_path)
wind_data_df.head()

## Getting all predictions

In [None]:
# Load the ground truth plumes CSV
ground_truth_csv_path = PACKAGE_ROOT.parent / "src" / "data" / "ancillary" / "ground_truth_plumes.csv"
ground_truth_df = pd.read_csv(ground_truth_csv_path)

# Filter for SBR Phase 0 plumes (2024 Casa Grande plumes with non-zero quantification)
phase0_plumes = ground_truth_df[
    (ground_truth_df["site"] == "Casa Grande, AZ")
    & (ground_truth_df["date"].str.startswith("2024"))
    & (ground_truth_df["quantification_kg_h"] > 0)
]

print(f"Found {len(phase0_plumes)} SBR Phase 0 plumes with quantification data")
print(phase0_plumes[["date", "quantification_kg_h"]])

# Plume date -> (before_date, earlier_date) overriding the automatic reference selection.
# NOTE(review): "2024-12-22" is itself listed in phase0_release_dates, so this override uses a
# release date as a reference tile — confirm that this is intended.
manual_reference_selections = {
    "2024-12-29": ("2024-12-24", "2024-12-22"),
}


def _tile_date_str(tile):
    """Return the date of a cropped tile's `tile_item` as an ISO `YYYY-MM-DD` string."""
    return tile["tile_item"].time.date().isoformat()


# Plume date -> prediction dict, each extended with timestamp, ground truth and reference dates.
predictions = {}
for _, plume_row in phase0_plumes.iterrows():
    plume_date_str = plume_row["date"]
    # Parsing doubles as a format check: the string comparisons below rely on ISO dates.
    plume_date = datetime.strptime(plume_date_str, "%Y-%m-%d")
    ground_truth_kg_h = plume_row["quantification_kg_h"]

    print(f"\nProcessing plume from {plume_date_str} with ground truth {ground_truth_kg_h} kg/h")

    # Find the main data for this plume date: reference tiles first, then the main tile loaded above
    main_tile = next((ref for ref in reference_data if _tile_date_str(ref) == plume_date_str), None)
    if main_tile is not None:
        print(f"Found main tile for {plume_date_str}: {main_tile['tile_item'].id}")
    elif _tile_date_str(main_data) == plume_date_str:
        main_tile = main_data
        print(f"Found main tile in main_data for {plume_date_str}: {main_data['tile_item'].id}")
    else:
        print(f"No tile found for plume date {plume_date_str}, skipping")
        continue

    # Collect candidate reference tiles for this plume
    if plume_date_str in manual_reference_selections:
        before_date, earlier_date = manual_reference_selections[plume_date_str]
        print(f"Using manually specified reference dates: {before_date} and {earlier_date}")
        clean_refs = [ref for ref in reference_data if _tile_date_str(ref) in (before_date, earlier_date)]
    else:
        # Automatic selection: strictly earlier than the plume, not cloudy, no known release
        excluded_dates = {*cloudy_dates, *phase0_release_dates}
        clean_refs = [
            ref
            for ref in reference_data
            if _tile_date_str(ref) < plume_date_str and _tile_date_str(ref) not in excluded_dates
        ]

    if len(clean_refs) < 2:  # noqa: PLR2004
        print(f"Not enough clean reference tiles found for {plume_date_str}, skipping")
        continue

    # Sort by date (newest first) and take the two most recent
    clean_refs.sort(key=_tile_date_str, reverse=True)
    reference_chips = clean_refs[:2]
    before_data, earlier_data = reference_chips

    print("Using reference tiles from:")
    print(f"  Before: {_tile_date_str(before_data)}")
    print(f"  Earlier: {_tile_date_str(earlier_data)}")

    # Generate predictions
    preds = generate_predictions(main_tile, reference_chips, model, device, band_concatenator, lossFn)
    preds = add_retrieval_to_pred(preds, main_tile["tile_item"])
    # Add timestamp and ground truth to predictions
    preds["timestamp"] = main_tile["tile_item"].time
    preds["ground_truth_kg_h"] = ground_truth_kg_h
    # Record which reference tiles were used for THIS plume: before_data / earlier_data are
    # overwritten on every iteration, so after the loop they only describe the last plume.
    preds["before_date"] = _tile_date_str(before_data)
    preds["earlier_date"] = _tile_date_str(earlier_data)

    predictions[plume_date_str] = preds

In [None]:
# Visualize all binary predictions in a grid
n_cols = 3
# Size the grid from the number of predictions: a fixed 3x3 grid raises IndexError as soon as
# more than nine plume dates are processed. Keep at least one row so subplots() stays valid.
n_rows = max(1, (len(predictions) + n_cols - 1) // n_cols)
# squeeze=False always returns a 2-D array of axes, so flatten() works for a single row too.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (date, pred) in zip(axes, predictions.items()):
    binary_mask = pred["binary_probability"]
    ax.imshow(binary_mask, cmap="pink_r", vmin=0, vmax=1, interpolation="nearest")
    grid16(ax=ax)
    ax.grid(True, which="both")
    ax.set_title(f"Date: {date}")

# Hide the panels that have no prediction to show.
for ax in axes[len(predictions) :]:
    ax.set_axis_off()

plt.tight_layout()

## Functions implementing quantification methods

In [None]:
def calculate_all_quantification_methods(
    plume_labels, binary_probability, conditional_retrieval, pixel_width, wind_speed, max_distance_pixels=10
):
    """
    Quantify the central plume with every length-scale method and both retrieval variants.

    The central plume is the label returned by ``find_central_plume``. Its mask is passed to the
    sqrtA, major-axis and circle-distance helpers twice: once with ``conditional_retrieval`` as
    given, and once with a retrieval rescaled by ``binary_probability`` divided by the highest
    probability found inside the central plume.

    Args:
        plume_labels: Labeled mask of plumes
        binary_probability: Binary probability map
        conditional_retrieval: Conditional retrieval values
        pixel_width: Width of a pixel in meters
        wind_speed: Wind speed in m/s
        max_distance_pixels: Maximum distance in pixels, used for the central-plume search and
            passed to the circle method for the rescaled retrieval

    Returns
    -------
        dict: ``L_*_m``, ``IME_*_mol`` and ``Q_*_kg_h`` entries for each of the three methods
            (``sqrtA``, ``major``, ``circle``) and each retrieval variant (``conditional``,
            ``rescaled``)

    Raises
    ------
        ValueError: If ``find_central_plume`` returns label 0 (no central plume found)
    """
    central_label = find_central_plume(plume_labels, max_distance_pixels, pixel_width)
    if central_label == 0:
        raise ValueError(f"No central plume found within {max_distance_pixels} pixels")
    central_plume_mask = plume_labels == central_label

    # Scale the retrieval by the probability map, normalised by its peak inside the central plume
    peak_probability = np.max(binary_probability[central_plume_mask])
    rescaled_retrieval = conditional_retrieval * binary_probability / peak_probability

    results = {}

    def record(keys, triple):
        # Each helper returns (L, IME, Q); strict zip fails loudly on any other length
        results.update(zip(keys, triple, strict=True))

    # --- conditional retrieval ---
    record(
        ("L_sqrtA_conditional_m", "IME_sqrtA_conditional_mol", "Q_sqrtA_conditional_kg_h"),
        calculate_sqrtA_quantification(central_plume_mask, conditional_retrieval, pixel_width, wind_speed),
    )
    record(
        ("L_major_conditional_m", "IME_major_conditional_mol", "Q_major_conditional_kg_h"),
        calculate_major_axis_quantification(central_plume_mask, conditional_retrieval, pixel_width, wind_speed),
    )
    # NOTE(review): this call pins min_distance_pixels=5 / max_distance_pixels=20, while the
    # rescaled circle call below passes this function's max_distance_pixels as the fifth
    # positional argument. Which parameter that slot binds to depends on the helper's signature
    # (not visible here) — confirm that the two variants are meant to differ.
    record(
        ("L_circle_conditional_m", "IME_circle_conditional_mol", "Q_circle_conditional_kg_h"),
        calculate_circle_distance_quantification(
            central_plume_mask,
            conditional_retrieval,
            pixel_width,
            wind_speed,
            min_distance_pixels=5,
            max_distance_pixels=20,
        ),
    )

    # --- rescaled retrieval ---
    record(
        ("L_sqrtA_rescaled_m", "IME_sqrtA_rescaled_mol", "Q_sqrtA_rescaled_kg_h"),
        calculate_sqrtA_quantification(central_plume_mask, rescaled_retrieval, pixel_width, wind_speed),
    )
    record(
        ("L_major_rescaled_m", "IME_major_rescaled_mol", "Q_major_rescaled_kg_h"),
        calculate_major_axis_quantification(central_plume_mask, rescaled_retrieval, pixel_width, wind_speed),
    )
    record(
        ("L_circle_rescaled_m", "IME_circle_rescaled_mol", "Q_circle_rescaled_kg_h"),
        calculate_circle_distance_quantification(
            central_plume_mask, rescaled_retrieval, pixel_width, wind_speed, max_distance_pixels
        ),
    )

    return results

In [None]:
def get_wind_data_from_csv(wind_data_df: pd.DataFrame, date_str: str, sensor: str = "S2") -> dict:
    """
    Look up the wind for one date and sensor and return it per wind source.

    Only the first row matching ``date_str`` and ``sensor`` is used.

    Args:
        wind_data_df: DataFrame containing wind data
        date_str: Date string in format 'YYYY-MM-DD'
        sensor: Sensor type (default: "S2" for Sentinel 2)

    Returns
    -------
        dict: Keys 'geos_fp', 'era5' and 'ground_truth', each mapped to (wind_speed, wind_direction).
              NaN components in the table propagate into the returned speed.

    Raises
    ------
        ValueError: If no row matches the requested date and sensor
    """
    matches = wind_data_df[(wind_data_df["date"] == date_str) & (wind_data_df["sensor"] == sensor)]
    if len(matches) == 0:
        raise ValueError(f"No wind data found for date={date_str}, sensor={sensor}")

    # (result key, u-component column, v-component column)
    component_columns = (
        ("geos_fp", "geos_ux", "geos_uy"),
        ("era5", "era5_ux", "era5_uy"),
        ("ground_truth", "gt_ux", "gt_uy"),
    )

    wind_data = {}
    for source, u_column, v_column in component_columns:
        u_wind = matches[u_column].values[0]
        v_wind = matches[v_column].values[0]
        # Speed is the magnitude of the (u, v) vector; direction comes from the shared helper
        wind_data[source] = (np.sqrt(u_wind**2 + v_wind**2), calc_wind_direction(u_wind, v_wind))

    return wind_data

## Obtaining all quantifications

In [None]:
def two_pass_watershed(binary_probability, watershed_floor_threshold=0.075, initial_mask_threshold=0.5, connectivity=2):
    """
    Segment a probability map with a marker-based watershed seeded at per-region maxima.

    Pass one thresholds the map at ``initial_mask_threshold`` and labels the connected regions.
    The highest-probability pixel of each region becomes a marker. Pass two floods the negated
    probability map from those markers, restricted to pixels above ``watershed_floor_threshold``.

    Args:
        binary_probability: Probability map of plume detection
        watershed_floor_threshold: Lower threshold for watershed segmentation (default: 0.075)
        initial_mask_threshold: Threshold for initial binary mask creation (default: 0.5)
        connectivity: Connectivity for labeling (default: 2)

    Returns
    -------
        Segmented mask after watershed algorithm
    """
    # Pass one: connected regions above the initial threshold
    above_initial = binary_probability > initial_mask_threshold
    seed_regions = label(above_initial, connectivity=connectivity)

    # One seed pixel per region, placed at the region's probability maximum
    seeds = np.zeros_like(above_initial)
    for region in regionprops(seed_regions):
        pixel_rows = region.coords[:, 0]
        pixel_cols = region.coords[:, 1]
        peak = np.argmax(binary_probability[pixel_rows, pixel_cols])
        seeds[pixel_rows[peak], pixel_cols[peak]] = 1

    # Give every seed its own integer id
    markers, _ = ndi.label(seeds)

    # Pass two: flood from the seeds, restricted to pixels above the floor threshold
    flood_domain = binary_probability > watershed_floor_threshold
    return watershed(-binary_probability, markers, mask=flood_domain)

In [None]:
# Update the results processing to include all masking algorithms
all_results = []
# (name, function) pairs; each function maps a probability map to a labeled plume mask.
masking_algorithms = [
    (
        "strict_watershed",
        lambda bp: label(
            retrieval_mask_using_watershed_algo(
                WatershedParameters(
                    marker_distance=1, marker_threshold=0.1, watershed_floor_threshold=0.1, closing_footprint_size=0
                ),
                bp,
            )
        ),
    ),
    (
        "sensitive_watershed",
        lambda bp: label(
            retrieval_mask_using_watershed_algo(
                WatershedParameters(
                    marker_distance=1, marker_threshold=0.1, watershed_floor_threshold=0.05, closing_footprint_size=0
                ),
                bp,
            )
        ),
    ),
    ("two_pass_watershed", lambda bp: two_pass_watershed(bp, 0.05, 0.1, 2)),
]

# Pixel width in meters handed to the quantification helpers; constant for every plume.
pixel_width = 20

for plume_date_str, preds in predictions.items():
    binary_probability = preds["binary_probability"].numpy()
    retrieval = preds["conditional_retrieval"]

    # Get wind data from all sources
    sensing_time = preds["timestamp"]
    try:
        wind_data = get_wind_data_from_csv(wind_data_df, sensing_time.date().isoformat())
    except ValueError as e:
        print(f"Error getting wind data for {plume_date_str}: {e}")
        continue

    # Calculate quantification for each masking algorithm
    for mask_name, mask_func in masking_algorithms:
        # Apply the masking algorithm
        try:
            predicted_mask = mask_func(binary_probability)

            # before_data / earlier_data are globals left over from the prediction loop and only
            # describe the LAST plume processed there. Prefer reference dates stored on the
            # prediction itself; fall back to the globals only when they are absent.
            if "before_date" in preds and "earlier_date" in preds:
                before_date_str = preds["before_date"]
                earlier_date_str = preds["earlier_date"]
            else:
                before_date_str = before_data["tile_item"].time.date().isoformat()
                earlier_date_str = earlier_data["tile_item"].time.date().isoformat()

            # Base result dictionary
            result = {
                "date": plume_date_str,
                "masking_algorithm": mask_name,
                "ground_truth_kg_h": preds["ground_truth_kg_h"],
                "before_date": before_date_str,
                "earlier_date": earlier_date_str,
            }

            # Store wind speeds and directions
            for wind_source, (wind_speed, wind_direction) in wind_data.items():
                result[f"{wind_source}_wind_speed_m_s"] = wind_speed
                result[f"{wind_source}_wind_direction_deg"] = wind_direction

            # Calculate quantification for each wind source
            for wind_source, (wind_speed, _) in wind_data.items():
                # Skip if wind speed is NaN
                if pd.isna(wind_speed):
                    print(f"No {wind_source} wind data available for {plume_date_str}, skipping this source")
                    continue

                quant_results = calculate_all_quantification_methods(
                    predicted_mask, binary_probability, retrieval, pixel_width, wind_speed
                )

                # Add wind_source prefix to all quantification results
                for key, value in quant_results.items():
                    result[f"{wind_source}_{key}"] = value

                print(
                    f"Processed {plume_date_str} with {mask_name}, {wind_source} wind data: "
                    f"sqrtA: {quant_results['Q_sqrtA_conditional_kg_h']:.3f} kg/h, "
                    f"circle rescaled: {quant_results['Q_circle_rescaled_kg_h']:.3f} kg/h"
                )

            all_results.append(result)

        # Best effort by design: a failing plume/mask combination (e.g. no central plume found)
        # is reported and skipped so the remaining combinations are still processed.
        except Exception as e:
            print(f"Error applying {mask_name} to {plume_date_str}: {e}")

## Analysis

In [None]:
# Create results dataframe with enhanced visualization
# One row per (plume date, masking algorithm); quantification columns are prefixed with the
# wind source they were computed with.
all_results_df = pd.DataFrame(all_results)

# Define methods and wind sources for analysis
# (display name, column suffix) — the full column is f"{wind_source}_{suffix}"
methods = [
    ("sqrtA+conditional", "Q_sqrtA_conditional_kg_h"),
    ("major_axis+conditional", "Q_major_conditional_kg_h"),
    ("circle+conditional", "Q_circle_conditional_kg_h"),
    ("sqrtA+rescaled", "Q_sqrtA_rescaled_kg_h"),
    ("major_axis+rescaled", "Q_major_rescaled_kg_h"),
    ("circle+rescaled", "Q_circle_rescaled_kg_h"),
]

wind_sources = ["geos_fp", "era5", "ground_truth"]
mask_algos_labels = all_results_df["masking_algorithm"].unique()

# Calculate error metrics for all methods, wind sources, and masking algorithms
# error_metrics[mask_algo][wind_source][method_name] -> {"MAPE (%)", "RMSPE (%)", "MPE (%)"}
error_metrics = {}
for mask_algo in mask_algos_labels:
    error_metrics[mask_algo] = {}
    mask_df = all_results_df[all_results_df["masking_algorithm"] == mask_algo].copy()

    for wind_source in wind_sources:
        error_metrics[mask_algo][wind_source] = {}
        for method_name, column in methods:
            full_column = f"{wind_source}_{column}"
            # Calculate percentage error
            # NOTE(review): raises KeyError when a wind source produced no quantification for any
            # plume, because the results loop then never creates this column.
            error_pct = (mask_df[full_column] - mask_df["ground_truth_kg_h"]) / mask_df["ground_truth_kg_h"] * 100
            # NOTE(review): mask_df is a local copy that is not read again after this loop, so
            # this per-row error column is never used.
            mask_df.loc[:, f"{wind_source}_{method_name}_error_pct"] = error_pct

            # Calculate metrics
            # All three are computed on PERCENTAGE errors: despite the variable names they are
            # the MAPE, RMSPE and MPE reported below.
            mae = np.abs(error_pct).mean()
            rmse = np.sqrt((error_pct**2).mean())
            bias = error_pct.mean()

            error_metrics[mask_algo][wind_source][method_name] = {
                "MAPE (%)": mae,
                "RMSPE (%)": rmse,
                "MPE (%)": bias,
            }

# Display results dataframe
print("\nQuantification Results with Different Masking Algorithms:")
display(all_results_df)

# Create a comprehensive error metrics table
metrics_table = []
for mask_algo in mask_algos_labels:
    for wind_source in wind_sources:
        # Every wind source has an entry here because the loop above fills all of them.
        for method_name in [m[0] for m in methods]:
            metrics_table.append(
                {
                    "Masking Algorithm": mask_algo,
                    "Wind Source": wind_source.upper(),
                    "Method": method_name,
                    "MAPE (%)": error_metrics[mask_algo][wind_source][method_name]["MAPE (%)"],
                    "RMSPE (%)": error_metrics[mask_algo][wind_source][method_name]["RMSPE (%)"],
                    "MPE (%)": error_metrics[mask_algo][wind_source][method_name]["MPE (%)"],
                }
            )

metrics_df = pd.DataFrame(metrics_table)

# Sort by masking algorithm, wind source, and MAPE (done while the values are still numeric)
metrics_df = metrics_df.sort_values(["Masking Algorithm", "Wind Source", "MAPE (%)"])

# Format the metrics to 2 decimal places
# NOTE: from here on the three metric columns of metrics_df hold strings, not floats.
for col in ["MAPE (%)", "RMSPE (%)", "MPE (%)"]:
    metrics_df[col] = metrics_df[col].map("{:.2f}".format)

# Display the metrics table
print("\nError Metrics by Masking Algorithm, Wind Source, and Method:")
display(metrics_df)

# Highlight the best method (lowest MAPE) for each masking algorithm and wind source
best_methods = []
for mask_algo in mask_algos_labels:
    for wind_source in wind_sources:
        # best_method is a (method_name, metrics dict) pair
        best_method = min(error_metrics[mask_algo][wind_source].items(), key=lambda x: x[1]["MAPE (%)"])
        best_methods.append(
            {
                "Masking Algorithm": mask_algo,
                "Wind Source": wind_source.upper(),
                "Best Method": best_method[0],
                "MAPE (%)": best_method[1]["MAPE (%)"],
                "RMSPE (%)": best_method[1]["RMSPE (%)"],
                "MPE (%)": best_method[1]["MPE (%)"],
            }
        )

best_df = pd.DataFrame(best_methods)

# Format the metrics to 2 decimal places
for col in ["MAPE (%)", "RMSPE (%)", "MPE (%)"]:
    best_df[col] = best_df[col].map("{:.2f}".format)

print("\nBest Method by Masking Algorithm and Wind Source:")
display(best_df)

# Create visualizations to compare masking algorithms
# For each wind source, create a plot comparing the best method across masking algorithms
for wind_source in wind_sources:
    plt.figure(figsize=(12, 8))

    # Marker and color cycle, indexed by the position of the masking algorithm
    markers = ["o", "s", "^"]
    colors = ["blue", "green", "red"]

    for i, mask_algo in enumerate(mask_algos_labels):
        # Find the best method for this masking algorithm and wind source
        best_method = min(error_metrics[mask_algo][wind_source].items(), key=lambda x: x[1]["MAPE (%)"])
        best_method_name = best_method[0]

        # Get the column name for this method
        method_column = next(col for name, col in methods if name == best_method_name)
        full_column = f"{wind_source}_{method_column}"

        # Filter data for this masking algorithm
        mask_df = all_results_df[all_results_df["masking_algorithm"] == mask_algo]

        plt.scatter(
            mask_df["ground_truth_kg_h"],
            mask_df[full_column],
            label=f"{mask_algo} ({best_method_name})",
            marker=markers[i % len(markers)],
            color=colors[i % len(colors)],
            s=100,
            alpha=0.7,
        )

    # Set reasonable x and y limits based on the data
    max_gt = all_results_df["ground_truth_kg_h"].max()
    max_pred = (
        all_results_df[[col for col in all_results_df.columns if col.startswith(f"{wind_source}_Q_")]].max().max()
    )
    max_value = max(max_gt, max_pred) * 1.1

    # max_value only sets the extent of the 1:1 line. The axes are limited from the ground truth
    # alone, so predictions above twice the largest ground truth fall outside the plot.
    plt.plot([0, max_value], [0, max_value], "k--", label="1:1 line")
    plt.xlim(0, max_gt * 1.1)
    plt.ylim(0, max_gt * 2.0)
    plt.xlabel("Ground Truth (kg/h)", fontsize=12)
    plt.ylabel("Predicted (kg/h)", fontsize=12)
    plt.title(f"Best Methods by Masking Algorithm using {wind_source.upper()} Wind Data", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
# The analysis cell stored the metrics as formatted strings; turn MAPE back into floats so the
# sort below is numeric rather than lexicographic.
mape_column = "MAPE (%)"
metrics_df[mape_column] = metrics_df[mape_column].astype(float)

# Show the table ordered by MAPE, then print it as markdown for easy copy-paste to
# GitHub/GitLab issues.
sorted_metrics_df = metrics_df.sort_values(mape_column)
display(sorted_metrics_df)
print(sorted_metrics_df.to_markdown(index=False))

In [None]:
# Inspect a single date (2024-12-19) with the ERA5 wind source
date_to_filter = "2024-12-19"
wind_source = "era5"

# Every column whose name contains the wind-source tag
era5_columns = [col for col in all_results_df.columns if wind_source in col]

# Rows of the chosen date and the identifying columns shared by both views below
id_columns = ["date", "masking_algorithm", "ground_truth_kg_h"]
rows_for_date = all_results_df[all_results_df["date"] == date_to_filter]

# View 1: all ERA5 columns
filtered_df = rows_for_date[id_columns + era5_columns]
print(f"\nResults for {date_to_filter} with {wind_source.upper()} wind source:")
display(filtered_df)

# View 2: only the ERA5 columns whose name contains "L_".
# NOTE(review): the printed title says "Quantification results" although this selects the
# "L_" columns — confirm whether "Q_" was intended.
quant_columns = [col for col in era5_columns if "L_" in col]
focused_df = rows_for_date[id_columns + quant_columns]
print(f"\nQuantification results for {date_to_filter} with {wind_source.upper()} wind source:")
display(focused_df)

In [None]:
# --- Helper functions for plotting ---


def _get_plot_attributes(method_name, colors, markers):
    """Determine plot color and marker based on method name."""
    if "conditional" in method_name:
        color = colors["conditional"]
    elif "rescaled" in method_name:
        color = colors["rescaled"]
    else:
        raise ValueError(f"Unknown scaling method in {method_name}")

    if "sqrtA" in method_name:
        marker = markers["sqrtA"]
    elif "major_axis" in method_name or "major" in method_name:
        marker = markers["major"]
    elif "circle" in method_name:
        marker = markers["circle"]
    else:
        raise ValueError(f"Unknown lengthscale method in {method_name}")
    return color, marker


def _plot_quantification_scatter(ax, df, wind_source, methods_to_plot, method_names, colors, markers):
    """
    Scatter each method's predictions against the ground truth on ``ax`` and add a 1:1 line.

    Methods whose ``{wind_source}_{method_col}`` column is missing from ``df`` are skipped.
    ``methods_to_plot`` and ``method_names`` must have the same length.
    """
    largest_prediction = 0
    for column_suffix, legend_label in zip(methods_to_plot, method_names, strict=True):
        column = f"{wind_source}_{column_suffix}"
        if column in df.columns:
            point_color, point_marker = _get_plot_attributes(legend_label, colors, markers)

            # Track the largest non-NaN prediction to size the y axis
            observed = df[column].dropna()
            if not observed.empty:
                largest_prediction = max(largest_prediction, observed.max())

            ax.scatter(
                df["ground_truth_kg_h"],
                df[column],
                label=legend_label,
                color=point_color,
                marker=point_marker,
                s=100,
                alpha=0.7,
            )

    ground_truth = df["ground_truth_kg_h"]
    largest_truth = 0 if ground_truth.empty else ground_truth.max()
    # 10 % headroom, and never less than 1.0 so an empty frame still yields a valid axis
    diagonal_end = max(max(largest_truth, largest_prediction) * 1.1, 1.0)

    ax.plot([0, diagonal_end], [0, diagonal_end], "k--", label="1:1 line")
    ax.set_xlim(0, max(largest_truth * 1.1, 1.0))
    ax.set_ylim(0, diagonal_end)
    ax.set_xlabel("Ground Truth (kg/h)", fontsize=12)
    ax.set_ylabel("Predicted (kg/h)", fontsize=12)
    ax.set_title(f"Quantification Methods Comparison ({wind_source.upper()})", fontsize=14)
    ax.legend()
    ax.grid(True)


def _plot_error_bars(ax, df, wind_source, methods_to_plot, method_names, colors, markers):
    """
    Draw one bar per quantification method showing its mean absolute percentage error.

    A method is included only if its ``{wind_source}_{method_col}`` column exists, is not entirely
    NaN and yields at least one non-NaN error. When no method qualifies, a "No data to plot"
    message is written on the axes instead.
    """
    bars = []  # (method label, mean absolute percentage error, bar color)
    for column_suffix, legend_label in zip(methods_to_plot, method_names, strict=True):
        column = f"{wind_source}_{column_suffix}"
        if column not in df.columns or df[column].isna().all():
            continue

        # Percentage error relative to the ground truth, ignoring rows without a valid value
        pct_errors = ((df[column] - df["ground_truth_kg_h"]) / df["ground_truth_kg_h"] * 100).dropna()
        if pct_errors.empty:
            continue

        mean_abs_error = abs(pct_errors).mean()
        bar_color, _ = _get_plot_attributes(legend_label, colors, markers)
        bars.append((legend_label, mean_abs_error, bar_color))

    if not bars:
        ax.text(0.5, 0.5, "No data to plot", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("Error by Quantification Method", fontsize=14)
        return

    bar_labels, bar_heights, bar_colors = (list(column) for column in zip(*bars))
    ax.bar(bar_labels, bar_heights, color=bar_colors, alpha=0.7)
    ax.set_ylabel("Mean Absolute Percentage Error (%)", fontsize=12)
    ax.set_title("Error by Quantification Method", fontsize=14)
    plt.sca(ax)  # plt.xticks acts on the current axes
    plt.xticks(rotation=45, ha="right")
    ax.grid(True, axis="y")


def _plot_time_series(ax, df, wind_source, methods_to_plot, method_names, colors, markers):
    """
    Plot the time series of predictions vs ground truth on ``ax``.

    Rows are drawn in chronological order of their ``date`` column. Methods whose
    ``{wind_source}_{method_col}`` column is missing or entirely NaN are skipped. The caller's
    ``df`` is left untouched: sorting happens on a copy.
    """
    # Sort on a parsed datetime column for proper ordering. assign() returns a new frame, so no
    # "datetime" column is written into the caller's DataFrame (which may be a slice).
    ordered = df.assign(datetime=pd.to_datetime(df["date"])).sort_values("datetime")
    dates = ordered["date"].tolist()  # Date strings, used as categorical x positions / tick labels

    # Plot ground truth
    ax.plot(
        dates,
        ordered["ground_truth_kg_h"],
        "k-",
        marker="o",
        linewidth=2,
        markersize=10,
        label="Ground Truth",
    )

    # Plot predictions for each method
    for method_col, method_name in zip(methods_to_plot, method_names, strict=True):
        full_col = f"{wind_source}_{method_col}"
        if full_col in ordered.columns and not ordered[full_col].isna().all():
            color, marker = _get_plot_attributes(method_name, colors, markers)
            ax.plot(
                dates,
                ordered[full_col],
                marker=marker,
                linestyle="-",  # Ensure lines connect points
                linewidth=1.5,
                markersize=8,
                label=method_name,
                color=color,
                alpha=0.7,
            )

    ax.set_xlabel("Date", fontsize=12)
    ax.set_ylabel("Quantification (kg/h)", fontsize=12)
    ax.set_title("Predictions vs Ground Truth Over Time", fontsize=14)
    ax.legend()
    ax.grid(True)
    plt.sca(ax)  # Set current axis for plt.xticks
    plt.xticks(rotation=45, ha="right")  # Adjust rotation for better readability


def _calculate_and_display_stats(df, wind_source, methods_to_plot, method_names):
    """
    Calculate per-method error statistics against the ground truth and display them.

    For every method whose ``{wind_source}_{method}`` column exists and has data, the rows
    where both the prediction and ``ground_truth_kg_h`` are present are used to compute the
    mean absolute error and the mean/min/max/std of the absolute percentage error.
    Percentage errors are only computed for rows whose ground truth magnitude exceeds a
    small epsilon; if no such row exists the percentage statistics are reported as NaN.

    Args:
        df: Results dataframe with a ``ground_truth_kg_h`` column.
        wind_source: Prefix of the prediction columns (e.g. "era5").
        methods_to_plot: Column suffixes of the methods to summarise.
        method_names: Display names, paired one-to-one with ``methods_to_plot``.
    """
    epsilon = 1e-6  # ground truth magnitudes at or below this are excluded from percentage errors
    rel_stat_names = ("MAPE (%)", "Min Error (%)", "Max Error (%)", "Std Dev (%)")

    stats_data = []
    for method_col, method_name in zip(methods_to_plot, method_names, strict=True):
        full_col = f"{wind_source}_{method_col}"
        if full_col not in df.columns or df[full_col].isna().all():
            continue

        # Keep only rows where both prediction and ground truth are available
        valid_rows = df[full_col].notna() & df["ground_truth_kg_h"].notna()
        preds = df.loc[valid_rows, full_col]
        gt = df.loc[valid_rows, "ground_truth_kg_h"]
        if preds.empty:
            continue

        abs_errors = (preds - gt).abs()

        # Divide by |gt| (not gt) so the percentage error is non-negative, consistent
        # with the |gt| > epsilon guard, and only divide on rows that pass the guard.
        usable_gt = gt.abs() > epsilon
        rel_errors = (abs_errors[usable_gt] / gt[usable_gt].abs() * 100).to_numpy()

        if rel_errors.size > 0:
            rel_stats = {
                "MAPE (%)": np.mean(rel_errors),
                "Min Error (%)": np.min(rel_errors),
                "Max Error (%)": np.max(rel_errors),
                "Std Dev (%)": np.std(rel_errors),
            }
        else:
            # No row has a usable ground truth: report MAE only
            rel_stats = dict.fromkeys(rel_stat_names, np.nan)

        stats_data.append({"Method": method_name, "MAE (kg/h)": abs_errors.mean(), **rel_stats})

    if stats_data:
        stats_df = pd.DataFrame(stats_data)
        print("\nSummary Statistics:")
        display(stats_df)
    else:
        print("\nNo valid data to calculate statistics.")


# Plot results for two_pass_watershed and ERA5
def plot_two_pass_era5_results():
    """
    Analyse the two_pass_watershed masking algorithm combined with ERA5 wind data.

    Reads the notebook-level ``all_results_df``, keeps the rows of the
    ``two_pass_watershed`` masking algorithm and produces:

    1. A scatter plot comparing the quantification methods
    2. A bar chart of percentage errors
    3. A time series of predictions vs ground truth
    4. A summary statistics table (displayed separately from the figure)

    Prints a message and returns early if ``all_results_df`` is undefined or has no
    rows for the masking algorithm.
    """
    mask_algo, wind_source = "two_pass_watershed", "era5"

    # all_results_df is expected to exist as a notebook global
    try:
        is_selected_algo = all_results_df["masking_algorithm"] == mask_algo
        filtered_df = all_results_df[is_selected_algo].copy()
    except NameError:
        print("Error: all_results_df not found. Please ensure it is defined.")
        return
    if filtered_df.empty:
        print(f"No data found for masking algorithm '{mask_algo}'. Skipping plot.")
        return

    # Methods to compare and how to draw them
    methods_to_plot = ["Q_circle_rescaled_kg_h", "Q_sqrtA_rescaled_kg_h", "Q_major_rescaled_kg_h"]
    method_names = ["circle+rescaled", "sqrtA+rescaled", "major_axis+rescaled"]
    colors = {"conditional": "blue", "rescaled": "red"}
    markers = {"sqrtA": "^", "major": "s", "circle": "o"}
    panel_args = (filtered_df, wind_source, methods_to_plot, method_names, colors, markers)

    # 2x2 grid; only the first three panels are used
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f"Analysis for {mask_algo} with {wind_source.upper()} Wind Data", fontsize=16)
    scatter_ax, error_ax, series_ax, spare_ax = axes.flatten()

    _plot_quantification_scatter(scatter_ax, *panel_args)
    _plot_error_bars(error_ax, *panel_args)
    _plot_time_series(series_ax, *panel_args)

    # The statistics table is rendered as cell output, not inside the figure
    _calculate_and_display_stats(filtered_df, wind_source, methods_to_plot, method_names)

    # Hide the unused fourth panel
    spare_ax.axis("off")

    # Leave room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


# Run the analysis
plot_two_pass_era5_results()

In [None]:
# Visualize, for every prediction date, the binary probability map next to the mask
# produced by each masking algorithm.
# One panel for the probability map plus one per masking algorithm, so the figure
# adapts when entries are added to or removed from `masking_algorithms`.
n_panels = 1 + len(masking_algorithms)
for example_date, example_preds in predictions.items():
    binary_probability = example_preds["binary_probability"].numpy()

    # Find the ground truth quantification for this date (if any)
    gt_value = all_results_df.loc[all_results_df["date"] == example_date, "ground_truth_kg_h"].values
    gt_str = f"{gt_value[0]:.0f} kg/h" if len(gt_value) > 0 else "N/A"

    # squeeze=False keeps `axes` two-dimensional even when there is a single panel
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
    axes = axes[0]

    # First panel: the binary probability
    im0 = axes[0].imshow(binary_probability, cmap="pink_r", vmin=0, vmax=1)
    axes[0].set_title("Binary Probability")
    grid16(axes[0])
    plt.colorbar(im0, ax=axes[0])

    # Remaining panels: one per masking algorithm
    for mask_ax, (mask_name, mask_func) in zip(axes[1:], masking_algorithms, strict=True):
        mask = mask_func(binary_probability)
        mask_ax.imshow(mask, cmap="tab20", interpolation="nearest")
        grid16(mask_ax)
        mask_ax.set_title(f"{mask_name}")

    plt.tight_layout()
    plt.suptitle(f"Comparison of Masking Algorithms for {example_date} (Ground Truth: {gt_str})", fontsize=16)
    plt.subplots_adjust(top=0.85)  # make room for the suptitle
    plt.show()

In [None]:
# Create a visualization comparing the same method across different masking algorithms
def visualize_method_across_masks(method_name, column_suffix):
    """
    Scatter one quantification method against ground truth, one series per masking algorithm.

    Reads the notebook-level ``all_results_df`` and ``mask_algos_labels``. ERA5 is the
    wind source, so the plotted column is ``era5_{column_suffix}``.

    Args:
        method_name: Name of the method for the plot title
        column_suffix: Column suffix for the method in the results dataframe
    """
    wind_source = "era5"
    target_column = f"{wind_source}_{column_suffix}"
    marker_cycle = ["o", "s", "^"]
    color_cycle = ["blue", "green", "red"]

    plt.figure(figsize=(12, 8))

    for idx, mask_algo in enumerate(mask_algos_labels):
        algo_rows = all_results_df[all_results_df["masking_algorithm"] == mask_algo]
        # Skip algorithms with no data for this method
        if target_column not in algo_rows.columns or algo_rows[target_column].isna().all():
            continue
        plt.scatter(
            algo_rows["ground_truth_kg_h"],
            algo_rows[target_column],
            label=f"{mask_algo}",
            marker=marker_cycle[idx % len(marker_cycle)],
            color=color_cycle[idx % len(color_cycle)],
            s=100,
            alpha=0.7,
        )

    # Axis limits come from the whole results frame: the largest ground truth and the
    # largest value in any column whose name starts with the target column name
    gt_ceiling = all_results_df["ground_truth_kg_h"].max()
    matching_columns = [col for col in all_results_df.columns if col.startswith(target_column)]
    pred_ceiling = all_results_df[matching_columns].max().max()
    axis_ceiling = max(gt_ceiling, pred_ceiling) * 1.1

    plt.plot([0, axis_ceiling], [0, axis_ceiling], "k--", label="1:1 line")
    plt.xlim(0, gt_ceiling * 1.1)
    plt.ylim(0, axis_ceiling)
    plt.xlabel("Ground Truth (kg/h)", fontsize=12)
    plt.ylabel("Predicted (kg/h)", fontsize=12)
    plt.title(f"{method_name} Method with {wind_source.upper()} Wind Data Across Masking Algorithms", fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Draw the cross-mask comparison for each method of interest
# (sqrtA+conditional is included as a common baseline method)
for method_label, method_column in (
    ("circle+rescaled", "Q_circle_rescaled_kg_h"),
    ("sqrtA+conditional", "Q_sqrtA_conditional_kg_h"),
    ("sqrtA+rescaled", "Q_sqrtA_rescaled_kg_h"),
    ("major_axis+rescaled", "Q_major_rescaled_kg_h"),
):
    visualize_method_across_masks(method_label, method_column)

In [None]:
def plot_retrievals_comparison(date_str, preds, vmax=0.1, center_buffer=10):
    """
    Plot candidate retrieval methods for a given date's predictions.

    Draws four panels side by side: the binary probability, the marginal retrieval, the
    conditional retrieval, and the marginal retrieval rescaled by the maximum binary
    probability found in the central window of the image. The ground truth for the date
    is looked up in the notebook-level ``all_results_df`` and added to the title if found.

    Args:
        date_str: Date string for the title
        preds: Prediction dictionary containing the retrievals ("binary_probability",
            "marginal_retrieval" and "conditional_retrieval")
        vmax: Maximum value for colorbar (default: 0.1 mol/m2)
        center_buffer: Half-width, in pixels, of the central window searched for the
            maximum probability (default: 10, i.e. the central 20x20 pixels). Must be positive.
    """
    binary_probability = preds["binary_probability"].numpy()
    marginal_retrieval = preds["marginal_retrieval"]
    conditional_retrieval = preds["conditional_retrieval"]

    cmap = "pink_r"

    # Find the maximum binary probability in the central window.
    h, w = binary_probability.shape
    center_y, center_x = h // 2, w // 2
    # Clamp the lower bounds at 0: a negative slice start would wrap around to the far
    # edge of the array and select the wrong region when the map is smaller than the window.
    center_region = binary_probability[
        max(center_y - center_buffer, 0) : center_y + center_buffer,
        max(center_x - center_buffer, 0) : center_x + center_buffer,
    ]
    max_prob = center_region.max()

    # Rescale the marginal retrieval by the maximum probability (left unscaled if it is zero)
    rescaled_retrieval = marginal_retrieval / max_prob if max_prob > 0 else marginal_retrieval.copy()

    # (image, title, colour-scale maximum) for each panel, left to right
    panels = [
        (binary_probability, "Binary Probability", 1),
        (marginal_retrieval, "Marginal Retrieval (mol/m²)", vmax),
        (conditional_retrieval, "Conditional Retrieval (mol/m²)", vmax),
        (rescaled_retrieval, f"Rescaled Retrieval (mol/m²)\nMax Prob: {max_prob:.3f}", vmax),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(20, 5))
    for ax, (image, title, scale_max) in zip(axes, panels, strict=True):
        im = ax.imshow(image, cmap=cmap, vmin=0, vmax=scale_max, interpolation="nearest")
        ax.set_title(title)
        grid16(ax)
        plt.colorbar(im, ax=ax)

    # Add ground truth info if available
    gt_value = all_results_df[all_results_df["date"] == date_str]["ground_truth_kg_h"].values
    gt_str = f"Ground Truth: {gt_value[0]:.0f} kg/h" if len(gt_value) > 0 else ""

    plt.suptitle(f"Retrieval Comparison for {date_str} {gt_str}", fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # make room for the suptitle
    plt.show()


# Plot retrievals for each date
for date_str, preds in predictions.items():
    plot_retrievals_comparison(date_str, preds)

# Sanity checks with Gorroño plumes

In [None]:
# from src.validation.fpr_dt_pipeline import load_gorrono_plumes
import tempfile
from pathlib import Path

from src.azure_wrap.blob_storage_sdk_v2 import download_from_blob
from src.azure_wrap.ml_client_utils import (
    get_azureml_uri,
    make_acceptable_uri,
)


def load_gorrono_plumes(ml_client) -> list[np.ndarray]:
    """
    Download the Gorroño plume enhancements and return them as arrays in mol/m².

    The files are downloaded into a temporary directory that is removed before returning;
    the arrays for plume folders ``0`` to ``4`` are loaded into memory first.
    """
    # The stored enhancements are in mol/cm²; the radtran functions expect mol/m²
    mol_per_cm2_to_mol_per_m2 = 1e4
    plume_count = 5

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = Path(scratch_dir)
        azureml_uri = get_azureml_uri(ml_client, "orbio-data/methane_enhancements_molpercm2")
        download_from_blob(make_acceptable_uri(str(azureml_uri)), scratch_path, recursive=True)

        enhancements = []
        for plume_idx in range(plume_count):
            raw = np.load(scratch_path / f"{plume_idx}/methane_enhancement.npy")
            enhancements.append(raw * mol_per_cm2_to_mol_per_m2)

    return enhancements


gorrono_plumes = load_gorrono_plumes(ml_client)
# According to the paper, every Gorroño simulation was run with a 3.5 m/s wind speed
gorrono_wind_speed = 3.5

# Show each plume in its own panel, zoomed on pixels 200-300 along both axes
plt.figure(figsize=(12, 8))
for panel_number, plume in enumerate(gorrono_plumes, start=1):
    plt.subplot(2, 3, panel_number)
    plt.imshow(plume, cmap="pink_r", vmin=0, vmax=0.2)
    plt.title(f"Plume {panel_number}\nmax: {plume.max():.2f} mol/m²")
    plt.xlim(200, 300)
    plt.ylim(200, 300)
    plt.grid()

In [None]:
# Quantify every Gorroño plume with a plume mask built from each concentration threshold.
concentration_thresholds = [0.0005, 0.001, 0.002, 0.005, 0.01]

# Settings shared by all plumes and thresholds
pixel_width = 20  # Sentinel 2 resolution
wind_speed = gorrono_wind_speed
max_distance_pixels = 10

all_quantifications = []
for i, plume in enumerate(gorrono_plumes):
    conditional_retrieval = plume
    for thresh in concentration_thresholds:
        binary_mask = plume > thresh
        binary_probability = binary_mask.astype(float)  # 0 or 1, no uncertainty
        # Positional order: plume_labels, binary_probability, conditional_retrieval,
        # pixel_width, wind_speed
        quantifications = calculate_all_quantification_methods(
            binary_mask,
            binary_probability,
            conditional_retrieval,
            pixel_width,
            wind_speed,
            max_distance_pixels=max_distance_pixels,
        )
        # Tag the result so it can be grouped by plume and threshold later
        quantifications.update(plume_index=i + 1, threshold=thresh)
        all_quantifications.append(quantifications)

In [None]:
quantifications_df = pd.DataFrame(all_quantifications)

# (result column, line colour, marker, legend label) for each quantification method
method_styles = [
    ("Q_sqrtA_conditional_kg_h", "#0072B2", "s", "sqrtA"),
    ("Q_circle_conditional_kg_h", "#D55E00", "o", "circle"),
    ("Q_major_conditional_kg_h", "#009E73", "^", "major"),
]

# One line per plume and method, quantification as a function of the mask threshold
for plume_number in range(1, len(gorrono_plumes) + 1):
    df_sub = quantifications_df[quantifications_df["plume_index"] == plume_number]
    for column, line_color, line_marker, legend_label in method_styles:
        plt.plot(
            df_sub["threshold"],
            df_sub[column],
            color=line_color,
            marker=line_marker,
            linestyle="-",
            # Only the first plume contributes legend entries, to avoid duplicates
            label=legend_label if plume_number == 1 else "",
        )

plt.xlabel("Concentration Threshold (mol/m²) for plume mask")
plt.ylabel("Quantification (kg/h)")
plt.grid(True)
# Horizontal dashed line at 1000 kg/hr
plt.axhline(y=1000, color="black", linestyle="--", label="ground truth")
plt.legend()
plt.title("Quantifications of the Gorroño plumes")

In [None]:
# Log of the ratio between each method's quantification and the 1000 kg/h reference rate
quantifications_df["log_error_sqrtA"] = np.log(quantifications_df["Q_sqrtA_conditional_kg_h"] / 1000)
quantifications_df["log_error_circle"] = np.log(quantifications_df["Q_circle_conditional_kg_h"] / 1000)
quantifications_df["log_error_major"] = np.log(quantifications_df["Q_major_conditional_kg_h"] / 1000)

# Median absolute log error per threshold for all three methods.
# A notebook cell only renders its last expression, so the three summaries are combined
# into a single table instead of being computed as separate (discarded) expressions.
log_error_columns = ["log_error_sqrtA", "log_error_circle", "log_error_major"]
quantifications_df[log_error_columns].abs().groupby(quantifications_df["threshold"]).median()

# Confidence intervals

In [None]:
# Load the alternative models (only the model itself, first element of the returned tuple)
alternative_model_ids = [
    "models:/torchgeo_pwr_unet/1395",
    "models:/torchgeo_pwr_unet/1391",
    "models:/torchgeo_pwr_unet/1486",
]
alternative_models = [
    load_model_and_concatenator(model_id, device, SatelliteID.S2)[0] for model_id in alternative_model_ids
]
# Use a dedicated loop variable: at notebook scope `for model in ...` would rebind the
# global `model` (the central model loaded during setup) to the last alternative model.
for alt_model in alternative_models:
    alt_model.eval()

In [None]:
# Run every alternative model on each plume date.
# The tile / reference selection logic is copy-pasted from the central predictions above.
# In reality, we'll also want to include a choice of reference tiles in the ensemble.
def _tile_iso_date(tile):
    """Return the date of a tile dict's ``tile_item`` as an ISO-formatted string."""
    return tile["tile_item"].time.date().isoformat()


min_reference_tiles = 2

alternative_predictions = {}
for _, plume_row in phase0_plumes.iterrows():
    plume_date_str = plume_row["date"]
    plume_date = datetime.strptime(plume_date_str, "%Y-%m-%d")
    ground_truth_kg_h = plume_row["quantification_kg_h"]

    print(f"\nProcessing plume from {plume_date_str} with ground truth {ground_truth_kg_h} kg/h")

    # Locate the tile acquired on the plume date: first among the reference tiles,
    # then fall back to the main data loaded above
    main_tile = next((tile for tile in reference_data if _tile_iso_date(tile) == plume_date_str), None)
    if main_tile is not None:
        print(f"Found main tile for {plume_date_str}: {main_tile['tile_item'].id}")
    elif _tile_iso_date(main_data) == plume_date_str:
        main_tile = main_data
        print(f"Found main tile in main_data for {plume_date_str}: {main_data['tile_item'].id}")
    else:
        print(f"No tile found for plume date {plume_date_str}, skipping")
        continue

    # Collect candidate reference tiles without controlled releases
    if plume_date_str in manual_reference_selections:
        # Manually specified reference dates take precedence
        before_date, earlier_date = manual_reference_selections[plume_date_str]
        print(f"Using manually specified reference dates: {before_date} and {earlier_date}")
        clean_refs = [tile for tile in reference_data if _tile_iso_date(tile) in (before_date, earlier_date)]
    else:
        # Automatic selection: not cloudy, not a release date, strictly before the plume date
        dated_tiles = [(_tile_iso_date(tile), tile) for tile in reference_data]
        clean_refs = [
            tile
            for tile_date, tile in dated_tiles
            if tile_date not in cloudy_dates
            and tile_date not in phase0_release_dates
            and tile_date < plume_date_str
        ]

    if len(clean_refs) < min_reference_tiles:
        print(f"Not enough clean reference tiles found for {plume_date_str}, skipping")
        continue

    # Keep the two most recent candidates (newest first)
    reference_chips = sorted(clean_refs, key=_tile_iso_date, reverse=True)[:2]
    before_data, earlier_data = reference_chips

    print("Using reference tiles from:")
    print(f"  Before: {_tile_iso_date(before_data)}")
    print(f"  Earlier: {_tile_iso_date(earlier_data)}")

    # Generate predictions
    def get_preds_for_model(model):
        """Run one model on the current main/reference tiles and attach retrieval and metadata."""
        model_preds = generate_predictions(
            main_tile, reference_chips, model, device, band_concatenator, lossFn  # noqa: B023
        )
        model_preds = add_retrieval_to_pred(model_preds, main_tile["tile_item"])  # noqa: B023
        # Record the acquisition timestamp and the ground truth alongside the predictions
        model_preds["timestamp"] = main_tile["tile_item"].time  # noqa: B023
        model_preds["ground_truth_kg_h"] = ground_truth_kg_h  # noqa: B023
        return model_preds

    alternative_predictions[plume_date_str] = [get_preds_for_model(model) for model in alternative_models]

In [None]:
# Defined once, outside the plotting loop: it does not depend on any loop variable.
def get_quantification(preds):
    """
    Calculate the emission rate quantification for a plume using the GEOS-FP wind speed.

    The plume mask comes from the watershed algorithm applied to the binary probability;
    the plume found by ``find_central_plume`` is kept. The marginal retrieval
    (conditional retrieval times binary probability) is rescaled by the maximum
    probability inside that plume and passed to the major-axis quantification.

    Args:
        preds: Prediction dictionary with "binary_probability", "conditional_retrieval"
            and "timestamp" entries.

    Returns:
        The quantification returned by ``calculate_major_axis_quantification``
        (third element of its result).
    """
    binary_probability = preds["binary_probability"].numpy()

    pixel_width = 20  # Sentinel 2 resolution
    sensing_time = preds["timestamp"]
    wind_data = get_wind_data_from_csv(wind_data_df, sensing_time.date().isoformat())
    wind_speed, _wind_direction = wind_data["geos_fp"]

    # just do strict watershed here
    labelled_mask = label(
        retrieval_mask_using_watershed_algo(
            WatershedParameters(
                marker_distance=1, marker_threshold=0.1, watershed_floor_threshold=0.1, closing_footprint_size=0
            ),
            binary_probability,
        )
    )
    central_plume_id = find_central_plume(labelled_mask, max_distance_pixels=10, pixel_width=pixel_width)
    central_plume_mask = labelled_mask == central_plume_id

    conditional_retrieval = preds["conditional_retrieval"]
    marginal_retrieval = conditional_retrieval * binary_probability
    rescaled_retrieval = marginal_retrieval / binary_probability[central_plume_mask].max()

    _l_major, _ime, q_major = calculate_major_axis_quantification(
        central_plume_mask, rescaled_retrieval, pixel_width, wind_speed
    )
    return q_major


# Create a figure for plotting
plt.figure(figsize=(10, 8))

for date_str, alt_preds in alternative_predictions.items():
    central_preds = predictions[date_str]

    Q_ensemble = [get_quantification(pred) for pred in alt_preds]
    print("-" * 100)
    print(f"Ensemble of quantification on {date_str}: {Q_ensemble}")
    Q_central = get_quantification(central_preds)
    print(f"Central quantification on {date_str}: {Q_central}")

    Q_interval = quantification_interval(Q_central, Q_ensemble)
    print(f"Quantification interval for {date_str}: {Q_interval}")

    # Extract ground truth value
    ground_truth = central_preds["ground_truth_kg_h"]

    # Plot the point with asymmetric error bars spanning the quantification interval
    plt.errorbar(
        x=ground_truth,
        y=Q_central,
        yerr=[[Q_central - Q_interval[0]], [Q_interval[1] - Q_central]],
        fmt="o",
        capsize=5,
        markersize=8,
        color="blue",
        ecolor="gray",
    )
    # add a date label next to the point, angled at 45 degrees
    month_day = date_str.split("-")[1] + "-" + date_str.split("-")[2]
    plt.text(ground_truth, Q_central, month_day, fontsize=10, ha="left", va="bottom", color="grey", rotation=45)
# Add a diagonal line representing perfect estimation (y=x)
plt.plot([0, 1500], [0, 1500], "k--", alpha=0.7, label="Perfect Estimation")

# Add labels and title
plt.xlabel("Ground Truth (kg/h)", fontsize=12)
plt.ylabel("Estimated Emission Rate (kg/h)", fontsize=12)
plt.title("Emission Rate Estimation with Uncertainty", fontsize=14)

# Add grid and legend (the legend call is required for the 1:1 line label to be shown)
plt.grid(True, alpha=0.3)
plt.legend()

# Show the plot
plt.tight_layout()

In [None]:
# Let's look at the error bars if we use ground truth wind speed
# Defined once, outside the plotting loop: it does not depend on any loop variable.
def get_quantification(preds):
    """
    Calculate the emission rate quantification for a plume using the ground truth wind speed.

    The plume mask comes from the watershed algorithm applied to the binary probability;
    the plume found by ``find_central_plume`` is kept. The marginal retrieval
    (conditional retrieval times binary probability) is rescaled by the maximum
    probability inside that plume and passed to the major-axis quantification.

    Args:
        preds: Prediction dictionary with "binary_probability", "conditional_retrieval"
            and "timestamp" entries.

    Returns:
        The quantification returned by ``calculate_major_axis_quantification``
        (third element of its result).
    """
    binary_probability = preds["binary_probability"].numpy()

    pixel_width = 20  # Sentinel 2 resolution
    sensing_time = preds["timestamp"]
    wind_data = get_wind_data_from_csv(wind_data_df, sensing_time.date().isoformat())
    wind_speed, _wind_direction = wind_data["ground_truth"]

    # just do strict watershed here
    labelled_mask = label(
        retrieval_mask_using_watershed_algo(
            WatershedParameters(
                marker_distance=1, marker_threshold=0.1, watershed_floor_threshold=0.05, closing_footprint_size=0
            ),
            binary_probability,
        )
    )
    central_plume_id = find_central_plume(labelled_mask, max_distance_pixels=10, pixel_width=pixel_width)
    central_plume_mask = labelled_mask == central_plume_id

    conditional_retrieval = preds["conditional_retrieval"]
    marginal_retrieval = conditional_retrieval * binary_probability
    rescaled_retrieval = marginal_retrieval / binary_probability[central_plume_mask].max()

    _l_major, _ime, q_major = calculate_major_axis_quantification(
        central_plume_mask,
        rescaled_retrieval,
        pixel_width,
        wind_speed,
    )
    return q_major


plt.figure(figsize=(10, 8))

for date_str, alt_preds in alternative_predictions.items():
    central_preds = predictions[date_str]

    Q_ensemble = [get_quantification(pred) for pred in alt_preds]
    print("-" * 100)
    print(f"Ensemble of quantification on {date_str}: {Q_ensemble}")
    Q_central = get_quantification(central_preds)
    print(f"Central quantification on {date_str}: {Q_central}")

    Q_interval = quantification_interval(Q_central, Q_ensemble, wind_MdALE=0.0)  # using ground truth wind speed
    print(f"Quantification interval for {date_str}: {Q_interval}")

    # Extract ground truth value
    ground_truth = central_preds["ground_truth_kg_h"]

    # Plot the point with asymmetric error bars spanning the quantification interval
    plt.errorbar(
        x=ground_truth,
        y=Q_central,
        yerr=[[Q_central - Q_interval[0]], [Q_interval[1] - Q_central]],
        fmt="o",
        capsize=5,
        markersize=8,
        color="blue",
        ecolor="gray",
    )
    # add a date label next to the point, angled at 45 degrees
    month_day = date_str.split("-")[1] + "-" + date_str.split("-")[2]
    plt.text(ground_truth, Q_central, month_day, fontsize=10, ha="left", va="bottom", color="grey", rotation=45)
# Add a diagonal line representing perfect estimation (y=x)
plt.plot([0, 1500], [0, 1500], "k--", alpha=0.7, label="Perfect Estimation")

# Add labels and title
plt.xlabel("Ground Truth (kg/h)", fontsize=12)
plt.ylabel("Estimated Emission Rate (kg/h)", fontsize=12)
plt.title("Emission Rate Estimation with Ground Truth Wind Speed", fontsize=14)

# Add grid and legend (the legend call is required for the 1:1 line label to be shown)
plt.grid(True, alpha=0.3)
plt.legend()

# Show the plot
plt.tight_layout()

In [None]:
# Wind speed magnitude from each source for the first row dated 2024-11-22
wind_day = wind_data_df.loc[wind_data_df["date"] == "2024-11-22"].iloc[0]
geos_fp_day_speed = np.sqrt(wind_day["geos_ux"] ** 2 + wind_day["geos_uy"] ** 2)
era5_day_speed = np.sqrt(wind_day["era5_ux"] ** 2 + wind_day["era5_uy"] ** 2)
ground_truth_day_speed = np.sqrt(wind_day["gt_ux"] ** 2 + wind_day["gt_uy"] ** 2)
print(f"GEOS-FP wind speed: {geos_fp_day_speed:.1f}")
print(f"ERA5 wind speed: {era5_day_speed:.1f}")
print(f"Ground truth wind speed: {ground_truth_day_speed:.1f}")

In [None]:
# Wind speed magnitude per source, from the x/y wind components
wind_data_df["geos_fp"] = np.sqrt(wind_data_df["geos_ux"] ** 2 + wind_data_df["geos_uy"] ** 2)
wind_data_df["era5"] = np.sqrt(wind_data_df["era5_ux"] ** 2 + wind_data_df["era5_uy"] ** 2)
wind_data_df["gt"] = np.sqrt(wind_data_df["gt_ux"] ** 2 + wind_data_df["gt_uy"] ** 2)

# Median absolute log error of each estimated wind speed against the ground truth (NaNs ignored)
print(f"GEOS-FP MdALE: {np.nanmedian(np.abs(np.log(wind_data_df['geos_fp']) - np.log(wind_data_df['gt']))):.2f}")
print(f"ERA5 MdALE: {np.nanmedian(np.abs(np.log(wind_data_df['era5']) - np.log(wind_data_df['gt']))):.2f}")

# Label both point series so the legend identifies them, not just the 1:1 line
plt.plot(wind_data_df["gt"], wind_data_df["era5"], "o", label="ERA5")
plt.plot(wind_data_df["gt"], wind_data_df["geos_fp"], "p", label="GEOS-FP")
plt.xlabel("Ground truth wind speed (m/s)")
plt.ylabel("Estimated wind speed (m/s)")
plt.title("Estimated vs ground truth wind speed")
plt.grid(True)
plt.plot([0, 10], [0, 10], "k--", alpha=0.7, label="1:1 line")
plt.legend()
plt.show()