In [1]:
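# Description — goals of this notebook:
# 1. Provide an xarray Dataset of Sentinel-2 cloud data for a UK 20 km grid tile.
# 2. Provide a DataFrame? (open question — not decided yet)
# 3. Provide the number of acquisitions per pixel.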

# pip install odc-stac
# pip install pystac-client
# pip install dask

import numpy as np
import pandas as pd
import xarray as xr

from pystac_client import Client
import geopandas as gpd
import rioxarray as xrx

import odc.stac

import dask.array as da
import dask.dataframe as dd

# --- Configuration -------------------------------------------------------
# Grid of UK 20 km tiles and the single tile to process.
tiles_path = "data/uk_20km_grid.gpkg"
tile_of_interest = "NZ26"

# STAC source: Element 84 Earth Search, Sentinel-2 Collection 1 Level-2A.
api_url = "https://earth-search.aws.element84.com/v1"
collection = "sentinel-2-c1-l2a"
time_of_interest = "2024-01-01/2024-12-31"  # STAC datetime interval "start/end"
bands_of_interest = ["cloud", "scl"]

# Output grid: pixel size (CRS units) and EPSG code (27700 = British National Grid).
res = 20
target_crs = 27700

# Reproject tile geometries to WGS84 lon/lat so their bounds can serve as a STAC bbox.
tiles = gpd.read_file(tiles_path).to_crs(4326)
# aoi = gpd.read_file(aoi).to_crs(4326)

In [2]:
def inspect_search(items, index=0):
    """Print debug information about one item of a STAC search result.

    Parameters:
    items (sequence of pystac.Item): search results, e.g. ``search.item_collection()``.
    index (int=0): position of the item to inspect (Default: 0, the first item).

    Nothing is returned; all output goes to stdout. Any error while reading the
    item, including an empty result or an out-of-range ``index``, is reported
    instead of raised, so this debug helper never interrupts the calling cell.
    """
    print("-" * 40)
    try:
        # Index inside the try so an empty search result is reported, not raised.
        item = items[index]
        print(f"Inspecting the asset #{index}")
        print(f"DATETIME is {item.datetime}")
        print(f"GEOMETRY is {item.geometry}")
        print(f"PROPERTIES are:\n{item.properties}")
        # Items may carry their CRS as 'proj:code' or as the older 'proj:epsg' field.
        print(f"CRS: {item.properties.get('proj:code') or item.properties.get('proj:epsg')}")
        print("STAC metadata:")
        metadata = odc.stac.extract_collection_metadata(item, cfg=None, md_plugin=None)
        # Show the odc-stac view of the collection metadata.
        print(metadata)
    except Exception as e:
        print(f"Error checking item[{index}]: {e}")
    print("-" * 40)


In [3]:
def mask_with_scl(data, scl_nodata=0, verbose=True):
    """Mask the cloud-probability band wherever the SCL band flags no-data.

    Parameters:
    data (xarray.Dataset or mapping): provides "cloud" and "scl" variables.
        They are assumed to share the same dimensions, so no broadcasting is done.
    scl_nodata (int=0): SCL value treated as no-data (Default: 0; SCL pixels equal
        to 0 are true no-data values).
    verbose (bool=True): print the dimensions of the inputs and of the result.

    Returns:
    The "cloud" variable cast to float32, with NaN where SCL == scl_nodata.
    """
    # Cast to float so masked pixels can hold NaN and probabilities are not truncated.
    cloud = data["cloud"].astype("float32")
    scl = data["scl"].astype("int16")

    valid = scl != scl_nodata
    cloud_masked = cloud.where(valid)

    if verbose:
        print(f"Cloud dimensions:\n{cloud.dims}")
        print(f"SCL dimensions:\n{scl.dims}")
        print(f"Cloud masked dimensions:\n{cloud_masked.dims}")

    return cloud_masked

In [4]:
def to_geotif(dataset, bands_of_interest: list = None, out_path: str = None, crs=None):
    """Save an xarray DataArray, or selected variables of a Dataset, to a GeoTIFF.

    Adapted from: https://discourse.pangeo.io/t/comparing-odc-stac-load-and-stackstac-for-raster-composite-workflow/4097

    Parameters:
    dataset (xr.DataArray or xr.Dataset): raster with "y" and "x" dims and an
        optional length-1 "year" dim (e.g. a single-year ``groupby("time.year")`` result).
    bands_of_interest (list=None): Dataset variables to write, one GeoTIFF band
        each. Ignored for a DataArray. Default: all data variables of the Dataset.
    out_path (str): destination file path (required).
    crs (int or str = None): CRS written to the file. Default: the notebook-level
        ``target_crs`` EPSG code, used for both DataArray and Dataset input.

    Returns:
    str: ``out_path`` once the file has been written.

    Raises:
    ValueError: if ``out_path`` is not given.
    """
    if not out_path:
        raise ValueError("out_path is required")
    if crs is None:
        crs = f"epsg:{target_crs}"

    is_single_array = isinstance(dataset, xr.DataArray)
    if is_single_array:
        image = dataset
    else:
        variables = list(dataset.data_vars) if bands_of_interest is None else list(bands_of_interest)
        # Stack the selected variables into one DataArray with a "band" dim.
        image = dataset[variables].to_dataarray(dim="band")

    # A single-year composite still carries a length-1 "year" dim; drop it for the 2-D/3-D raster.
    if "year" in image.dims:
        image = image.squeeze("year")
    if "band" in image.dims:
        image = image.transpose("band", "y", "x")
    image = image.rio.write_crs(crs)

    if is_single_array:
        num_nodata = image.isnull().sum().item()
        print(f"Number of NoData (NaN) pixels: {num_nodata}")

    image.rio.to_raster(out_path)
    # Report only after the write has succeeded.
    print(f"Output GeoTIFF saved to {out_path}")
    return out_path

In [None]:
if tile_of_interest:
    print(f"Calculating cloud probability for a tile of interest {tile_of_interest}:")
    tile_name = tile_of_interest

    print("-" * 40)
    print(f'Process for tile {tile_name} started.', flush=True)

    selected_tile = tiles[tiles["tile_name"] == tile_of_interest]
    # An unknown tile name yields an empty selection whose total_bounds are all NaN,
    # which would otherwise surface later as an obscure STAC/ODC error.
    if selected_tile.empty:
        raise ValueError(f"Tile {tile_of_interest!r} not found in the tile grid")
    # Bounds are lon/lat because the tile grid was reprojected to EPSG:4326 on load.
    bbox_of_interest = selected_tile.total_bounds.tolist()
    print(f"Bbox of interest is: {bbox_of_interest}")

    catalog = Client.open(api_url)
    search = catalog.search(
        collections=collection,
        bbox=bbox_of_interest,
        datetime=time_of_interest,
    )
    items = search.item_collection()
    print(f"Number of items: {len(items)}")
    if len(items) == 0:
        raise ValueError(f"No {collection} items found for tile {tile_name} in {time_of_interest}")
    inspect_search(items, index=0)

    # NOTE: without `chunks`, every asset is read eagerly into memory:
    #   NZ26 tile, cloud + scl, 2024, 720 assets -> 24m 25s (1460 s)
    #   HP40 tile, cloud + scl, 2024, 333 assets -> 10m 56s
    # Passing e.g. chunks={'time': 12, 'x': 600, 'y': 600} returns lazy dask-backed arrays instead.
    # NOTE: ODC extends the loaded extent beyond the bbox by a few (observed: 4) pixels per side.
    # `odc.stac.load` is the documented entry point; `odc.stac.stac_load` is an alias of it.
    data = odc.stac.load(
        items,
        bands=bands_of_interest,
        bbox=bbox_of_interest,
        resolution=res,
        crs=target_crs,
    )

    print(data)
    print(data.dims)

    # TODO: to process every tile, repeat the search/load steps above for each row of
    # `tiles.bounds`, taking the tile name from `tiles["tile_name"]`.

Calculating cloud probability for a tile of interest NZ26:
----------------------------------------
Process for tile NZ26 started.
Bbox of interest is: [-1.6893950016480788, 54.93308063265072, -1.3744845424180632, 55.11399779421551]
Number of items: 720
----------------------------------------
Inspecting the asset #0
DATETIME is 2024-12-31 11:16:09.671000+00:00
GEOMETRY is {'type': 'Polygon', 'coordinates': [[[-1.4892868190198094, 55.037602642484046], [-1.945565451302893, 54.055591019680406], [-1.3232267769574286, 54.04852390275896], [-1.2822817967468176, 55.034856954918745], [-1.4892868190198094, 55.037602642484046]]]}
PROPERTIES are:
{'created': '2024-12-31T16:39:31.284Z', 'platform': 'sentinel-2a', 'constellation': 'sentinel-2', 'instruments': ['msi'], 'eo:cloud_cover': 100, 'proj:centroid': {'lat': 54.46203, 'lon': -1.53085}, 'mgrs:utm_zone': 30, 'mgrs:latitude_band': 'U', 'mgrs:grid_square': 'WF', 'grid:code': 'MGRS-30UWF', 'view:azimuth': 108.6879993147173, 'view:incidence_angle'

In [None]:
# Annual mean cloud probability per pixel.
# A per-pixel median over time (groupby("time.year").median) was observed to crash,
# likely from memory pressure, so the mean is used instead.
# Observations are masked with the SCL band (SCL == 0 is true no-data), consistent with
# mask_with_scl. Masking on cloud == 0 would also drop genuinely clear observations, so a
# pixel that was cloud-free in every acquisition would end up NoData instead of 0.
cloud_prob = (
    data["cloud"]
    .where(data["scl"] != 0)  # NaN where SCL flags no-data
    .groupby("time.year")
    .mean(dim="time", skipna=True)
    .astype("float32")
)

print(cloud_prob.rio.crs)
print(cloud_prob.rio.bounds())
print(type(cloud_prob))

# DEBUG: sanity check of the composite (unique values, zeros).
cloud_prob_2d = cloud_prob.squeeze("year")  # shape: (y, x)
values_2d = cloud_prob_2d.values  # materialize once and reuse below
unique_values = np.unique(values_2d)
print(f"Number of unique mean values: {len(unique_values)}")
print("First 20 unique values:", unique_values[:20])
num_zeros = np.sum(values_2d == 0)
percent_zeros = num_zeros / cloud_prob_2d.size * 100
print(f"Number of pixels with value 0: {num_zeros}")
print(f"Percent of zeros: {percent_zeros:.2f}%")

print(cloud_prob.shape)

# File name carries the composite year(s) taken from the data, e.g. data/NZ26_2024.tif.
year_label = "-".join(str(year) for year in cloud_prob["year"].values)
out_path = f"data/{tile_name}_{year_label}.tif"
to_geotif(cloud_prob, bands_of_interest[:-1], out_path=out_path)

In [None]:
# Quick look at the composite: type, shape (num_years, y, x), coords (year, y, x),
# attrs and dims, followed by the total count of NaN (NoData) cells.
for summary in (
    type(cloud_prob),
    cloud_prob.shape,
    cloud_prob.coords,
    cloud_prob.attrs,
    cloud_prob.dims,
):
    print(summary)

nodata_total = cloud_prob.isnull().sum().item()
print(f"Total NoData (NaN) values: {nodata_total}")