This Notebook is the raw code to get the cloud probability product without much description.

In [1]:
# pip install odc-stac
# pip install pystac-client
# pip install dask

import numpy as np
import pandas as pd
import xarray as xr

from pystac_client import Client
import geopandas as gpd
import rioxarray as xrx

import odc.stac

import dask.array as da
import dask.dataframe as dd

tiles="data/uk_20km_grid.gpkg"
tile_of_interest="NZ44" #NOTE: NZ26 is the main one
collection = "sentinel-2-c1-l2a"
time_of_interest = '2024-01-01/2024-12-31'
bands_of_interest = ["cloud", "scl"]
api_url="https://earth-search.aws.element84.com/v1"
res = 20
target_crs=27700

tiles = gpd.read_file(tiles).to_crs(4326)

In [2]:
def check_kernel_limit():
    """This checks the local kernel limitations on memory.
    Prints `-1` if inherited from OS and no restrictions."""
    import resource

    soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    print(f"Address space (virtual memory) soft/hard: {soft}/{hard}")
    soft, hard = resource.getrlimit(resource.RLIMIT_DATA)
    print(f"Data segment size soft/hard: {soft}/{hard}")

check_kernel_limit()

Address space (virtual memory) soft/hard: -1/-1
Data segment size soft/hard: -1/-1


In [3]:
def inspect_search(items,index=0):
    """Print inspect/debug info about the one of the assets (usually the first one).
    Parameters:
    items (pystac.ItemCollection): STAC collection
    index (int=0): random tile number to inspect (Default: 0)
    """
    item = items[index]
    try:
        print("-" * 40)
        print(f"Inspecting the asset #{index}")
        print(f"DATETIME is {item.datetime}")
        print(f"GEOMETRY is {item.geometry}")
        print(f"PROPERTIES are:\n{item.properties}")
        print(f"CRS: {item.properties.get('proj:code') or item.properties.get('proj:epsg')}")
        metadata=odc.stac.extract_collection_metadata(item, cfg=None, md_plugin=None)
        print(f"STAC metadata:\n{metadata}")
        print("-" * 40)
    except Exception as e:
        print("-" * 40)
        print(f"Error checking item[{index}]: {e}")
        print("-" * 40)


In [4]:
def mask_with_scl(data):
    """
    Replace cloud values based on SCL mask.
    Rules:
    - If SCL == 0 → cloud = -1
    - If SCL != 0 → keep cloud as is
    - Assign -1 as the new no-data value for the cloud band
    - Ensure no NaNs remain (fill them with -1)
    
    Parameters:
    data (xarray.Dataset): original dataset with `cloud` and `scl` variables
    Returns:
    cloud (xarray.DataArray): output array with masked `cloud` band
    """
    cloud = data["cloud"].astype("float32").copy()
    scl = data["scl"].astype("int16")

    # apply rules
    cloud = cloud.where(scl != 0, -1)

    # replace any remaining NaNs with -1
    cloud = cloud.fillna(-1)
    # assign -1 as the nodata value for output
    cloud.attrs["nodata"] = -1

    print(f"Cloud dimensions: {cloud.dims}")
    print(f"SCL dimensions: {scl.dims}")
    print("-1 is set as the no-data value (no NaNs remain).")

    return cloud

    """
    #NOTE: DEBUG for checking the cloud masked 
    #NOTE: heavy calculation as it opens the whole numpy array (computes)
    num_nodata = cloud_masked.isnull().sum().compute().item()
    print(f"Number of no-data (NaN) values: {num_nodata}")
    
    # select the first time slice if 'time' is one of the dimensions
    if "time" in cloud_masked.dims:
        first_scene = cloud_masked.isel(time=0)
    else:
        first_scene = cloud_masked

    # ensure CRS and spatial transform are defined
    first_scene = first_scene.rio.write_crs(data["cloud"].rio.crs, inplace=False)
    # export to GeoTIFF
    output_path = "cloud_masked_first_scene.tif"
    first_scene.rio.to_raster(output_path)
    print(f"Exported first scene to {output_path}")
    """
    """
    # DEPRECATED
    # define bad pixels
    nodatas = bitmask
    all_bad_pixels = nodatas(dim="time")

    # Expand dimensions to match data shape
    all_bad_expanded = all_bad_pixels.broadcast_like(nodatas)

    # For these pixels, we’ll override and mark them as good
    effective_bad_mask = nodatas.where(~all_bad_expanded, other=False)

    # Apply the mask: keep data where bad == False
    masked = data.where(~effective_bad_mask)
    return masked
    """

