## Masks Refinement

## Modules

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os

# Resolve the project's `src` directory, i.e. `../src` relative to the current
# working directory (normally the folder the notebook was started from).
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../src"))

# Prepend it to sys.path so the local modules imported below
# (`process_masks`, `scripting`) are found there first.
sys.path.insert(0, parent_dir)

from process_masks import (
    refine_cloud_mask,
    refine_shadow_mask,
    generate_seasonal_backgrounds
)

from scripting import (
    print_map,
    get_season_id,
    load_s2,
    preprocess,
    set_bands,
    drop_aux_bands,
    get_scl_mask,
)

import rioxarray
import dask
import dask.distributed
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import yaml
import re
import datetime
import rasterio
from glob import glob
from datetime import date

def read_yaml(file_path: str) -> dict:
    """Open ``file_path`` and return its content parsed with ``yaml.safe_load``."""
    with open(file_path, 'r') as handle:
        content = yaml.safe_load(handle)
    return content
    
def fix_paths_for_nb(input_dict, old_substring = "/home/hrlcuser/media", new_substring = "/media/datapart/lucazanolo"):
    """
    Return a shallow copy of ``input_dict`` where every string value has each
    occurrence of ``old_substring`` replaced by ``new_substring``.

    Non-string values (numbers, lists, nested dicts, ...) are copied unchanged;
    only top-level string values are rewritten.
    """
    remapped = {}
    for key, value in input_dict.items():
        if isinstance(value, str):
            value = value.replace(old_substring, new_substring)
        remapped[key] = value
    return remapped


## Parameters

In [None]:
parameters = fix_paths_for_nb(
    read_yaml("/home/lucazanolo/luca-zanolo/scripts/config_files/1.masks_refinement.yaml")
)

resolutions = parameters["resolutions"]

# Output sub-directories, all rooted at the configured output path
dask_graph_path, bg_output_path, cloud_masks_path, shadow_masks_path = (
    f"{parameters['output_path']}/{sub_dir}"
    for sub_dir in ("dask_graph/", "backgrounds", "cloud_masks", "shadow_masks")
)

for out_dir in (cloud_masks_path, bg_output_path, dask_graph_path, shadow_masks_path):
    os.makedirs(out_dir, exist_ok=True)

SAVE_DATA = True
parameters

## Load Dataset

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(
        os.cpu_count() or 2
    ),

) as client:

    print(f"Dask dashboard: {client.dashboard_link}")

    # Load one dataset per requested resolution group and set the band names
    dss = {res: load_s2(parameters["input_path"],
                        group=res,
                        preprocess=preprocess,
                        tile=parameters["tile_id"],
                        sensor=parameters["sensor"],
                        year=parameters["year"]) for res in resolutions}

    dss = {res: set_bands(dss[res], only_bands=False) for res in resolutions}
    # Isolate the SCL band before the auxiliary bands are dropped; it is only
    # taken from the 20m/60m groups (10m is derived below).
    scls = {res : dss[res].sel(band='SCL') for res in resolutions if res in ["20m", "60m"]}
    dss, band_names = drop_aux_bands(**dss)  # Drop auxiliary bands

    if "10m" in resolutions:
        # Resample the 20m (or, if 20m was not requested, the 60m) SCL onto the
        # 10m grid with nearest-neighbour interpolation.
        print("Estimating SCL mask for 10m resolution")
        ref = "20m" if "20m" in resolutions else "60m"
        scls["10m"] = scls[ref].interp(
            dict(x=dss["10m"].coords["x"], y=dss["10m"].coords["y"]),
            method="nearest",
            kwargs=dict(fill_value="extrapolate"),
        )

    print(f"Start retrieving masks from SCL band. Masks Requested: {parameters['mask_definitions']}")

    # For every acquisition and every entry of `mask_definitions`, build a
    # boolean mask from the SCL of the first requested resolution, stacked
    # along new "mask_type" and "time" dimensions.
    scl_masks = xr.concat([
        xr.concat(
            [
                get_scl_mask(scls[resolutions[0]].sel(time=t), scl_values)
                .expand_dims({"mask_type": [mask_type]})
                .astype(bool)
                for mask_type, scl_values in parameters["mask_definitions"].items()
            ],
            dim="mask_type"
        )
        for t in scls[resolutions[0]].time.values
    ], dim="time")

    scl_valid = scl_masks

    # Bring the bands used downstream onto a common 10m grid
    dss_up: dict[str, xr.Dataset] = {}

    # Keep the 10m data (which includes B2 and B8) as is
    dss_up[resolutions[0]] = dss[resolutions[0]]

    # B11 is taken from the 20m group. The presence check must inspect the same
    # group that is selected below; otherwise configurations such as
    # ["10m", "60m"] fail later with an opaque KeyError.
    if "20m" not in resolutions or "B11" not in dss["20m"].band.values:
        raise ValueError(
            "Band B11 from the '20m' resolution group is required to build the "
            f"stacked dataset; requested resolutions: {resolutions}"
        )

    # Interpolate only B11 onto the 10m grid
    dss_up["20m"] = (
        dss["20m"]
        .sel(band="B11")
        .interp(
            dict(
                x=dss_up[resolutions[0]].coords["x"],
                y=dss_up[resolutions[0]].coords["y"],
            ),
            method="nearest",
            kwargs=dict(fill_value="extrapolate"),
        )
        .astype(np.uint16)
    )

    # Concatenate the 10m data (B2, B8) with the interpolated B11
    ds = xr.concat([dss_up[resolutions[0]], dss_up["20m"]], dim="band")

    ds.attrs["long_name"] = band_names
    ds = ds.assign(dict(masks=scl_valid))

    print_map(scls, "\n\n SCL ISOLATED BANDs\n")
    print(f"\n\n SCL MASKS (retrieved from interpolated DSS (-> DS))\n\n{scl_valid}\n\n")
    print_map(dss, "\n\n DSS + SET_BANDS() + AUX_BANDS_DROP()\n")
    print(f"\n\n DSS + SET_BANDS() + AUX_BANDS_DROP() + INTERPOLATION to 10M\n\n{ds}\n\n")


