In [4]:
# Import required packages
import math
import folium
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import xarray as xr
from matplotlib import colors as mcolours
from matplotlib.animation import FuncAnimation
from pathlib import Path
from pyproj import Transformer
from shapely.geometry import box
from skimage.exposure import rescale_intensity
import matplotlib.cm as cm



def _degree_to_zoom_level(l1, l2, margin=0.0):
    """
    Calculates an integer zoom level based on the difference between two geographic coordinates.

    This function estimates an appropriate map zoom level such that the bounding box defined by 
    the two coordinates fits nicely within the viewport, optionally including a margin.

    Parameters:
    -----------
    l1 : float
        The first coordinate (latitude or longitude in degrees).

    l2 : float
        The second coordinate (latitude or longitude in degrees).

    margin : float, optional (default=0.0)
        A fractional margin to increase the bounding box size, e.g., 0.1 adds 10% padding.

    Returns:
    --------
    zoom_level_int : int
        An integer zoom level (typically between 0 and 18), where higher values mean closer zoom.

    Notes:
    ------
    - If the coordinates are identical (`degree == 0`), returns a default zoom level of 18.
    - Uses the formula based on logarithm of the ratio between full map width (360 degrees) and the bounding box.
    """


    degree = abs(l1 - l2) * (1 + margin)
    zoom_level_int = 0
    if degree != 0:
        zoom_level_float = math.log(360 / degree) / math.log(2)
        zoom_level_int = int(zoom_level_float)
    else:
        zoom_level_int = 18
    return zoom_level_int

def display_map(x, y, crs='EPSG:4326', margin=-0.5, zoom_bias=0, centroid=None):
    """
    Generate an interactive folium map showing a bounding rectangle or a centroid marker.

    The four corners defined by `x` and `y` are transformed from `crs` to
    longitude/latitude (EPSG:4326). The map is centred on the mean of those
    corners, and the zoom level is derived from their latitude and longitude
    spans. Either a red rectangle outlining the extent or, when `centroid`
    is given, a red circle at that point is drawn on top of the imagery
    tiles.

    Last modified: July 2025

    Adapted from a function by Otto Wagner:
    https://github.com/ceos-seo/data_cube_utilities/tree/master/data_cube_utilities

    Parameters
    ----------
    x : tuple of float
        Tuple of (min, max) x coordinates in the specified CRS (easting, or
        longitude for geographic CRSs).
    y : tuple of float
        Tuple of (min, max) y coordinates in the specified CRS (northing, or
        latitude for geographic CRSs).
    crs : str, optional
        Coordinate reference system of the input coordinates (default 'EPSG:4326').
    margin : float, optional
        Fractional adjustment applied to the latitude/longitude span when
        the zoom level is computed (span * (1 + margin)). It only affects
        the zoom level, not the drawn rectangle (default -0.5).
    zoom_bias : float or int, optional
        Value added to the computed zoom level; positive values zoom in,
        negative zoom out (default 0).
    centroid : tuple of float or None, optional
        Optional centroid coordinate as (latitude, longitude). If provided, a red circle will
        mark this point instead of drawing the bounding rectangle.

    Returns
    -------
    folium.Map
        The interactive map with the overlay and a click-to-show lat/lon popup.

    Example
    -------
    >>> display_map((500000, 510000), (2000000, 2010000), crs='EPSG:3857', margin=0.1)
    """
    # Convert each corner coordinate to lon/lat. `always_xy=True` forces
    # (x, y) -> (lon, lat) ordering on both sides; without it pyproj follows
    # the authority axis order of EPSG:4326 and returns (lat, lon), which
    # swapped the two axes for any non-EPSG:4326 input CRS.
    all_x = (x[0], x[1], x[0], x[1])
    all_y = (y[0], y[0], y[1], y[1])
    transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
    all_longitude, all_latitude = transformer.transform(all_x, all_y)

    # Calculate zoom level based on coordinates; the smaller of the two
    # zoom levels is used so that both spans fit
    lat_zoom_level = _degree_to_zoom_level(
        min(all_latitude), max(all_latitude), margin=margin) + zoom_bias
    lon_zoom_level = _degree_to_zoom_level(
        min(all_longitude), max(all_longitude), margin=margin) + zoom_bias
    zoom_level = min(lat_zoom_level, lon_zoom_level)

    # Identify centre point for plotting
    center = [np.mean(all_latitude), np.mean(all_longitude)]

    # Create map
    interactive_map = folium.Map(
        location=center,
        zoom_start=zoom_level,
        tiles="http://mt1.google.com/vt/lyrs=y&z={z}&x={x}&y={y}",
        attr="Google")

    if centroid is not None:
        # Add the centroid point as an overlay (first element is used as
        # latitude, last element as longitude)
        interactive_map.add_child(
            folium.Circle(location=[centroid[0], centroid[-1]],
                          color='red',
                          opacity=1,
                          radius=10000,
                          fill=False))
    else:
        # Closed ring of corner coordinates: corners 0 -> 1 -> 3 -> 2 -> 0
        # walks around the rectangle instead of crossing its diagonal
        line_segments = [(all_latitude[0], all_longitude[0]),
                         (all_latitude[1], all_longitude[1]),
                         (all_latitude[3], all_longitude[3]),
                         (all_latitude[2], all_longitude[2]),
                         (all_latitude[0], all_longitude[0])]

        # Add bounding box as an overlay
        interactive_map.add_child(
            folium.features.PolyLine(locations=line_segments,
                                     color='red',
                                     opacity=0.8))

    # Add clickable lat-lon popup box
    interactive_map.add_child(folium.features.LatLngPopup())

    return interactive_map

def rgb(ds,
        bands=['nbart_red', 'nbart_green', 'nbart_blue'],
        index=None,
        index_dim='time',
        robust=True,
        percentile_stretch=None,
        col_wrap=4,
        size=6,
        aspect=None,
        titles=None,
        savefig_path=None,
        savefig_kwargs={},
        **kwargs):
    """
    Plot three bands of an xarray Dataset as an RGB image (single panel or facets).

    The selected bands are stacked into one DataArray and drawn with
    `.plot.imshow()`. Observations can be selected by position with `index`;
    several indices produce a faceted plot. The figure can optionally be
    written to disk.

    Last modified: July 2025

    Adapted from dc_rgb.py by John Rattz:
    https://github.com/ceos-seo/data_cube_utilities/blob/master/data_cube_utilities/dc_rgb.py

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset containing the imagery bands. Its `.odc` accessor is
        used to look up the spatial dimension names and the geobox aspect.

    bands : list of str, optional
        Names of the three variables used as red, green and blue. Defaults to
        ['nbart_red', 'nbart_green', 'nbart_blue']. The list is only read,
        never modified.

    index : int or list of int, optional
        Position(s) along `index_dim` to plot. A list with more than one
        entry produces a faceted plot. Defaults to None (no selection).

    index_dim : str, optional
        Dimension name to index along when selecting observations. Defaults to 'time'.

    robust : bool, optional
        Passed to `imshow`; scales colour limits with percentiles instead of
        the extreme values. Defaults to True.

    percentile_stretch : tuple(float, float), optional
        Quantiles (between 0 and 1, e.g. (0.02, 0.98)) used to compute
        explicit `vmin`/`vmax` from the selected data. Defaults to None.

    col_wrap : int, optional
        Number of columns in faceted plots. Defaults to 4.

    size : int or float, optional
        Height in inches of each subplot. Ignored when `ax` is passed. Defaults to 6.

    aspect : float or None, optional
        Aspect ratio (width/height) of each subplot. If not given, taken
        from `ds.odc.geobox.aspect`. Ignored when `ax` is passed.

    titles : str or list of str, optional
        Custom subplot titles, applied in plotting order. A single string is
        treated as a one-element list. Defaults to None.

    savefig_path : str, optional
        File path to save the generated figure. If None, figure is not saved.

    savefig_kwargs : dict, optional
        Additional keyword arguments passed to the figure's `savefig()`.

    **kwargs
        Additional keyword arguments passed to `imshow` (e.g. `ax`, `col`).

    Returns
    -------
    The object returned by `imshow`: an image artist for a single panel, or
    a facet grid for faceted plots.

    Raises
    ------
    Exception
        If the data has several observations along `index_dim` and neither
        `index` nor `col` was supplied; if `index` is a float; or if both
        `index` and `col` are supplied.

    Example
    -------
    >>> rgb(ds, index=0)  # Plot the first image in the time dimension
    >>> rgb(ds, index=[0,1], titles=['Jan', 'Feb'])  # Faceted plot of first two images with custom titles
    >>> rgb(ds, savefig_path='output.png')  # Save the RGB plot to a file
    """

    # Get names of x and y dims
    y_dim, x_dim = ds.odc.spatial_dims

    # If ax is supplied via kwargs, aspect and size must not be passed on
    if 'ax' in kwargs:
        aspect_size_kwarg = {}
    else:
        if not aspect:
            aspect = ds.odc.geobox.aspect
        aspect_size_kwarg = {'aspect': aspect, 'size': size}

    # If no value is supplied for `index` (the default), plot using default
    # values and arguments passed via `**kwargs`
    if index is None:

        # Select bands and convert to DataArray
        da = ds[bands].to_array().compute()

        # If percentile_stretch is set, clip plotting to percentile vmin, vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        if ((len(ds.dims) > 2) and ('col' not in kwargs) and
                (index_dim in da.dims)):

            # A length-1 index dimension is squeezed out so that a single
            # image can be drawn
            if len(da[index_dim]) == 1:
                da = da.squeeze(dim=index_dim)

            # Several observations cannot be drawn as one image: tell the
            # user to use `col` or `index`
            elif len(da[index_dim]) > 1:
                raise Exception(
                    f'The input dataset `ds` has more than two dimensions: '
                    f'{list(ds.dims)}. Please select a single observation '
                    'using e.g. `index=0`, or enable faceted plotting by adding '
                    'the arguments e.g. `col="time", col_wrap=4` to the function '
                    'call')

        img = da.plot.imshow(x=x_dim,
                             y=y_dim,
                             robust=robust,
                             col_wrap=col_wrap,
                             **aspect_size_kwarg,
                             **kwargs)

    # If values provided for `index`, extract corresponding observations and
    # plot as either single image or facet plot
    else:

        # If a float is supplied instead of an integer index, raise exception
        if isinstance(index, float):
            raise Exception(
                f'Please supply `index` as either an integer or a list of '
                'integers')

        # If col argument is supplied as well as `index`, raise exception
        if 'col' in kwargs:
            raise Exception(
                f'Cannot supply both `index` and `col`; please remove one and '
                'try again')

        # Convert index to a list so that the number of indices supplied
        # can be computed
        index = index if isinstance(index, list) else [index]

        # Select bands and observations and convert to DataArray
        da = ds[bands].isel(**{index_dim: index}).to_array().compute()

        # If percentile_stretch is set, clip plotting to percentile vmin, vmax
        if percentile_stretch:
            vmin, vmax = da.quantile(percentile_stretch).values
            kwargs.update({'vmin': vmin, 'vmax': vmax})

        # If multiple index values are supplied, plot as a faceted plot
        if len(index) > 1:
            img = da.plot.imshow(x=x_dim,
                                 y=y_dim,
                                 robust=robust,
                                 col=index_dim,
                                 col_wrap=col_wrap,
                                 **aspect_size_kwarg,
                                 **kwargs)

        # If only one index is supplied, squeeze out index_dim and plot as a
        # single panel
        else:
            img = da.squeeze(dim=index_dim).plot.imshow(robust=robust,
                                                        **aspect_size_kwarg,
                                                        **kwargs)

    # Apply custom titles. Facet grids expose their axes as `.axs`, whereas
    # a single-panel plot returns an image artist with one `.axes`; the
    # previous code assumed `.axs` everywhere and failed on single panels.
    if titles is not None:
        title_list = [titles] if isinstance(titles, str) else titles
        target_axes = img.axs.flat if hasattr(img, 'axs') else [img.axes]
        for ax, title in zip(target_axes, title_list):
            ax.set_title(title)

    # If an export path is provided, save image to file. Individual and
    # faceted plots expose the figure under different attribute names, so
    # look it up explicitly rather than hiding save errors in a bare except
    if savefig_path:

        print(f'Exporting image to {savefig_path}')

        figure = getattr(img, 'figure', None)
        if figure is None:
            figure = img.fig
        figure.savefig(savefig_path, **savefig_kwargs)

    return img

