# In [1]:
from datetime import datetime, date
from rasterio.warp import reproject, Resampling
from rasterio.windows import from_bounds




def filter_fmask_files(files_s3_fmask5, files_s3_fmask4, mrgs_id_target, to_julian_date):
    """
    Select the Fmask 5 and Fmask 4 files that belong to one target ID.

    The day token of every matching Fmask 5 file is taken from the fourth
    dot-separated field of its name. Fmask 4 files are kept only when they
    contain the target ID and at least one of those day tokens.

    Parameters:
        files_s3_fmask5 (list): List of filenames for fmask5.
        files_s3_fmask4 (list): List of filenames for fmask4.
        mrgs_id_target (str): Target ID; files are matched by substring.
        to_julian_date (function): Currently unused. The day tokens are
            compared against the fmask4 names verbatim, without conversion.
            The parameter is kept so existing callers keep working.

    Returns:
        tuple: Four items, in this order:
            files_fmask5_test (list): fmask5 files containing the target ID.
            files_fmask4_test (list): fmask4 files containing the target ID
                and at least one of the day tokens.
            day_uniques (list): Sorted unique day tokens (plain ``str``).
            day_uniques_julian (list): Copy of ``day_uniques``.

    Raises:
        IndexError: If a matching fmask5 filename has fewer than four
            dot-separated fields.
    """
    # Filter fmask5 files by target ID
    files_fmask5_test = [f for f in files_s3_fmask5 if mrgs_id_target in f]

    # Unique day tokens: 4th element when splitting the name by '.'.
    # sorted(set(...)) gives the same ordering np.unique produced, but with
    # native str items and without needing numpy for plain string work.
    day_uniques = sorted({f.split('.')[3] for f in files_fmask5_test})

    # No conversion is applied; the tokens are used as-is for matching.
    day_uniques_julian = list(day_uniques)

    # Filter fmask4 files by target ID and matching day tokens
    files_fmask4_test = [
        f for f in files_s3_fmask4
        if mrgs_id_target in f and any(jd in f for jd in day_uniques_julian)
    ]

    return files_fmask5_test, files_fmask4_test, day_uniques, day_uniques_julian

def select_fmask_files(fmask_version, day, files_fmask4_test, files_fmask5_test, day_uniques, day_uniques_julian):
    """
    Pick the file(s) to load for one product version and one day.

    Parameters:
        fmask_version (str): Selector string. Checked in this order:
            ``''`` (exact match), then substrings ``'cirrus'``,
            ``'fmask4.7'`` and ``'fmask5'``.
        day (str): Day token; must be an element of ``day_uniques`` for every
            branch except the ``'fmask5'`` one.
        files_fmask4_test (list): Candidate Fmask 4.7 / band file paths.
        files_fmask5_test (list): Candidate Fmask 5 file paths.
        day_uniques (list): Known day tokens.
        day_uniques_julian (list): Tokens parallel to ``day_uniques`` that are
            searched for inside the Fmask 4.7 / band paths.

    Returns:
        list, str or None:
            - ``''``: list of every path containing the day token and ``'.B0'``.
            - ``'cirrus'``: first path containing the day token and ``'B10'``.
            - ``'fmask4.7'``: first path containing the day token and
              ``'Fmask.tif'``.
            - ``'fmask5'``: first Fmask 5 path containing ``day``.
            The single-path branches return ``None`` when nothing matches.

    Raises:
        ValueError: If ``fmask_version`` matches no branch, or if ``day`` is
            not in ``day_uniques`` for a branch that needs the lookup.
    """

    def _julian_for_day():
        # Resolved lazily so the 'fmask5' and error branches never touch
        # day_uniques (list.index raises ValueError on a missing day).
        return day_uniques_julian[day_uniques.index(day)]

    def _first_or_none(candidates):
        return candidates[0] if candidates else None

    # Reference RGB: return every matching B0* band path.
    if fmask_version == '':
        token = _julian_for_day()
        return [path for path in files_fmask4_test if token in path and '.B0' in path]

    # Cirrus: single B10 band path.
    if 'cirrus' in fmask_version:
        token = _julian_for_day()
        return _first_or_none(
            [path for path in files_fmask4_test if token in path and 'B10' in path]
        )

    # Fmask 4.7 layer.
    if 'fmask4.7' in fmask_version:
        token = _julian_for_day()
        return _first_or_none(
            [path for path in files_fmask4_test if token in path and 'Fmask.tif' in path]
        )

    # Fmask 5 layer: matched directly on the day string.
    if 'fmask5' in fmask_version:
        return _first_or_none([path for path in files_fmask5_test if day in path])

    raise ValueError(f"Unknown fmask_version: {fmask_version}")