## Background calculation

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=os.cpu_count() or 2,
) as client:

    print(f"Dask dashboard: {client.dashboard_link}")

    # Seasonal backgrounds (0.25 quantile), indexed by season_id
    backgrounds = generate_seasonal_backgrounds(ds, quantile=0.25)
    print(f"\n\nBackground:\n\n{backgrounds}\n\n")

    # Write one GeoTIFF per season, skipping seasons already on disk
    for season, group in ds.groupby("season_id"):
        print(f"[{season}] Processing season: {season}")
        background_path = f"{bg_output_path}/background_{group.season_id.values[0]}.tif"
        background = backgrounds.sel(season_id=season)

        if os.path.exists(background_path):
            continue

        print(f"[{season}] Start generating and saving background at {background_path}")
        background.rio.to_raster(
            background_path,
            compress="DEFLATE",
            num_threads="all_cpus",
        )
        print(f"[{season}] Background saved")


## Refine cloud and shadow masks

In [None]:
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(
        os.cpu_count() or 2
    ),

) as client:

    print(f"Dask dashboard: {client.dashboard_link}")

    refined_cloud_masks = []
    refined_shadow_masks = []

    for season, group in ds.groupby("season_id"):

        # Make sure this season's background is on disk (same logic as the
        # background cell), then reload it from the GeoTIFF into memory.
        print(f"[{season}] Processing season: {season}")
        background_path = f"{bg_output_path}/background_{group.season_id.values[0]}.tif"
        background = backgrounds.sel(season_id = season)

        if not os.path.exists(background_path):
            print(f"[{season}] Start generating and saving background at {background_path}")

            background.rio.to_raster(
                background_path,
                compress="DEFLATE",
                num_threads="all_cpus"
            )

            print(f"[{season}] Background saved")

        background = rioxarray.open_rasterio(background_path).squeeze()
        background = background.compute()

        # Masks refinement, one acquisition at a time
        for time in group.time.values:

            print(f"Generating and saving Sen2cor and refined cloud and shadow masks for {time}")
            # Single acquisition of this season. Every refinement input below
            # must come from `c_group`, not from the whole-season `group`:
            # otherwise each per-date raster would contain all dates of the
            # season and the final concat along "time" would duplicate them.
            c_group = group.sel(time = time)

            file_stem = str(c_group.file_name.values)[:-4]
            ref_cloud_mask_path = f"{cloud_masks_path}/{file_stem}_cloudMediumMask.tif"
            ref_shadow_mask_path = f"{shadow_masks_path}/{file_stem}_shadowMask.tif"

            # Load the acquisition into memory only if some output still has to be written
            if not os.path.exists(ref_shadow_mask_path) or not os.path.exists(ref_cloud_mask_path):
                c_group = c_group.load()

            refined_cloud_mask = xr.apply_ufunc(
                refine_cloud_mask,
                c_group.data.sel(band="B2"),           # Blue band (B2)
                c_group.masks.sel(mask_type="cloud"),  # Sen2Cor cloud mask
                background,                            # Seasonal background
                kwargs = {'cloud_coverage_threshold':parameters["cloud_coverage_threshold"]},
                input_core_dims=[['y', 'x'], ['y', 'x'], ['y', 'x']],
                output_core_dims=[['y', 'x']],
                vectorize=True,
                dask="parallelized",
                output_dtypes=[np.uint8],
                keep_attrs=True,
                dask_gufunc_kwargs={'allow_rechunk': True}
            )
            # keep_attrs copies the B2 input's attrs; drop its `long_name` (if
            # present) so it is not written out with the mask raster
            refined_cloud_mask.attrs.pop('long_name', None)

            refined_shadow_mask = xr.apply_ufunc(
                refine_shadow_mask,
                c_group.data.sel(band="B2"),            # Blue band (B2)
                c_group.data.sel(band="B8"),            # NIR band (B8)
                c_group.data.sel(band="B11"),           # SWIR band (B11)
                refined_cloud_mask,                     # Refined cloud mask
                c_group.masks.sel(mask_type="shadow"),  # Sen2Cor shadow mask
                kwargs={"cloud_coverage_threshold"   : parameters["cloud_coverage_threshold"],
                        "image_brightness_threshold" : parameters["image_brightness_threshold"]},
                input_core_dims=[
                    ['y', 'x'], ['y', 'x'], ['y', 'x'], ['y', 'x'], ['y', 'x']],
                output_core_dims=[['y', 'x']],
                vectorize=True,
                dask="parallelized",
                output_dtypes=[np.uint8],
                keep_attrs=True,
                dask_gufunc_kwargs={'allow_rechunk': True}
            )
            refined_shadow_mask.attrs.pop('long_name', None)

            # Deliberately not named `date`: that would shadow the imported
            # `datetime.date`, which later cells call.
            date_str = str(time)[:10]

            if not os.path.exists(ref_cloud_mask_path):
                refined_cloud_mask.rio.to_raster(
                    ref_cloud_mask_path,
                    compress="DEFLATE",
                    num_threads="all_cpus"
                )
                print(f"[{season}][{date_str}] Saved refined cloud mask at {ref_cloud_mask_path}")
            else:
                print(f"[{season}][{date_str}] Skipping. Already exists: {ref_cloud_mask_path}")

            if not os.path.exists(ref_shadow_mask_path):
                refined_shadow_mask.rio.to_raster(
                    ref_shadow_mask_path,
                    compress="DEFLATE",
                    num_threads="all_cpus"
                )
                print(f"[{season}][{date_str}] Saved refined shadow mask at {ref_shadow_mask_path}")
            else:
                print(f"[{season}][{date_str}] Skipping. Already exists: {ref_shadow_mask_path}")

            refined_shadow_masks.append(refined_shadow_mask)
            refined_cloud_masks.append(refined_cloud_mask)

            del c_group, refined_cloud_mask, refined_shadow_mask

        del background

    # Each refined mask is a single acquisition carrying a scalar `time`
    # coordinate, so concatenating along "time" rebuilds the time series.
    refined_cloud_masks = xr.concat(refined_cloud_masks, dim='time')
    refined_shadow_masks = xr.concat(refined_shadow_masks, dim='time')

    refined_masks = xr.Dataset({
        "refined_cloud_masks" : refined_cloud_masks,
        "refined_shadow_masks" : refined_shadow_masks
    })

    print(f"\n\n REFINED MASKS \n\n{refined_masks}\n\n")
    print(f"\n\n BACKGROUNDS \n\n{backgrounds}\n\n")
        

