This notebook downloads a collection of hard negative examples from the Planetary Computer.

The only thing you need to change is `base_dataset_dir`: point it at a local drive with 300 GB of available storage.

In [1]:
from pathlib import Path
import numpy as np
import rasterio as rio
from tqdm.auto import tqdm
import pandas as pd
from datetime import datetime
import xarray as xr
import cubo
from rasterio.transform import from_origin
from rasterio.crs import CRS
from datetime import timedelta
from multiprocessing.pool import ThreadPool

In [None]:
# Root directory for the downloaded data. Point this at a local drive with
# enough free space (the notebook intro asks for 300 GB).
base_dataset_dir = Path("/media/nick/4TB Working 7/Datasets/OCM datasets")


In [None]:
# Output folder for the images and label masks written by this notebook.
hard_negative_dir = base_dataset_dir / "Hard negative"
# Create it (and any missing parents); a pre-existing folder is not an error.
hard_negative_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Metadata CSV with one row per image to download. The path is relative, so it
# is resolved against the kernel's current working directory.
dataset_csv_path = Path("OCM hard negative dataset metadata.csv")

In [None]:
# Chip edge length in pixels: passed to cubo as edge_size and used as the crop
# size when an image is saved.
EDGE_PX = 509
RES_M = 10  # 10 m request
BANDS = ["B04", "B03", "B8A"]  # Red, Green, NIR-narrow

In [None]:
# One row per image. download_image reads the columns filename, center_lat,
# center_lon, date_time, crs and source_scene from each row.
dataset_metadata = pd.read_csv(dataset_csv_path)

In [None]:
def fetch_stack(
    lat: float, lon: float, start_date: datetime, end_date: datetime, scene_id: str
) -> xr.DataArray:
    """Request a cube centred on (lat, lon) and keep only the scene ``scene_id``.

    The request targets the ``sentinel-2-l2a`` collection for the module-level
    ``BANDS``, with ``EDGE_PX`` as edge size and ``RES_M`` as resolution, and is
    limited to the window ``start_date``..``end_date`` (sent as ``YYYY-MM-DD``
    strings). Entries whose ``id`` coordinate differs from ``scene_id`` are
    dropped from the returned array.
    """
    day_format = "%Y-%m-%d"
    request = dict(
        lat=lat,
        lon=lon,
        collection="sentinel-2-l2a",
        bands=BANDS,
        start_date=start_date.strftime(day_format),
        end_date=end_date.strftime(day_format),
        edge_size=EDGE_PX,
        resolution=RES_M,
    )
    cube = cubo.create(**request)

    # Keep only the entries belonging to the requested scene.
    is_requested_scene = cube["id"] == scene_id
    return cube.where(is_requested_scene, drop=True)

In [None]:
def save_geotiff(
    path: Path, chw_u16: np.ndarray, ref_da: xr.DataArray, crs: str
) -> None:
    """
    Write a band-first array to ``path`` as an LZW-compressed uint16 GeoTIFF.

    chw_u16  C H W uint16 array; cropped to at most EDGE_PX rows and columns
    ref_da   xarray DataArray whose x and y coords give the pixel centres
    crs      EPSG code, converted with CRS.from_epsg
    """
    cropped = chw_u16[:, :EDGE_PX, :EDGE_PX]
    band_count, n_rows, n_cols = cropped.shape

    xs = ref_da.x.values
    ys = ref_da.y.values

    # Pixel size is the mean spacing between neighbouring coordinates.
    pixel_w = float(np.mean(np.diff(xs)))
    pixel_h = float(abs(np.mean(np.diff(ys))))

    # Coordinates are pixel centres, so step half a pixel outwards from the
    # first x / y value to reach the outer top-left corner.
    transform = from_origin(
        float(xs[0] - pixel_w / 2.0),
        float(ys[0] + pixel_h / 2.0),
        pixel_w,
        pixel_h,
    )

    profile = dict(
        driver="GTiff",
        height=n_rows,
        width=n_cols,
        count=band_count,
        dtype="uint16",
        crs=CRS.from_epsg(crs),
        compress="lzw",
        transform=transform,
    )
    with rio.open(path, "w", **profile) as dst:
        dst.write(cropped)


