# Goal: Run and visualize the new chip and tile selection for Sentinel-2 (S2) step by step

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import pandas as pd

pd.set_option("display.max_colwidth", 250)
import collections
import datetime
import logging
import random
import warnings

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shapely
import tqdm
from pystac_client import Client
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from hydra.core.global_hydra import GlobalHydra
import hydra

import json
from omnicloudmask import predict_from_array
import time
import cv2
import rasterio
from pyproj import Transformer
from shapely.geometry import Polygon, box
from rasterio.enums import Resampling

from src.data.generation.base import DataGenerationConfig
from src.data.common.sim_plumes import PlumeType
from src.data.sentinel2 import (
    BAND_RESOLUTIONS,
    Sentinel2Item,
)
from src.data.sentinel2_l1c import Sentinel2L1CItem
from src.data.sentinel2 import SceneClassificationLabel as SCLabel
from src.data.azure_run_data_generation import (
    get_queries_by_cloud_coverage,
    get_quality_thresholds,
    compute_transformation_combinations,
)
from src.data.generate import parse_quality_thresholds, SATELLITE_CLASSES
from src.utils.parameters import SatelliteID
from src.azure_wrap.ml_client_utils import (
    create_ml_client_config,
    get_abfs_output_directory,
    get_default_blob_storage,
    initialize_blob_service_client,
    make_acceptable_uri,
    get_azureml_uri,
    initialize_ml_client,
)
from src.utils.git_utils import get_git_revision_hash