## Inspect background/refined masks - Generate reports

Only the "Load Dataset" cell needs to be run before this section; the backgrounds and refined masks must already be available on disk.

In [None]:
def load_refined_masks(ds_slice):
    """
    Attach the refined cloud/shadow masks stored on disk to a single-acquisition slice.

    The mask files are looked up in ``cloud_masks_path`` / ``shadow_masks_path``
    using the slice's ``file_name`` with its last four characters stripped.
    Each raster is wrapped in a DataArray that reuses the dims, coords and attrs
    of the matching Sen2Cor mask (``masks`` with ``mask_type`` "cloud"/"shadow"),
    concatenated along ``time`` and cast to bool.

    Returns
    -------
    The slice with two extra variables, ``refined_cloud_masks`` and
    ``refined_shadow_masks``.
    """
    print(f"\n\nDataset slice \n{ds_slice}\n")
    print(ds_slice.masks.sel(mask_type = "cloud"))

    stem = str(ds_slice.file_name.values)[:-4]
    cloud_path = f"{cloud_masks_path}/{stem}_cloudMediumMask.tif"
    shadow_path = f"{shadow_masks_path}/{stem}_shadowMask.tif"

    def _read_like(path, template):
        # Raster pixels wrapped with the template mask's labels and metadata
        return xr.DataArray(
            rioxarray.open_rasterio(path).squeeze().data,
            dims=template.dims,
            coords=template.coords,
            attrs=template.attrs,
        )

    print(f"Loading refined cloud mask: {cloud_path}")
    cloud = _read_like(cloud_path, ds_slice.masks.sel(mask_type = "cloud"))

    print(f"Loading refined shadow mask: {shadow_path}")
    shadow = _read_like(shadow_path, ds_slice.masks.sel(mask_type = "shadow"))

    return ds_slice.assign({
        "refined_cloud_masks": xr.concat([cloud], dim='time').astype(bool),
        "refined_shadow_masks": xr.concat([shadow], dim='time').astype(bool),
    })

