# Construct Zarr Chips

This notebook constructs Zarr chips from the raw, tile-based Sentinel-2 data stored on the server.

In [1]:
from pathlib import Path
import xml.etree.ElementTree as ET
import os
import warnings

import numpy as np
import pandas as pd
import geopandas as gpd
from tqdm import tqdm
import xarray as xr
import rioxarray
import shutil

# Open all Samples and Labels 

In [2]:
# Sample table read with geopandas; later cells use its `dataset`, `source`,
# `s2_tile`, `sample_id` and `geometry` columns.
samples = gpd.read_parquet("data/samples.parquet")
# Label table; it is loaded here but not referenced by the cells visible in this notebook.
labels = pd.read_parquet("data/labels.parquet")

In [3]:
# Maps a readable band name to the "<band>_<resolution>" token that identifies the
# band's raster file. get_band_files looks for these tokens inside manifest file
# locations, get_hvrpp_band_files substitutes them into file names, and
# load_sentinel2_bands selects bands by the resolution suffix ("10m" / "20m").
bands = {
    "blue": "B02_10m",
    "green": "B03_10m",
    "red": "B04_10m",
    "nir": "B08_10m",
    "rededge1": "B05_20m",
    "rededge2": "B06_20m",
    "rededge3": "B07_20m",
    "nir08": "B8A_20m",
    "swir16": "B11_20m",
    "swir22": "B12_20m",
    "scl": "SCL_20m",
}

In [5]:
def get_band_files(product_folder):
    """Locate the raster file of every configured band inside a product folder.

    Reads ``manifest.safe`` from ``product_folder`` and inspects the file
    location of each listed data object. Whenever one of the tokens in
    ``bands.values()`` occurs in that location, the file is recorded for the
    token (a later match for the same token replaces an earlier one).

    Args:
        product_folder: ``pathlib.Path`` of the product directory.

    Returns:
        dict mapping band token (e.g. ``"B02_10m"``) to the file's path
        (``product_folder / href``). Tokens without a match are absent.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
    """
    manifest_root = ET.parse(product_folder / "manifest.safe").getroot()
    entries = manifest_root.find("dataObjectSection").findall("dataObject")

    located = {}
    for entry in entries:
        # Relative path of the file this manifest entry describes
        href = entry.find("byteStream").find("fileLocation").attrib["href"]
        located.update(
            {
                token: product_folder / href
                for token in bands.values()
                if token in href
            }
        )

    return located

In [6]:
def write_zarr_chip(ds, geom, sample_id, full_size=128):
    """Cut a square chip around a sample point and write or append it to the sample's zarr store.

    The chip is a ``full_size`` x ``full_size`` pixel window around the grid cell
    nearest to ``geom``. Near a raster edge the window is shifted inwards (on
    both the lower and the upper side) so the chip keeps its full size whenever
    the raster is at least ``full_size`` pixels wide; a truncated chip would
    differ in shape from what the later centre crop of the stored chips expects.

    Chips containing any pixel with ``SCL == 0`` are skipped with a warning.

    Args:
        ds: xarray Dataset with ``x``, ``y`` and ``time`` coordinates and an
            ``SCL`` variable, as built by the processing loops.
        geom: point geometry exposing ``.x`` and ``.y`` in the CRS of ``ds``.
        sample_id: identifier used for the store path ``data/chips/<sample_id>.zarr``.
        full_size: edge length of the chip in pixels (default 128).

    Returns:
        None. The chip is written to disk as a side effect.
    """
    # Index of the grid cell closest to the sample point on each axis
    x_idx = np.abs(ds.x - geom.x).argmin().item()
    y_idx = np.abs(ds.y - geom.y).argmin().item()

    # Centre the window on the point, then clamp it so it stays inside the raster
    # on both sides (the inner max() guards rasters smaller than the chip).
    half = full_size // 2
    x_start = min(max(0, x_idx - half), max(0, ds.sizes["x"] - full_size))
    x_end = x_start + full_size
    y_start = min(max(0, y_idx - half), max(0, ds.sizes["y"] - full_size))
    y_end = y_start + full_size

    # Subset dataset
    subset = ds.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))

    # Reduce over every dimension so the check yields a single boolean even if
    # the dataset carries more than one time step.
    if bool((subset["SCL"] == 0).any()):
        warnings.warn(
            f"Sample {sample_id} contains no-data pixels (SCL == 0). Skipping."
        )
        return

    chip = subset.astype("uint16")

    # There's a bug with xarray doing over-eager conversion of timestamps (see https://github.com/pydata/xarray/issues/3942)
    # so we need to specify a time encoding
    encoding = {
        "time": {
            "units": "seconds since 2015-01-01",
            "calendar": "standard",
            "dtype": "int64",
        }
    }

    # This is to remove scale and offset, it messes with appending correct dtypes
    for data_var in chip.data_vars:
        encoding[data_var] = {"dtype": "uint16"}
        chip[data_var].attrs = {}

    # Write to zarr
    zarr_path = f"data/chips/{sample_id}.zarr"
    if not os.path.exists(zarr_path):
        # First time: initialize the zarr (encoding can only be set on creation)
        chip.to_zarr(zarr_path, mode="w", zarr_format=3, encoding=encoding)
    else:
        # Next times: append along time
        chip.drop_attrs().to_zarr(
            zarr_path, append_dim="time", mode="a", zarr_format=3
        )