In [5]:
def to_geotif(dataset,bands_of_interest:list=None, out_path:str=None):
    """The function saves the dataset with several variables to geotiff"""
    # Taken from: https://discourse.pangeo.io/t/comparing-odc-stac-load-and-stackstac-for-raster-composite-workflow/4097

    if isinstance(dataset, xr.DataArray):
        print(f"dataarray")
        image = dataset.squeeze('year').rio.write_crs(target_crs)

        # NOTE: this might overflow dask
        '''num_nodata = image.isnull().sum()
        print(f"Number of nodata pixels: {num_nodata}")'''
        
    else:
        image = (
            dataset[bands_of_interest]
            .to_dataarray(dim="band")
            .transpose(..., "band")
            .squeeze('year')
            .transpose('band', 'y', 'x')
            .rio.write_crs(f"epsg:{target_crs}")
        )

    return image.rio.to_raster(out_path)

In [6]:
if tile_of_interest: 
    print(f"Calculating cloud probability for a tile of interest {tile_of_interest}:")
    tile_name = tile_of_interest
    
    print("-" * 40)
    print(f'Process for tile {tile_name} started.', flush=True)
    
    selected_tile = tiles[tiles["tile_name"]==tile_of_interest]
    bbox_of_interest = selected_tile.total_bounds.tolist()
    print(f"Bbox of interest is: {bbox_of_interest}")

    catalog = Client.open(
    api_url
    )
    search = catalog.search(
        collections=collection,
        bbox=bbox_of_interest,
        datetime=time_of_interest
    )
    items = search.item_collection()
    print(f"Number of items: {len(items)}")
    print(f"Type of items: {type(items)}")
    inspect_search(items, index=0)

    data = odc.stac.stac_load(
        items, bands=bands_of_interest, 
        bbox=bbox_of_interest,
        resolution=res,
        crs=target_crs,
        chunks={'time': 20, 'x': 300, 'y': 300}
        #align=target_resolution_grid  # optional, aligns to custom grid # NOTE: ODC automatically extends the data from the bbox by a few pixels to each side
    ) 
    # NOTE: odc.stac.stac_load is not covered by the documentation yet (only the old version - odc.stac.load in 0.39.0)
    # TODO - to check docs 0.40.0

    print(data)
    print(data.dims) 
    
    # NOTE: TODO - for further development MULTIPLE TILES
    """for idx, row in tiles.bounds.iterrows():
    tile_name = tiles.loc[row.name, "tile_name"]

    print("-" * 40)
    print(f'Process for tile {tile_name} started.', flush=True)

    bbox_of_interest = row.to_list()
    print(bbox_of_interest)

    catalog = Client.open(
    api_url
    )
    search = catalog.search(
        collections=collection,
        bbox=bbox_of_interest,
        datetime=time_of_interest
    )
    items = search.item_collection()
    inspect_search(items, index=0)

    data = odc.stac.stac_load(
        items, bands=bands_of_interest, 
        bbox=bbox_of_interest,
        resolution=res,
        crs=target_crs
    )

    print(data)
    print(data.dims)"""