def calculate_percentage(mask):
    """
    Return the percentage (0-100) of elements of ``mask`` equal to 1.

    ``mask`` may be any array-like (NumPy array, nested list, ...). Boolean
    masks work too, since True compares equal to 1. An empty mask yields 0.0
    rather than NaN, which avoids a division-by-zero warning and keeps the
    result printable with a fixed-point format.
    """
    values = np.asarray(mask)
    if values.size == 0:
        return 0.0
    return float(np.count_nonzero(values == 1)) / values.size * 100.0

def create_report(image, background, s2c_cloud_mask, s2c_shadow_mask, s2c_cloud_mask_ref, s2c_shadow_mask_ref, output_path="report.png", scale=3000):
    """
    Create a 3x2 report of an image, its background and the Sen2Cor/refined
    cloud and shadow masks, then save it to ``output_path`` (300 dpi) and show it.

    Parameters
    ----------
    image : array-like, shape (3, y, x), (1, y, x) or (y, x)
        Channel-first image. Three channels are shown as an RGB composite in the
        order given (first channel -> red, last -> blue); pass e.g. bands
        B4, B3, B2 for a true-colour view. A single band is shown with the
        "jet" colormap. Values are divided by ``scale`` and clipped to [0, 1].
    background : 2-D array
        Seasonal background, shown with the "jet" colormap.
    s2c_cloud_mask, s2c_shadow_mask : 2-D arrays
        Sen2Cor masks; each title reports the percentage of pixels equal to 1.
    s2c_cloud_mask_ref, s2c_shadow_mask_ref : 2-D arrays
        Refined masks, reported the same way.
    output_path : str
        Destination of the saved figure.
    scale : float
        Divisor mapping raw image values to the [0, 1] display range.
    """
    # Calculate percentages for masks
    cloud_percentage = calculate_percentage(s2c_cloud_mask)
    shadow_percentage = calculate_percentage(s2c_shadow_mask)
    cloud_ref_percentage = calculate_percentage(s2c_cloud_mask_ref)
    shadow_ref_percentage = calculate_percentage(s2c_shadow_mask_ref)

    # Create the figure and subplots
    fig, axes = plt.subplots(3, 2, figsize=(20, 30))

    # Plot the image: RGB composite for 3 channels, colormapped single band otherwise
    img = np.asarray(image)
    if img.ndim == 3 and img.shape[0] == 1:
        img = img[0]
    if img.ndim == 3:
        # (band, y, x) -> (y, x, band); imshow ignores `cmap` for RGB data
        axes[0, 0].imshow(np.clip(img.transpose(1, 2, 0) / scale, 0, 1))
        axes[0, 0].set_title("RGB Composite")
    else:
        axes[0, 0].imshow(np.clip(img / scale, 0, 1), cmap="jet")
        axes[0, 0].set_title("Image")
    axes[0, 0].axis("off")

    # Plot the background
    axes[0, 1].imshow(background, cmap="jet")
    axes[0, 1].set_title("Background")
    axes[0, 1].axis("off")

    # Plot cloud mask
    axes[1, 0].imshow(s2c_cloud_mask, cmap="gray")
    axes[1, 0].set_title(f"Cloud Mask\n{cloud_percentage:.2f}% Cloud Pixels")
    axes[1, 0].axis("off")

    # Plot cloud mask refined
    axes[1, 1].imshow(s2c_cloud_mask_ref, cmap="gray")
    axes[1, 1].set_title(f"Cloud Mask Ref\n{cloud_ref_percentage:.2f}% Cloud Pixels")
    axes[1, 1].axis("off")

    # Plot shadow mask
    axes[2, 0].imshow(s2c_shadow_mask, cmap="gray")
    axes[2, 0].set_title(f"Shadow Mask\n{shadow_percentage:.2f}% Shadow Pixels")
    axes[2, 0].axis("off")

    # Plot shadow mask refined
    axes[2, 1].imshow(s2c_shadow_mask_ref, cmap="gray")
    axes[2, 1].set_title(f"Shadow Mask Ref\n{shadow_ref_percentage:.2f}% Shadow Pixels")
    axes[2, 1].axis("off")

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.show()
    plt.close(fig)

    print(f"Report saved as {output_path}")
    