In [7]:
def load_sentinel2_bands(band_files, resolution="10m"):
    """Open all band rasters of one resolution and combine them into a Dataset.

    Args:
        band_files: dict mapping "<band>_<resolution>" tokens to raster paths.
        resolution: suffix selecting which entries to load (default "10m").

    Returns:
        xarray Dataset with one variable per selected band, named by the part of
        the token before the first underscore (e.g. "B02").
    """

    def open_single_band(raster_path):
        # rioxarray keeps the spatial reference; values are left unscaled
        raster = rioxarray.open_rasterio(
            raster_path, chunks=True, mask_and_scale=False
        )
        # The rasters hold a single band, so the band dimension is dropped
        return raster.squeeze().drop_vars("band")

    return xr.Dataset(
        {
            token.split("_")[0]: open_single_band(raster_path)
            for token, raster_path in band_files.items()
            if token.endswith(resolution)
        }
    )

In [8]:
def filter_acquisitions(acquisitions, samples_df):
    """Drop acquisitions that were already written in a previous run.

    The zarr store of the last row in ``samples_df`` serves as reference: only
    acquisitions whose timestamp (second "_"-separated part of the path stem)
    lies after the last time stored there are kept.

    Args:
        acquisitions: list of paths of candidate acquisitions.
        samples_df: frame of samples with a ``sample_id`` column.

    Returns:
        ``acquisitions`` unchanged if the reference store does not exist yet,
        otherwise a new list with the not-yet-processed acquisitions.
    """
    reference_store = f"data/chips/{samples_df.iloc[-1].sample_id}.zarr"
    if not os.path.exists(reference_store):
        return acquisitions

    # Times in the store have a coarser precision than the file names, hence
    # a few seconds of tolerance on top of the last stored time.
    stored_times = xr.open_zarr(reference_store).time.values
    cutoff = pd.Timestamp(stored_times[-1]) + pd.Timedelta(seconds=10)

    remaining = []
    for acquisition in acquisitions:
        acquired_at = pd.Timestamp(acquisition.stem.split("_")[1])
        if acquired_at > cutoff:
            remaining.append(acquisition)
    return remaining

# FNEWS

In [17]:
# Evoland samples whose source is the regional forestry departments (FNEWS)
fnews_samples = samples.query(
    "dataset=='Evoland' and source=='Regional Forestry Departments'"
)
# NOTE(review): the [3:] slice skips the first three tiles — presumably they were
# already processed in an earlier run; confirm before re-running from scratch.
tiles = list(fnews_samples["s2_tile"].unique())[3:]
tiles

['33UUS', '33UVS', '32UMU', '32TMT']

In [56]:
# Show the selected FNEWS samples for a visual check
fnews_samples