In [None]:
def make_mask(image_path):
    """Write an all-zero, single-band uint8 label raster for ``image_path``.

    The label is written into ``hard_negative_dir`` under the image's file name
    with ``_image_l2a.tif`` replaced by ``_label.tif``. It reuses the image's
    raster profile, overriding only dtype and band count, and has the same
    height and width as the image.

    Parameters
    ----------
    image_path : Path
        Path of an existing raster whose name contains ``_image_l2a.tif``.

    Raises
    ------
    ValueError
        If the file name does not contain ``_image_l2a.tif``; without this check
        the label would be written under the image's own file name.
    """
    file_name = image_path.name.replace("_image_l2a.tif", "_label.tif")
    if file_name == image_path.name:
        raise ValueError(
            f"Cannot derive a label name from {image_path.name!r}: "
            "expected it to contain '_image_l2a.tif'"
        )
    mask_path = hard_negative_dir / file_name

    # The context manager closes the source even if reading the profile fails.
    with rio.open(image_path) as src:
        profile = src.profile
        height, width = src.height, src.width
    profile.update(dtype=rio.uint8, count=1)

    # Size the mask from the image itself so the write matches the profile.
    mask = np.zeros((1, height, width), dtype=np.uint8)
    with rio.open(mask_path, "w", **profile) as dst:
        dst.write(mask)

In [None]:
def download_image(row):
    """Download one image described by a metadata row and write its label mask.

    Reads ``filename``, ``center_lat``, ``center_lon``, ``date_time``
    (``%Y%m%dT%H%M%S``), ``crs`` and ``source_scene`` from ``row``. If the image
    file already exists only the mask is (re)written. Otherwise the scene is
    searched for within three days either side of ``date_time``, NaNs are
    replaced by 0, and the data is saved as uint16.

    The image is first written to a ``.part`` file next to the target and then
    renamed, so an interrupted write cannot leave a truncated file that a later
    run would treat as a finished download.

    Any exception is caught and printed together with the row label; the
    function always returns None.
    """
    try:
        out_path = hard_negative_dir / row["filename"]
        if out_path.exists():
            make_mask(out_path)
            return

        lat = row["center_lat"]
        lon = row["center_lon"]
        date_time = datetime.strptime(row["date_time"], "%Y%m%dT%H%M%S")
        crs = row["crs"]
        start_date = date_time - timedelta(days=3)
        end_date = date_time + timedelta(days=3)
        scene_id = row["source_scene"]

        da = fetch_stack(lat, lon, start_date, end_date, scene_id)
        # Fail with a readable message instead of an out-of-bounds index error
        # when the search window does not contain the requested scene.
        if da.sizes["time"] == 0:
            raise ValueError(
                f"scene {scene_id} not found between "
                f"{start_date:%Y-%m-%d} and {end_date:%Y-%m-%d}"
            )
        scene = da.isel(time=0)
        array = np.nan_to_num(scene.values, nan=0).astype("uint16")

        tmp_path = out_path.with_name(out_path.name + ".part")
        try:
            save_geotiff(tmp_path, array, scene, crs)
            tmp_path.replace(out_path)
        finally:
            # No-op after a successful rename; removes the partial file otherwise.
            tmp_path.unlink(missing_ok=True)
        make_mask(out_path)
    except Exception as e:
        print(f"Error processing row {row.name}: {e}")

In [None]:
# Download every metadata row using two worker threads.
with ThreadPool(processes=2) as pool:
    pending = pool.imap(
        download_image,
        [record for _, record in dataset_metadata.iterrows()],
    )
    # Exhausting the progress-wrapped iterator waits for every row to finish.
    _ = list(tqdm(pending, total=len(dataset_metadata)))