with dask.distributed.Client(
    processes=False,
    threads_per_worker=(
        os.cpu_count() or 2
    ),

) as client:

    print(f"Dask dashboard: {client.dashboard_link}")

    # Tile and output folder come from the configuration instead of being
    # hard-coded, so reports for other tiles are not mislabelled.
    tile_id = parameters["tile_id"]
    report_path = f"{parameters['output_path']}/reports"
    os.makedirs(report_path, exist_ok=True)
    limit = 2  # maximum number of new reports generated per run
    generated = 0

    for time_idx, time in enumerate(ds.time.values):
        if generated >= limit:
            break

        print(time)
        print(f"Generating report for time index: {time_idx}")
        curr_report_path = f"{report_path}/report_{tile_id}_{time_idx}.png"
        print(curr_report_path)
        if os.path.exists(curr_report_path):
            continue

        ds_test = ds.isel(time=time_idx)
        # Backgrounds are written as background_<season_id>.tif, so build the
        # path from the acquisition's own season_id.
        bg_path = f"{bg_output_path}/background_{ds_test.season_id.item()}.tif"
        ds_test = load_refined_masks(ds_test)

        # Red, green, blue order (B4, B3, B2) for a true-colour composite
        rgb_image = ds_test.data.sel(band = ["B4", "B3", "B2"]).values
        background = rioxarray.open_rasterio(bg_path).squeeze().values
        s2c_cloud_mask = ds_test.masks.sel(mask_type = "cloud").values
        s2c_shadow_mask = ds_test.masks.sel(mask_type = "shadow").values
        s2c_cloud_mask_ref = ds_test.refined_cloud_masks.squeeze().values
        s2c_shadow_mask_ref = ds_test.refined_shadow_masks.squeeze().values

        create_report(rgb_image, background, s2c_cloud_mask, s2c_shadow_mask, s2c_cloud_mask_ref, s2c_shadow_mask_ref, curr_report_path)
        generated += 1




### Plot backgrounds

In [None]:
import rioxarray
import matplotlib.pyplot as plt

background_path = "/media/datapart/lucazanolo/S2_processed_masks/backgrounds/background_18NWL2019_1.tif"
background = rioxarray.open_rasterio(background_path)
# Drop size-1 dimensions (e.g. a single band) so xarray draws a 2-D image
background_2d1 = background.squeeze(drop=True)

plt.figure(figsize=(8, 6))
background_2d1.plot()
# Put the file name in the title so plots of different backgrounds can be told apart
plt.title(f"Background Image\n{os.path.basename(background_path)}")
plt.show()


In [None]:
import rioxarray
import matplotlib.pyplot as plt

background_path = "/media/datapart/lucazanolo/S2_processed_masks/backgrounds/backgroundImage_18NWL2019_1.tif"
background = rioxarray.open_rasterio(background_path)
# Drop size-1 dimensions (e.g. a single band) so xarray draws a 2-D image
background_2d2 = background.squeeze(drop=True)

plt.figure(figsize=(8, 6))
background_2d2.plot()
# Put the file name in the title so plots of different backgrounds can be told apart
plt.title(f"Background Image\n{os.path.basename(background_path)}")
plt.show()