def read_true_color_from_files(red_path, green_path, blue_path, normalize=True):
    """
    Reads three single-band images (paths) and stacks them into a true color composite.

    Parameters:
        red_path (str): File path to the red band image.
        green_path (str): File path to the green band image.
        blue_path (str): File path to the blue band image.
        normalize (bool): Whether to divide the composite by its global
            maximum. An all-zero composite is returned unchanged instead of
            being turned into NaN by a 0/0 division.

    Returns:
        np.ndarray: float32 RGB composite image array (height, width, 3).
    """
    # Band 1 of each file, read in red, green, blue order; every dataset is
    # closed again before the next one is opened.
    bands = []
    for band_path in (red_path, green_path, blue_path):
        with rasterio.open(band_path) as src:
            bands.append(src.read(1))

    rgb = np.dstack(bands).astype(np.float32)

    if normalize:
        peak = np.max(rgb)
        # Guard the division: a zero maximum would yield NaN everywhere.
        if peak != 0:
            rgb /= peak

    return rgb

def read_raster_or_rgb(file, bucket_name, read_true_color_from_files, fmask4_ref = None):
    """
    Reads raster data or an RGB composite from S3, aligning the result to
    the grid/extent of the Fmask 4 reference file if one is provided.

    Parameters:
        file (str or list/tuple): Single object key, or a collection of band
            keys that contains B04, B03 and B02 entries.
        bucket_name (str): S3 bucket name; every key is opened as
            ``s3://{bucket_name}/{key}``.
        read_true_color_from_files (func): RGB composite reader, called with
            the red, green and blue paths and ``normalize=True``.
        fmask4_ref (str, optional): Object key (in the same bucket) of the
            reference raster whose extent should be matched.

    Returns:
        np.ndarray:
            - For a single key: the first band as float32. Keys containing
              'b10' (case-insensitive) get negative values set to NaN; all
              other keys get 0 and 255 set to NaN and are then passed through
              ``convert_fmask_unique`` (defined elsewhere). With a reference,
              the result is reprojected onto the reference grid using
              nearest-neighbour resampling.
            - For a list/tuple: the RGB composite, cropped to the reference
              bounds when a reference is given.

    Raises:
        ValueError: If a list/tuple lacks a B04, B03 or B02 entry.
        TypeError: If ``file`` is neither a string nor a list/tuple.
    """
    # Open the reference raster once and keep only the georeferencing we need.
    ref_profile = None
    if fmask4_ref:
        with rasterio.open(f's3://{bucket_name}/{fmask4_ref}') as ref:
            ref_profile = {
                "crs": ref.crs,
                "transform": ref.transform,
                "width": ref.width,
                "height": ref.height,
                "bounds": ref.bounds
            }

    if isinstance(file, str):
        path = f's3://{bucket_name}/{file}'
        with rasterio.open(path) as src:
            raster = src.read(1).astype(np.float32)
            src_crs = src.crs
            src_transform = src.transform

        if 'b10' in file.lower():
            # Negative values are treated as missing data.
            mask = raster < 0
            raster[mask] = np.nan
        else:
            # Mask out no-data values (255 and 0)
            raster[(raster == 255) | (raster == 0)] = np.nan
            # Applied to every non-B10 raster, whichever Fmask version it is.
            raster = convert_fmask_unique(raster)

        # Resample onto the reference grid if available
        if ref_profile:
            dst = np.empty((ref_profile["height"], ref_profile["width"]), dtype=np.float32)
            reproject(
                source=raster,
                destination=dst,
                src_transform=src_transform,
                src_crs=src_crs,
                dst_transform=ref_profile["transform"],
                dst_crs=ref_profile["crs"],
                resampling=Resampling.nearest
            )
            raster = dst

        return raster

    elif isinstance(file, (list, tuple)):
        # Identify band files in the list
        red_path = next((f for f in file if 'B04' in f), None)
        green_path = next((f for f in file if 'B03' in f), None)
        blue_path = next((f for f in file if 'B02' in f), None)

        if not all([red_path, green_path, blue_path]):
            raise ValueError("RGB band files (B04, B03, B02) not found in file list.")

        rgb = read_true_color_from_files(
            f's3://{bucket_name}/{red_path}',
            f's3://{bucket_name}/{green_path}',
            f's3://{bucket_name}/{blue_path}',
            normalize=True
        )

        # Crop RGB to the reference bounds if available.
        # NOTE(review): the bounds are applied to the red band's transform
        # without any CRS conversion - assumes both rasters share a CRS.
        if ref_profile:
            with rasterio.open(f's3://{bucket_name}/{red_path}') as src:
                win = from_bounds(
                    *ref_profile["bounds"],
                    transform=src.transform
                )
                win = win.round_offsets().round_lengths()

            # Clamp at 0: a negative offset (reference extends past the top or
            # left edge) would otherwise be read as "count from the end" by
            # the slice and return an empty or wrong crop.
            row_start = max(0, int(win.row_off))
            row_stop = max(0, int(win.row_off + win.height))
            col_start = max(0, int(win.col_off))
            col_stop = max(0, int(win.col_off + win.width))
            rgb = rgb[row_start:row_stop, col_start:col_stop, :]

        return rgb

    else:
        raise TypeError("`file` must be a string or a list/tuple of strings.")


