In [None]:
#!/usr/bin/env python

import sys

# Put the relative "../src" directory first on the import path (relative to the
# kernel's working directory) — presumably so the project imports below resolve
# against the local checkout; confirm.
sys.path.insert(0, "../src")

from coastmonitor.io.drive_config import configure_instance

# NOTE(review): project-specific setup for the "dev" branch; what it configures
# is not visible from this notebook.
configure_instance(branch="dev")

import logging
import os

import dask

# Turn off dask's dataframe query planning. This runs ahead of the
# dask_geopandas import below — NOTE(review): presumably because the option has
# to be set before the dataframe machinery is first imported; confirm.
dask.config.set({"dataframe.query-planning": False})
import re
from collections import defaultdict

import dask_geopandas
import fsspec
from dotenv import load_dotenv
from geopandas.array import GeometryDtype

from coastpy.geo.quadtiles_utils import add_geo_columns

# Release identifier; it is embedded in every URI built below.
VERSION = "2024-12-21"
OUT_BASE_URI = f"az://shorelinemonitor-raw-series/release/{VERSION}"
# Same path as OUT_BASE_URI, but with "tmp/" resp. "log/" inserted right after
# the "az://" scheme.
TMP_BASE_URI = OUT_BASE_URI.replace("az://", "az://tmp/")
LOG_BASE_URI = OUT_BASE_URI.replace("az://", "az://log/")

# Temporary outputs: per-part files and the per-box files built from them.
TMP_OBS_PART_URI = f"{TMP_BASE_URI}/obs/part"
TMP_OBS_BOX_URI = f"{TMP_BASE_URI}/obs/box"

# Log locations mirroring the temporary layout.
LOG_OBS_PART_URI = f"{LOG_BASE_URI}/obs/part"
LOG_OBS_BOX_URI = f"{LOG_BASE_URI}/obs/box"
LOG_FAILED_TRANSECTS_URI = f"{LOG_BASE_URI}/obs/failed-transects"


# Load variables from a .env file; override=True lets values from the file
# replace variables that are already set in the environment.
load_dotenv(override=True)

# os.getenv returns None for a missing variable; nothing here checks for that.
sas_token = os.getenv("AZURE_STORAGE_SAS_TOKEN")
account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME")
# Passed as keyword arguments when the "az" filesystem is created further down.
storage_options = {"account_name": account_name, "credential": sas_token}

# Only let WARNING and above through from the "azure" logger.
logging.getLogger("azure").setLevel(logging.WARNING)


def to_out_href(base_href, filename):
    """Return ``filename`` appended directly to ``base_href``.

    No separator is inserted, so ``base_href`` has to carry its own trailing
    ``/`` when one is wanted.
    """
    out_href = base_href + filename
    return out_href


# Column name -> dtype for the shoreline observation table.
# Values mix plain Python types (str, int, bool), NumPy/pandas dtype strings and
# the geopandas GeometryDtype extension dtype.
# NOTE(review): nothing in the visible cells reads DTYPES — confirm it is still needed.
# NOTE(review): text columns use `str` except "composite:image_id", which uses
# the "string" dtype alias — confirm the difference is intentional.
DTYPES = {
    # Identifiers
    "transect_id": str,  # Unique identifier for the transect
    "shoreline_id": str,  # Unique identifier for the shoreline observation
    # Date/Time
    "datetime": "datetime64[ns]",  # Timestamp of the shoreline observation
    # Positional Attributes
    "geometry": GeometryDtype(),  # Shoreline geometry as a LineString
    "lon": "float32",  # Longitude of the shoreline position
    "lat": "float32",  # Latitude of the shoreline position
    "chainage": "float32",  # Distance along the transect
    "utm_epsg": int,  # EPSG code for the UTM zone
    "transect_lon": "float32",  # Longitude of the transect origin
    "transect_lat": "float32",  # Latitude of the transect origin
    "quadkey": str,  # Quadkey for spatial indexing
    "bbox": object,  # Bounding box for spatial indexing
    # Derived Metrics
    "sinuosity": "float32",  # Ratio of actual length to straight-line length
    "self_intersection_density": "float32",  # Density of self-intersections per unit length
    "fractal_dimension": "float32",  # Fractal complexity of the shoreline
    "is_shoal": bool,  # Flag indicating if the shoreline is over a shoal
    # Observation Attributes
    "obs_group": int,  # Group identifier for aggregated observations
    "obs_group_stdev": "float32",  # Standard deviation within the observation group
    "obs_group_range": "float32",  # Range of values within the observation group
    "obs_group_count": int,  # Count of observations within the group
    "obs_is_qa": bool,  # Flag indicating whether the observation passed QA
    # Metrics from source data
    "otsu_threshold": "float32",  # Otsu threshold value for water detection
    "otsu_separability": "float32",  # Otsu separability score
    "composite:image_id": "string",  # Identifier for the composite image
    "composite:start_datetime": "datetime64[ns, UTC]",  # Start of composite period
    "composite:end_datetime": "datetime64[ns, UTC]",  # End of composite period
    "composite:determination_datetimes": "object",  # Array of datetime objects
    "composite:cloud_cover": "object",  # Array of cloud cover percentages
}