In [None]:
def generate_mask_report(tile: str, date: datetime.date, output_path="report.png"):
    """
    Save an 8-panel figure comparing the existing and the reimplemented
    pipelines for one tile and acquisition date.

    Panels: the background of each pipeline, the Sen2Cor cloud/shadow masks and
    the refined cloud/shadow masks of each pipeline. Mask panels are titled with
    the percentage of pixels equal to 1. Files that cannot be found are reported
    on stdout and their panel is left blank.

    Parameters
    ----------
    tile : str
        Sentinel-2 tile id, case-insensitive (e.g. "18NWL").
    date : datetime.date
        Acquisition date used to match the file names.
    output_path : str
        Destination of the saved figure.
    """
    tile = tile.upper()

    base_old = "/media/datapart/lucazanolo/S2_processed_masks_old"
    base_new = "/media/datapart/lucazanolo/S2_processed_masks"

    # Product-name fragment shared by every mask file of this tile/date
    product_prefix = (
        rf"MSIL2A_{date.strftime('%Y%m%d')}T\d+_N\d+_R\d+_T{re.escape(tile)}_\d+T\d+_"
    )

    def find_background(path):
        # Season index from the month: Dec/Jan/Feb -> 1, Mar-May -> 2,
        # Jun-Aug -> 3, Sep-Nov -> 4 (1=Winter, ..., 4=Fall)
        season = (date.month % 12 + 3) // 3
        pattern = f"backgroundImage_{tile}{date.year}_{season}.tif"
        matches = sorted(glob(os.path.join(path, "backgrounds", pattern)))
        return matches[0] if matches else None

    def find_mask(base_path, subfolder, suffix):
        # First file (sorted, for determinism) in base_path/subfolder whose name
        # contains the product fragment followed by `suffix` and ".tif".
        # A missing folder yields None instead of raising FileNotFoundError;
        # `search` tolerates a leading mission prefix such as "S2A_".
        folder = os.path.join(base_path, subfolder)
        if not os.path.isdir(folder):
            return None
        pattern = re.compile(product_prefix + rf"{suffix}\.tif$")
        for file in sorted(os.listdir(folder)):
            if pattern.search(file):
                return os.path.join(folder, file)
        return None

    files = {
        "Existing Pipeline - Background": find_background(base_old),
        "Reimplemented Pipeline - Background": find_background(base_new),
        "Sen2Cor Cloud": find_mask(base_old, "cloud_masks", "cloudMediumMask_Sen2Cor"),
        "Sen2Cor Shadow": find_mask(base_old, "shadow_masks", "shadowMask_Sen2Cor"),
        "Existing Pipeline - Refined Cloud": find_mask(base_old, "cloud_masks", "cloudMediumMask"),
        "Existing Pipeline - Refined Shadow": find_mask(base_old, "shadow_masks", "shadowMask"),
        "Reimpl. Pipeline - Refined Cloud": find_mask(base_new, "cloud_masks", "cloudMediumMask"),
        "Reimpl. Pipeline - Refined Shadow": find_mask(base_new, "shadow_masks", "shadowMask"),
    }

    fig, axs = plt.subplots(4, 2, figsize=(10, 20))
    axs = axs.flatten()

    for ax, (title, path) in zip(axs, files.items()):
        # Hide ticks/frame for every panel, including those left empty
        ax.axis('off')
        if path is None or not os.path.exists(path):
            print(f"{title} - Path: {path} does not exist!")
            continue
        with rasterio.open(path) as src:
            img = src.read(1)
        if "Background" not in title:
            # Percentage of active pixels, only for the masks
            percentage = calculate_percentage(img)
            ax.set_title(f"{title}\n{percentage:.2f}%", fontsize=18)
            ax.imshow(img, cmap='gray')
        else:
            ax.set_title(title, fontsize=18)
            ax.imshow(img)

    # Blank out any axes beyond the number of panels
    for ax in axs[len(files):]:
        ax.axis('off')

    # Suptitle with tile and date
    fig.suptitle(f"Tile: {tile} — Date: {date.isoformat()}", fontsize=22)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(output_path)
    plt.close(fig)
    print(f"Report saved at: {output_path}")


In [None]:
# TILE 10UED -> 2019 10 06 report_10ued_20191006
# TILE 18NWL -> 2019 02 23 report_18nwl_20190223
# `date` must still be datetime.date here; re-run the imports cell if another
# cell has rebound that name.
generate_mask_report("18NWL", date(2019, 2, 23), "report_18nwl_20190223.png")


In [None]:
generate_mask_report(tile="10UED", date=date(2019, 10, 6), output_path="report_10ued_20191006.png")
