In [None]:
import os
import warnings
from datetime import date, timedelta

# Silence library warnings (set before the third-party imports below)
warnings.filterwarnings("ignore")

# Data manipulation and analysis
import numpy as np
import pandas as pd
import xarray as xr

# Planetary Computer tools for STAC API access and authentication
import pystac_client
import planetary_computer as pc
from odc.stac import stac_load
from pystac.extensions.eo import EOExtension as eo

# Progress bars
from tqdm import tqdm

In [None]:
PC_STAC_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"

# Open the Planetary Computer STAC API. pc.sign_inplace is installed as a
# modifier so items returned by searches come back with signed asset URLs.
catalog = pystac_client.Client.open(PC_STAC_URL, modifier=pc.sign_inplace)

def create_bbox(lat, lon, patch_size=64, pixel_resolution=30):
    """
    Build a square lon/lat bounding box centred on (lat, lon).

    The box spans ``patch_size * pixel_resolution`` metres on each side. It is
    converted to degrees with the flat approximation of ~111,320 m per degree
    of latitude. The longitude extent is widened by 1 / cos(lat).

    Returns:
        [min_lon, min_lat, max_lon, max_lat]
    """
    import math

    metres_per_degree = 111320
    half_extent = (patch_size * pixel_resolution) / 2

    d_lat = half_extent / metres_per_degree
    d_lon = half_extent / (metres_per_degree * math.cos(math.radians(lat)))

    west, east = lon - d_lon, lon + d_lon
    south, north = lat - d_lat, lat + d_lat
    return [west, south, east, north]

def extract_landsat(row, catalog, patch_size=64, pixel_resolution=30, max_cloud_cover=10, time_window=30):
    """
    Extract a full Landsat patch as an xr.Dataset for a single sample.

    The ``landsat-c2-l2`` collection is searched for scenes that meet three conditions:
    they intersect a box of ``patch_size`` x ``patch_size`` pixels (each
    ``pixel_resolution`` metres) centred on the sample, they were acquired within
    ``time_window`` days of the sample date, and their ``eo:cloud_cover`` is below
    ``max_cloud_cover``. The scene closest in time to the sample date is loaded.

    Args:
        row: pandas row (or mapping) with 'Latitude', 'Longitude' and
            'Sample Date' (the date is parsed day-first).
        catalog: STAC client used for the search.
        patch_size: box width in pixels.
        pixel_resolution: pixel size in metres, used to size the box.
        max_cloud_cover: exclusive upper bound on the scene's ``eo:cloud_cover``.
        time_window: half-width of the search window, in days.

    Returns:
        The raw Dataset with all bands at full spatial resolution. The time
        dimension is removed with ``isel(time=0)``, and the attrs ``sample_lat``,
        ``sample_lon``, ``sample_date`` and ``scene_datetime`` are set.

        Returns None in three cases: the sample's location or date is missing or
        unparseable, no scene matches, or selecting/loading the scene raises.
        Errors from the STAC search itself are not caught here.
    """
    lat = row['Latitude']
    lon = row['Longitude']
    sample_date = pd.to_datetime(row['Sample Date'], dayfirst=True, errors='coerce')

    # errors='coerce' yields NaT for unparseable dates (and a missing value gives
    # None/NaN). Without this guard, date arithmetic or the search below would
    # raise outside the try block and abort the caller's whole batch loop.
    if pd.isna(sample_date) or pd.isna(lat) or pd.isna(lon):
        print(f"  Skipping sample with invalid location/date: ({lat}, {lon}, {row['Sample Date']!r})")
        return None

    bbox = create_bbox(lat=lat, lon=lon, patch_size=patch_size, pixel_resolution=pixel_resolution)

    start_date = sample_date - timedelta(days=time_window)
    end_date = sample_date + timedelta(days=time_window)
    datetime_range = f"{start_date.isoformat()}/{end_date.isoformat()}"

    search = catalog.search(
        collections=["landsat-c2-l2"],
        bbox=bbox,
        datetime=datetime_range,
        query={"eo:cloud_cover": {"lt": max_cloud_cover}},
    )

    items = search.item_collection()
    if not items:
        return None

    try:
        sample_date_utc = sample_date.tz_localize("UTC") if sample_date.tzinfo is None else sample_date.tz_convert("UTC")

        # Scene acquired closest to the sample date (first one wins on ties).
        closest = min(
            items,
            key=lambda item: abs(pd.to_datetime(item.properties["datetime"]).tz_convert("UTC") - sample_date_utc),
        )
        selected_item = pc.sign(closest)

        ds = stac_load([selected_item], bbox=bbox).isel(time=0)

        # Attach metadata
        ds.attrs['sample_lat'] = lat
        ds.attrs['sample_lon'] = lon
        ds.attrs['sample_date'] = str(sample_date.date())
        ds.attrs['scene_datetime'] = closest.properties["datetime"]

        return ds

    except Exception as e:
        print(f"  Error for ({lat}, {lon}): {e}")
        return None


def pad_patch(ds, patch_size=64):
    """
    Fit every band of ``ds`` into a ``patch_size`` x ``patch_size`` float32 array.

    Each band must be 2-D. Values are cast to float32. A band smaller than the
    target is NaN-padded on the bottom/right. A larger band is cropped to its
    top-left ``patch_size`` x ``patch_size`` corner.

    Returns:
        dict mapping band name -> numpy array, in ``ds.data_vars`` order.
    """
    def _fit(values):
        src = values.astype("float32")
        rows, cols = src.shape
        keep_r = min(rows, patch_size)
        keep_c = min(cols, patch_size)
        canvas = np.full((patch_size, patch_size), np.nan, dtype="float32")
        canvas[:keep_r, :keep_c] = src[:keep_r, :keep_c]
        return canvas

    return {name: _fit(ds[name].values) for name in ds.data_vars}


def save_batch(patches, metadata, batch_idx, output_dir, patch_size=64):
    """
    Combine a list of extracted patches into a single .nc file.

    Saved Dataset has dims (sample, y, x) with metadata as coordinates.

    Args:
        patches: list of per-sample Datasets as returned by ``extract_landsat``.
        metadata: list of dicts, one per patch in the same order, with keys
            ``lat``, ``lon``, ``sample_date``, ``scene_datetime``, ``original_index``.
        batch_idx: batch number, used in the file name ``batch_{batch_idx:03d}.nc``.
        output_dir: directory the file is written to.
        patch_size: spatial size every band is padded/cropped to.

    Nothing is written when ``patches`` is empty.

    Raises:
        ValueError: if ``patches`` and ``metadata`` differ in length.
    """
    if not patches:
        return
    if len(patches) != len(metadata):
        raise ValueError(
            f"patches ({len(patches)}) and metadata ({len(metadata)}) must have the same length"
        )

    # Pad all patches to uniform size
    padded = [pad_patch(p, patch_size) for p in patches]

    # Scenes from different Landsat missions can expose different band sets.
    # Take the union (first-seen order) and fill bands a sample lacks with NaN,
    # instead of raising KeyError on the first mismatch.
    band_names = list(dict.fromkeys(band for p in padded for band in p))
    missing = np.full((patch_size, patch_size), np.nan, dtype="float32")

    # Stack into (n_samples, patch_size, patch_size) per band
    batch_ds = xr.Dataset({
        band: (["sample", "y", "x"], np.stack([p.get(band, missing) for p in padded]))
        for band in band_names
    })

    # Attach metadata as coordinates on the sample dimension
    for key in ("lat", "lon", "sample_date", "scene_datetime", "original_index"):
        batch_ds.coords[key] = ("sample", [m[key] for m in metadata])

    path = os.path.join(output_dir, f"batch_{batch_idx:03d}.nc")
    # Write to a temp file, then rename it into place. The extraction loop skips
    # any batch whose file already exists, so a partially written file from an
    # interrupted run must never appear under the final name.
    tmp_path = path + ".tmp"
    try:
        batch_ds.to_netcdf(tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"  Saved {path} ({len(patches)} patches)")

In [None]:
import xarray as xr

# Paths and extraction settings
DATA_DIR = os.path.join("..", "..", "Data")
PATCH_SIZE = 64   # height/width (pixels) every saved band is padded/cropped to
BATCH_SIZE = 100  # samples per saved .nc file

train_df, val_df = (
    pd.read_csv(os.path.join(DATA_DIR, name)) for name in ("train.csv", "validation.csv")
)

print(f"Training samples: {len(train_df)}\nValidation samples: {len(val_df)}")
train_df.head()

### Extract Training Patches
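Samples are processed in batches of `BATCH_SIZE` (100). Each batch is written to its own `.nc` file as soon as it finishes, so progress is preserved if the run is interrupted. Batches whose file already exists are skipped on re-run. A batch in which every sample fails produces no file and is therefore retried on the next run.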

In [None]:
train_output_dir = os.path.join(DATA_DIR, "landsat_patches", "train")
os.makedirs(train_output_dir, exist_ok=True)

n_batches = -(-len(train_df) // BATCH_SIZE)  # ceil(len(train_df) / BATCH_SIZE)
failed_indices = []

for batch_idx in range(n_batches):
    # Resume support: skip batches whose output file is already on disk
    batch_path = os.path.join(train_output_dir, f"batch_{batch_idx:03d}.nc")
    if os.path.exists(batch_path):
        print(f"Batch {batch_idx:03d} already exists, skipping.")
        continue

    start, end = batch_idx * BATCH_SIZE, min((batch_idx + 1) * BATCH_SIZE, len(train_df))
    batch_df = train_df.iloc[start:end]
    print(f"Batch {batch_idx:03d} ({start}-{end-1}):")

    patches, metadata = [], []
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc="  Extracting"):
        ds = extract_landsat(row, catalog, patch_size=PATCH_SIZE)
        if ds is None:
            failed_indices.append(idx)
            continue
        patches.append(ds)
        metadata.append(dict(
            lat=ds.attrs["sample_lat"],
            lon=ds.attrs["sample_lon"],
            sample_date=ds.attrs["sample_date"],
            scene_datetime=ds.attrs["scene_datetime"],
            original_index=idx,
        ))

    save_batch(patches, metadata, batch_idx, train_output_dir, patch_size=PATCH_SIZE)

shown = failed_indices[:20]
more = "..." if len(failed_indices) > 20 else ""
print(f"\nDone. {len(failed_indices)} failed samples: {shown}{more}")

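### Extract Validation Patches
The same batched, resumable procedure as for training. Files are written to `landsat_patches/validation`.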

In [None]:
val_output_dir = os.path.join(DATA_DIR, "landsat_patches", "validation")
os.makedirs(val_output_dir, exist_ok=True)

n_batches_val = -(-len(val_df) // BATCH_SIZE)  # ceil(len(val_df) / BATCH_SIZE)
failed_indices_val = []

for batch_idx in range(n_batches_val):
    # Resume support: skip batches whose output file is already on disk
    batch_path = os.path.join(val_output_dir, f"batch_{batch_idx:03d}.nc")
    if os.path.exists(batch_path):
        print(f"Batch {batch_idx:03d} already exists, skipping.")
        continue

    start, end = batch_idx * BATCH_SIZE, min((batch_idx + 1) * BATCH_SIZE, len(val_df))
    batch_df = val_df.iloc[start:end]
    print(f"Batch {batch_idx:03d} ({start}-{end-1}):")

    patches, metadata = [], []
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc="  Extracting"):
        ds = extract_landsat(row, catalog, patch_size=PATCH_SIZE)
        if ds is None:
            failed_indices_val.append(idx)
            continue
        patches.append(ds)
        metadata.append(dict(
            lat=ds.attrs["sample_lat"],
            lon=ds.attrs["sample_lon"],
            sample_date=ds.attrs["sample_date"],
            scene_datetime=ds.attrs["scene_datetime"],
            original_index=idx,
        ))

    save_batch(patches, metadata, batch_idx, val_output_dir, patch_size=PATCH_SIZE)

shown_val = failed_indices_val[:20]
more_val = "..." if len(failed_indices_val) > 20 else ""
print(f"\nDone. {len(failed_indices_val)} failed samples: {shown_val}{more_val}")

### Verify Saved Batches
Load a batch back to confirm structure and metadata.

In [None]:
# Load the first saved training batch back from disk and inspect its structure.
# Use the lowest-numbered batch that exists: batch_000.nc is absent when every
# sample in that batch failed, because no file is written for an empty batch.
batch_files = sorted(
    name for name in os.listdir(train_output_dir)
    if name.startswith("batch_") and name.endswith(".nc")
)
assert batch_files, f"No saved batches found in {train_output_dir}"

# load_dataset reads the data into memory and closes the file, so no handle
# stays open on the .nc file.
sample_batch = xr.load_dataset(os.path.join(train_output_dir, batch_files[0]))

# .sizes is the dim-name -> length mapping. Dataset.dims is being changed to a
# plain set of names in newer xarray, which would break dict()/['y'] access.
sizes = sample_batch.sizes
print("Dims:", dict(sizes))
print("Bands:", list(sample_batch.data_vars))
print("Coords:", list(sample_batch.coords))
print(f"\nSample 0: lat={float(sample_batch.lat[0])}, lon={float(sample_batch.lon[0])}, date={str(sample_batch.sample_date[0].values)}")
print(f"Patch shape per band: ({sizes['y']}, {sizes['x']})")
sample_batch