def single_band(ds,
        band=None,
        index=None,
        index_dim='time',
        robust=True,
        vmin=None,
        vmax=None,
        label=None,
        col_wrap=4,
        size=6,
        aspect=None,
        titles=None,
        savefig_path=None,
        savefig_kwargs={},
        **kwargs):
    """
    Plot a single band as a colour-mapped image (single panel or facets).

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data to plot. Its `.odc` accessor is used to look up the spatial
        dimension names and the geobox aspect. If the data has several
        observations along `index_dim`, either use `index` to select one
        (`index=0`) or multiple observations (`index=[0, 1]`), or create a
        custom faceted plot using e.g. `col="time"`.
    band : string, optional
        Name of the variable to plot when `ds` is a Dataset. May be omitted
        when the Dataset holds exactly one variable. Ignored when `ds` is
        already a DataArray.
    index : integer or list of integers, optional
        Position(s) along `index_dim` to plot. If multiple images are
        requested these will be plotted as a faceted plot.
    index_dim : string, optional
        The dimension along which observations are selected with `index`.
        Defaults to `time`.
    robust : bool, optional
        Passed to `imshow`; computes the colormap range from percentiles
        instead of the extreme values. Defaults to True.
    vmin, vmax : float, optional
        Explicit colormap limits passed to `imshow`. Defaults to None.
    label : string, optional
        Text passed to the colorbar's `set_label`. Defaults to None.
    col_wrap : integer, optional
        The number of columns allowed in faceted plots. Defaults to 4.
    size : integer, optional
        The height (in inches) of each plot. Ignored when `ax` is passed.
        Defaults to 6.
    aspect : float, optional
        Aspect ratio of each facet in the plot, so that aspect * size
        gives width of each facet in inches. Defaults to None, which
        uses `ds.odc.geobox.aspect`. Ignored when `ax` is passed.
    titles : string or list of strings, optional
        Custom subplot titles, applied in plotting order. A single string
        is treated as a one-element list. Defaults to None.
    savefig_path : string, optional
        Path to export an image file of the plot. Defaults to None,
        which does not export an image file.
    savefig_kwargs : dict, optional
        A dict of keyword arguments to pass to the figure's `savefig`
        when exporting an image file.
    **kwargs : optional
        Additional keyword arguments to pass to `imshow`, e.g. `ax` to
        plot into an existing matplotlib axes object, or `col`.

    Returns
    -------
    The object returned by `imshow`: an image artist for a single panel,
    or a facet grid for faceted plots.

    Raises
    ------
    ValueError
        If `ds` is a Dataset with more than one variable and `band` is not
        given.
    Exception
        If the data has several observations along `index_dim` and neither
        `index` nor `col` was supplied; if `index` is a float; or if both
        `index` and `col` are supplied.
    """

    # Get names of x and y dims
    y_dim, x_dim = ds.odc.spatial_dims

    # If ax is supplied via kwargs, aspect and size must not be passed on
    if 'ax' in kwargs:
        aspect_size_kwarg = {}
    else:
        if not aspect:
            aspect = ds.odc.geobox.aspect
        aspect_size_kwarg = {'aspect': aspect, 'size': size}

    # Resolve the input to one DataArray. Previously `band` was ignored and
    # the two code paths disagreed on the input type (one called the
    # Dataset-only `to_array`, the other plotted the input directly).
    if hasattr(ds, 'data_vars'):
        if band is None:
            var_names = list(ds.data_vars)
            if len(var_names) != 1:
                raise ValueError(
                    'Please supply `band` to choose which variable of the '
                    f'dataset to plot; available variables: {var_names}')
            band = var_names[0]
        da = ds[band]
    else:
        da = ds

    # Explicit colour limits (None leaves the choice to imshow / `robust`)
    kwargs.update({'vmin': vmin, 'vmax': vmax})

    # If no value is supplied for `index` (the default), plot using default
    # values and arguments passed via `**kwargs`
    if index is None:

        da = da.compute()

        if ('col' not in kwargs) and (index_dim in da.dims):

            # A length-1 index dimension is squeezed out so that a single
            # image can be drawn
            if da.sizes[index_dim] == 1:
                da = da.squeeze(dim=index_dim)

            # Several observations cannot be drawn as one image: tell the
            # user to use `col` or `index`
            elif da.sizes[index_dim] > 1:
                raise Exception(
                    f'The input dataset `ds` has more than two dimensions: '
                    f'{list(da.dims)}. Please select a single observation '
                    'using e.g. `index=0`, or enable faceted plotting by adding '
                    'the arguments e.g. `col="time", col_wrap=4` to the function '
                    'call')

        img = da.plot.imshow(x=x_dim,
                             y=y_dim,
                             robust=robust,
                             col_wrap=col_wrap,
                             **aspect_size_kwarg,
                             **kwargs)

    # If values provided for `index`, extract corresponding observations and
    # plot as either single image or facet plot
    else:

        # If a float is supplied instead of an integer index, raise exception
        if isinstance(index, float):
            raise Exception(
                f'Please supply `index` as either an integer or a list of '
                'integers')

        # If col argument is supplied as well as `index`, raise exception
        if 'col' in kwargs:
            raise Exception(
                f'Cannot supply both `index` and `col`; please remove one and '
                'try again')

        # Convert index to a list so that the number of indices supplied
        # can be computed
        index = index if isinstance(index, list) else [index]

        # Select observations
        da = da.isel(**{index_dim: index}).compute()

        # If multiple index values are supplied, plot as a faceted plot
        if len(index) > 1:
            img = da.plot.imshow(x=x_dim,
                                 y=y_dim,
                                 robust=robust,
                                 col=index_dim,
                                 col_wrap=col_wrap,
                                 **aspect_size_kwarg,
                                 **kwargs)

        # If only one index is supplied, squeeze out index_dim and plot as a
        # single panel
        else:
            img = da.squeeze(dim=index_dim).plot.imshow(robust=robust,
                                                        **aspect_size_kwarg,
                                                        **kwargs)

    # Apply custom titles. Facet grids expose their axes as `.axs`, whereas
    # a single-panel plot returns an image artist with one `.axes`.
    if titles is not None:
        title_list = [titles] if isinstance(titles, str) else titles
        target_axes = img.axs.flat if hasattr(img, 'axs') else [img.axes]
        for ax, title in zip(target_axes, title_list):
            ax.set_title(title, fontsize=22)

    # Style the colorbar. Facet grids store it as `.cbar`, image artists as
    # `.colorbar`; there is none when e.g. `add_colorbar=False` was passed.
    cbar = getattr(img, 'cbar', None)
    if cbar is None:
        cbar = getattr(img, 'colorbar', None)
    if cbar is not None:
        cbar.ax.tick_params(labelsize=30)
        cbar.set_label(label=label, size=30, weight='bold')

    # If an export path is provided, save image to file. Individual and
    # faceted plots expose the figure under different attribute names, so
    # look it up explicitly rather than hiding save errors in a bare except
    if savefig_path:

        print(f'Exporting image to {savefig_path}')

        figure = getattr(img, 'figure', None)
        if figure is None:
            figure = img.fig
        figure.savefig(savefig_path, **savefig_kwargs)

    return img