def group_files_by_box(filenames):
    """
    Bucket parquet part-file names by the box they belong to.

    A name is kept only when it contains ``box_<digits>_`` followed, somewhere
    later in the string, by ``.parquet``; every other name is silently dropped.

    Args:
        filenames (list): File names (or paths) to bucket.

    Returns:
        dict: ``box_<digits>`` mapped to the matching names, in input order.
    """
    box_part_re = re.compile(r"(box_\d+)_.*\.parquet")

    by_box = {}
    for name in filenames:
        hit = box_part_re.search(name)
        if hit is None:
            continue
        # First name seen for a box creates its bucket.
        by_box.setdefault(hit.group(1), []).append(name)

    return by_box

In [None]:
# from distributed import Client

# client = Client()
# print(client.dashboard_link)

fs = fsspec.filesystem("az", **storage_options)

# Part files waiting to be merged, grouped as {"box_<n>": [part paths]}.
files = group_files_by_box(fs.ls(TMP_OBS_PART_URI))

# NOTE list the already processed files in the box directory
try:
    processed_files = fs.ls(TMP_OBS_BOX_URI)
except FileNotFoundError:
    # The box directory does not exist yet, so nothing has been processed.
    processed_files = []

# The dot is escaped so that only a literal ".parquet" counts as a finished box.
box_pattern = re.compile(r"(box_\d+)\.parquet")
# A set, because it is only ever used for membership tests.
processed_boxes = {
    match.group(1)
    for match in map(box_pattern.search, processed_files)
    if match
}

# NOTE: remove already processed boxes
files = {
    box_id: filenames
    for box_id, filenames in files.items()
    if box_id not in processed_boxes
}

# Pick the first remaining box for the cells below. Both names are always bound
# here, so a later cell can never silently reuse a stale `filenames` left over
# from an earlier run when nothing is left to process.
if files:
    box_id, filenames = next(iter(files.items()))
    print(f"Processing box {box_id} with {len(filenames)} files")
else:
    box_id, filenames = None, []
    print("No unprocessed boxes found")

In [None]:
# Read only the `transect_id` column of the selected box's part files and keep
# the computed partitions around via persist().
ddf = dask_geopandas.read_parquet(
    filenames, filesystem=fs, gather_spatial_partitions=False, columns=["transect_id"]
).persist()

In [None]:
# Unique transect ids of every partition of ddf; each partition is computed on
# its own, one after the other.
r = []
for p in ddf.partitions:
    r.append(p.transect_id.unique().compute())

In [None]:
import geopandas as gpd

# Read one box file fully into memory as `df`.
# NOTE(review): the URI is hard-coded to release 2024-12-17, whereas VERSION in
# the first cell is "2024-12-21" — confirm which release is meant.
with fs.open(
    "az://shorelinemonitor-raw-series/release/2024-12-17/box_028.parquet"
) as f:
    # with fs.open("az://shorelinemonitor-raw-series/release/2024-12-17/box_142.parquet") as f:
    df = gpd.read_parquet(f)

In [None]:
import math

import numpy as np
import pandas as pd


def calculate_divisions(
    index_array: np.ndarray | pd.Series, npartitions: int
) -> np.ndarray:
    """
    Calculate division indices for repartitioning a Dask dataframe.

    Args:
        index_array (Union[np.ndarray, pd.Series]): Array of index values.
        npartitions (int): Desired number of partitions.

    Returns:
        np.ndarray: Division indices for repartitioning.

    Example:
        >>> coastline_names = ddf["coastline_name"].unique().compute()
        >>> divisions = calculate_divisions(coastline_names, 20)
    """
    if isinstance(index_array, pd.Series):
        index_array = index_array.values

    step = math.ceil(len(index_array) / npartitions)
    divisions = np.concatenate([index_array[0:-1:step], [index_array[-1]]])
    return divisions

In [None]:
# Divisions are cut from the sorted, distinct transect ids. The raw column can
# repeat ids and follows file order, which would give repeated, unordered
# division values.
transect_ids = ddf["transect_id"].unique().compute().sort_values()
divisions = calculate_divisions(transect_ids, 20)