def to_julian_date(dt):
    """
    Convert a date (string, datetime, or date) to the year + day-of-year
    string YYYYDDD (often called a "Julian date" in file names).

    Strings are parsed with ``datetime.fromisoformat`` first and, if that
    fails, as compact ``YYYYMMDD``. ``date`` and ``datetime`` objects are
    used directly.

    Examples:
        '2025-11-05' -> '2025309'
        '20240229' -> '2024060'
        datetime(2025, 11, 5) -> '2025309'

    Raises:
        ValueError: If a string matches neither supported format.
    """
    if isinstance(dt, str):
        try:
            dt = datetime.fromisoformat(dt)
        except ValueError:
            dt = datetime.strptime(dt, "%Y%m%d")

    # Both date and datetime expose .year and .timetuple(), so no conversion
    # between the two is needed. Day-of-year is zero-padded to three digits.
    return f"{dt.year}{dt.timetuple().tm_yday:03d}"


    

def MNDWI(dataset,green_band,swir1_band):
    """
    Compute the Modified Normalized Difference Water Index (MNDWI).

    The index is the normalized difference of the two selected bands:
    (Green - SWIR1) / (Green + SWIR1).

    Parameters
    ----------
    dataset : xarray.Dataset or similar
        Container indexable by band name.
    green_band : str
        Key of the green band in ``dataset``.
    swir1_band : str
        Key of the short-wave infrared band 1 (SWIR1) in ``dataset``.

    Returns
    -------
    xarray.DataArray or similar
        Element-wise MNDWI values, in whatever type the band arithmetic yields.
    """
    green = dataset[green_band]
    swir1 = dataset[swir1_band]

    difference = green - swir1
    total = green + swir1
    return difference / total



def swir_diff(dataset,swir1_band,swir2_band):
    """
    Compute the element-wise ratio SWIR1 / SWIR2.

    Despite the name, the result is a quotient of the two bands, not a
    difference.

    Parameters
    ----------
    dataset : xarray.Dataset or similar
        Container indexable by band name.
    swir1_band : str
        Key of the numerator band (SWIR1).
    swir2_band : str
        Key of the denominator band (SWIR2).

    Returns
    -------
    xarray.DataArray or similar
        SWIR1 divided by SWIR2, in whatever type the band arithmetic yields.
    """
    numerator = dataset[swir1_band]
    denominator = dataset[swir2_band]
    return numerator / denominator


