In [1]:
# --- IMPORTS ---
import numpy as np
import stackstac
import pystac_client
import planetary_computer
import xrspatial.multispectral as ms
import dask.array as da
from dask.distributed import Client, LocalCluster
from urllib3.util.retry import Retry
from dask.diagnostics import ProgressBar
import xarray as xr
import bottleneck
import matplotlib.pyplot as plt
import matplotlib as mpl
import rioxarray
import os
import ipyleaflet
import dask


# --- PARAMETERS ---
# Choose how to group the time dimension: "month", "week", or "year"
time_grouping = "month"  # Options: "month", "week", "year"
spatial_resolution = 30  # Spatial resolution in meters
bands_to_load = ['B02', 'B03', 'B04', 'SCL']  # Blue, Green, Red, Scene Classification
local_cluster = True  # Set to False if using Coiled

In [None]:

# --- DASK CLUSTER SETUP ---
if local_cluster:
    cluster = LocalCluster()
    client = Client(cluster)
else:
    import coiled
    cluster = coiled.Cluster(name="Timelapse", shutdown_on_close=True)
    cluster.adapt(n_workers=1, maximum=8)
    client = cluster.get_client()


In [18]:

# --- MAP FOR BOUNDING BOX SELECTION ---
m = ipyleaflet.Map(scroll_wheel_zoom=True)
m.center = (41.64933994767867, -69.94438630063088)
m.zoom = 12
m.layout.height = "500px"
m.layout.width = "500px"
display(m)