Unnamed: 0,sample_id,original_sample_id,interpreter,dataset,source,source_description,s2_tile,cluster_id,cluster_description,comment,confidence,geometry
606,606,612,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",32UNA,198.0,Damage polygons,"1, 2019/08/18",high,POINT (9.05234 50.07864)
607,607,616,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",32UNA,209.0,Damage polygons,"1, 2019/08/18",high,POINT (9.31017 50.11753)
608,608,624,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",32UNA,208.0,Damage polygons,"1, 2019/08/18",high,POINT (9.0414 50.07218)
609,609,625,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",32UNA,200.0,Damage polygons,"1, 2019/08/18",high,POINT (9.04996 50.06712)
610,610,631,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",32UNA,202.0,Damage polygons,"1, 2019/08/18",high,POINT (9.02246 50.08142)
...,...,...,...,...,...,...,...,...,...,...,...,...
1001,1001,1263,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",33UVS,421.0,Damage polygons,22,high,POINT (14.08833 50.99596)
1002,1002,1264,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",33UVS,420.0,Damage polygons,22,high,POINT (14.10273 51.0011)
1003,1003,1265,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",33UVS,,Damage polygons,,high,POINT (14.11205 51.00166)
1004,1004,1266,pum,Evoland,Regional Forestry Departments,"FNews Project, German Forestry Departmetns Sou...",33UVS,419.0,Damage polygons,22,high,POINT (14.10425 51.00364)


In [None]:
# For every FNEWS tile: load each not-yet-processed acquisition and write one chip
# per sample point located in that tile.
for tile in tiles:
    print(tile)
    # Path.glob gives no ordering guarantee. Acquisitions are processed in
    # chronological order so chips are appended with increasing time and
    # filter_acquisitions can resume from the last stored timestamp.
    acquisitions = sorted(
        Path(f"//digs110/FER/fnews/RasterData/L2A/{tile}").glob("*/*.SAFE"),
        key=lambda product: pd.Timestamp(product.stem.split("_")[1]),
    )
    # Reproject the sample points to the tile's CRS; the EPSG code is built from
    # the first two characters of the tile id (326xx — assumes northern-hemisphere tiles).
    fnews_tile_reprojected = fnews_samples.query("s2_tile==@tile").to_crs(
        f"EPSG:326{tile[0:2]}"
    )
    filtered_acquisitions = filter_acquisitions(acquisitions, fnews_tile_reprojected)
    for product_folder in tqdm(filtered_acquisitions):
        # Acquisition time is the second "_"-separated part of the folder name
        timestamp = pd.Timestamp(product_folder.stem.split("_")[1])
        try:
            band_files = get_band_files(product_folder)
        except FileNotFoundError:
            print(f"manifest file missing for tile {product_folder}")
            continue

        # load data
        try:
            ds_10m = load_sentinel2_bands(band_files, "10m")
            # Bring the 20 m bands onto the 10 m grid by nearest-neighbour lookup
            ds_20m = (
                load_sentinel2_bands(band_files, "20m")
                .interp(
                    x=ds_10m["x"],
                    y=ds_10m["y"],
                    method="nearest",
                    kwargs={"fill_value": "extrapolate"},
                )
                .astype("uint16")
            )
        except KeyError:
            print(f"some bands not found in manifest for tile {product_folder}")
            continue
        # One dataset per acquisition with a length-1 time dimension, loaded into memory
        ds = (
            xr.merge([ds_10m, ds_20m])
            .expand_dims(dim="time")
            .assign_coords(time=[timestamp])
            .compute()
        )

        # write out chips for all samples
        with warnings.catch_warnings():
            # Silences every warning raised while writing, including the
            # skip warning emitted by write_zarr_chip
            warnings.simplefilter("ignore")
            _ = fnews_tile_reprojected.apply(
                lambda geo_series: write_zarr_chip(
                    ds, geo_series.geometry, geo_series.sample_id
                ),
                axis=1,
            )

33UUS


0it [00:00, ?it/s]


33UVS


100%|██████████| 421/421 [4:00:59<00:00, 34.35s/it]  


32UMU


  1%|          | 5/469 [02:03<3:13:20, 25.00s/it]

maniftest file missing for tile \\digs110\FER\fnews\RasterData\L2A\32UMU\2015\L2A_20150908T103707_108A_32UMU.SAFE


  3%|▎         | 15/469 [05:46<3:07:25, 24.77s/it]

maniftest file missing for tile \\digs110\FER\fnews\RasterData\L2A\32UMU\2016\L2A_20160205T103556_108A_32UMU.SAFE


  8%|▊         | 37/469 [14:23<3:18:07, 27.52s/it]

maniftest file missing for tile \\digs110\FER\fnews\RasterData\L2A\32UMU\2016\L2A_20160929T102344_065A_32UMU.SAFE


  9%|▊         | 40/469 [15:24<2:52:08, 24.08s/it]

maniftest file missing for tile \\digs110\FER\fnews\RasterData\L2A\32UMU\2016\L2A_20161022T103357_108A_32UMU.SAFE
maniftest file missing for tile \\digs110\FER\fnews\RasterData\L2A\32UMU\2016\L2A_20161101T103156_108A_32UMU.SAFE