def alpha(dataset,blue_band,green_band,swir1_band,swir2_band):
    """
    Compute the alpha weighting coefficient used by the ENDISI formula.

    Alpha is computed as:
        (2 * mean(blue_band)) / (mean(swir_diff) + mean(MNDWI^2))

    where ``swir_diff`` is SWIR1 / SWIR2 and every mean is taken with
    ``np.mean`` over the whole band.

    Parameters
    ----------
    dataset : xarray.Dataset or similar
        Container indexable by band name.
    blue_band : str
        Key of the blue band.
    green_band : str
        Key of the green band.
    swir1_band : str
        Key of the SWIR1 band.
    swir2_band : str
        Key of the SWIR2 band.

    Returns
    -------
    scalar
        The alpha coefficient.
    """
    blue_mean = np.mean(dataset[blue_band])
    swir_ratio_mean = np.mean(swir_diff(dataset,swir1_band,swir2_band))
    mndwi_sq_mean = np.mean(MNDWI(dataset,green_band,swir1_band)**2)

    return (2 * blue_mean) / (swir_ratio_mean + mndwi_sq_mean)


def ENDISI(dataset,blue_band,green_band,swir1_band,swir2_band):
    """
    Compute the ENDISI spectral index.

    With ``term = alpha * (swir_diff + MNDWI^2)`` the index is:

        (blue - term) / (blue + term)

    ``alpha``, ``swir_diff`` and ``MNDWI`` are the sibling functions of the
    same names, evaluated on the same dataset and bands.

    NOTE(review): the literature appears to expand ENDISI as "Enhanced
    Normalized Difference Impervious Surfaces Index" - confirm the intended
    name and application.

    Parameters
    ----------
    dataset : xarray.Dataset or similar
        Container indexable by band name.
    blue_band : str
        Key of the blue band.
    green_band : str
        Key of the green band.
    swir1_band : str
        Key of the SWIR1 band.
    swir2_band : str
        Key of the SWIR2 band.

    Returns
    -------
    xarray.DataArray or similar
        Element-wise ENDISI values.
    """
    mndwi = MNDWI(dataset,green_band,swir1_band)
    swir_ratio = swir_diff(dataset,swir1_band,swir2_band)
    alpha_value = alpha(dataset,blue_band,green_band,swir1_band,swir2_band)

    blue = dataset[blue_band]
    # The same weighted term appears in numerator and denominator.
    weighted_term = alpha_value * (swir_ratio + mndwi**2)

    return (blue - weighted_term) / (blue + weighted_term)
    
def create_quality_mask(quality_data, bit_nums=None):
    """
    Creates a binary mask indicating pixels flagged by specified bits in a quality (Fmask) layer.

    By default, bits 1 through 5 are used if `bit_nums` is not provided.

    Parameters
    ----------
    quality_data : numpy.ndarray or array-like
        Array of integer quality flags (e.g., from Fmask), possibly stored as
        floats and containing NaNs. The input is never modified.
    bit_nums : list of int, optional
        List of bit positions to check in each pixel's quality flag.
        Defaults to [1, 2, 3, 4, 5].

    Returns
    -------
    numpy.ndarray
        Boolean array with the shape of `quality_data`, True where any of the
        specified bits is set. All False when `bit_nums` is empty.

    Notes
    -----
    - NaN and infinite values are treated as 0 (no bits set).
    - Bit positions correspond to bits in the flag integer, with 0 being the least significant bit.
    """
    if bit_nums is None:
        bit_nums = [1, 2, 3, 4, 5]

    # The fill value must go in by keyword: the second positional parameter of
    # np.nan_to_num is `copy`, so passing 0 there would overwrite the caller's
    # array in place.
    # int64 rather than int8 so flag values of 128 and above (bit 7) survive
    # the cast and can be tested.
    flags = np.nan_to_num(
        np.asarray(quality_data), nan=0, posinf=0, neginf=0
    ).astype(np.int64)

    # Start from a boolean array so the result is boolean even with no bits.
    mask_array = np.zeros(flags.shape, dtype=bool)
    for bit in bit_nums:
        # OR in the pixels that have this bit set.
        mask_array |= (flags & (1 << bit)) > 0
    return mask_array