Map(center=[41.64933994767867, -69.94438630063088], controls=(ZoomControl(options=['position', 'zoom_in_text',…

In [19]:

# Wait for user to select area, then get bounding box
# (You may want to add a widget for interactive selection in a real workflow)
bounding_box = (m.west, m.south, m.east, m.north)


In [20]:

# --- LOAD DATA FROM PLANETARY COMPUTER ---
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

search = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=bounding_box,
    datetime="2024-01-01/2024-12-31"
)
items = search.item_collection()
print(f"Found {len(items)} items in the selected area and time range.")


Found 145 items in the selected area and time range.


In [21]:

# Stack the data using stackstac
data = stackstac.stack(
    items,
    assets=bands_to_load,
    resolution=spatial_resolution,
    epsg=3857,
    bounds_latlon=bounding_box
)

print("Array size information:")
print(f"Shape: {data.shape}")
print(f"Size in bytes: {data.data.nbytes}")
print(f"Size in GB: {data.data.nbytes / 1e9:.2f} GB")
print(f"Number of chunks: {data.data.npartitions}")
print(f"Chunksize: {(data.data.nbytes / data.data.npartitions) / 1e6:.2f} MB")

data


Array size information:
Shape: (145, 4, 319, 319)
Size in bytes: 472171040
Size in GB: 0.47 GB
Number of chunks: 580
Chunksize: 0.81 MB


Unnamed: 0,Array,Chunk
Bytes,450.30 MiB,795.01 kiB
Shape,"(145, 4, 319, 319)","(1, 1, 319, 319)"
Dask graph,580 chunks in 3 graph layers,580 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 450.30 MiB 795.01 kiB Shape (145, 4, 319, 319) (1, 1, 319, 319) Dask graph 580 chunks in 3 graph layers Data type float64 numpy.ndarray",145  1  319  319  4,

Unnamed: 0,Array,Chunk
Bytes,450.30 MiB,795.01 kiB
Shape,"(145, 4, 319, 319)","(1, 1, 319, 319)"
Dask graph,580 chunks in 3 graph layers,580 chunks in 3 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [24]:
import xarray as xr
import numpy as np
import dask.array as da

# --- Your initial setup code ---
# Establish SCL band on its own
scl_band = data.sel(band='SCL')

# Classes to keep (i.e not clouds or defective pixels)
good_classes = [4, 5, 6]

# Clear mask, i.e where pixels are not cloudy
clear_mask = scl_band.isin(good_classes)

# Apply mask to all bands. This may create NaNs.
data = data.where(clear_mask)

# Establish spectral bands on their own
spectral_bands = data.drop_sel(band = 'SCL')


# --- START OF THE NEW, CORRECTED SOLUTION ---

def mode_using_bincount(arr):
    """
    A fast, NumPy-based function to find the mode of a 1D array of integers.
    This function will be applied to each pixel's time-series.
    """
    valid_data = arr[~np.isnan(arr)]
    if valid_data.size == 0:
        return np.nan
    return np.bincount(valid_data.astype(int)).argmax()

def dask_mode_runner(group):
    """
    This function takes a daily group (an xarray.DataArray) and applies the mode
    calculation along its time axis in a parallel, Dask-friendly way.
    
    IT NOW CORRECTLY RETURNS AN xarray.DataArray.
    """
    # apply_along_axis runs the mode function on 1D slices along the time axis (axis 0).
    # This returns a raw Dask Array.
    result_dask_array = da.apply_along_axis(
        mode_using_bincount,
        axis=0,  # The 'time' dimension of the group
        arr=group.data
    ).astype(np.uint8)

    # --- THIS IS THE FIX ---
    # Wrap the raw Dask Array result back into an xarray.DataArray.
    # We must provide the correct dimensions and coordinates from the original group,
    # excluding the 'time' dimension which has been collapsed by the mode calculation.
    return xr.DataArray(
        result_dask_array,
        coords={
            "y": group.coords["y"],
            "x": group.coords["x"],
        },
        dims=["y", "x"],
    )

# 1. As before, compute the grouping key first.
daily_groups = scl_band.time.dt.floor('D')
computed_groups = daily_groups.compute()

# 2. Group the SCL data and apply our new Dask-friendly mode runner.
scl_mode = scl_band.groupby(computed_groups).apply(dask_mode_runner)

# 3. Group the spectral data and compute the median.
spectral_median = spectral_bands.groupby(computed_groups).median(dim='time', skipna=True)

# 4. Combine the results and rename the 'group' dimension back to 'time'.
# The SCL band name needs to be re-assigned before concatenation
scl_mode = scl_mode.assign_coords(band='SCL')
daily_composites = xr.concat([spectral_median, scl_mode], dim='band')
daily_composites = daily_composites.rename({'floor': 'time'})

# 5. Fill any gaps that may have been created by the masking.
daily_composites = daily_composites.ffill("time").bfill("time")


print("--- Final Lazy DataArray Definition ---")
daily_composites

--- Final Lazy DataArray Definition ---


Unnamed: 0,Array,Chunk
Bytes,220.49 MiB,2.33 MiB
Shape,"(71, 4, 319, 319)","(1, 3, 319, 319)"
Dask graph,142 chunks in 591 graph layers,142 chunks in 591 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 220.49 MiB 2.33 MiB Shape (71, 4, 319, 319) (1, 3, 319, 319) Dask graph 142 chunks in 591 graph layers Data type float64 numpy.ndarray",71  1  319  319  4,

Unnamed: 0,Array,Chunk
Bytes,220.49 MiB,2.33 MiB
Shape,"(71, 4, 319, 319)","(1, 3, 319, 319)"
Dask graph,142 chunks in 591 graph layers,142 chunks in 591 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [26]:

'''def select_best_time_group(data_array, time_grouping, good_classes=[4, 5, 6]):
    """
    Selects the best image (with the highest percentage of good pixels) for each time group (month, week, or year).

    Parameters:
    -----------
    data_array : xr.DataArray
        The input xarray DataArray containing a 'band' dimension with 'SCL' classification.
    time_grouping : str
        The time grouping to use. Must be one of 'month', 'week', or 'year'.
    good_classes : list, optional
        List of SCL class values considered as 'good' pixels. Default is [4, 5, 6].

    Returns:
    --------
    xr.DataArray
        DataArray containing the best image for each time group.
    """

    # --- GROUPING LOGIC ---
    grouping_dict = {
        "month": "time.month",
        "week": "time.week",
        "year": "time.year"
    }
    if time_grouping not in grouping_dict:
        raise ValueError(f"Invalid time_grouping: {time_grouping}. Choose from {list(grouping_dict.keys())}")

    grouping_accessor = grouping_dict[time_grouping]

    # --- GOOD PIXEL CALCULATION ---
    classification = data_array.sel(band="SCL").squeeze(drop=True)
    good_pixels = classification.isin(good_classes)
    good_pixel_count = good_pixels.sum(dim=["x", "y"])
    valid_pixel_count = (classification != 0).sum(dim=["x", "y"])
    good_pixel_percentage = ((good_pixel_count / valid_pixel_count) * 100).fillna(0)

    # --- COMPUTE PERCENTAGES ---
    good_pixel_percentage = good_pixel_percentage.compute()

    # --- SELECT BEST TIME STEP IN EACH GROUP ---
    def select_best_in_group(group):
        """
        Selects the time step with the highest percentage of good pixels in a group.
        """
        return group.isel(time=group.argmax(dim="time"))

    best_entries = good_pixel_percentage.groupby(grouping_accessor).apply(select_best_in_group)
    best_timestamps = best_entries.time.values
    best_time_group_data = data_array.sel(time=best_timestamps)
    best_time_group_data = best_time_group_data.persist()

    print(f"Selected best time step for each {time_grouping}.")
    return best_time_group_data'''

'''def select_best_time_group_optimized(data_array, time_grouping="month", good_classes=[4, 5, 6]):
    """
    Selects the best image (with the highest percentage of good pixels) for each time group (month, week, or year)
    using a lazy, Dask-aware method.
    """
    print("--- Lazily calculating good pixel percentage ---")
    
    classification = data_array.sel(band="SCL", drop=True)
    
    good_pixels = classification.isin(good_classes)
    good_pixel_count = good_pixels.sum(dim=["x", "y"])
    valid_pixel_count = (classification != 0).sum(dim=["x", "y"])
    good_pixel_percentage = ((good_pixel_count / valid_pixel_count) * 100).fillna(0)

    print(f"--- Lazily finding best timestamp for each {time_grouping} ---")
    
    grouping_dict = {
        "month": data_array.time.dt.month,
        "week": data_array.time.dt.isocalendar().week,
        "year": data_array.time.dt.year
    }
    if time_grouping not in grouping_dict:
        raise ValueError(f"Invalid time_grouping: {time_grouping}.")
        
    grouping_accessor = grouping_dict[time_grouping]

    def find_best_timestamp(group):
        """
        This function is applied to each monthly group.
        It now correctly returns an xarray.DataArray.
        """
        best_position = group.argmax(dim="time")
        best_timestamp_scalar = group.time.isel(time=best_position)
        
        # --- THIS IS THE FIX ---
        # Wrap the scalar result in xr.DataArray() to ensure the return type
        # is correct for the groupby().apply() operation.
        return xr.DataArray(best_timestamp_scalar)

    # Group by the accessor, then apply our custom function to each group.
    lazy_best_timestamps = good_pixel_percentage.groupby(grouping_accessor).apply(find_best_timestamp)

    # Compute the small list of timestamps.
    print("--- Computing the small list of best timestamps ---")
    with ProgressBar():
        best_timestamps = lazy_best_timestamps.compute()

    print("--- Lazily selecting final images ---")
    # We now need to use the .values to select, as we have a DataArray of timestamps
    best_images = data_array.sel(time=best_timestamps.values)

    return best_images'''

def select_best_time_group_dataframe(data_array, time_grouping="month", good_classes=[4, 5, 6]):
    """
    Selects the best image for each time group using a robust, manually-constructed Dask DataFrame.
    """
    print("--- Lazily calculating good pixel percentage ---")
    
    classification = data_array.sel(band="SCL", drop=True)
    
    good_pixels = classification.isin(good_classes)
    good_pixel_count = good_pixels.sum(dim=["x", "y"])
    valid_pixel_count = (classification != 0).sum(dim=["x", "y"])
    good_pixel_percentage = ((good_pixel_count / valid_pixel_count) * 100).fillna(0)

    print(f"--- Finding best timestamp for each {time_grouping} using Dask DataFrame ---")

    # --- THIS IS THE FINAL, GUARANTEED DATAFRAME LOGIC ---
    
    # 1. Manually construct a Dask DataFrame from the lazy Dask arrays.
    # This gives us full control and avoids the problematic .to_dask_dataframe() method.
    df = dask.dataframe.from_dask_array(
        good_pixel_percentage.data,       # The lazy data
        columns=['percentage'],           # Explicitly name the single column
        index=good_pixel_percentage.time  # Use the time coordinate as the index
    )

    # 2. Create the grouping column (e.g., month number) in the DataFrame.
    grouping_dict = {
        "month": df.index.month,
        "week": df.index.isocalendar().week,
        "year": df.index.year
    }
    if time_grouping not in grouping_dict:
        raise ValueError(f"Invalid time_grouping: {time_grouping}.")
    df['group_id'] = grouping_dict[time_grouping]

    # 3. Use the highly optimized groupby().idxmax() pattern from Dask DataFrame.
    best_indices_lazy = df.groupby('group_id')['percentage'].idxmax()

    # 4. Compute the result. This is a very small and fast computation.
    print("--- Computing the small list of best timestamps ---")
    with ProgressBar():
        best_timestamps = best_indices_lazy.compute()

    # The result is a pandas Series, so we extract the values.
    print("--- Lazily selecting final images ---")
    best_images = data_array.sel(time=best_timestamps.values)

    return best_images


best_images = select_best_time_group_dataframe(daily_composites, "month")

--- Lazily calculating good pixel percentage ---
--- Finding best timestamp for each month using Dask DataFrame ---


ValueError: 'index' must be an instance of dask.dataframe.Index

In [None]:
# Calculate global min and max for consistent scaling across all time steps
rgb_data = best_images.sel(band=['B02', 'B03', 'B04'])
global_min, global_max = dask.compute(rgb_data.min(), rgb_data.max())

# Plot with consistent vmin and vmax across all subplots
rgb_data.plot.imshow(
    col="time", 
    rgb="band", 
    col_wrap=5,
    vmin=global_min,
    vmax=global_max
)

In [None]:
# Import geogif for creating animated GIFs from geospatial data
import geogif

geogif = geogif.dgif(rgb_data, fps=4)
geogif.compute()