def setup_logging() -> logging.Logger:
    """Configure the root logger and silence noisy Azure logging.

    The root logger is configured at INFO level with a
    ``timestamp - level - message`` format, while the ``azure`` logger is
    raised to ERROR so that only failures from it are shown.

    Returns:
        The logger named after the current module (``__name__``).
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    # Keep the "azure" logger quiet unless something actually goes wrong
    logging.getLogger("azure").setLevel(logging.ERROR)
    # Return a logger for the calling module
    return logging.getLogger(__name__)


logger = setup_logging()

# Setup

In [None]:
# Clear any Hydra state left over from a previous run of this cell before initializing again
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(config_path="../src/data/config", version_base=None)
# Compose the data-generation config with the Sentinel-2 satellite override
config: DictConfig = hydra.compose(config_name="config", overrides=["satellite=s2"])
print(f"Satellite: {config.satellite.name}")

# Typed identifiers built from the raw config values
plume_type = PlumeType(config.plumes.plume_type)
satellite = SatelliteID(config.satellite.id)

# Get transformation combinations (the grid is resolved to plain Python containers first)
transformations_grid = OmegaConf.to_container(config.satellite_split.transformations_grid, resolve=True)
transformations = compute_transformation_combinations(transformations_grid)  # type: ignore[arg-type]
print(f"Generated {len(transformations)} transformation combinations: {transformations_grid}")

# Initialize Azure ML
ml_client = initialize_ml_client()

# Set up plume catalog
plumes_catalog_uri = get_azureml_uri(ml_client, config.plumes_split.catalog_uri)
plume_catalog_acceptable_uri = make_acceptable_uri(str(plumes_catalog_uri))
print(f"Plumes catalog URI: {plumes_catalog_uri}")

# Suffix appended to both the output directory and the experiment name below
suffix = "2024_02_12_revamp_TEST"

# Set up paths and names
out_base_dir = f"data/{plume_type.value}/{satellite}/{config.split.name}_{suffix}"
experiment_name = f"{satellite}-{plume_type.value}-{config.split.name}-{suffix}"

# Get git revision for tracking
git_revision_hash = get_git_revision_hash()

# Override the tile query files from the config; only the uncommented CSV is used
config.satellite_split.tiles_query_files = [
    # "../src/data/tiles/s2/csv_files/2025_05_22_MGRS_with_IDs_within_OG_val_248.csv",
    "../src/data/tiles/s2/csv_files/2025_05_22_MGRS_with_IDs_within_OG_train_3413.csv",
]
# Get satellite-specific queries grouped by cloud coverage
queries_by_coverage = get_queries_by_cloud_coverage(
    config.satellite_split.tiles_query_files,
    config.satellite.cloud_coverage_threshold,
    satellite,
    config.split.name,
)

# Load the same query CSVs into one DataFrame for inspection in later cells
df = pd.concat([pd.read_csv(file) for file in config.satellite_split.tiles_query_files], ignore_index=True)
print(len(df))
# "date" is split on "-": the first part is stored as the year, the second as the month
df["month"] = df["date"].apply(lambda x: x.split("-")[1])
df["year"] = df["date"].apply(lambda x: x.split("-")[0])

# One job per (query, transformation) pair, tagged with the cloud range key it was grouped under.
# Jobs are ordered by cloud range first, then query, then transformation.
all_jobs = [
    (query, transformation_params, cloud_range)
    for cloud_range, queries in queries_by_coverage.items()
    for query in queries
    for transformation_params in transformations
]
# all_jobs
# Shown as the cell output: total number of jobs
len(all_jobs)

In [None]:
# Draw one job at random and unpack it into the globals used below
job_idx = np.random.randint(len(all_jobs))  # 294 = Ice
print(job_idx)

query, transformation, cloud_range = all_jobs[job_idx]
transformation_params = transformation
out_dir = str(Path(out_base_dir))

# Fetch the thresholds for this job's cloud range and parse them in one step
quality_thresholds = parse_quality_thresholds(
    get_quality_thresholds(config=config, cloud_range=cloud_range, satellite=satellite)
)

whole_size = BAND_RESOLUTIONS["B11"]
azure_cluster = False
test = False

print(quality_thresholds)

# Validate transformation params: a dict whose keys are str and whose values are float
_invalid_params_msg = "transformation_params must be a dictionary with each string key having one float value"
assert isinstance(transformation_params, dict), _invalid_params_msg
assert all(
    isinstance(name, str) and isinstance(value, float) for name, value in transformation_params.items()
), _invalid_params_msg

from src.utils.parameters import SATELLITE_SPATIAL_RESOLUTIONS

# Collect the keyword arguments first, then create the base config from them
base_config_kwargs = dict(
    plume_catalog=plume_catalog_acceptable_uri,
    plume_type=plume_type,
    out_dir=out_dir,
    crop_size=config.crop_size,
    quality_thresholds=quality_thresholds,
    random_seed=config.random_seed,
    transformation_params=transformation_params,
    azure_cluster=azure_cluster,
    git_revision_hash=git_revision_hash,
    test=test,
    ml_client=None,
    s3_client=None,
    storage_options=None,
    psf_sigma=config.satellite.psf_sigma,
    target_spatial_resolution=SATELLITE_SPATIAL_RESOLUTIONS[config.satellite.id],
    concentration_rescale_value=config.plumes.concentration_rescale_value,
    plume_proba_dict=config.satellite_split.plume_proba_dict,
    hapi_data_path=config.satellite.hapi_data_path,
)
base_config = DataGenerationConfig(**base_config_kwargs)

# Run and visualize a random ID

In [None]:
# Peek at the first job tuple: (query, transformation_params, cloud_range)
all_jobs[0]

In [None]:
"S2B_MSIL2A_20200721T185919_R013_T10SEJ_20200817T012938"

In [None]:
# Row label in `df` of the query for MGRS tile 10SEJ on 2020-07-21.
# NOTE(review): `all_jobs` is ordered by cloud range and repeated once per transformation, so
# confirm this label matches the intended position before using it as `job_idx`.
df[(df["mgrs"] == "10SEJ") & (df["date"] == "2020-07-21")].index[0]

## Find the main tile and reference tiles, download the bands, and predict cloud and cloud-shadow masks with OmniCloudMask
- Set `visualize_tiles=True` to visualize the large tiles together with the OmniCloudMask predictions

In [None]:
# Show the currently selected job index
job_idx

In [None]:
import traceback

visualize_tiles = False

# Try 10 or 25
nb_reference_ids = 10

# A single attempt by default; the loop/try/break shape allows retrying with
# another random job by increasing the range.
for _attempt in range(1):
    try:
        job_idx = np.random.randint(len(all_jobs))
        # job_idx = df[(df["mgrs"] == "10SEJ") & (df["date"] == "2020-07-21")].index[0]

        query, transformation, cloud_range = all_jobs[job_idx]
        sentinel_MGRS = query["mgrs"]
        sentinel_date = query["date"]
        s2_identifier = f"{sentinel_MGRS}_{sentinel_date}"
        print(job_idx, s2_identifier)

        # Thresholds and transformation must come from the job drawn in THIS cell,
        # not from whatever an earlier cell left in the notebook globals.
        quality_thresholds = parse_quality_thresholds(
            get_quality_thresholds(config=config, cloud_range=cloud_range, satellite=satellite)
        )
        print(f"{quality_thresholds=}")
        transformation_params = transformation
        print(f"{transformation_params=}")

        sentinel_date_obj = datetime.datetime.strptime(sentinel_date, "%Y-%m-%d").date()
        print(sentinel_date_obj)

        # Create transformation-cloud_coverage-specific output directory
        out_dir = str(Path(out_base_dir))  # / transform_str / cloud_bucket_str)

        # Get satellite-specific class and parameters
        SatelliteClass = SATELLITE_CLASSES[satellite]
        bands = config.satellite.bands
        if "," in bands:
            bands = bands.split(",")
        satellite_params = {
            "sentinel_MGRS": sentinel_MGRS,
            "sentinel_date": sentinel_date_obj,
            "bands": bands,
            "time_delta_days": config.satellite.time_delta_days,
            "nb_reference_ids": config.satellite.nb_reference_ids,
            "omnicloud_cloud_t": config.satellite.omnicloud_cloud_t,
            "omnicloud_shadow_t": config.satellite.omnicloud_shadow_t,
            # "reference_chip_max_bad_px_perc": config.satellite.reference_chip_max_bad_px_perc,
        }

        # `base_config` was built for the job drawn in an earlier cell; copy it with this
        # job's thresholds and transformation so the pipeline runs what was printed above.
        job_config = base_config.model_copy(
            update={
                "quality_thresholds": quality_thresholds,
                "transformation_params": transformation_params,
            }
        )

        # Create and run pipeline
        pipeline = SatelliteClass(**satellite_params, **job_config.model_dump())
        # Deliberately override the configured number of reference IDs for experimentation
        pipeline.nb_reference_ids = nb_reference_ids

        pipeline.visualize_tiles = visualize_tiles
        data = pipeline.download_data()
        break
    except Exception as err:
        # Notebook boundary: report the failure and keep the kernel usable.
        print(err)
        # print_exc() writes the traceback itself and returns None, so it must not be wrapped in print()
        traceback.print_exc()

## Visualize chipping
- Setting `pipeline.visualize_crops=True` and `pipeline.visualize_crops_show_frac = 0.05` visualizes 5% of the created chips.
- Setting `pipeline.visualize_insertion=True` visualizes the insertion of methane into the created chips.
- Running the next cell with `pipeline.visualize_crops=False` and `pipeline.visualize_insertion=False` creates all chips and prints a report of what happened at the end. The output dataframe is saved to "test.parquet", and a report with summary statistics is saved to "test.json".

In [None]:
# Reset Crop counts
# Reset Crop counts so that re-running this cell starts from zero
for _counter_name in (
    "non_overlapping_count",
    "overlapping_count",
    "too_much_main_nodata_count",
    "succeed_5perc_count",
    "failed_5perc_count",
):
    setattr(pipeline, _counter_name, 0)
pipeline.reference_indices_all = []

# Visualization switches: everything off, so all chips are generated without plots
pipeline.visualize_crops = False
pipeline.visualize_crops_show_frac = 0.05
pipeline.visualize_insertion = False

# Chip the downloaded data, insert plumes, then write the result locally only
crops = pipeline.generate_crops(data)
data_items = pipeline.generate_synthetic_data_items(pipeline.plume_files, crops)
local_parquet_path = pipeline.save_parquet(data_items, save_cloud=False, save_local=True)

In [None]:
# Load the chip table from "test.parquet" and show its (rows, columns) shape.
# NOTE(review): this rebinds `df`, which earlier held the tile-query CSV rows.
df = pd.read_parquet("test.parquet")
df.shape

In [None]:
# Parse the run report directly from the UTF-8 text stream
with open("test.json", encoding="utf-8") as f:
    json_data = json.load(f)
json_data

In [None]:
# df.columns[:50], df.columns[50:100], df.columns[100:]

In [None]:
# Show up to 10 random rows of the selected chip / plume-count columns.
# The sample size is capped at the table length so small runs do not raise ValueError.
df[
    [
        "exclusion_perc",
        "chip_cloud_combined_perc_main",
        "chip_cloud_shadow_omni_perc_main",
        "how_many_plumes_we_wanted",
        "how_many_plumes_we_inserted",
        "tile_SCL_SNOW_perc_main",
    ]
].sample(min(10, len(df)))

In [None]:
# Show up to 10 random rows of the selected chip- and tile-level cloud / no-data columns.
# The sample size is capped at the table length so small runs do not raise ValueError.
df[
    [
        "exclusion_perc",
        "chip_cloud_combined_perc_main",
        "chip_cloud_shadow_omni_perc_main",
        "tile_cloud_omni_perc_main",
        "tile_cloud_combined_perc_main",
        "tile_cloud_shadow_omni_perc_main",
        "tile_cloud_shadow_combined_perc_main",
        "tile_no_data_perc_main",
    ]
].sample(min(10, len(df)))

In [None]:
# Show up to 2 random rows of the plume, crop and tile-ID detail columns.
# The sample size is capped at the table length so small runs do not raise ValueError.
df[
    [
        "plume_files",
        "plume_sizes",
        "how_many_plumes_we_wanted",
        "how_many_plumes_we_inserted",
        "plumes_inserted_idxs",
        "plume_emissions",
        "frac_abs_sum",
        "bands",
        "size",
        "crop_x",
        "crop_y",
        "main_and_reference_ids",
        "main_and_reference_dates",
    ]
].sample(min(2, len(df)))