def urban_growth_plot(ds,urban_area,baseline_year,analysis_year):
    """
    Plot urban growth between two years on top of the baseline-year ENDISI image.

    The `ENDISI` variable for the baseline year is drawn in a grey colormap
    as background. Pixels where `urban_area` increases by exactly 1 between
    the baseline and analysis year (non-urban to urban) are overlaid in red.

    Last modified: July 2025

    Rewritten from the DEA notebook here
    https://knowledge.dea.ga.gov.au/notebooks/Real_world_examples/Urban_change_detection/

    Parameters:
    -----------
    ds : xarray.Dataset
        Dataset containing a variable `ENDISI` indexed by `year`, with
        spatial dimensions `x` and `y`.

    urban_area : xarray.DataArray
        Urban extent by year. Values of 1 indicate urban areas, and 0
        indicates non-urban.

    baseline_year : int
        The starting year to visualize urban extent.

    analysis_year : int
        The ending year to compare against the baseline year to detect urban growth.

    Notes:
    ------
    - The legend has two entries: 'Urban growth hotspots' (red) and
      'Remains urban' (dark grey).
    - Nothing is returned; the plot is drawn on a new matplotlib figure.

    Example:
    --------
    >>> urban_growth_plot(ds, urban_area, 2015, 2020)
    """
    # Imported locally because these names are not among the imports at the
    # top of the file
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    # Plot the baseline-year index in grey as a background. The aspect is
    # width / height; it was previously y.size / y.size, which is always 1.
    plot = ds.ENDISI.sel(year=baseline_year).plot(
        cmap='Greys',
        size=6,
        aspect=ds.x.size / ds.y.size,
        add_colorbar=False,
    )

    # Plot the change to urban: a difference of exactly 1 marks pixels that
    # were non-urban in the baseline year and urban in the analysis year
    to_urban = '#b91e1e'
    urban_area_diff = (urban_area.sel(year=analysis_year) -
                       urban_area.sel(year=baseline_year))
    xr.where(urban_area_diff == 1, 1,
             np.nan).plot(ax=plot.axes,
                          add_colorbar=False,
                          cmap=ListedColormap([to_urban]))

    # Add the legend. Handles and labels are paired one-to-one; the former
    # third (white) handle had no label and was never shown.
    plot.axes.legend([Patch(facecolor=to_urban),
                      Patch(facecolor='darkgrey')],
                     ['Urban growth hotspots', 'Remains urban'])
    plot.axes.set_title('Urban growth between ' + str(baseline_year) +
                        ' and ' + str(analysis_year))
    

Fmask plot

In [None]:
def linear_stretch(band, lower_percent=2, upper_percent=98):
    """
    Perform linear contrast stretching on a single band.

    The band is rescaled so that its `lower_percent` percentile maps to 0
    and its `upper_percent` percentile maps to 1; values outside that range
    are clipped.

    Parameters:
        band (np.ndarray): 2D array (single band). NaN values are ignored
            when the percentiles are computed and stay NaN in the output.
        lower_percent (float): Lower percentile for stretch.
        upper_percent (float): Upper percentile for stretch.

    Returns:
        np.ndarray: Stretched band with values normalized to [0,1]. If both
        percentiles are equal (e.g. a constant band) every non-NaN value is
        0.
    """
    # NaN-aware percentiles: a single NaN would otherwise turn both limits,
    # and therefore the whole output, into NaN
    lower = np.nanpercentile(band, lower_percent)
    upper = np.nanpercentile(band, upper_percent)
    span = upper - lower

    if np.isfinite(span) and span != 0:
        stretched = (band - lower) / span
    else:
        # No contrast to stretch: avoid dividing by zero. Multiplying by
        # zero keeps the input's shape and leaves NaN values as NaN.
        stretched = (band - lower) * 0.0

    return np.clip(stretched, 0, 1)  # limit values to [0,1]

def stretch_rgb(rgb_image, lower_percent=2, upper_percent=98):
    """
    Apply linear contrast stretching to each band in an RGB image.

    Each of the first three bands along the last axis is stretched
    independently with `linear_stretch`.

    Parameters:
        rgb_image (np.ndarray): 3D array (H, W, 3).
        lower_percent (float): Lower percentile for stretch.
        upper_percent (float): Upper percentile for stretch.

    Returns:
        np.ndarray: Contrast-stretched RGB image with values in [0,1]. The
        dtype is the input dtype if that is a floating type, otherwise
        float64.
    """
    # The output must be floating point: with an integer input,
    # `np.zeros_like(rgb_image)` would be an integer array and the stretched
    # [0, 1] values would be truncated to 0 or 1 on assignment.
    if np.issubdtype(rgb_image.dtype, np.floating):
        out_dtype = rgb_image.dtype
    else:
        out_dtype = np.float64
    stretched = np.zeros_like(rgb_image, dtype=out_dtype)

    for i in range(3):  # For each band R, G, B
        stretched[..., i] = linear_stretch(rgb_image[..., i], lower_percent, upper_percent)

    return stretched

In [None]:
import numpy as np
import math
import numpy as np
import math

def log_stretching_reflectance_optimized(arr,REF_MIN_THRESHOLD=0.01,REF_MAX_THRESHOLD=0.65):
    """
    Log-stretch float reflectance values into an 8-bit display range.

    Strictly positive, non-NaN values are clipped to
    [REF_MIN_THRESHOLD, REF_MAX_THRESHOLD], log-transformed, rescaled
    against the log of the thresholds and clipped to [1, 255]. Every other
    element (zero, negative, NaN) is written as 0.

    Parameters:
    arr (np.ndarray): Input array (2D or 3D) of float values.
    REF_MIN_THRESHOLD (float): Lower reflectance threshold of the stretch.
    REF_MAX_THRESHOLD (float): Upper reflectance threshold of the stretch.

    Returns:
    np.ndarray: uint8 array with the same shape as the input. An all-zero
    array is returned (after printing a message) when the two log
    thresholds coincide or when no element is strictly positive.
    """
    # Output range used for stretched (valid) pixels; 0 is kept for invalid
    out_floor, out_ceiling = 1, 255

    # Log of the thresholds; the max() guards keep both arguments positive
    # and the upper threshold at least as large as the lower one
    log_floor = math.log(max(REF_MIN_THRESHOLD, 1e-6))
    log_ceiling = math.log(max(REF_MAX_THRESHOLD, REF_MIN_THRESHOLD))
    log_span = log_ceiling - log_floor

    # Nothing to stretch when both thresholds map to the same log value
    if log_span == 0:
        print("gen_browse_img: high and low thresholds are equal.")
        return np.zeros_like(arr, dtype=np.uint8)

    # Work on a float32 copy of the input
    as_float = arr.astype(np.float32)

    # Only strictly positive, non-NaN values take part in the stretch
    usable = (as_float > 0) & ~np.isnan(as_float)

    if not np.any(usable):
        print("gen_browse_img: no valid positive data.")
        return np.zeros_like(arr, dtype=np.uint8)

    # Start from an all-zero (invalid/background) result
    result = np.zeros_like(arr, dtype=np.uint8)

    # Clip to the thresholds, take the log, then rescale against the
    # threshold logs
    logged = np.log(
        np.clip(as_float[usable], REF_MIN_THRESHOLD, REF_MAX_THRESHOLD))
    scaled = out_ceiling * (logged - log_floor) / log_span

    # Keep stretched pixels inside [1, 255] and write them back in place
    result[usable] = np.clip(scaled, out_floor, out_ceiling).astype(np.uint8)

    return result





In [None]:


def plot_fmask_raster(ax, raster, i=None, fig=None, show_colorbar=False):
    """
    Plot a categorical Fmask raster with a custom colormap and optional colorbar.

    Classes 1-4 are drawn as dark blue, cyan, grey and white; NaN/masked
    pixels are drawn black. The axis frame and ticks are hidden.

    Parameters:
        ax (matplotlib.axes.Axes): The axis to plot on.
        raster (np.ndarray): 2D raster array (values 1–4, NaNs allowed).
        i (int, optional): Index of subplot. Currently unused; kept for
            compatibility with existing callers.
        fig (matplotlib.figure.Figure, optional): Figure handle, required if show_colorbar=True.
        show_colorbar (bool, optional): Whether to display a colorbar (e.g., for last subplot).

    Returns:
        The axes created for the colorbar, or None when no colorbar was
        drawn (`show_colorbar` is False or `fig` is None).
    """
    # Define categorical boundaries and colors: each class k occupies the
    # interval [k, k + 1)
    bounds = [1, 2, 3, 4, 5]
    colors_ = ['darkblue', 'cyan', 'grey', 'white']

    # `mcolours` is the alias under which matplotlib.colors is imported at
    # the top of the file; the bare name `colors` is not imported there
    custom_cmap = mcolours.ListedColormap(colors_)
    custom_cmap.set_bad(color='black')  # for NaN/masked values
    norm = mcolours.BoundaryNorm(bounds, custom_cmap.N)

    # Plot raster
    ax.imshow(raster, cmap=custom_cmap, norm=norm, interpolation='nearest')

    cax = None

    # Optionally add colorbar
    if show_colorbar and fig is not None:
        # Place a narrow axis just to the right of the plot, same height
        cax = fig.add_axes([
            ax.get_position().x1 + 0.01,  # slightly right of plot
            ax.get_position().y0,         # aligned bottom
            0.015,                        # width
            ax.get_position().height      # same height
        ])

        cbar = plt.colorbar(
            plt.cm.ScalarMappable(norm=norm, cmap=custom_cmap),
            boundaries=bounds,
            cax=cax
        )

        # Tick labels sit at the centre of each class interval
        cbar.ax.set_yticks([1.5, 2.5, 3.5, 4.5])
        cbar.ax.set_yticklabels(['Water', 'Cloud Shadow', 'Snow/Ice', 'Cloud'])
        cbar.ax.tick_params(labelsize=10)

    ax.set_axis_off()

    return cax


def Cirrus_percentile(arr, p_min=10, p_max=90):
    """
    Stretch a cirrus band to enhance visibility using percentile clipping.

    Values are clipped to the [p_min, p_max] percentile range (computed
    ignoring NaN) and rescaled linearly so that the lower percentile maps
    to 0 and the upper percentile maps to 1.

    Parameters:
        arr (np.ndarray): 2D cirrus band array.
        p_min, p_max (float): Percentile limits for contrast stretch.

    Returns:
        np.ndarray: Contrast-stretched cirrus band (float32, scaled 0–1).
        NaN values stay NaN. If both percentiles are equal, every non-NaN
        value is 0.
    """
    arr = arr.astype(np.float32)
    v_min = np.nanpercentile(arr, p_min)
    v_max = np.nanpercentile(arr, p_max)
    span = v_max - v_min

    clipped = np.clip(arr, v_min, v_max)

    # The result was previously cast straight to uint8 without rescaling,
    # which truncated or wrapped the values and destroyed NaNs; rescale to
    # the documented 0-1 float range instead.
    if np.isfinite(span) and span != 0:
        stretched = (clipped - v_min) / span
    else:
        # No contrast to stretch: avoid dividing by zero, keep NaN as NaN
        stretched = (clipped - v_min) * 0.0

    return stretched.astype(np.float32)