In [None]:
import numpy as np
import pandas as pd


def split_dataframe_by_unique_ids(df, max_rows):
    """
    Split a DataFrame into consecutive row blocks without cutting a `transect_id` in two.

    Cut points start out evenly spaced and are then pushed forward to the next
    row whose `transect_id` differs from the row before it, so every run of
    equal ids ends up in a single partition. Rows are addressed by position
    throughout, so the index labels of `df` do not matter.

    Args:
        df (pd.DataFrame): Input frame with a `transect_id` column, ordered so
            that rows sharing a `transect_id` are adjacent.
        max_rows (int): Approximate maximum number of rows per partition; must
            be positive. A partition can exceed it when a run of equal ids
            crosses a cut point.

    Returns:
        list[pd.DataFrame]: Non-empty partitions in row order; `[]` when `df`
        is empty.

    Raises:
        ValueError: If `max_rows` is not positive.
    """
    total_rows = len(df)
    if total_rows == 0:
        return []
    if max_rows <= 0:
        raise ValueError(f"max_rows must be positive, got {max_rows}")

    # Estimate the number of partitions and the spacing of the initial cuts.
    num_partitions = int(np.ceil(total_rows / max_rows))
    step_size = max(total_rows // num_partitions, 1)

    # Positional view of the ids: independent of the frame's index labels and
    # much cheaper than a label lookup per row.
    ids = df["transect_id"].to_numpy()

    cut_points = []
    for idx in range(step_size, total_rows, step_size):
        # Move forward until the cut sits on the first row of a new id.
        while idx < total_rows and ids[idx] == ids[idx - 1]:
            idx += 1
        # Drop cuts that ran off the end, and cuts that collapsed onto the
        # previous one (a long run of one id spanning several initial cuts) —
        # keeping those would yield empty partitions.
        if idx < total_rows and (not cut_points or idx > cut_points[-1]):
            cut_points.append(idx)

    # Slice between consecutive cut points; the tail after the last cut is
    # never empty because every kept cut is below total_rows.
    partitions = []
    start_idx = 0
    for idx in cut_points:
        partitions.append(df.iloc[start_idx:idx])
        start_idx = idx
    partitions.append(df.iloc[start_idx:])

    return partitions

In [None]:
import geopandas as gpd

# Read one box file fully into memory as `df` and print its shape.
# NOTE(review): the URI is hard-coded to release 2024-12-17, whereas VERSION in
# the first cell is "2024-12-21" — confirm which release is meant.
with fs.open(
    "az://shorelinemonitor-raw-series/release/2024-12-17/box_028.parquet"
) as f:
    # with fs.open("az://shorelinemonitor-raw-series/release/2024-12-17/box_142.parquet") as f:
    df = gpd.read_parquet(f)
    print(df.shape)

# Request "quadkey" and "bbox" columns for df at quadkey zoom level 12.
df = add_geo_columns(df, geo_columns=["quadkey", "bbox"], quadkey_zoom_level=12)
# Build point geometries (EPSG:4326) from the transect_lon / transect_lat
# columns and request their quadkey.
# NOTE(review): no quadkey_zoom_level is passed here, so this call uses the
# function's default — confirm that it matches the level 12 used above.
quadkeys = add_geo_columns(
    gpd.GeoSeries.from_xy(df.transect_lon, df.transect_lat, crs=4326).to_frame(
        "geometry"
    ),
    geo_columns=["quadkey"],
)
# NOTE(review): this assignment aligns on the index — assumes add_geo_columns
# keeps the index of the frame it is given; confirm.
df["transect_quadkey"] = quadkeys.quadkey
# Sort so rows of one transect are adjacent and in time order, then renumber
# the index 0..n-1.
df = df.sort_values(["transect_quadkey", "transect_id", "datetime"]).reset_index(
    drop=True
)

In [None]:
# Split the sorted frame into partitions of roughly 1000 rows.
# NOTE: this rebinds `r`, which an earlier cell used for the per-partition id arrays.
r = split_dataframe_by_unique_ids(df, 1000)

In [None]:
# Distinct transect ids of each partition, in partition order.
ids = []
for df_ in r:
    ids.append(df_.transect_id.unique())

In [None]:
# NOTE(review): within the visible cells check_partition_uniqueness is only
# defined further down, so on a fresh kernel run top to bottom this call raises
# NameError.
check_partition_uniqueness(ids)

In [None]:
# Shape of the re-assembled partitions, for comparison with df.shape.
pd.concat(r).shape

In [None]:
# Print the shape of every partition.
for df_ in r:
    print(df_.shape)

In [None]:
# Display the last partition.
r[-1]

In [None]:
# (rows, columns) of the full frame.
df.shape

In [None]:
# Back-of-the-envelope sizing. 328707 equals 100 MiB / 319 bytes
# (104857600 / 319), i.e. the figure from the size_to_bytes cell further down.
# NOTE(review): the meaning of the 100 and 2.83 factors is not documented — confirm.
100 / (328707 / len(df)) / 2.83

In [None]:
from coastpy.geo.size import estimate_memory_usage_per_row

# Display the per-row memory estimate for df.
estimate_memory_usage_per_row(df)

In [None]:
# Rows that fit into a 100MB partition at 319 bytes per row.
# NOTE(review): 319 is presumably the per-row estimate from the previous cell.
# The original note here read "1e6 rows per partition", but 100MB / 319 bytes
# is roughly 3e5 rows — confirm the intended target.
# NOTE(review): the previous cell imports from coastpy.geo.size while this one
# imports from coastpy.utils.size — confirm both module paths are right.
from coastpy.utils.size import size_to_bytes

size_to_bytes("100MB") / 319

In [None]:
# Number of entries in `r`.
len(r)

In [None]:
# Count rows of the first partition whose transect_id also occurs in the fourth.
# The comparison is done on the id column: DataFrame.isin(DataFrame) only
# matches where index AND column labels agree, and partitions sliced from the
# same frame have disjoint indexes, so the frame-level form can never report an
# overlap.
r[0]["transect_id"].isin(r[3]["transect_id"]).sum()

In [None]:
# This cell was an unfinished draft: it ended in `for p in r:` with no body,
# which is a SyntaxError. The dangling loop header is dropped so the notebook
# runs top to bottom; the complete check is check_partition_uniqueness.
seen_ids = set()

In [None]:
def check_partition_uniqueness(r):
    """
    Check that no transect_id value occurs in more than one partition.

    Args:
        r: Iterable of partitions. A partition is either an iterable of
            transect ids (list, array, Series, ...) or a DataFrame-like object
            with a `transect_id` column, in which case the distinct values of
            that column are used. For plain iterables every element counts, so
            an id repeated inside one iterable is reported as a duplicate too.

    Returns:
        tuple[bool, set]: `(is_unique, duplicates)`; `duplicates` holds every
        id that was encountered more than once.
    """
    seen_ids = set()
    duplicates = set()

    for partition in r:
        columns = getattr(partition, "columns", None)
        if columns is not None and "transect_id" in columns:
            # Iterating a DataFrame yields its column labels, not its rows, so
            # pull the ids out of the column explicitly.
            partition_ids = partition["transect_id"].unique()
        else:
            partition_ids = partition

        for transect_id in partition_ids:
            if transect_id in seen_ids:
                duplicates.add(transect_id)
            else:
                seen_ids.add(transect_id)

    # Unique across partitions exactly when nothing was seen twice.
    is_unique = len(duplicates) == 0

    return is_unique, duplicates


# `r` is whatever the most recently executed cell bound it to; in top-to-bottom
# order that is the list returned by split_dataframe_by_unique_ids.
check_partition_uniqueness(r)

In [None]:
# Display `r` as a whole. NOTE(review): this echoes every partition; a slice or
# len(r) keeps the output small.
r

In [None]:
# First five entries of the first element of `r`.
r[0][:5]

In [None]:
# Report ids shared between partitions. The bookkeeping is delegated to
# check_partition_uniqueness (defined in an earlier cell) instead of keeping a
# second, hand-copied version of the same loop here.
is_unique, duplicates = check_partition_uniqueness(r)

# Report duplicates if any
if duplicates:
    print("Duplicate transect IDs found across partitions:", duplicates)
else:
    print("All transect IDs are unique across partitions.")

In [None]:
# Bind `p` to the first partition of ddf and display it. The previous
# `for ... break` form did the same binding, but its bare `p` inside the loop
# body displayed nothing, because only a cell's last top-level expression is
# echoed.
p = ddf.partitions[0]
p

In [None]:
# Materialise ddf in memory.
# NOTE: this rebinds `r` once more; earlier cells used it for the list of partitions.
r = ddf.compute()

In [None]:
# Compute and display the single partition bound to `p`.
p.compute()

In [None]:
# NOTE(review): in top-to-bottom order `r` already holds the result of
# ddf.compute(), so there is nothing lazy left to evaluate here — confirm
# whether this cell is still needed.
dask.compute(r)