100%|██████████| 469/469 [3:09:34<00:00, 24.25s/it]  


32TMT


100%|██████████| 516/516 [3:49:35<00:00, 26.70s/it]  


# Evoland

In [11]:
# Remaining Evoland samples, i.e. all except those from the regional forestry departments
evo_samples = samples.query(
    "dataset=='Evoland' and source != 'Regional Forestry Departments'"
)
# NOTE: rebinds the notebook-wide `tiles` variable used by the processing loop below
tiles = list(evo_samples["s2_tile"].unique())
tiles

['30SUF', '33VVJ']

In [12]:
# For every Evoland tile: load each not-yet-processed acquisition and write one chip
# per sample point located in that tile.
for tile in tiles:
    print(tile)
    # Path.glob gives no ordering guarantee. Acquisitions are processed in
    # chronological order so chips are appended with increasing time and
    # filter_acquisitions can resume from the last stored timestamp.
    acquisitions = sorted(
        Path(f"//digs110/FER/EvoLand/WP2_6_CFM/RasterData/L2A/{tile}").glob("*.SAFE"),
        key=lambda product: pd.Timestamp(product.stem.split("_")[1]),
    )
    # Reproject the sample points to the tile's CRS; the EPSG code is built from
    # the first two characters of the tile id (326xx — assumes northern-hemisphere tiles).
    tiles_reprojected = evo_samples.query("s2_tile==@tile").to_crs(
        f"EPSG:326{tile[0:2]}"
    )
    filtered_acquisitions = filter_acquisitions(acquisitions, tiles_reprojected)
    for product_folder in tqdm(filtered_acquisitions):
        # Acquisition time is the second "_"-separated part of the folder name
        timestamp = pd.Timestamp(product_folder.stem.split("_")[1])
        try:
            band_files = get_band_files(product_folder)
        except FileNotFoundError:
            print(f"manifest file missing for tile {product_folder}")
            continue

        # load data
        try:
            ds_10m = load_sentinel2_bands(band_files, "10m")
            # Bring the 20 m bands onto the 10 m grid by nearest-neighbour lookup
            ds_20m = (
                load_sentinel2_bands(band_files, "20m")
                .interp(
                    x=ds_10m["x"],
                    y=ds_10m["y"],
                    method="nearest",
                    kwargs={"fill_value": "extrapolate"},
                )
                .astype("uint16")
            )
        except KeyError:
            print(f"some bands not found in manifest for tile {product_folder}")
            continue
        # One dataset per acquisition with a length-1 time dimension, loaded into memory
        ds = (
            xr.merge([ds_10m, ds_20m])
            .expand_dims(dim="time")
            .assign_coords(time=[timestamp])
            .compute()
        )

        # write out chips for all samples
        with warnings.catch_warnings():
            # filtering warning about consolidated zarr metadata
            # (this ignores every warning, including the skip warning from write_zarr_chip)
            warnings.filterwarnings("ignore")
            _ = tiles_reprojected.apply(
                lambda geo_series: write_zarr_chip(
                    ds, geo_series.geometry, geo_series.sample_id
                ),
                axis=1,
            )

30SUF


100%|██████████| 216/216 [4:24:49<00:00, 73.56s/it]  


33VVJ


100%|██████████| 137/137 [3:26:28<00:00, 90.42s/it] 


# HRVPP

In [4]:
# All samples belonging to the HRVPP dataset
hrvpp_samples = samples.query("dataset=='HRVPP'")
# NOTE: rebinds the notebook-wide `tiles` variable used by the processing loop below
tiles = list(hrvpp_samples["s2_tile"].unique())
tiles

['29SPC',
 '29UNV',
 '30SVG',
 '30TXQ',
 '31UDR',
 '31UFS',
 '32TNS',
 '33UVS',
 '33VUF',
 '34TDS',
 '34WDB',
 '35TLG',
 '35VMJ']