def plot_cirrus_band(ax, cirrus_band, cmap='viridis', p_min=10, p_max=90, title="Cirrus (B10)"):
    """
    Plot a cirrus band image after percentile stretching.

    Parameters:
        ax (matplotlib.axes.Axes): Axis to plot on.
        cirrus_band (np.ndarray): 2D cirrus band array.
        cmap (str or Colormap): Colormap used for the image (default: 'viridis').
        p_min, p_max (float): Percentile limits for stretch, passed to
            Cirrus_percentile.
        title (str): Plot title.
    """
    stretched = Cirrus_percentile(cirrus_band, p_min, p_max)

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # accepts both names and Colormap instances. Work on a copy so set_bad
    # does not alter the globally registered colormap.
    cmap = plt.get_cmap(cmap).copy()

    # Invalid (NaN / masked) pixels are drawn in black.
    cmap.set_bad(color='black')

    ax.imshow(stretched, cmap=cmap)
    ax.set_title(title, fontsize=12)
    ax.axis('off')




In [None]:
# ------------------------------------------------------------
#  Helper function
# ------------------------------------------------------------
def plot_fmask_diff(raster_list, axes, fig):
    """
    Compare cloud and cloud-shadow masks between Fmask 4.7 and Fmask 5.

    For each class a signed difference map is drawn: +1 where only Fmask 4.7
    flags the pixel, -1 where only Fmask 5 flags it, 0 where both flag it, and
    NaN (rendered black) where neither version flags it.

    Parameters:
        raster_list (list of np.ndarray): [RGB, Fmask4.7, Fmask5, ...]. Only
            entries 1 and 2 are read; they are not modified.
        axes (list): List of matplotlib Axes. The cloud map is drawn on
            axes[5] and the cloud-shadow map on axes[4].
        fig (matplotlib.figure.Figure): Figure object for colorbar positioning.

    Returns:
        matplotlib.axes.Axes: The colorbar axes added next to the cloud panel.
    """
    # Shared visualisation setup: -1 = only Fmask5, 0 = same, +1 = only Fmask4.7
    colors_diff = ['darkorange', 'white', 'dodgerblue']
    cmap_diff = colors.ListedColormap(colors_diff)
    cmap_diff.set_bad('black')
    norm_diff = colors.BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap_diff.N)

    # Initialised once, outside the loop, so the colorbar axes created for
    # the cloud panel is still what gets returned after the shadow pass.
    cax = None

    for j in [4, 2]:  # 4 = Cloud, 2 = Cloud shadow
        # Select subplot based on j
        ax = axes[5] if j == 4 else axes[4]

        # --- Binary mask for selected class (cloud or shadow).
        # np.where allocates new arrays, so the inputs are left untouched.
        raster_fmask_47 = np.where(raster_list[1] == j, 1, 0)
        raster_fmask_5 = np.where(raster_list[2] == j, 1, 0)

        # --- Define valid regions (where either version detects the class)
        mask_union = np.where((raster_fmask_47 + raster_fmask_5) != 0, 1, np.nan)

        # --- Compute signed difference (NaN outside the union)
        raster_diff = (raster_fmask_47 - raster_fmask_5) * mask_union

        # --- Plot diff map
        ax.imshow(raster_diff, cmap=cmap_diff, norm=norm_diff, interpolation='nearest')
        ax.axis("off")

        if j == 4:
            # --- Add colorbar beside the cloud subplot only
            position = ax.get_position()
            cax = fig.add_axes([
                position.x1 + 0.0025,
                position.y0,
                0.015,
                position.height
            ])
            cbar = plt.colorbar(
                plt.cm.ScalarMappable(norm=norm_diff, cmap=cmap_diff),
                cax=cax
            )
            cbar.ax.set_yticks([-1, 0, 1])
            cbar.ax.set_yticklabels([
                "Only Fmask5",
                "No difference",
                "Only Fmask4.7"
            ])
            ax.set_title("Cloud mask (Fmask 4.7 - Fmask 5)", fontsize=10)
        else:
            ax.set_title("Cloud shadow (Fmask 4.7 - Fmask 5)", fontsize=10)

    return cax

In [None]:
def plot_fmask_percentage_by_day(day, df_fmask_percentage, axes):
    """
    Draw a grouped bar chart of Fmask class percentages for one date.

    Parameters:
        day: Value compared for equality against the 'Date' column; rows that
            match are plotted. It is used as given, without any reformatting.
        df_fmask_percentage (pd.DataFrame): DataFrame with columns
            ['Date', 'Features', 'Percentage', 'fmask_version'].
        axes (list): List of matplotlib Axes; the chart is drawn on axes[7].
    """
    # Keep only the rows belonging to the requested date.
    rows_for_day = df_fmask_percentage['Date'] == day
    day_df = df_fmask_percentage.loc[rows_for_day]

    # One bar group per feature class, one bar per Fmask version.
    feature_order = ['Water', 'Snow/Ice', 'Cloud Shadow', 'Cloud']
    sns.barplot(
        ax=axes[7],
        data=day_df,
        x='Features',
        y='Percentage',
        hue='fmask_version',
        order=feature_order,
        palette='magma',
    )


In [None]:
def share_axes_group(axes, end_index):
    """
    Share x and y axes among subplots up to end_index.
    Works with modern Matplotlib (>=3.6).

    The first axes (in flattened, row-major order) is the reference; every
    other axes up to and including end_index is linked to it.

    Parameters:
        axes (list or np.ndarray): Collection of subplot axes. May be a single
            Axes, a flat sequence, or a 2-D grid as returned by plt.subplots.
        end_index (int): Last (flattened) index to include in shared group.
    """
    # Always flatten: a 2-D grid would otherwise yield whole rows, not Axes.
    flat_axes = np.ravel(axes)
    if flat_axes.size == 0:
        return

    # Reference axis
    ref_ax = flat_axes[0]

    # Share x/y axes among the specified range
    for ax in flat_axes[:end_index + 1]:
        if ax is not ref_ax:
            ax.sharex(ref_ax)
            ax.sharey(ref_ax)

    

def customize_axes_layout(axes, cax_fmask, cax_fmask_diff, i, fig):
    """
    Adjust colorbar positions, hide ticks and certain axes, and share axes for interactive comparison.

    Parameters:
        axes (list): List of matplotlib Axes.
        cax_fmask (Axes or None): Colorbar axes for fmask; re-anchored beside
            axes[2]. Skipped when None.
        cax_fmask_diff (Axes or None): Colorbar axes for difference;
            re-anchored beside axes[5]. Skipped when None.
        i (int): Current subplot index to control conditional adjustments
            (colorbars are only repositioned when i is 2 or 5).
        fig (matplotlib.figure.Figure): Figure whose subplot margins are adjusted.
    """
    fig.subplots_adjust(left=0.0,
                        right=0.86,
                        top=0.92,
                        bottom=0.05,
                        wspace=0.1, hspace=0.1)

    if i in [2, 5]:
        # subplots_adjust moved the panels, so re-anchor each colorbar just
        # right of its panel with matching bottom edge and height. A colorbar
        # that was never created (None) is simply skipped.
        for cax, panel in ((cax_fmask, axes[2]), (cax_fmask_diff, axes[5])):
            if cax is None:
                continue
            panel_pos = panel.get_position()
            cax.set_position([
                panel_pos.x1 + 0.0025,
                panel_pos.y0,
                0.015,
                panel_pos.height
            ])

    # Remove x and y ticks from all axes except axes[7]
    for idx, ax in enumerate(axes):
        if idx != 7:
            ax.set_xticks([])
            ax.set_yticks([])

    # Turn off specific axes visibility
    for k in [6, 8]:
        if k < len(axes):
            axes[k].set_visible(False)

    # Share x and y axes among axes 0 to 5 for interactive comparison
    share_axes_group(axes, 5)


In [None]:
def plot_fmask_comparison_stats(
    df_fmask_percentage,
    fmask4_col='Fmask4',
    fmask5_col='Fmask5',
    figsize=(8.97, 8.69),
    palette='magma',
    violin_gap=0.1,
    scatter_alpha=0.6,
    line_width=1,
    annotation_fontsize=12,
    top_legend_loc='upper left',
    bottom_legend_loc='lower right',
    save_path=None,
    dpi=300
):
    """
    Create a two-panel comparison plot for Fmask 4.7 and 5.0.

    The top panel is a split violin plot of 'Percentage' per feature class and
    Fmask version. The bottom panel shows, per feature class, a scatter of the
    Fmask 4 value against the Fmask 5 value with a fitted regression line and
    an R² / p-value annotation.

    Parameters:
    -----------
    df_fmask_percentage : pandas DataFrame
        DataFrame containing Fmask percentage data with columns:
        - 'Features': Feature class names
        - 'Percentage': Cloud coverage percentage
        - 'fmask_version': Fmask version ('Fmask4' or 'Fmask5')
        - 'Date': Date of observation
        - 'mrgs_id': MGRS tile ID
    fmask4_col : str, default='Fmask4'
        Column name for Fmask 4.7 version in the pivot table
    fmask5_col : str, default='Fmask5'
        Column name for Fmask 5.0 version in the pivot table
    figsize : tuple, default=(8.97, 8.69)
        Figure size (width, height) in inches
    palette : str, default='magma'
        Color palette for the violin plot
    violin_gap : float, default=0.1
        Gap between violins in the split violin plot
    scatter_alpha : float, default=0.6
        Transparency of scatter points in regression plot
    line_width : float, default=1
        Width of regression line
    annotation_fontsize : int, default=12
        Font size for R² annotations
    top_legend_loc : str, default='upper left'
        Location of legend in top panel
    bottom_legend_loc : str, default='lower right'
        Location of legend in bottom panel
    save_path : str, optional
        Path to save the figure. If None, figure is not saved.
    dpi : int, default=300
        Resolution for saved figure

    Returns:
    --------
    fig : matplotlib.figure.Figure
        The figure object
    axes : numpy.ndarray
        Array of axes objects
    """

    # Create figure with two subplots
    fig, axes = plt.subplots(figsize=figsize, nrows=2)
    ax_top = axes[0]
    ax_bottom = axes[1]

    # Top panel: Violin plot comparing distributions
    sns.violinplot(
        data=df_fmask_percentage,
        x="Features",
        y="Percentage",
        hue="fmask_version",
        split=True,
        inner="quart",
        fill=False,
        gap=violin_gap,
        palette=palette,
        ax=ax_top
    )
    ax_top.legend(loc=top_legend_loc)

    # Bottom panel: Regression plots for correlation analysis
    # Average duplicates so each (date, tile, version, class) has one value
    df_grouped = df_fmask_percentage.groupby(['Date', 'mrgs_id', 'fmask_version', 'Features'])
    df_mean = df_grouped['Percentage'].mean().reset_index()

    # Pivot so each Fmask version becomes its own column
    df_pivot = df_mean.pivot_table(
        index=['Date', 'mrgs_id', 'Features'],
        columns='fmask_version',
        values='Percentage'
    ).reset_index()

    # Get unique feature classes
    unique_classes = np.unique(df_pivot['Features'])
    n_classes = len(unique_classes)

    # Calculate y-coordinate positions for annotations dynamically
    y_coords = np.linspace(0.95, 0.65, n_classes)

    # Plot regression for each feature class
    for idx, class_name in enumerate(unique_classes):
        df_class = df_pivot.loc[df_pivot['Features'] == class_name]

        # Remove NaN values for regression
        df_clean = df_class.dropna(subset=[fmask4_col, fmask5_col])
        if df_clean.empty:
            continue

        x = df_clean[fmask4_col]
        y = df_clean[fmask5_col]

        # A line cannot be fitted through fewer than two distinct x values
        # (linregress raises when all x are identical); in that case only
        # the scatter is drawn and no statistics are annotated.
        can_fit = x.nunique() > 1

        sns.regplot(
            data=df_clean,
            x=fmask4_col,
            y=fmask5_col,
            ax=ax_bottom,
            fit_reg=can_fit,
            scatter_kws={'alpha': scatter_alpha},
            line_kws={'linewidth': line_width},
            label=class_name
        )
        if not can_fit:
            continue

        # Calculate regression statistics
        fit = stats.linregress(x, y)
        r_squared = fit.rvalue ** 2

        # Annotate with R-squared and p-value
        ax_bottom.annotate(
            f'{class_name}: R² = {r_squared:.2f} (p={fit.pvalue:.3f})',
            xy=(0.05, y_coords[idx]),
            xycoords='axes fraction',
            fontsize=annotation_fontsize,
            color='black'
        )

    ax_bottom.legend(loc=bottom_legend_loc)
    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()

    # Save figure if path is provided
    if save_path is not None:
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    return fig, axes

