In [None]:
# pip install matplotlib python-dotenv sentinelhub geopandas pandas numpy tifffile IProgress

In [None]:
import os
import geopandas as gpd
import pandas as pd
from dotenv import load_dotenv
from sentinelhub import SHConfig, SentinelHubCatalog, BBox, CRS, DataCollection
from sentinelhub import MimeType, SentinelHubDownloadClient, SentinelHubRequest, bbox_to_dimensions, filter_times
import datetime as dt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
import json
import uuid
import glob

# Pull the credentials defined in a local .env file into the process environment
load_dotenv()

# Build the Sentinel Hub configuration: service endpoints first, then the
# OAuth client credentials read from the environment (None when unset)
config = SHConfig()
config.sh_base_url = 'https://sh.dataspace.copernicus.eu'
config.sh_token_url = 'https://identity.dataspace.copernicus.eu/auth/realms/CDSE/protocol/openid-connect/token'
config.sh_client_id = os.environ.get("SENTINELHUB_CLIENT_ID")
config.sh_client_secret = os.environ.get("SENTINELHUB_CLIENT_SECRET")

# Catalog client used by the patch loop to search for acquisitions
catalog = SentinelHubCatalog(config=config)

#################################
######### PATCHES LOGIC #########
#################################
# Load the patch grid; every row is expected to carry a `patch_id` and a geometry
patches_gdf = gpd.read_file("../data/patches/highlighted_patches.geojson")
print(f"Loaded {len(patches_gdf)} patches")

# Validate the CRS with an explicit exception: `assert` is stripped when Python
# runs with -O, and `.crs` is None when the file declares no CRS at all.
if patches_gdf.crs is None or patches_gdf.crs.to_epsg() != 3857:
    raise ValueError(f"Expected EPSG:3857 in patches.geojson, found {patches_gdf.crs}")

# Extract the numeric index from patch_id (e.g. "patch_02149_fa728d65" -> 2149)
# and fail with a readable message when an id does not follow that pattern,
# instead of the NaN-to-int conversion error `astype(int)` would raise.
patch_numbers = patches_gdf["patch_id"].str.extract(r'patch_(\d{5})')[0]
if patch_numbers.isna().any():
    bad_ids = patches_gdf.loc[patch_numbers.isna(), "patch_id"].tolist()
    raise ValueError(f"patch_id values do not match 'patch_NNNNN': {bad_ids[:5]}")
patches_gdf["patch_index"] = patch_numbers.astype(int)
patches_gdf = patches_gdf.sort_values("patch_index").reset_index(drop=True)

# Resume from specific patch
start_from_patch = "patch_02149_fa728d65"  # ← Change this to resume from another
if start_from_patch in patches_gdf["patch_id"].values:
    # The index was just reset, so the label of the first match is also its position
    start_index = patches_gdf[patches_gdf["patch_id"] == start_from_patch].index[0]
    patches_gdf = patches_gdf.iloc[start_index:].reset_index(drop=True)
    print(f"Resuming from patch: {start_from_patch}")
else:
    print(f"Start patch '{start_from_patch}' not found. Starting from the beginning.")

# Limit number of patches processed in this run
patch_limit = 2
patches_gdf = patches_gdf.head(patch_limit)
print(f"Processing {len(patches_gdf)} patches...")

# Set paths
processed_dir = "../data/processed/sentinel2"
logfile_path = "../data/processed/patch_download_log.csv"


def patch_id_from_filename(fname):
    """
    Recover the ``patch_<index>_<hash>`` id from a metadata file name.

    Output files are named ``patch_{patch_id}_{time}_res{N}.json`` and patch ids
    themselves start with ``patch_``, so names on disk look like
    ``patch_patch_02149_fa728d65_...``. Names with a single ``patch_`` prefix
    are accepted as well.

    Returns the patch id, or None when the name does not contain a numeric
    index followed by a second token.
    """
    stem = os.path.splitext(os.path.basename(fname))[0]
    tokens = stem.split("_")
    # Drop every leading "patch" token so both naming variants line up
    while tokens and tokens[0] == "patch":
        tokens.pop(0)
    if len(tokens) < 2 or not tokens[0].isdigit():
        return None
    return f"patch_{tokens[0]}_{tokens[1]}"