In [10]:
def get_hvrpp_band_files(product_folder):
    """Derive the per-band file paths from the path of a scene-classification raster.

    Args:
        product_folder: path of the scene-classification file; despite the
            parameter name it is a file path, stored as the "SCL_20m" entry.

    Returns:
        dict mapping each token in ``bands.values()`` to a path. All entries
        except "SCL_20m" are obtained by replacing "SCENECLASSIFICATION_20M" in
        the given path with "TOC-<TOKEN IN UPPER CASE>". The existence of the
        derived files is not checked.
    """
    scl_location = str(product_folder)
    derived = {"SCL_20m": product_folder}
    derived.update(
        {
            token: Path(
                scl_location.replace(
                    "SCENECLASSIFICATION_20M", f"TOC-{token.upper()}"
                )
            )
            for token in bands.values()
            if token != "SCL_20m"
        }
    )
    return derived

In [11]:
# For every HRVPP tile: load each not-yet-processed acquisition and write one chip
# per sample point located in that tile.
for tile in tiles:
    print(tile)
    # Acquisitions are found via their scene-classification rasters. Path.glob gives
    # no ordering guarantee, so they are sorted chronologically: chips are then
    # appended with increasing time and filter_acquisitions can resume from the
    # last stored timestamp.
    acquisitions = sorted(
        Path("//digs110/FER/HR-VPP2/Data/TOC/v00/").glob(
            f"**/*_{tile}_SCENECLASSIFICATION*.tif"
        ),
        key=lambda product: pd.Timestamp(product.stem.split("_")[1]),
    )
    # Reproject the sample points to the tile's CRS; the EPSG code is built from
    # the first two characters of the tile id (326xx — assumes northern-hemisphere tiles).
    tiles_reprojected = hrvpp_samples.query("s2_tile==@tile").to_crs(
        f"EPSG:326{tile[0:2]}"
    )
    filtered_acquisitions = filter_acquisitions(acquisitions, tiles_reprojected)
    for product_folder in tqdm(filtered_acquisitions):
        # Acquisition time is the second "_"-separated part of the file name
        timestamp = pd.Timestamp(product_folder.stem.split("_")[1])
        band_files = get_hvrpp_band_files(product_folder)

        # load data
        try:
            ds_10m = load_sentinel2_bands(band_files, "10m")
            # Bring the 20 m bands onto the 10 m grid by nearest-neighbour lookup.
            # Cast to uint16 right away, as the other processing loops do; the
            # chips are stored as uint16 anyway, so this only keeps the merged
            # dataset smaller in memory.
            ds_20m = (
                load_sentinel2_bands(band_files, "20m")
                .interp(
                    x=ds_10m["x"],
                    y=ds_10m["y"],
                    method="nearest",
                    kwargs={"fill_value": "extrapolate"},
                )
                .astype("uint16")
            )
        except KeyError:
            print(f"some bands not found in manifest for tile {product_folder}")
            continue
        # One dataset per acquisition with a length-1 time dimension, loaded into memory
        ds = (
            xr.merge([ds_10m, ds_20m])
            .expand_dims(dim="time")
            .assign_coords(time=[timestamp])
            .compute()
        )

        # write out chips for all samples
        with warnings.catch_warnings():
            # Silences every warning raised while writing, including the
            # skip warning emitted by write_zarr_chip
            warnings.simplefilter("ignore")
            _ = tiles_reprojected.apply(
                lambda geo_series: write_zarr_chip(
                    ds, geo_series.geometry, geo_series.sample_id
                ),
                axis=1,
            )

29SPC


0it [00:00, ?it/s]


29UNV


0it [00:00, ?it/s]


30SVG


100%|██████████| 172/172 [3:20:34<00:00, 69.97s/it]  


30TXQ


100%|██████████| 1803/1803 [19:15:37<00:00, 38.46s/it]   


31UDR


 24%|██▍       | 296/1209 [3:31:03<10:51:00, 42.78s/it]


KeyboardInterrupt: 

