In [1]:
# Import libraries
import calendar
import json
import os

import bottleneck
import dask.array as da
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import planetary_computer
import pystac_client
import rioxarray
import stackstac
import xarray as xr
import xrspatial.multispectral as ms
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster
from urllib3.util.retry import Retry

In [3]:
# Set variables
# When True, a Dask LocalCluster is started on this machine; when False, a cluster is
# provisioned through Coiled. A local cluster is recommended for testing.
local = False
# Spatial resolution requested for the fetched datasets. stackstac resamples the
# source data to this resolution.
resolution = 10
# Assets (bands) to fetch. 'SCL' is always required because it is used for pixel masking;
# the others can be chosen freely (see the Sentinel-2 documentation for band names and
# descriptions). Limiting this to one band + SCL is useful while testing.
bands = ['B04', 'B03', 'B02','B07', 'B08', 'SCL']

# Set search parameters
# Months per year to search for data within. Only the first and last entries are used:
# they define the start and end of each yearly date range.
months = ['june', 'july', 'august']
# Years to iterate through; one median over the months above is processed per year.
years = ['2018','2019','2020','2021', '2022', '2023']
# Path to the study area. Should be a GeoJSON Polygon (or MultiPolygon) file.
geojson_path = r"./working_dir/study_area/Arvidsjaur.geojson"
# Point this to a folder named working_dir in the same directory as this notebook
# (you might need to create it).
working_dir = "./working_dir/"
# Upper limit on the number of items returned by each STAC search after filtering.
# Useful to set to a low value while testing a workflow over and over again.
max_items = 10
# Only items whose tile-wide cloud cover is below this value are included.
# NOTE: the statistic is calculated over a whole Sentinel-2 tile (100x100 km), so it is
# not necessarily representative of the area of interest.
tile_max_cloud = 20


# Folder layout expected by the rest of the workflow (paths relative to the
# current working directory).
directories = [
    "working_dir",
    "working_dir/labels",
    "working_dir/prediction_maps",
    "working_dir/study_area",
]

# Make sure every folder exists, reporting what was done for each one.
for directory in directories:
    if os.path.exists(directory):
        print(f"Directory already exists: {directory}")
        continue
    os.makedirs(directory)
    print(f"Created directory: {directory}")


Directory already exists: working_dir
Created directory: working_dir/labels
Directory already exists: working_dir/prediction_maps
Directory already exists: working_dir/study_area


In [None]:
# Set parameters and initialize pystac + the Dask cluster (local or Coiled)

# Retry policy for requests to the STAC catalog, as recommended by Microsoft:
# up to 5 attempts with backoff on 502/503/504 responses, for every HTTP method.
retry = Retry(
    total=5, backoff_factor=1, status_forcelist=[502, 503, 504], allowed_methods=None
)
# Keep a reference to the retry-enabled I/O object. Constructing it without keeping it
# (as was done before) has no effect: the retries only apply when this object is passed
# as `stac_io=` to `pystac_client.Client.open(...)`.
stac_api_io = pystac_client.stac_api_io.StacApiIO(max_retries=retry)


# If `local` is True, skip importing Coiled and create a cluster on this machine.
# No worker count is passed, so LocalCluster falls back to its own defaults.
if local:
    cluster = LocalCluster(name = "Linnars-Dator")
    client = Client(cluster)
else:
    # Imported lazily so that Coiled is only required when it is actually used.
    import coiled # See the documentation available for Coiled on how to set up Coiled + Dask!
    cluster = coiled.Cluster(name="WetlandsClassification", shutdown_on_close=True)
    # NOTE(review): confirm that `n_workers` is an accepted keyword of coiled's
    # `Cluster.adapt`; the lower bound may need to be spelled `minimum=1` instead.
    cluster.adapt(n_workers = 1, maximum=10)
    client = cluster.get_client()



In [17]:
# This cell converts the GeoJSON study area into a bare geometry dict (`search_geometry`)
# suitable for searching the STAC catalog. The file may be a FeatureCollection (the first
# feature is used), a single Feature, or a bare geometry. Only Polygon and MultiPolygon
# geometries are accepted.

# Load the GeoJSON file (GeoJSON is UTF-8 encoded per RFC 7946)
with open(geojson_path, encoding="utf-8") as f:
    region = json.load(f)

# Unwrap the container down to the geometry itself
if region['type'] == 'FeatureCollection':
    if not region.get('features'):
        raise ValueError(f"No features found in GeoJSON file: {geojson_path}")
    # Get the first feature's geometry
    geometry = region['features'][0]['geometry']
elif region['type'] == 'Feature':
    # A Feature is not a geometry; the geometry lives under its 'geometry' key
    geometry = region['geometry']
else:
    # Assume the document already is a bare geometry; validated below
    geometry = region

# A Feature may legally carry a null geometry
if geometry is None:
    raise ValueError(f"Feature without geometry in GeoJSON file: {geojson_path}")

# Validate the geometry type regardless of which branch produced it
if geometry['type'] in ['MultiPolygon', 'Polygon']:
    search_geometry = geometry
else:
    raise ValueError(f"Unsupported geometry type: {geometry['type']}")

In [None]:
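# This computes the bounding box (minx, miny, maxx, maxy) of the study-area geometry loaded above. The STAC search itself uses the geometry (intersects); the bounding box is used to clip the stacked data (bounds_latlon).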

def get_bbox(geometry):
    """Return the bounding box of a GeoJSON-style Polygon or MultiPolygon dict.

    Only the first ring (index 0) of each polygon is read.

    Parameters
    ----------
    geometry : dict
        Mapping with a 'type' key ('Polygon' or 'MultiPolygon') and a
        'coordinates' key laid out as in GeoJSON.

    Returns
    -------
    list
        [minx, miny, maxx, maxy] taken from the first and second coordinate columns.

    Raises
    ------
    ValueError
        If the geometry type is neither 'Polygon' nor 'MultiPolygon'.
    """
    kind = geometry['type']
    if kind not in ('Polygon', 'MultiPolygon'):
        raise ValueError(f"Unsupported geometry type: {geometry['type']}")

    if kind == 'Polygon':
        ring_points = np.array(geometry['coordinates'][0])
    else:
        # Pool the first-ring vertices of every member polygon into one array
        ring_points = np.array(
            [vertex for polygon in geometry['coordinates'] for vertex in polygon[0]]
        )

    xs = ring_points[:, 0]
    ys = ring_points[:, 1]
    return [xs.min(), ys.min(), xs.max(), ys.max()]

# Calculate and print the bounding box of the study area.
# `bbox` is later handed to stackstac.stack(bounds_latlon=...), so the GeoJSON
# coordinates are presumably lon/lat degrees - TODO confirm for your input file.
bbox = get_bbox(search_geometry)
print("Bounding box coordinates (minx, miny, maxx, maxy):", bbox)

In [None]:
# Create a list of datetime ranges, one "YYYY-MM-DD/YYYY-MM-DD" string per year.
# Each range runs from the first day of months[0] to the last day of months[-1];
# entries in between do not affect the result.

# Month name -> zero-padded month number (built once, not once per year)
month_to_num = {
    'january': '01', 'february': '02', 'march': '03', 'april': '04',
    'may': '05', 'june': '06', 'july': '07', 'august': '08',
    'september': '09', 'october': '10', 'november': '11', 'december': '12'
}

# Convert the first and last month names to numbers (case-insensitive)
first_month = month_to_num[months[0].lower()]
last_month = month_to_num[months[-1].lower()]

# A range that ends before it starts is never valid
if int(first_month) > int(last_month):
    raise ValueError(
        f"First month ({months[0]}) must not come after last month ({months[-1]})"
    )

datetime_ranges = []
for year in years:
    # Use the real length of the closing month (30/31 days, 28/29 for February)
    # instead of a hard-coded 31, which yields invalid dates such as 2018-06-31.
    last_day = calendar.monthrange(int(year), int(last_month))[1]
    datetime_ranges.append(f"{year}-{first_month}-01/{year}-{last_month}-{last_day:02d}")

datetime_ranges

In [None]:

# For each datetime range: search the STAC catalog, stack the items into a lazy array,
# mask out non-clear pixels using the SCL band, compute the per-pixel median over time
# and write the result to a GeoTIFF.
#
# Sometimes reading a dataset will fail - usually when processing times are long
# (over 30 min per datetime range). This is probably caused by access to the STAC
# items timing out. There is no full error handling for that yet; the workaround is to
# rerun the notebook after removing the already-processed years from `years`.

# Scene-classification band and the class codes that count as clear pixels
# (4=vegetation, 5=not vegetated, 6=water, 7=unclassified, 11=snow/ice according to
# the Sentinel-2 SCL legend - confirm against the Sentinel-2 documentation).
scl_band_name = 'SCL'
clear_classes = [4,5,6,7,11]

# NOTE: the loop variable `datetime` is a "start/end" date string, not the stdlib module.
for datetime in datetime_ranges:

    # The retry policy only takes effect when a StacApiIO built from it is passed to
    # Client.open via `stac_io`. A fresh instance is created for each catalog session.
    stac_api_io = pystac_client.stac_api_io.StacApiIO(max_retries=retry)
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=planetary_computer.sign_inplace,
        stac_io=stac_api_io,
    )

    # Find the items for this datetime range that intersect the study area
    search = catalog.search(
        collections=["sentinel-2-l2a"],
        intersects=search_geometry,
        datetime=datetime,
        max_items=max_items,
        query={"eo:cloud_cover": {"lt": tile_max_cloud}},
    )
    items = search.item_collection()
    print(f"\nFound {len(items)} items for range {datetime}")

    # stackstac cannot stack an empty collection; skip this range instead of crashing
    if len(items) == 0:
        print(f"Skipping range {datetime}: no items matched the search")
        continue

    # Stack the item collection into a lazy (time, band, y, x) array.
    # Values <= 0 are replaced with NaN so they are ignored by the median.
    data = (
        stackstac.stack(
            items,
            assets=bands,  # The selected bands from the assets list
            resolution=resolution,
            epsg=3006,
            chunksize=(-1, 1, 128, 128),
            dtype=np.dtype('float'),
            bounds_latlon=bbox,
        )
        .where(lambda x: x > 0, other=np.nan)
    )

    # Learning about how Dask data is chunked is useful - chunksize needs to be big
    # enough to get good parallelization efficiency, but small enough to fit in memory.
    # A size of around 100MB per chunk is generally a good aim.
    print("Array size information:")
    print(f"Shape: {data.shape}")
    print(f"Size in bytes: {data.data.nbytes}")
    print(f"Size in GB: {data.data.nbytes / 1e9:.2f} GB")
    print(f"Number of chunks: {data.data.npartitions}")

    # Select the SCL band. Fail loudly if it is missing: continuing would either raise a
    # NameError or silently reuse the mask from the previous iteration.
    try:
        data_scl = data.sel(band=scl_band_name)
    except KeyError as err:
        raise KeyError(
            f"SCL band name '{scl_band_name}' was not found in data.coords['band']. "
            f"Available bands: {list(data.band.values)}"
        ) from err

    # Create mask for clear pixels
    clear_mask = data_scl.isin(clear_classes)

    # Apply the mask to every band (the SCL band itself is masked as well)
    data = data.where(clear_mask)

    # Compute the per-pixel median over time for this range; this triggers the Dask computation
    data_median = data.median(dim='time', skipna=True).compute()

    # Write the result to a GeoTIFF file named after the datetime range
    output_path = f"./working_dir/median_pixels_{datetime.replace('/', '_')}.tif"
    data_median.rio.to_raster(output_path, driver="GTiff", compress = "LZW")
    print(f"Saved dataset to {output_path}")