# Collect the ids of patches that already have a metadata JSON on disk
existing_files = glob.glob(os.path.join(processed_dir, "patch_*.json"))
processed_patch_ids = {
    found_id
    for found_id in map(patch_id_from_filename, existing_files)
    if found_id is not None
}

# Make sure the folder that will hold the CSV log is present
os.makedirs(os.path.dirname(logfile_path), exist_ok=True)

# Start an empty log when no CSV exists yet; otherwise continue the previous one
if not os.path.exists(logfile_path):
    log_df = pd.DataFrame(columns=["patch_id", "uuid", "timestamp", "file_processed", "status"])
    logged_patch_ids = set()
    print("Initialized new patch log")
else:
    log_df = pd.read_csv(logfile_path)
    logged_patch_ids = set(log_df["patch_id"])
    print(f"Loaded log with {len(log_df)} entries")


# #### REPLACED BY PATCHES ####
# # Define the bounding box
# # bbox = BBox(bbox=[115.24026, -8.52927, 115.28474, -8.48453], crs=CRS.WGS84)
# bbox = BBox(bbox=[115.1716, -8.5968, 115.3534, -8.4170], crs=CRS.WGS84)

##########################
####### LOOP #############
##########################

# Loop over patches
for idx, row in patches_gdf.iterrows():
    patch_id, geom = row["patch_id"], row["geometry"]

    # Decide whether this patch is bypassed; the first matching reason wins,
    # so the geometry is only inspected for patches not processed before
    skip_status = None
    if patch_id in processed_patch_ids:
        print(f"Patch {patch_id} already processed. Skipping.")
        skip_status = "skipped"
    elif geom.geom_type != "Polygon":
        print(f"Skipping non-polygon geometry in patch {patch_id}")
        skip_status = "non-polygon"

    if skip_status is not None:
        # Record the bypass (no image-level fields) and persist the log right away
        skip_record = pd.DataFrame([{
            "patch_id": patch_id,
            "uuid": None,
            "timestamp": None,
            "file_processed": None,
            "status": skip_status,
        }])
        log_df = pd.concat([log_df, skip_record], ignore_index=True)
        log_df.to_csv(logfile_path, index=False)
        continue

    # Build the request bounding box from the patch polygon's envelope.
    # The bounds are passed through unchanged and tagged as Web Mercator.
    patch_bbox = BBox(bbox=geom.bounds, crs=CRS.POP_WEB)  # EPSG:3857 == POP_WEB

    print(f"\nProcessing patch {patch_id} with patch_bbox: {patch_bbox}")

    # Everything below runs once per patch and uses `patch_bbox` as the area of
    # interest: search, timestamp filtering, request creation, download, save.

    # Search window and cloud-cover threshold for the catalog query
    time_interval = ("2010-01-01", "2026-12-31")
    cloud_cover_filter = 2  # exclusive upper bound on eo:cloud_cover (the filter uses `<`)

    # Query the catalog for Sentinel-2 L2A items intersecting the patch
    search_iterator = catalog.search(
        DataCollection.SENTINEL2_L2A,
        bbox=patch_bbox,
        time=time_interval,
        filter= f"eo:cloud_cover < {cloud_cover_filter}",  # keep only scenes below the threshold
        fields={"include": ["id", "properties.datetime", "properties.eo:cloud_cover", "geometry"], "exclude": []},
    )

    # Materialize the search results
    features = list(search_iterator)

    # Tabular view of the results, used only for the printout below
    results_gdf = gpd.GeoDataFrame.from_features(features)

    # Display the results
    print("Total number of results:", len(features))
    print(results_gdf)

    # Tolerance handed to filter_times below; the download step also uses it to
    # widen each acquisition timestamp into a request time interval
    time_difference = dt.timedelta(hours=1)

    # NOTE(review): the iterator was already consumed into `features` above;
    # presumably get_timestamps() reuses the fetched results — confirm.
    all_timestamps = search_iterator.get_timestamps()
    # Presumably collapses timestamps closer together than `time_difference`
    # into a single acquisition — verify against the sentinelhub docs.
    unique_acquisitions = filter_times(all_timestamps, time_difference)

    if not unique_acquisitions:
        print(f"No acquisitions found for patch {patch_id}. Skipping.")
        # Record the empty result so the log shows this patch was attempted
        log_df = pd.concat([log_df, pd.DataFrame([{
            "patch_id": patch_id,
            "uuid": None,
            "timestamp": None,
            "file_processed": None,
            "status": "no_acquisitions"
        }])], ignore_index=True)
        log_df.to_csv(logfile_path, index=False)
        continue

    print('unique_acquisitions: ',unique_acquisitions)

    # Directory checked for already-downloaded raw images
    save_dir = "../data/raw/sentinel2"
    os.makedirs(save_dir, exist_ok=True)

    # Output description injected into the evalscript below:
    # 12 spectral bands + SCL, requested as digital numbers in uint16
    bands_num = 13
    bands_units = "DN"
    sampleType = "uint16"

    # Requested output resolution in metres. Every band is requested on the same
    # grid, so the value below applies to all of them regardless of native
    # resolution.
    # Native band resolutions (matching `original_band_res` defined later in
    # this loop):
    #   10 m: B02, B03, B04, B08
    #   20 m: B05, B06, B07, B8A, B11, B12, SCL
    #   60 m: B01, B09
    # SCL is the Scene Classification Layer; it labels each pixel (clouds,
    # water, vegetation, ...).
    # SCL value -> meaning
    #  0  No Data
    #  1  Saturated/Defective
    #  2  Dark Area Pixels
    #  3  Cloud Shadow
    #  4  Vegetation
    #  5  Bare Soils
    #  6  Water
    #  7  Clouds low probability / Unclassified
    #  8  Clouds medium probability
    #  9  Clouds high probability
    # 10  Thin Cirrus
    # 11  Snow or Ice
    resolution = 10 # metres per pixel for the requested output

    # Evalscript sent with every request: returns the 13 bands in the order
    # listed in `bands`, with SCL last. Doubled braces are literal braces in
    # the f-string.
    all_bands_evalscript = f"""
    //VERSION=3

    function setup() {{
        return {{
            input: [{{
                bands: ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12", "SCL"],
                units: "{bands_units}"
            }}],
            output: {{
                bands: {bands_num},
                sampleType: "{sampleType}"
            }}
        }};
    }}

    function evaluatePixel(sample) {{
        return [
            sample.B01, sample.B02, sample.B03, sample.B04, sample.B05, sample.B06,
            sample.B07, sample.B08, sample.B8A, sample.B09, sample.B11, sample.B12,
            sample.SCL
        ];
    }}
    """

    # Store requests to download
    # Note: The requests are not executed yet, they are just stored in a list
    process_requests = []

    # Check if image already exists for unique acquisitions and specified resolution
    for timestamp in unique_acquisitions:
        iso_time = timestamp.isoformat().replace(":", "").replace("-", "")
        base_filename = f"patch_{patch_id}_{iso_time}_res{resolution}"
        filename = os.path.join(save_dir, f"{base_filename}.tiff")

        # Check if the file already exists
        if not os.path.exists(filename):
            request = SentinelHubRequest(
                evalscript=all_bands_evalscript,
                input_data=[
                    SentinelHubRequest.input_data(
                        data_collection=DataCollection.SENTINEL2_L2A.define_from("s2l2a", service_url=config.sh_base_url),
                        time_interval=(timestamp - time_difference, timestamp + time_difference),
                    )
                ],
                responses=[SentinelHubRequest.output_response("default", MimeType.TIFF)],  # Use TIFF for multi-band data
                bbox=patch_bbox,
                size=bbox_to_dimensions(patch_bbox, resolution=resolution),  # Adjust resolution as needed
                config=config,
            )
            process_requests.append(request)
        else:
            print(f"File {filename} already exists. Skipping download.")

    print("process_requests:", len(process_requests))

    # Download the images
    client = SentinelHubDownloadClient(config=config)

    download_requests = [request.download_list[0] for request in process_requests]

    if not download_requests:
        print(f"No new images to download for patch {patch_id}")
        continue

    # Download the images
    data = client.download(download_requests)

    # Example: Check dtype of each image in the downloaded data list
    for i, array in enumerate(data):
        print(f"Image {i}: shape={array.shape}, dtype={array.dtype}")

    print(data)

    ###############################
    ###### PLOTTING & SAVING  #####
    ###############################

    # Output locations (created on demand) and the display scaling factor
    pixel_scaling = 1
    base_raw_dir = "../data/raw/sentinel2"
    base_proc_dir = "../data/processed/sentinel2"
    os.makedirs(base_raw_dir, exist_ok=True)
    os.makedirs(base_proc_dir, exist_ok=True)

    # Corner coordinates of the patch, rounded to 4 decimals for the metadata.
    # lower_left is (min x, min y) and upper_right is (max x, max y).
    # NOTE(review): the names say lon/lat, but the values are in whatever CRS
    # `patch_bbox` carries — confirm the units before treating them as degrees.
    ul_lon, br_lat = (round(coord, 4) for coord in patch_bbox.lower_left)
    br_lon, ul_lat = (round(coord, 4) for coord in patch_bbox.upper_right)

    # Band order of the downloaded arrays and the native resolution (metres) of
    # each band, paired by position
    band_list = ["B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12", "SCL"]
    original_band_res = dict(zip(
        band_list,
        (60, 10, 10, 10, 20, 20, 20, 10, 20, 60, 20, 20, 20),
    ))

    def normalize_by_scl(image, scl, valid_scl_values=(4, 5)):
        """
        Percentile-normalize each band using only pixels with a valid SCL class.

        Args:
            image: array indexed as (row, column, band).
            scl: 2-D class array with the same row/column layout as ``image``.
            valid_scl_values: iterable of SCL classes whose pixels are used to
                compute the percentiles. The default is an immutable tuple so
                it cannot be altered between calls.

        Returns:
            Tuple ``(norm_image, valid_mask, norm_p2, norm_p98)``:
            ``norm_image`` is a float32 array shaped like ``image``, with every
            band clipped to its [p2, p98] range and rescaled to roughly [0, 1];
            ``valid_mask`` is the boolean mask of valid pixels; ``norm_p2`` and
            ``norm_p98`` are lists holding the per-band 2nd and 98th
            percentiles as Python floats.
        """
        # list() lets callers pass any iterable (tuple, list, set, ...)
        valid_mask = np.isin(scl, list(valid_scl_values))
        norm_image = np.zeros_like(image, dtype=np.float32)
        norm_p2, norm_p98 = [], []

        for band in range(image.shape[2]):
            band_data = image[:, :, band]
            band_valid = band_data[valid_mask]
            if band_valid.size > 0:
                p2, p98 = np.percentile(band_valid, (2, 98))
            else:
                # No valid pixel at all: fall back to a fixed [0, 1] range
                p2, p98 = 0, 1
            norm_p2.append(float(p2))
            norm_p98.append(float(p98))
            band_data = np.clip(band_data, p2, p98)
            # The epsilon avoids a division by zero when p2 == p98
            band_data = (band_data - p2) / (p98 - p2 + 1e-6)
            norm_image[:, :, band] = band_data

        return norm_image, valid_mask, norm_p2, norm_p98

    # Persist every downloaded acquisition: normalized TIFF, JSON sidecar, log row
    # NOTE(review): images and timestamps are paired purely by position; this is
    # only correct if `data` and `unique_acquisitions` are index-aligned — confirm.
    for idx, (image, timestamp) in enumerate(zip(data, unique_acquisitions)):
        if not isinstance(timestamp, dt.date):
            print(f"Invalid timestamp at index {idx}. Skipping image.")
            continue

        # Short random tag stored in the metadata and the log (not in the file name)
        short_uuid = str(uuid.uuid4())[:8]
        timestamp_iso = timestamp.isoformat()
        iso_time = timestamp_iso.replace(":", "").replace("-", "")
        base_filename = f"patch_{patch_id}_{iso_time}_res{resolution}"

        # Output locations; the JSON sidecar sits next to the processed TIFF
        raw_tiff_path = os.path.join(base_raw_dir, f"{base_filename}.tiff")
        proc_tiff_path = os.path.join(base_proc_dir, f"{base_filename}.tiff")
        json_path = os.path.join(base_proc_dir, f"{base_filename}.json")
        proc_tiff_name = os.path.basename(proc_tiff_path)

        # Log row for this image; the status is overwritten if saving fails
        log_entry = {
            "patch_id": patch_id,
            "uuid": short_uuid,
            "timestamp": timestamp_iso,
            "file_processed": proc_tiff_name,
            "status": "success",
        }

        image = image.astype(np.float32)

        # The last band is SCL, everything before it is spectral data
        scl = image[:, :, -1]
        spectral = image[:, :, :-1]

        # Percentiles are computed over SCL classes 4 and 5 only
        valid_scl_values = [4, 5]
        norm_image, valid_mask, norm_p2, norm_p98 = normalize_by_scl(
            spectral, scl, valid_scl_values=valid_scl_values
        )

        # ### PLOTTING ONLY (disabled) ###
        # # Visualize RGB with pixel scaling
        # if norm_image.shape[2] >= 4:
        #     rgb_indices = [3, 2, 1]  # B04, B03, B02
        #     rgb_image = norm_image[:, :, rgb_indices] * pixel_scaling
        #     rgb_image[~valid_mask] = 1.0  # Set invalid (cloud) pixels to white
        #     plt.imshow(rgb_image)
        #     plt.title("True Color (masked using SCL)")
        #     plt.axis("off")
        #     plt.show()
        # else:
        #     print(f"Warning: Image at index {idx} doesn't have 3+ bands. Skipping display.")

        # ### RAW SAVE (disabled): unmodified bands + SCL ###
        # try:
        #     image_raw = np.transpose(image, (2, 0, 1))  # (bands, height, width)
        #     tifffile.imwrite(raw_tiff_path, image_raw)
        #     print(f"Saved RAW: {raw_tiff_path}")
        # except Exception as e:
        #     print(f"Error saving raw image at index {idx}: {e}")
        #     continue

        # Write the normalized spectral bands (SCL excluded), band axis first
        try:
            image_norm = norm_image.transpose(2, 0, 1)
            tifffile.imwrite(proc_tiff_path, image_norm)
            print(f"Saved PROCESSED: {proc_tiff_path}")
        except Exception as e:
            print(f"Error saving processed image at index {idx}: {e}")
            # Record the failure (message truncated to 100 chars) and move on
            log_entry["status"] = f"error: {str(e)[:100]}"
            log_df = pd.concat([log_df, pd.DataFrame([log_entry])], ignore_index=True)
            log_df.to_csv(logfile_path, index=False)
            continue

        # Sidecar metadata describing the request and the normalization applied
        bbox_corners = {
            "upper_left": [ul_lat, ul_lon],
            "bottom_right": [br_lat, br_lon],
        }
        normalization_info = {
            "percentiles": [2, 98],
            "p2": norm_p2,
            "p98": norm_p98,
        }
        metadata = {
            "source": "SENTINEL2_L2A",
            "patch_id": patch_id,
            "uuid": short_uuid,
            "timestamp": timestamp_iso,
            "resolution": resolution,
            "sampleType": sampleType,
            "bands_units": bands_units,
            "bands_num": bands_num,
            "bands_shape": image.shape,
            "bands": band_list,
            "bbox": bbox_corners,
            "original_band_resolutions": original_band_res,
            "cloud_mask_applied": True,
            "max_search_cloud_cover": cloud_cover_filter,
            "normalization": normalization_info,
            "file_raw": os.path.basename(raw_tiff_path),
            "file_processed": proc_tiff_name,
        }

        with open(json_path, "w") as f:
            json.dump(metadata, f, indent=4)
        print(f"Saved metadata: {json_path}")

        # Append the success row and flush the log immediately so an
        # interrupted run keeps its progress
        log_df = pd.concat([log_df, pd.DataFrame([log_entry])], ignore_index=True)
        log_df.to_csv(logfile_path, index=False)
        print(f"Logged patch {patch_id} to {logfile_path}")


    ### "OUTER LOOP"
    # Report the shape of the first downloaded image, if there is one
    print(f"Shape of the first image: {data[0].shape}" if data else "No data available.")

    # ######################################
    # ### Plotting Multispectral Bands #####
    # ######################################

    # # Band order including SCL as last band (index 12)
    # band_order = ["B01", "B02", "B03", "B04", "B05", "B06",
    #               "B07", "B08", "B8A", "B09", "B11", "B12", "SCL"]
    # numbered_band_order = {index: band for index, band in enumerate(band_order)}

    # def extract_bands(image, band_names, band_dict):
    #     """
    #     Extract specific bands from image based on band names.
    #     """
    #     indices = [idx for name in band_names for idx, b in band_dict.items() if b == name]
    #     return image[:, :, indices]

    # def normalize_by_scl(image, scl, valid_scl_values=[4, 5]):
    #     """
    #     Normalize image using percentile scaling over valid SCL mask.
    #     Returns normalized image and the mask used.
    #     """
    #     valid_mask = np.isin(scl, valid_scl_values)
    #     norm_image = np.zeros_like(image, dtype=np.float32)

    #     for band in range(image.shape[2]):
    #         band_data = image[:, :, band]
    #         band_valid = band_data[valid_mask]
    #         if band_valid.size > 0:
    #             p2, p98 = np.percentile(band_valid, (2, 98))
    #         else:
    #             p2, p98 = 0, 1
    #         band_data = np.clip(band_data, p2, p98)
    #         band_data = (band_data - p2) / (p98 - p2 + 1e-6)
    #         norm_image[:, :, band] = band_data

    #     return norm_image, valid_mask

    # # ---- Load your DN image (shape: H, W, 13), for example:
    # # image = tifffile.imread(path).transpose(1, 2, 0)

    # # Split into spectral + SCL
    # spectral = image[:, :, :-1]
    # scl = image[:, :, -1]

    # # Normalize using SCL land pixels
    # norm_image, valid_mask = normalize_by_scl(spectral, scl)

    # # Extract different band composites
    # rgb = norm_image[:, :, [band_order.index(b) for b in ["B04", "B03", "B02"]]]
    # vegetation = norm_image[:, :, [band_order.index(b) for b in ["B08", "B04", "B03"]]]
    # urban = norm_image[:, :, [band_order.index(b) for b in ["B12", "B11", "B04"]]]
    # water = norm_image[:, :, [band_order.index(b) for b in ["B08", "B11", "B02"]]]

    # # Optional: Mask out invalid areas (e.g., clouds) as white
    # rgb[~valid_mask] = 1.0
    # vegetation[~valid_mask] = 1.0
    # urban[~valid_mask] = 1.0
    # water[~valid_mask] = 1.0

    # # ---- Plot
    # fig, axes = plt.subplots(2, 2, figsize=(15, 15))

    # axes[0, 0].imshow(rgb)
    # axes[0, 0].set_title("True Color (B04, B03, B02)")
    # axes[0, 0].axis('off')

    # axes[0, 1].imshow(vegetation)
    # axes[0, 1].set_title("False Color Vegetation (B08, B04, B03)")
    # axes[0, 1].axis('off')

    # axes[1, 0].imshow(urban)
    # axes[1, 0].set_title("Urban/Buildings (B12, B11, B04)")
    # axes[1, 0].axis('off')

    # axes[1, 1].imshow(water)
    # axes[1, 1].set_title("Water Detection (B08, B11, B02)")
    # axes[1, 1].axis('off')

    # plt.tight_layout()
    # plt.show()