In [None]:
def plot_cloud_coverage_timeseries(
    cloud_coverage_df,
    figsize=(12, 8),
    linewidth=1,
    markersize=4,
    show_stats=True,
    show_plot=True
):
    """
    Plot cloud coverage time series comparison between Fmask 4.7 and 5.0.

    The upper panel shows both coverage series with dashed mean lines; the
    lower panel shows the per-time-step difference as coloured bars (red for
    positive, blue for negative, gray for zero) with a dashed mean line.

    Parameters
    ----------
    cloud_coverage_df : pandas.DataFrame
        DataFrame with columns: 'time', 'fmask4_coverage', 'fmask5_coverage', 'difference'.
        Output from calculate_cloud_coverage_timeseries().
    figsize : tuple, optional
        Figure size (width, height). Default is (12, 8).
    linewidth : float, optional
        Line width for time series plots. Default is 1.
    markersize : float, optional
        Marker size for time series plots. Default is 4.
    show_stats : bool, optional
        If True, print summary statistics. Default is True.
    show_plot : bool, optional
        If True, display the plot using plt.show(). Default is True.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object (None if the input is None or empty).
    axes : tuple of matplotlib.axes.Axes
        Tuple of (ax1, ax2) for the two subplots (None if the input is None or empty).
    """
    # Guard first so that invalid input never touches the plotting backend.
    if cloud_coverage_df is None or len(cloud_coverage_df) == 0:
        print("Error: cloud_coverage_df is None or empty")
        return None, None

    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    times = cloud_coverage_df['time']
    differences = cloud_coverage_df['difference']

    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, sharex=True)

    # Plot 1: Time series comparison
    ax1.plot(times, cloud_coverage_df['fmask4_coverage'],
             'o-', label='Fmask 4.7', color='#1f77b4', linewidth=linewidth, markersize=markersize)
    ax1.plot(times, cloud_coverage_df['fmask5_coverage'],
             's-', label='Fmask 5.0', color='#ff7f0e', linewidth=linewidth, markersize=markersize)
    ax1.set_ylabel('Cloud Coverage (%)', fontsize=12)
    ax1.set_title('Cloud Coverage Time Series Comparison', fontsize=14, fontweight='bold')
    ax1.legend(loc='best', fontsize=11)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim(bottom=0)

    # Add mean lines
    mean_4 = cloud_coverage_df['fmask4_coverage'].mean()
    mean_5 = cloud_coverage_df['fmask5_coverage'].mean()
    ax1.axhline(y=mean_4, color='#1f77b4', linestyle='--', alpha=0.5, linewidth=1.5)
    ax1.axhline(y=mean_5, color='#ff7f0e', linestyle='--', alpha=0.5, linewidth=1.5)

    # Label positions use a blended transform: x is an axes fraction (0.98 =
    # near the right edge) and y is in data units. Scaling the date-axis
    # limit by 0.98 instead would shift the label by hundreds of days.
    ax1.text(0.98, mean_4, f'Mean: {mean_4:.2f}%',
             transform=ax1.get_yaxis_transform(),
             ha='right', va='bottom', color='#1f77b4', fontsize=9, alpha=0.7)
    ax1.text(0.98, mean_5, f'Mean: {mean_5:.2f}%',
             transform=ax1.get_yaxis_transform(),
             ha='right', va='top', color='#ff7f0e', fontsize=9, alpha=0.7)

    # Plot 2: Difference time series
    bar_colors = ['red' if x > 0 else 'blue' if x < 0 else 'gray' for x in differences]
    ax2.bar(times, differences, color=bar_colors, alpha=0.6, width=0.8)
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)
    ax2.set_ylabel('Difference (Fmask 5.0 - 4.7) (%)', fontsize=12)
    ax2.set_xlabel('Time', fontsize=12)
    ax2.set_title('Cloud Coverage Difference Over Time', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='y')

    # Add mean difference line
    mean_diff = differences.mean()
    ax2.axhline(y=mean_diff, color='purple', linestyle='--', alpha=0.7, linewidth=1.5)
    ax2.text(0.98, mean_diff, f'Mean diff: {mean_diff:.2f}%',
             transform=ax2.get_yaxis_transform(),
             ha='right', va='bottom' if mean_diff > 0 else 'top',
             color='purple', fontsize=9, fontweight='bold')

    # Format x-axis dates
    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax2.xaxis.set_major_locator(mdates.AutoDateLocator())
    plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, ha='right')

    fig.tight_layout()

    # Show plot if requested
    if show_plot:
        plt.show()

    # Print summary statistics if requested
    if show_stats:
        print("\nSummary Statistics:")
        print(f"  Mean Fmask 4.7 coverage: {mean_4:.2f}%")
        print(f"  Mean Fmask 5.0 coverage: {mean_5:.2f}%")
        print(f"  Mean difference: {mean_diff:.2f}%")
        print(f"  Total time steps: {len(cloud_coverage_df)}")

    return fig, (ax1, ax2)