In [14]:
def clean_zarr(zarr_id):
    """Rewrite one raw chip store as a cleaned, rechunked store and delete the original.

    Steps:
      1. open ``data/chips/<zarr_id>.zarr`` without mask/scale decoding,
      2. crop the central 128 x 128 pixels,
      3. drop every time step that has at least one ``SCL == 0`` pixel,
      4. rechunk to 32 x 128 x 128 (time, y, x) and write the result to
         ``data/cleaned_chips/<zarr_id>.zarr``,
      5. remove the original store.

    Args:
        zarr_id: sample id that names the store.

    Returns:
        None.
    """
    chip_px = 128
    time_chunk = 32
    half = chip_px // 2

    raw_chip = xr.open_zarr(
        f"data/chips/{zarr_id}.zarr", mask_and_scale=False, decode_coords="all"
    )

    def centre_window(axis_len):
        # chip_px-wide slice around the middle of an axis
        mid = axis_len // 2
        return slice(mid - half, mid + half)

    centred = raw_chip.isel(
        x=centre_window(len(raw_chip.x)),
        y=centre_window(len(raw_chip.y)),
    )

    # only keep time steps with full coverage
    has_nodata = (centred["SCL"] == 0).any(dim=("y", "x"))
    complete = centred.sel(time=~has_nodata)

    # 32x128x128 chunks are around 1MB uncompressed; sharding would reduce the
    # number of files further but has not been made to work with xarray yet
    rechunked = complete.chunk({"time": time_chunk, "y": chip_px, "x": chip_px})
    for name in rechunked:
        # the chunk layout recorded when reading must not override the new one
        del rechunked[name].encoding["chunks"]

    with warnings.catch_warnings():
        # suppress UserWarnings originating from the zarr module during the write
        warnings.filterwarnings("ignore", category=UserWarning, module="zarr")
        rechunked.to_zarr(f"data/cleaned_chips/{zarr_id}.zarr", zarr_format=3)

    # delete the original zarr
    shutil.rmtree(f"data/chips/{zarr_id}.zarr")

In [7]:
# Exclude the samples of one tile from cleaning.
# NOTE(review): presumably "35VMJ" is the tile whose chips were not finished yet
# (the name `only_finished` suggests so) — confirm against the processing state.
expect_last_tile = "35VMJ"
only_finished = hrvpp_samples.query("s2_tile!=@expect_last_tile")

In [19]:
# Clean the chip stores one by one.
# NOTE(review): only the first ten ids are processed — this looks like a trial run; confirm.
for zarr_id in tqdm(list(only_finished.sample_id)[0:10]):
    try:
        clean_zarr(zarr_id)
    except Exception as exc:
        # Keep going with the next store, but report why this one failed
        print(f"Error processing {zarr_id}: {exc}")

In [None]:
import xarray as xr
import warnings
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


def clean_zarr(zarr_id):
    """Clean one chip store; failures are printed instead of raised.

    Opens ``data/chips/<zarr_id>.zarr``, crops the central 128 x 128 pixels,
    removes time steps containing any ``SCL == 0`` pixel, rechunks to
    32 x 128 x 128 (time, y, x), writes ``data/cleaned_chips/<zarr_id>.zarr``
    and finally deletes the original store.

    Args:
        zarr_id: sample id that names the store.

    Returns:
        None, also when an error occurred (the error is only printed).
    """
    try:
        chip_px = 128
        time_chunk = 32
        half = chip_px // 2

        source = xr.open_zarr(
            f"data/chips/{zarr_id}.zarr", mask_and_scale=False, decode_coords="all"
        )
        x_mid = len(source.x) // 2
        y_mid = len(source.y) // 2
        centred = source.isel(
            x=slice(x_mid - half, x_mid + half),
            y=slice(y_mid - half, y_mid + half),
        )

        # time steps with at least one no-data pixel are dropped
        incomplete = (centred["SCL"] == 0).any(dim=("y", "x"))
        rechunked = centred.sel(time=~incomplete).chunk(
            {"time": time_chunk, "y": chip_px, "x": chip_px}
        )

        for name in rechunked:
            # forget the chunk layout recorded when the store was read
            del rechunked[name].encoding["chunks"]

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="zarr")
            rechunked.to_zarr(f"data/cleaned_chips/{zarr_id}.zarr", zarr_format=3)

        # the raw store is only removed after the cleaned one was written
        shutil.rmtree(f"data/chips/{zarr_id}.zarr")
    except Exception as e:
        print(f"Error processing {zarr_id}: {e}")


# Parallel execution wrapper
def process_zarr_ids_parallel(zarr_ids, max_workers=10):
    """Run ``clean_zarr`` for many ids in a process pool, showing a progress bar.

    Args:
        zarr_ids: iterable of sample ids to clean.
        max_workers: number of worker processes (default 10).

    Returns:
        None.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for zarr_id in zarr_ids:
            pending[pool.submit(clean_zarr, zarr_id)] = zarr_id

        progress = tqdm(
            as_completed(pending), total=len(pending), desc="Processing Zarr files"
        )
        for finished in progress:
            # clean_zarr prints its own errors; result() only re-raises
            # anything that escaped it
            finished.result()

In [None]:
# process_zarr_ids_parallel(zarr_ids, max_workers=4)