Calculating cloud probability for a tile of interest NZ44:
----------------------------------------
Process for tile NZ44 started.
Bbox of interest is: [-1.3800354189419266, 54.75137797895958, -1.0651866799222283, 54.93308063265072]
Number of items: 292
Type of items: <class 'pystac.item_collection.ItemCollection'>
----------------------------------------
Inspecting the asset #0
DATETIME is 2024-12-31 11:16:09.671000+00:00
GEOMETRY is {'type': 'Polygon', 'coordinates': [[[-1.4892868190198094, 55.037602642484046], [-1.945565451302893, 54.055591019680406], [-1.3232267769574286, 54.04852390275896], [-1.2822817967468176, 55.034856954918745], [-1.4892868190198094, 55.037602642484046]]]}
PROPERTIES are:
{'created': '2024-12-31T16:39:31.284Z', 'platform': 'sentinel-2a', 'constellation': 'sentinel-2', 'instruments': ['msi'], 'eo:cloud_cover': 100, 'proj:centroid': {'lat': 54.46203, 'lon': -1.53085}, 'mgrs:utm_zone': 30, 'mgrs:latitude_band': 'U', 'mgrs:grid_square': 'WF', 'grid:code': 'MGRS-30

In [7]:
def export_s2_scenes(data, items, output_dir="data/test", band="cloud"):
    """
    Export Sentinel-2 scenes from an xarray Dataset or DataArray to individual GeoTIFFs.
    Useful for visual checks in 
    Filenames are based on the STAC 's2:tile_id' property."""
    
    import os

    os.makedirs(output_dir, exist_ok=True)
    # extract tile IDs from STAC items
    tile_ids = [item.properties.get("s2:tile_id", f"scene_{i}") for i, item in enumerate(items)]
    # attach as the coordinate
    if "tile_id" not in data.coords:
        data = data.assign_coords(tile_id=("time", tile_ids))
    # wrap in a dataset if it's dataarray
    if isinstance(data, xr.DataArray):
        print("Input is a DataArray — converting to Dataset for export.")
        data = data.to_dataset(name=data.name or band)
    # if band exists
    if band not in data.data_vars:
        raise ValueError(f"Band '{band}' not found in dataset. Available bands: {list(data.data_vars)}")
        
    # loop over scenes
    for i, tile_id in enumerate(data.tile_id.values):
        print(f"Processing scene {i+1}/{len(data.time)} → {tile_id}")

        # select one scene and load into memory
        scene = data.isel(time=i).compute()
        out_path = os.path.join(output_dir, f"{band}_{tile_id}.tif")
        # to check if crs is written
        scene_band = scene[band]
        scene_band = scene_band.rio.write_crs(scene_band.rio.crs or data[band].rio.crs, inplace=False)
        
        scene_band.rio.to_raster(out_path)
        print(f"Exported: {out_path}")

    print(f"\n All {len(data.time)} scenes exported to '{output_dir}'.")

# USAGE (to export unmasked scenes)
# export_s2_scenes(data, items, output_dir="data/test/unmasked", band="cloud")

In [8]:
try:
    cloud_masked = mask_with_scl(data)
    print(f"Cloud masked with SCL.")
except Exception as e:
    print(f"Failed to mask cloud with SCL band")

print("Type:", type(cloud_masked))
print("Name:", cloud_masked.name)
print("Dimensions:", cloud_masked.dims)
print("Coordinates:", list(cloud_masked.coords))

#DEBUG: check each scene
#export_s2_scenes(cloud_masked, items, output_dir="data/test/masked", band="cloud")

Cloud dimensions: ('time', 'y', 'x')
SCL dimensions: ('time', 'y', 'x')
-1 is set as the no-data value (no NaNs remain).
Cloud masked with SCL.
Type: <class 'xarray.core.dataarray.DataArray'>
Name: cloud
Dimensions: ('time', 'y', 'x')
Coordinates: ['y', 'x', 'spatial_ref', 'time']


In [9]:
# NOTE: version with cloud_prob, masked by scl:
# NOTE: do not skip `.where(cloud-masked!=-1)` because otherwise it just won't consider pixels covered by satellite, but with a 0 cloud probability
# NOTE: we use now `-1` instead of 0 because that's our new no data value
cloud_prob = (
    cloud_masked
    .where(cloud_masked != -1)
    .groupby("time.year")
    .mean(dim="time", skipna=True)
    .astype('float32')
)

print(cloud_prob.rio.crs)
print(cloud_prob.rio.bounds()) # NOTE: bounds are extended

# NOTE: DEBUG check of the output dataset (unique values, 0 values, no data values) - this is all very heavy for Dask computations
"""cloud_prob_2d = cloud_prob.squeeze('year')  # shape: (y, x)
unique_values = np.unique(cloud_prob_2d.values)
print(f"Number of unique values: {len(unique_values)}")
print("First 20 unique values:", unique_values[:20])
num_zeros = np.sum(cloud_prob_2d.values == 0)
total_pixels = cloud_prob_2d.size
percent_zeros = num_zeros / total_pixels * 100
print(f"Number of pixels with value 0: {num_zeros}")
print(f"Percent of zeros: {percent_zeros:.2f}%")

print(cloud_prob.shape)"""

'''
cloud_prob_2d = cloud_prob.squeeze('year')  # shape: (y, x)
num_zeros = np.sum(cloud_prob_2d.values == 0)
total_pixels = cloud_prob_2d.size
percent_zeros = num_zeros / total_pixels * 100
print(f"Number of pixels with value 0: {num_zeros}")
print(f"Percent of zeros: {percent_zeros:.2f}%")

# squeeze year dimension if single year
cloud_prob = cloud_prob.squeeze('year')

# convert variables to float32
cloud_prob = cloud_prob.astype('float32')

# set spatial dims for rioxarray
cloud_prob = cloud_prob.rio.set_spatial_dims(x_dim='x', y_dim='y')
# set CRS
cloud_dataset = cloud_prob.rio.write_crs(27700)

out_path = "data/cloud_median_count_2024.tif"
cloud_dataset.rio.to_raster(out_path)
print(f"Saved multi-band GeoTIFF to {out_path}")'''

out_path = f'data/{tile_name}_2024.tif'
try:
    to_geotif(cloud_prob, bands_of_interest[:-1], out_path=out_path)
    print(f"Output GeoTIFF saved to {out_path}")
except Exception as e:
    print(f"Failed to save GeoTIFF: {e}")

'''
tif = xrx.open_rasterio("data/NZ26_2024.tif", masked=True)  # masked=True treats nodata as NaN
num_nodata = int(tif.isnull().sum())
total_pixels = tif.size
percent_nodata = num_nodata / total_pixels * 100

print(f"GeoTIFF shape: {tif.shape}")
print(f"NoData pixels (NaN): {num_nodata}")
print(f"Percent NoData: {percent_nodata:.2f}%")'''

EPSG:27700
(439820.0, 539760.0, 460280.0, 560240.0)
dataarray


  dest = _reproject(


Output GeoTIFF saved to data/NZ44_2024.tif


'\ntif = xrx.open_rasterio("data/NZ26_2024.tif", masked=True)  # masked=True treats nodata as NaN\nnum_nodata = int(tif.isnull().sum())\ntotal_pixels = tif.size\npercent_nodata = num_nodata / total_pixels * 100\n\nprint(f"GeoTIFF shape: {tif.shape}")\nprint(f"NoData pixels (NaN): {num_nodata}")\nprint(f"Percent NoData: {percent_nodata:.2f}%")'

In [10]:
# TODO - decide how to handle cropping tiles in mosaic (as odc stac returns arrays with extended borders)
"""
def crop_array(array, bbox):
"Crops the output array by the initial bounding box.
Parameters:
xarray.DataArray: output array
bbox: initial bounding box
Returns:
xarray.DataArray: cropped array
    .rio.clip_box() # to crop by the initial bounding box
    
crop_array(cloud_prob,bbox)    
"""

'\ndef crop_array(array, bbox):\n"Crops the output array by the initial bounding box.\nParameters:\nxarray.DataArray: output array\nbbox: initial bounding box\nReturns:\nxarray.DataArray: cropped array\n    .rio.clip_box() # to crop by the initial bounding box\n    \ncrop_array(cloud_prob,bbox)    \n'

In [11]:
def check_array(array):
    """General info about the output array"""
    print(type(array))
    print(array.shape)
    print(array.coords)
    print(array.attrs)
    print(array.dims)
    num_nodata = array.isnull().sum()
    print(f"Total nodata values: {num_nodata}")

check_array(cloud_prob)

<class 'xarray.core.dataarray.DataArray'>
(1, 1024, 1023)
Coordinates:
  * y            (y) float64 8kB 5.602e+05 5.602e+05 ... 5.398e+05 5.398e+05
  * x            (x) float64 8kB 4.398e+05 4.398e+05 ... 4.602e+05 4.603e+05
    spatial_ref  int32 4B 27700
  * year         (year) int64 8B 2024
{'nodata': -1}
('year', 'y', 'x')
Total nodata values: <xarray.DataArray 'cloud' ()> Size: 8B
dask.array<sum-aggregate, shape=(), dtype=int64, chunksize=(), chunktype=numpy.ndarray>
Coordinates:
    spatial_ref  int32 4B 27700


In [12]:
def inspect_chunks(obj):
    """
    Inspects Dask chunking for an xarray Dataset or DataArray.
    Prints total number of chunks, average size, and alignment info.
    """
    # Handle Dataset (multiple variables)
    if isinstance(obj, xr.Dataset):
        print(f"Dataset with {len(obj.data_vars)} variables:")
        print("=" * 60)
        for var_name, da in obj.data_vars.items():
            print(f"\nVariable: {var_name}")
            inspect_chunks(da)
        return

    da = obj #handle dataarray (single variable)

    if not hasattr(da.data, "chunks"):
        print("Array not chunked (not a Dask array).")
        return

    chunks = da.data.chunks
    dtype_size = da.dtype.itemsize

    print("-" * 60)
    total_chunks = 1
    uneven = False

    for dim, sizes in zip(da.dims, chunks):
        total_chunks *= len(sizes)
        equal = len(set(sizes)) == 1
        if not equal:
            uneven = True
        print(f"{dim:>6}: {len(sizes)} chunks | sizes = {sizes[:5]}{'...' if len(sizes) > 5 else ''}")

    avg_chunk_elems = np.prod([np.mean(s) for s in chunks])
    avg_chunk_bytes = avg_chunk_elems * dtype_size
    avg_chunk_mb = avg_chunk_bytes / 1e6

    print("-" * 60)
    print(f"Total chunks: {total_chunks}")
    print(f"Average chunk size: {avg_chunk_mb:.2f} MB ({da.dtype})")
    print(f"Chunks evenly sized? {'Yes' if not uneven else 'No, uneven chunks'}")

inspect_chunks(data)

Dataset with 2 variables:

Variable: cloud
------------------------------------------------------------
  time: 15 chunks | sizes = (20, 20, 20, 20, 20)...
     y: 4 chunks | sizes = (300, 300, 300, 124)
     x: 4 chunks | sizes = (300, 300, 300, 123)
------------------------------------------------------------
Total chunks: 240
Average chunk size: 1.27 MB (uint8)
Chunks evenly sized? No, uneven chunks

Variable: scl
------------------------------------------------------------
  time: 15 chunks | sizes = (20, 20, 20, 20, 20)...
     y: 4 chunks | sizes = (300, 300, 300, 124)
     x: 4 chunks | sizes = (300, 300, 300, 123)
------------------------------------------------------------
Total chunks: 240
Average chunk size: 1.27 MB (uint8)
Chunks evenly sized? No, uneven chunks


In [13]:
import numpy as np
import geopandas as gpd
from shapely.geometry import box

def get_chunk_polygons(da, var_name=None):
    """
    Generate a geodataframe with polygons for each Dask chunk of a dataarray.
    Parameters:
    da (xr.DataArray): dask-backed DataArray with 'x' and 'y' coordinates
    var_name (str, optional): variable name for labeling in the geodataframe
    Returns:
    gdf (geopandas.GeoDataFrame): each row is a polygon representing a chunk
    """
    # check
    if not hasattr(da.data, "chunks"):
        raise ValueError("Input array is not chunked (not a Dask array).")
    if 'x' not in da.dims or 'y' not in da.dims:
        raise ValueError("DataArray must have 'x' and 'y' dimensions.")

    # chunk borders
    x_dim = da.dims.index('x')
    y_dim = da.dims.index('y')

    x_chunks = da.data.chunks[x_dim]
    y_chunks = da.data.chunks[y_dim]
    x_edges = np.cumsum([0] + list(x_chunks))
    y_edges = np.cumsum([0] + list(y_chunks))

    x_vals = da['x'].values
    y_vals = da['y'].values
    y_descending = y_vals[0] > y_vals[-1]

    polygons = []
    for i in range(len(y_edges)-1):
        for j in range(len(x_edges)-1):
            xmin = x_vals[x_edges[j]]
            xmax = x_vals[x_edges[j+1]-1]
            if y_descending:
                ymax = y_vals[y_edges[i]]
                ymin = y_vals[y_edges[i+1]-1]
            else:
                ymin = y_vals[y_edges[i]]
                ymax = y_vals[y_edges[i+1]-1]
            polygons.append(box(xmin, ymin, xmax, ymax))

    gdf = gpd.GeoDataFrame({'variable': var_name or getattr(da, 'name', 'unnamed'),
                            'geometry': polygons},
                           crs=getattr(da.rio, 'crs', None))
    return gdf

cloud_gdf = get_chunk_polygons(data['cloud'], var_name="cloud")
cloud_gdf.to_file("data/cloud_chunks.gpkg", driver="GPKG")

In [14]:
    ####### PERFORMANCE without chunks
    # Loading collection (NZ26 tile, cloud + scl, without calculations, without chunks, 2024, 720 assets)  - 24m, 25s (1460s)
    # Loading collection (HP40 tile, cloud + scl, without calculations, without chunks, 2024, 333 assets) - 10m, 56s

    ####### PERFORMANCE WITH CHUNKS(without counting records per pixel per year)
    # chunks={'time': 20, 'x': 300, 'y': 300}
    # Tile NZ06 - load, mask, calculate, export (cloud + scl, 2024, 290 assets) ~233s
    # Tile NZ04 - load, mask, calculate, export (cloud + scl, 2024, 146 assets) ~144s
    # Tile NZ26 - load, mask, calculate, export (cloud + scl, 2024, 720 assets) 433s
    # Tile NZ46 - load, mask, calculate, export (cloud + scl, 2024, 584 assets) 286s
    # Tile NZ44 - load, mask, calculate, export (cloud + scl, 2024, 292 assets) 165s
    # Tile NZ24 - load, mask, calculate, export (cloud + scl, 2024, 292 assets) 167s
    ###### Other years
    # Tile NZ26 (2023) - load, mask, calculate, export (cloud+scl, 713 assets) - 358s
    # Tile NZ26 (2022) - load, mask, calculate, export (cloud+scl, 713 assets) ~40s
    # Tile NZ26 (2021) - load, mask, calculate, export (cloud+scl, 721 assets) ~411s   