def visualize_and_export_cloud_frequency(
    fmask4_freq,
    fmask5_freq,
    fmask4_red=None,
    fmask4_green=None,
    fmask4_blue=None,
    output_dir='./outputs',
    mrgs_id=None,
    mask_type="cloud",  # "cloud" or "cloud_shadow"
    save_format=['geotiff', 'png'],
    figsize=(15, 12),
    cmap='viridis',
    vmin=0,
    vmax=1,
    export=False,
    show_plot=True,
    auto_zoom=True,
    extent=None  # Optional: tuple (x_min, x_max, y_min, y_max) to set plot extent
):
    """
    Visualize and optionally export cloud frequency maps for Fmask 4.7 and 5.0.

    Layout is 2x2: the top row shows the two frequency maps; the bottom row
    shows the RGB composite (left) and the difference map (right). When no RGB
    composite can be drawn, the difference map takes the bottom-left panel and
    the bottom-right panel is hidden.

    Parameters
    ----------
    fmask4_freq : xarray.DataArray
        Cloud frequency map for Fmask 4.7 with dimensions (y, x).
        Values range from 0.0 (never cloud) to 1.0 (always cloud).
    fmask5_freq : xarray.DataArray
        Cloud frequency map for Fmask 5.0 with dimensions (y, x).
        Values range from 0.0 (never cloud) to 1.0 (always cloud).
    fmask4_red : xarray.DataArray, optional
        Red band DataArray for Fmask 4.7 with dimensions (y, x).
        Expected to be a 2D array (median-computed or single time step).
        If provided along with green and blue, an RGB composite will be plotted.
        Values equal to -9999 are treated as missing. Default is None.
    fmask4_green : xarray.DataArray, optional
        Green band DataArray for Fmask 4.7 with dimensions (y, x).
        Expected to be a 2D array. Default is None.
    fmask4_blue : xarray.DataArray, optional
        Blue band DataArray for Fmask 4.7 with dimensions (y, x).
        Expected to be a 2D array. Default is None.
    output_dir : str, optional
        Directory to save exported files. Default is './outputs'.
    mrgs_id : str, optional
        MGRS tile identifier for filename and figure title. If None, uses
        'unknown' in filenames and no figure title is added.
    mask_type : str, optional
        Type of mask being visualized. "cloud" or "cloud_shadow" (any value
        other than "cloud" is labelled as cloud shadow). Default is "cloud".
    save_format : list, optional
        List of formats to export: 'geotiff', 'png', 'netcdf', 'jpg'.
        Default is ['geotiff', 'png']. Only used if export=True. The list is
        only read, never modified.
    figsize : tuple, optional
        Figure size for visualization (width, height). Default is (15, 12).
    cmap : str, optional
        Colormap name for the frequency maps. Default is 'viridis'.
    vmin, vmax : float, optional
        Color scale limits for the frequency maps. Default is 0 and 1.
    export : bool, optional
        If True, export files to disk. If False, only visualize.
        Default is False.
    show_plot : bool, optional
        If True, display the plot using plt.show(). Default is True.
    auto_zoom : bool, optional
        If True, applies the same xlim/ylim to all visible panels. The limits
        are `extent` when given, otherwise the union of the valid-data bounds
        of all plotted datasets plus 2% padding. Default is True.
    extent : tuple, optional
        Optional tuple (x_min, x_max, y_min, y_max) to set a custom plot extent.
        Only applied when auto_zoom is True, in which case it replaces the
        computed extent. Default is None.

    Returns
    -------
    tuple or dict
        If export=False: Returns (fig, axes) tuple for further customization.
        If export=True: Returns dict with paths to saved files (keys 'figure',
        'fmask4_freq', 'fmask5_freq', 'difference', 'netcdf' as applicable)
        plus 'fig' and 'axes'; the figure is closed before returning. When
        both 'png' and 'jpg' are requested, 'figure' holds the last one written.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from pathlib import Path
    import xarray as xr

    # Set MGRS ID for filenames
    tile_id = mrgs_id if mrgs_id else 'unknown'

    # Derived up front: labels and filenames need these whether or not an
    # MRGS ID was supplied.
    mask_name = "Cloud" if mask_type == "cloud" else "Cloud Shadow"
    mask_suffix = "cloud" if mask_type == "cloud" else "cloud_shadow"

    # Calculate difference map
    difference = fmask5_freq - fmask4_freq

    # Determine if RGB data is available
    has_rgb = all(band is not None for band in (fmask4_red, fmask4_green, fmask4_blue))

    # Create 2x2 subplot layout
    fig, axes = plt.subplots(2, 2, figsize=figsize, sharex=True, sharey=True)
    axes = axes.flatten()

    # Add figure title with MRGS ID
    if mrgs_id:
        fig.suptitle(f'{mask_name} Frequency Comparison - MRGS ID: {mrgs_id}',
                     fontsize=16, fontweight='bold', y=0.98)

    # Set equal aspect ratio for all axes
    for ax in axes:
        ax.set_aspect('equal', adjustable='box')

    # Plot Fmask 4.7 frequency (top left)
    fmask4_freq.plot(
        ax=axes[0],
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        add_colorbar=True,
        cbar_kwargs={'label': f'{mask_name} Frequency'}
    )
    axes[0].set_title(f'Fmask 4.7 {mask_name} Frequency')
    axes[0].set_xlabel('Longitude')
    axes[0].set_ylabel('Latitude')

    # Plot Fmask 5.0 frequency (top right)
    fmask5_freq.plot(
        ax=axes[1],
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        add_colorbar=True,
        cbar_kwargs={'label': f'{mask_name} Frequency'}
    )
    axes[1].set_title(f'Fmask 5.0 {mask_name} Frequency')
    axes[1].set_xlabel('Longitude')
    axes[1].set_ylabel('Latitude')

    # Plot RGB composite or difference map in bottom row
    bottom_left_idx = 2
    bottom_right_idx = 3

    def prepare_band_for_log_stretch(band):
        """Map a band onto the 0.001-0.8 range using its 0.5/99.5 percentiles."""
        band_data = band.values.copy().astype(np.float32)
        band_data = np.where(band_data == -9999, np.nan, band_data)

        valid_mask = ~np.isnan(band_data)
        if not np.any(valid_mask):
            return band_data

        valid_data = band_data[valid_mask]
        p_min = np.nanpercentile(valid_data, 0.5)
        p_max = np.nanpercentile(valid_data, 99.5)
        if p_max <= p_min:
            return np.zeros_like(band_data)

        normalized = (band_data - p_min) / (p_max - p_min)
        reflectance_range = 0.001 + normalized * (0.8 - 0.001)
        reflectance_range[~valid_mask] = np.nan
        return reflectance_range

    def normalize_band_to_01(band_data):
        """Normalize band to 0-1 range using its own min/max (NaNs kept)."""
        band_min = np.nanmin(band_data)
        band_max = np.nanmax(band_data)
        if band_max > band_min:
            normalized = (band_data - band_min) / (band_max - band_min)
            return np.where(np.isnan(band_data), np.nan, normalized)
        return np.zeros_like(band_data)

    # Tracks whether the composite was actually drawn, so a failure falls
    # back to the no-RGB layout instead of leaving a blank panel.
    rgb_plotted = False

    if has_rgb:
        try:
            REF_MIN = 0.01
            REF_MAX = 0.65

            # Use RGB bands directly (expected to be 2D arrays without time dimension)
            stretched_bands = []
            for band in (fmask4_red, fmask4_green, fmask4_blue):
                band_norm = normalize_band_to_01(prepare_band_for_log_stretch(band))
                band_scaled = band_norm * (REF_MAX - REF_MIN) + REF_MIN
                # The helper's output is divided by 255 to bring it to 0-1.
                band_stretched = log_stretching_reflectance_optimized(
                    band_scaled,
                    REF_MIN_THRESHOLD=REF_MIN,
                    REF_MAX_THRESHOLD=REF_MAX
                ).astype(np.float32) / 255.0
                stretched_bands.append(band_stretched)

            # Stack into RGB array
            rgb_array = np.stack(stretched_bands, axis=-1).astype(np.float32)
            rgb_array = np.clip(rgb_array, 0, 1)

            # Plot RGB composite (bottom left)
            axes[bottom_left_idx].imshow(
                rgb_array,
                extent=[fmask4_red.x.min().values, fmask4_red.x.max().values,
                        fmask4_red.y.min().values, fmask4_red.y.max().values],
                origin='upper',
                aspect='equal'
            )
            axes[bottom_left_idx].set_title('Fmask 4.7 RGB Composite (Median)')
            axes[bottom_left_idx].set_xlabel('Longitude')
            axes[bottom_left_idx].set_ylabel('Latitude')
            rgb_plotted = True
        except Exception as e:
            # Best effort: the frequency and difference maps are still useful.
            print(f"Warning: Error creating RGB composite: {e}. Skipping RGB composite.")

    # Plot difference map (bottom right, or bottom left if no RGB)
    diff_max = max(abs(difference.min().values), abs(difference.max().values))
    diff_ax_idx = bottom_right_idx if rgb_plotted else bottom_left_idx

    difference.plot(
        ax=axes[diff_ax_idx],
        cmap='RdBu_r',
        vmin=-diff_max,
        vmax=diff_max,
        add_colorbar=True,
        cbar_kwargs={'label': f'Difference ({mask_name})'}
    )
    axes[diff_ax_idx].set_title('Difference Map')
    axes[diff_ax_idx].set_xlabel('Longitude')
    axes[diff_ax_idx].set_ylabel('Latitude')

    # Hide unused subplot if no RGB composite was drawn
    if not rgb_plotted:
        axes[bottom_right_idx].set_visible(False)

    # Apply shared zoom if requested
    if auto_zoom:
        limits = None

        if extent is not None:
            # If extent is provided, use it directly (no padding)
            x_min, x_max, y_min, y_max = extent
            limits = (x_min, x_max, y_min, y_max)
        else:
            # Calculate extent from data bounds
            def get_data_bounds(data_array):
                """Return (x_min, x_max, y_min, y_max) of non-NaN cells, or None."""
                if set(data_array.dims) != {'x', 'y'}:
                    return None

                # Force (y, x) order so rows map to y and columns to x.
                valid_mask = ~np.isnan(data_array.transpose('y', 'x').values)
                if not valid_mask.any():
                    return None

                # Rows / columns that contain at least one valid cell.
                valid_cols = np.flatnonzero(valid_mask.any(axis=0))
                valid_rows = np.flatnonzero(valid_mask.any(axis=1))
                x_valid = data_array.x.values[valid_cols]
                y_valid = data_array.y.values[valid_rows]

                # Take min/max of the coordinate values themselves: y usually
                # decreases with row index, so index order must not decide
                # which end is the minimum (that would invert the y-axis).
                return (x_valid.min(), x_valid.max(), y_valid.min(), y_valid.max())

            bounds_list = []
            for data in [fmask4_freq, fmask5_freq, difference]:
                bounds = get_data_bounds(data)
                if bounds is not None:
                    bounds_list.append(bounds)

            if rgb_plotted:
                try:
                    rgb_bounds = get_data_bounds(fmask4_red)
                    if rgb_bounds is not None:
                        bounds_list.append(rgb_bounds)
                except Exception as e:
                    print(f"Warning: Could not get RGB data bounds for auto_zoom: {e}")

            if bounds_list:
                x_mins, x_maxs, y_mins, y_maxs = zip(*bounds_list)
                x_min = min(x_mins)
                x_max = max(x_maxs)
                y_min = min(y_mins)
                y_max = max(y_maxs)

                # 2% padding around the union of all bounds
                x_pad = (x_max - x_min) * 0.02
                y_pad = (y_max - y_min) * 0.02
                limits = (x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad)

        if limits is not None:
            for ax in axes:
                if ax.get_visible():
                    ax.set_xlim(limits[0], limits[1])
                    ax.set_ylim(limits[2], limits[3])

    # Adjust subplot spacing (leave room for the suptitle when present)
    top_margin = 0.92 if mrgs_id else 0.95
    fig.subplots_adjust(left=0.05, right=0.95, top=top_margin, bottom=0.05,
                        wspace=0.1, hspace=0.1)
    fig.tight_layout(rect=[0.05, 0.05, 0.95, top_margin])

    # Show plot if requested
    if show_plot:
        plt.show()

    # Export files only if export=True
    if not export:
        return fig, axes

    saved_files = {}
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save figure. fig.savefig targets this figure explicitly; after
    # plt.show() the pyplot "current figure" may no longer be this one.
    for fmt in ['png', 'jpg']:
        if fmt in save_format:
            fig_path = output_path / f'{mask_suffix}_frequency_{tile_id}.{fmt}'
            fig.savefig(fig_path, dpi=300, bbox_inches='tight')
            saved_files['figure'] = str(fig_path)
            print(f"Saved figure: {fig_path}")

    # Export as GeoTIFF
    if 'geotiff' in save_format:
        try:
            import rioxarray  # noqa: F401  (registers the .rio accessor)

            fmask4_path = output_path / f'fmask4_{mask_suffix}_frequency_{tile_id}.tif'
            fmask4_freq.rio.to_raster(fmask4_path)
            saved_files['fmask4_freq'] = str(fmask4_path)
            print(f"Saved Fmask 4.7 frequency map: {fmask4_path}")

            fmask5_path = output_path / f'fmask5_{mask_suffix}_frequency_{tile_id}.tif'
            fmask5_freq.rio.to_raster(fmask5_path)
            saved_files['fmask5_freq'] = str(fmask5_path)
            print(f"Saved Fmask 5.0 frequency map: {fmask5_path}")

            diff_path = output_path / f'{mask_suffix}_frequency_difference_{tile_id}.tif'
            difference.rio.to_raster(diff_path)
            saved_files['difference'] = str(diff_path)
            print(f"Saved difference map: {diff_path}")
        except Exception as e:
            # Best effort: a failed raster export must not lose the other outputs.
            print(f"Warning: Could not export GeoTIFF files: {e}")

    # Export as NetCDF
    if 'netcdf' in save_format:
        try:
            ds = xr.Dataset({
                'fmask4_frequency': fmask4_freq,
                'fmask5_frequency': fmask5_freq,
                'difference': difference
            })
            nc_path = output_path / f'{mask_suffix}_frequency_{tile_id}.nc'
            ds.to_netcdf(nc_path)
            saved_files['netcdf'] = str(nc_path)
            print(f"Saved NetCDF file: {nc_path}")
        except Exception as e:
            print(f"Warning: Could not export NetCDF file: {e}")

    saved_files['fig'] = fig
    saved_files['axes'] = axes

    plt.close(fig)
    return saved_files

def calculate_cloud_frequency_per_pixel(fmask_data, mask_value=4):
    """
    Compute, per pixel, how often a given Fmask class occurs over time.

    The frequency is the number of time steps whose value equals
    `mask_value`, divided by the number of time steps that are not NaN.

    Parameters
    ----------
    fmask_data : xarray.DataArray
        Time-series Fmask DataArray with dimensions (time, y, x).
        Values: 1=water, 2=cloud_shadow, 3=snow_ice, 4=cloud, NaN=clear.
    mask_value : int, optional
        Fmask value to calculate frequency for. Default is 4 (cloud).
        Use 2 for cloud shadow, 4 for cloud.

    Returns
    -------
    xarray.DataArray or None
        Frequency map with the 'time' dimension reduced away. Values range
        from 0.0 (never detected) to 1.0 (always detected); pixels with no
        non-NaN observation are NaN. None is returned when `fmask_data` is None.
    """
    import numpy as np

    if fmask_data is None:
        return None

    # Time steps per pixel where the requested class is present.
    n_hits = (fmask_data == mask_value).sum(dim='time')

    # Time steps per pixel that carry a non-NaN value.
    n_valid = (~np.isnan(fmask_data)).sum(dim='time')

    # Ratio of the two; pixels without any valid observation are masked out
    # (set to NaN) rather than keeping the result of a division by zero.
    return (n_hits / n_valid).where(n_valid > 0)

In [None]:
def batch_time_series_comparison(
    mrgs_id_list,
    bucket,
    fmask4_files,
    fmask5_files,
    output_dir='./outputs',
    min_valid_ratio=0.6,
    vmin=0.0,
    vmax=1.0,
    cmap='viridis',
    figsize=(15, 12),
    export=False,
    show_plot=True,
    auto_zoom=True,
    mask_value=4,  # 4 for cloud, 2 for cloud shadow
    extent=None  # Optional: tuple (x_min, x_max, y_min, y_max) to set plot extent
):
    """
    Batch process the time series comparison (Section 5 workflow) for several MGRS tiles.

    Note: parameter names and printed messages use the spelling "mrgs"; it
    refers to MGRS tile IDs and is kept unchanged for backward compatibility.

    For each tile ID, this function:
    1. Loads Fmask 4.7 and 5.0 time-series data
    2. Filters RGB bands to keep only time steps with sufficient valid pixels
       and takes the per-pixel median over time
    3. Calculates frequency per pixel for both Fmask versions (cloud or cloud shadow)
    3.5. Calculates and plots the coverage time series (line plots over time)
    4. Creates a 2x2 panel plot showing:
       - Top left: Fmask 4.7 Frequency
       - Top right: Fmask 5.0 Frequency
       - Bottom left: RGB Composite (if available)
       - Bottom right: Difference Map

    A failure while processing one tile is caught, printed with its traceback
    and recorded in the result; the remaining tiles are still processed.

    Parameters
    ----------
    mrgs_id_list : list of str
        List of MGRS tile IDs to process (e.g., ['10UEV', '31UDQ', '11SLT']).
    bucket : str
        S3 bucket name containing the Fmask data.
    fmask4_files : str or list of str
        S3 link/pattern to Fmask 4.7 files. Can be:
        - Full S3 URI with tile ID: 's3://bucket/path/31UDQ/' (used as-is for
          that tile, with or without the trailing slash)
        - S3 URI with placeholder: 's3://bucket/path/{mrgs_id}/'
          (placeholder is matched case-insensitively)
        - S3 URI without the tile ID: 's3://bucket/path/' (tile ID is appended)
        - List of file keys: ['key1', 'key2', ...] (passed through unchanged)
    fmask5_files : str or list of str
        S3 link/pattern to Fmask 5.0 files. Same format options as fmask4_files.
        Example: 's3://hls-debug-output/Fmask5_outputs/TS/S30/{mrgs_id}/'
    output_dir : str, optional
        Directory to save exported files. Default is './outputs'.
    min_valid_ratio : float, optional
        Minimum ratio of valid pixels required for RGB time steps (0.0-1.0).
        Default is 0.6 (60%).
    vmin, vmax : float, optional
        Color scale limits for cloud frequency maps. Default is 0.0 and 1.0.
    cmap : str, optional
        Colormap name for cloud frequency visualization. Default is 'viridis'.
    figsize : tuple, optional
        Figure size (width, height) in inches. Default is (15, 12).
    export : bool, optional
        If True, export plots and data files for each tile. Default is False.
    show_plot : bool, optional
        If True, display each plot using plt.show(). Default is True.
    auto_zoom : bool, optional
        If True, applies shared zoom to all panels for direct comparison.
        Default is True.
    mask_value : int, optional
        Fmask value to analyze. Default is 4 (cloud). Use 2 for cloud shadow.
    extent : tuple, optional
        Optional tuple (x_min, x_max, y_min, y_max) to set a custom plot extent.
        It is forwarded to the frequency plot, where it is documented to
        override the auto_zoom calculation. Default is None.

    Returns
    -------
    dict
        Dictionary keyed by tile ID. Each value is one of:
        {'status': 'success',
         'fmask4_frequency': xarray.DataArray,
         'fmask5_frequency': xarray.DataArray,
         'fig': matplotlib.figure.Figure,      # 2x2 frequency plot
         'axes': numpy.ndarray,                # axes of the frequency plot
         'cloud_coverage_df': pandas.DataFrame or None,  # time series data
         'timeseries_fig': Figure or None,     # time series plot
         'timeseries_axes': tuple or None,     # axes of the time series plot
         'exported_files': dict}               # only when the plot helper returned a dict
        {'status': 'skipped', 'reason': 'missing_data'}
        {'status': 'error', 'error': str}

    Example
    -------
    >>> from module.data_access import load_fmask_pair
    >>> from module.plotting import batch_time_series_comparison
    >>>
    >>> mrgs_ids = ['10UEV', '31UDQ', '11SLT', '42QWM']
    >>> results = batch_time_series_comparison(
    ...     mrgs_id_list=mrgs_ids,
    ...     bucket='hls-debug-output',
    ...     fmask4_files='s3://hls-debug-output/Fmask4_outputs/TS/S30/{mrgs_id}/',
    ...     fmask5_files='s3://hls-debug-output/Fmask5_outputs/TS/S30/{mrgs_id}/',
    ...     export=True,
    ...     show_plot=True
    ... )
    """
    import re
    import traceback
    from collections import Counter

    import numpy as np

    def _construct_s3_path_with_mrgs(base_path, mrgs_id):
        """
        Build the S3 path for one tile from a base path or pattern.

        Rules, checked in this order:
        1. A list/tuple of keys is returned unchanged.
        2. A '{mrgs_id}' placeholder (any letter case) is replaced by the tile ID.
        3. A path that already holds the tile ID as a directory component
           (with or without a trailing slash) is returned unchanged.
        4. Otherwise the tile ID is appended as a new trailing directory.

        Examples:
        - 's3://bucket/path/{mrgs_id}/' -> 's3://bucket/path/31UDQ/'
        - 's3://bucket/path/' -> 's3://bucket/path/31UDQ/'
        - 's3://bucket/path' -> 's3://bucket/path/31UDQ/'
        - 's3://bucket/path/31UDQ/' -> 's3://bucket/path/31UDQ/' (unchanged)
        """
        if isinstance(base_path, (list, tuple)):
            return base_path

        base_path = str(base_path).strip()

        if '{mrgs_id}' in base_path.lower():
            # A callable replacement inserts the ID literally, so backslashes or
            # group references inside it are never interpreted by the regex engine.
            return re.sub(r'\{mrgs_id\}', lambda _match: mrgs_id, base_path,
                          flags=re.IGNORECASE)

        # This check must run BEFORE the trailing-slash handling; otherwise a
        # path such as '.../31UDQ/' would get the ID appended a second time.
        if re.search(r'/' + re.escape(mrgs_id) + r'(/|$)', base_path):
            return base_path

        if not base_path.endswith('/'):
            base_path += '/'
        return base_path + mrgs_id + '/'

    # Note: This function assumes the following are available in the namespace:
    # - load_fmask_pair (from module.data_access)
    # - calculate_cloud_frequency_per_pixel (from this module)
    # - visualize_and_export_cloud_frequency (from this module)
    # - filter_valid_time_steps (from this module)
    # - calculate_cloud_coverage_timeseries (from module/fmask)
    # - plot_cloud_coverage_timeseries (from this module)
    # These should be loaded via %run in the notebook before calling this function

    results = {}

    mask_name = "Cloud" if mask_value == 4 else "Cloud Shadow" if mask_value == 2 else f"Mask {mask_value}"
    banner = '=' * 80
    band_names = ('red', 'green', 'blue')

    print(banner)
    print(f"Batch {mask_name} Time Series Comparison")
    print(f"Processing {len(mrgs_id_list)} MRGS IDs: {mrgs_id_list}")
    print(f"{banner}\n")

    for idx, mrgs_id in enumerate(mrgs_id_list, 1):
        print(f"\n{banner}")
        print(f"Processing MRGS ID {idx}/{len(mrgs_id_list)}: {mrgs_id}")
        print(banner)

        try:
            # Step 1: Load Fmask 4.7 and 5.0 time-series data
            print(f"\n[Step 1] Loading Fmask data for {mrgs_id}...")

            fmask4_path = _construct_s3_path_with_mrgs(fmask4_files, mrgs_id)
            fmask5_path = _construct_s3_path_with_mrgs(fmask5_files, mrgs_id)

            print(f"  Fmask 4.7 path: {fmask4_path if isinstance(fmask4_path, str) else 'list of files'}")
            print(f"  Fmask 5.0 path: {fmask5_path if isinstance(fmask5_path, str) else 'list of files'}")

            fmask_pair = load_fmask_pair(
                mrgs_id,
                bucket=bucket,
                fmask4_files=fmask4_path,
                fmask5_files=fmask5_path,
            )

            fmask4_data = fmask_pair.get("fmask4")
            fmask5_data = fmask_pair.get("fmask5")

            if fmask4_data is None or fmask5_data is None:
                print(f"  ⚠ Warning: Missing data for {mrgs_id}. Skipping...")
                results[mrgs_id] = {
                    'status': 'skipped',
                    'reason': 'missing_data'
                }
                continue

            print(f"  ✓ Loaded Fmask 4.7: {fmask4_data.shape}")
            print(f"  ✓ Loaded Fmask 5.0: {fmask5_data.shape}")

            # Step 2: Filter RGB bands
            print(f"\n[Step 2] Filtering RGB bands for {mrgs_id}...")

            # Fetch the three bands and turn the -9999 no-data value into NaN.
            raw_bands = {}
            for band in band_names:
                raw = fmask_pair.get(f"fmask4_{band}")
                if raw is not None:
                    raw = raw.where(raw != -9999, np.nan)
                raw_bands[band] = raw

            band_medians = dict.fromkeys(band_names)

            # The composite needs all three bands; filtering is skipped
            # entirely when any of them is missing.
            has_rgb = all(raw is not None for raw in raw_bands.values())

            if has_rgb:
                for band in band_names:
                    filtered = filter_valid_time_steps(
                        raw_bands[band], min_valid_ratio=min_valid_ratio
                    )
                    if filtered is not None:
                        band_medians[band] = filtered.median(dim='time')

                # A band can drop out here when none of its time steps passed the filter.
                has_rgb = all(median is not None for median in band_medians.values())

                if has_rgb:
                    print("  ✓ RGB bands filtered and medians calculated")
                else:
                    print("  ⚠ Warning: Some RGB bands missing after filtering")

            # Step 3: Calculate frequency per pixel
            print(f"\n[Step 3] Calculating {mask_name.lower()} frequency per pixel for {mrgs_id}...")
            fmask4_frequency = calculate_cloud_frequency_per_pixel(fmask4_data, mask_value=mask_value)
            fmask5_frequency = calculate_cloud_frequency_per_pixel(fmask5_data, mask_value=mask_value)

            print(f"  ✓ Fmask 4.7 {mask_name.lower()} frequency: {fmask4_frequency.shape}, "
                  f"range: {fmask4_frequency.min().values:.3f} to {fmask4_frequency.max().values:.3f}")
            print(f"  ✓ Fmask 5.0 {mask_name.lower()} frequency: {fmask5_frequency.shape}, "
                  f"range: {fmask5_frequency.min().values:.3f} to {fmask5_frequency.max().values:.3f}")

            # Step 3.5: Calculate and plot coverage time series.
            # This step is best-effort: a failure is reported but does not
            # prevent the frequency plot below from being produced.
            print(f"\n[Step 3.5] Calculating and plotting {mask_name.lower()} coverage time series for {mrgs_id}...")
            cloud_coverage_df = None
            timeseries_fig = None
            timeseries_axes = None

            try:
                cloud_coverage_df = calculate_cloud_coverage_timeseries(
                    fmask4_data,
                    fmask5_data,
                    verbose=False,  # Suppress verbose output in batch mode
                    mask_value=mask_value
                )

                if cloud_coverage_df is not None and len(cloud_coverage_df) > 0:
                    # Plot without showing so the title can be added first.
                    timeseries_fig, timeseries_axes = plot_cloud_coverage_timeseries(
                        cloud_coverage_df,
                        figsize=(12, 8),
                        linewidth=1,
                        markersize=4,
                        show_stats=False,  # Suppress stats in batch mode
                        show_plot=False
                    )

                    if timeseries_fig is not None:
                        timeseries_fig.suptitle(f'{mask_name} Coverage Time Series - MRGS ID: {mrgs_id}',
                                                fontsize=14, fontweight='bold', y=0.98)
                        # Lay out this specific figure (not whichever figure is
                        # current in pyplot), leaving room for the title.
                        timeseries_fig.tight_layout(rect=[0, 0, 1, 0.96])

                    if show_plot:
                        # pyplot is only needed for on-screen display.
                        import matplotlib.pyplot as plt
                        plt.show()

                    print(f"  ✓ {mask_name} coverage time series plotted ({len(cloud_coverage_df)} time steps)")
                else:
                    print(f"  ⚠ Warning: Could not calculate {mask_name.lower()} coverage time series")
            except Exception as e:
                print(f"  ⚠ Warning: Error creating {mask_name.lower()} coverage time series plot: {e}")
                traceback.print_exc()

            # Step 4: Visualize and optionally export
            print(f"\n[Step 4] Creating visualization for {mrgs_id}...")

            plot_kwargs = dict(
                fmask4_freq=fmask4_frequency,
                fmask5_freq=fmask5_frequency,
                output_dir=output_dir,
                mrgs_id=mrgs_id,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                figsize=figsize,
                export=export,
                show_plot=show_plot,
                auto_zoom=auto_zoom,
                extent=extent,
                mask_type=mask_name.lower().replace(' ', '_'),  # "cloud" or "cloud_shadow"
            )
            if has_rgb:
                # RGB arguments are only passed when a full composite exists, so
                # the plotting helper's own defaults apply otherwise.
                plot_kwargs.update(
                    fmask4_red=band_medians['red'],
                    fmask4_green=band_medians['green'],
                    fmask4_blue=band_medians['blue'],
                )
            result = visualize_and_export_cloud_frequency(**plot_kwargs)

            # The plotting helper returns a dict (including 'fig' and 'axes')
            # when exporting and a (fig, axes) tuple otherwise; dispatch on the
            # actual type instead of trusting the export flag alone.
            if isinstance(result, dict):
                fig = result.get('fig')
                axes = result.get('axes')
                exported_files = result
            else:
                fig, axes = result
                exported_files = None

            # Store results
            result_dict = {
                'status': 'success',
                'fmask4_frequency': fmask4_frequency,
                'fmask5_frequency': fmask5_frequency,
                'fig': fig,  # 2x2 cloud frequency plot
                'axes': axes,
                'cloud_coverage_df': cloud_coverage_df,  # Time series DataFrame
                'timeseries_fig': timeseries_fig,  # Time series plot figure
                'timeseries_axes': timeseries_axes  # Time series plot axes
            }

            if exported_files is not None:
                result_dict['exported_files'] = exported_files

            results[mrgs_id] = result_dict
            print(f"  ✓ Completed processing {mrgs_id}")

        except Exception as e:
            # Deliberate per-tile boundary: record the failure and move on so
            # one bad tile does not abort the whole batch.
            print(f"  ✗ Error processing {mrgs_id}: {e}")
            traceback.print_exc()
            results[mrgs_id] = {
                'status': 'error',
                'error': str(e)
            }

    status_counts = Counter(entry.get('status') for entry in results.values())

    print(f"\n{banner}")
    print("Batch processing complete!")
    print(f"  Successfully processed: {status_counts['success']}")
    print(f"  Skipped: {status_counts['skipped']}")
    print(f"  Errors: {status_counts['error']}")
    print(f"{banner}\n")

    return results


In [None]:
def filter_valid_time_steps(data_array, min_valid_ratio=0.9):
    """
    Filter time steps to keep only those with sufficient valid (non-NaN) pixels.

    The value -9999 is treated as no-data: it is replaced by NaN before the
    valid-pixel ratio is computed, and the returned array carries that
    replacement too. A one-line summary of how many time steps were kept is
    printed.

    Parameters
    ----------
    data_array : xarray.DataArray or None
        Time-series DataArray with dimensions (time, y, x).
    min_valid_ratio : float, optional
        Minimum ratio of valid (non-NaN) pixels required for a time step to be kept.
        Default is 0.9 (90%).

    Returns
    -------
    xarray.DataArray or None
        Filtered DataArray. ``None`` if the input is ``None``, if the time axis
        is empty, or if no time step meets the criterion. An input without a
        'time' dimension is returned unchanged (no -9999 replacement is applied).
    """
    if data_array is None:
        return None
    if 'time' not in data_array.dims:
        return data_array

    import numpy as np

    initial_time_steps = data_array.sizes['time']
    if initial_time_steps == 0:
        # An empty time axis has nothing to keep. Returning early also avoids
        # indexing time step 0 and dividing by zero in the summary below.
        print("Filtered time steps: 0 -> 0 (no time steps available)")
        return None

    # Replace -9999 (no data value) with NaN so it counts as invalid.
    data_array = data_array.where(data_array != -9999, np.nan)

    # Number of valid (non-NaN) pixels in each time step.
    valid_pixels_per_time_step = (~np.isnan(data_array)).sum(dim=['x', 'y'])
    # Total number of elements in a single time slice.
    total_pixels_per_time_step = data_array.isel(time=0).size
    valid_ratio = valid_pixels_per_time_step / total_pixels_per_time_step

    # Keep the time steps whose valid ratio reaches the threshold.
    keep_step = valid_ratio >= min_valid_ratio
    filtered_data_array = data_array.isel(time=keep_step)
    final_time_steps = filtered_data_array.sizes['time']

    print(f"Filtered time steps: {initial_time_steps} -> {final_time_steps} "
          f"(kept {final_time_steps/initial_time_steps*100:.1f}% with >= {min_valid_ratio*100:.0f}% valid pixels)")

    if final_time_steps == 0:
        return None
    return filtered_data_array
