In [None]:
from pathlib import Path
from glob import glob
import rasterio as rio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.windows import Window, from_bounds
from shapely.geometry import box
from shapely.ops import unary_union
from concurrent.futures import ProcessPoolExecutor
from tqdm.notebook import tqdm
import geopandas as gpd
import tempfile
import os
import numpy as np
import math

# --- CONFIG ---
# Scratch directory for the intermediate reprojected rasters.
TEMP_DIR = "/data/datasets/dakota/temp"
Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
NUM_PROCESSES = 24      # Worker processes used for reprojection
TILE_SIZE_M = 10000     # Nominal tile size (10 km)
TILE_OVERLAP_M = 1000   # Overlap buffer (1 km)

# --- Directories ---
# Input composites are read from in_dir; mosaic tiles are written under out_dir.
in_dir = Path('/home/dhester/server/dbcenter/products/land/temporalcomposites/CPB_S2L2A_time_composites')
out_dir = Path('/home/dhester/server/guser/dh/CPB_S2L2A_time_composites')

# --- Target projection and resolution ---
target_crs = 'EPSG:5070'
target_res = 10  # meters

# --- Load and prepare footprint geometries ---
# The footprint is reprojected to the target CRS, simplified with a tolerance
# of 100 CRS units, and grown by the tile overlap before being dissolved into
# a single geometry.
footprint_gdf = gpd.read_file("../data/cpb_lc/footprint.gpkg")
footprint_gdf = footprint_gdf.to_crs(target_crs)
_simplified_footprint = footprint_gdf.simplify(100)
footprint_gdf["geometry"] = _simplified_footprint.buffer(TILE_OVERLAP_M)
footprint_union = unary_union(footprint_gdf.geometry)
footprint_bounds = footprint_union.bounds  # (minx, miny, maxx, maxy)

def reproject_to_temp(src_path, dst_crs=target_crs, dst_res=target_res, nodata_val=None):
    """Reproject one raster into a temporary float32 GeoTIFF and return its path.

    Args:
        src_path: Path of the raster to reproject.
        dst_crs: CRS of the output raster.
        dst_res: Output pixel size, in units of ``dst_crs``.
        nodata_val: Nodata value used for both source and destination. When
            None, the source's declared nodata is used, falling back to NaN
            if the source declares none.

    Returns:
        Path of the temporary ``.tif`` created inside ``TEMP_DIR``. The caller
        owns the file and is responsible for deleting it.

    Raises:
        Whatever rasterio raises while opening, reprojecting or re-reading the
        file. The temporary file is removed before the exception propagates,
        so failed runs do not accumulate orphaned files in ``TEMP_DIR``.
    """
    fd, tmp_path = tempfile.mkstemp(suffix='.tif', dir=TEMP_DIR)
    os.close(fd)  # Only the reserved name is needed; rasterio reopens the path itself.

    try:
        with rio.open(src_path) as src:
            if nodata_val is None:
                nodata_val = src.nodata if src.nodata is not None else np.nan

            transform, width, height = calculate_default_transform(
                src.crs, dst_crs, src.width, src.height, *src.bounds, resolution=dst_res
            )
            kwargs = src.meta.copy()
            kwargs.update({
                'driver': 'GTiff',  # Matches the '.tif' suffix whatever the input format is.
                'crs': dst_crs,
                'transform': transform,
                'width': width,
                'height': height,
                'compress': 'lzw',
                'tiled': True,
                'blockxsize': 512,
                'blockysize': 512,
                'BIGTIFF': 'YES',
                'dtype': 'float32',
                'nodata': nodata_val,
            })

            with rio.open(tmp_path, 'w', **kwargs) as dst:
                for band in range(1, src.count + 1):
                    reproject(
                        source=rio.band(src, band),
                        destination=rio.band(dst, band),
                        src_transform=src.transform,
                        src_crs=src.crs,
                        dst_transform=transform,
                        dst_crs=dst_crs,
                        resampling=Resampling.lanczos,
                        src_nodata=nodata_val,
                        dst_nodata=nodata_val,
                    )

        # The writer is closed at this point; reopen once to confirm that the
        # result is readable before handing the path to the caller.
        with rio.open(tmp_path) as check:
            _ = check.meta
    except BaseException:
        # Do not leave a partial file behind. The cleanup is best effort: the
        # original error is the one worth reporting.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    return tmp_path


def _tile_extent(bounds, tx, ty):
    """Return (left, bottom, right, top) of tile (tx, ty), padded by the overlap and clamped to bounds."""
    left = max(bounds.left, bounds.left + tx * TILE_SIZE_M - TILE_OVERLAP_M)
    bottom = max(bounds.bottom, bounds.bottom + ty * TILE_SIZE_M - TILE_OVERLAP_M)
    right = min(bounds.right, bounds.left + (tx + 1) * TILE_SIZE_M + TILE_OVERLAP_M)
    top = min(bounds.top, bounds.bottom + (ty + 1) * TILE_SIZE_M + TILE_OVERLAP_M)
    return left, bottom, right, top


def _valid_mask(data, nodata):
    """Return a boolean array marking the values of data that are not nodata (NaN-aware)."""
    if nodata is None:
        return np.isfinite(data)
    if np.isnan(nodata):
        # NaN never compares equal to itself, so `data != nodata` would mark
        # every pixel as valid.
        return ~np.isnan(data)
    return data != nodata


def _mosaic_tile(srcs, tile_bounds, tile_shape, nodata):
    """Average, band by band, every source overlapping the tile; return a float64 array of tile_shape."""
    fill = np.nan if nodata is None else nodata
    tile_geom = box(*tile_bounds)

    # Accumulate from zero: starting from the nodata value would add it to
    # every sum.
    sums = np.zeros(tile_shape, dtype=np.float64)
    counts = np.zeros(tile_shape, dtype=np.uint16)

    for src in srcs:
        if not box(*src.bounds).intersects(tile_geom):
            continue

        # Read the whole tile extent from this source. `boundless` pads the
        # part outside the source with `fill`, and `out_shape` pins the array
        # to the tile grid, so every source yields an array that lines up with
        # the accumulators even though the sources have different origins.
        # Assumes every source has the same band count as tile_shape[0].
        window = from_bounds(*tile_bounds, transform=src.transform)
        data = src.read(
            window=window,
            out_shape=tile_shape,
            boundless=True,
            fill_value=fill,
            resampling=Resampling.lanczos,
        )

        valid = _valid_mask(data, fill)
        sums[valid] += data[valid]
        counts[valid] += 1

    # np.maximum avoids dividing by zero where no source contributed; those
    # cells are replaced by the fill value.
    return np.where(counts > 0, sums / np.maximum(counts, 1), fill)


def process_year(year):
    """Reproject one year's composites and write them as overlapping mosaic tiles.

    All ``s2_composite_mean.tif`` rasters below ``in_dir/<year>/annual`` are
    reprojected in parallel into temporary files. The mosaic extent is the
    intersection of their combined bounds with the footprint bounds. It is
    cut into ``TILE_SIZE_M`` tiles padded by ``TILE_OVERLAP_M`` and clamped to
    the extent. For every tile that intersects the footprint, the mean of the
    overlapping sources is written to
    ``out_dir/<year>/s2_composite_mean_mosaic_tile_<tx>_<ty>.tif``.

    Args:
        year: Year whose composites are processed; used as a directory name.

    Returns:
        ``year``, unchanged, including when no input rasters were found.

    Raises:
        Any exception raised while reprojecting, reading or writing. Opened
        datasets are closed and the temporary files of every finished
        reprojection are removed before the exception propagates.
    """
    # Function-scope import: the module header does not import from_origin.
    from rasterio.transform import from_origin

    year_in = in_dir / str(year) / 'annual'
    year_out = out_dir / str(year)
    year_out.mkdir(parents=True, exist_ok=True)

    # Sorted so that the source providing the output metadata is deterministic.
    rasters = sorted(glob(str(year_in / '*' / '*' / 's2_composite_mean.tif')))
    if not rasters:
        print(f'No rasters found for {year}')
        return year

    futures = []
    srcs = []
    try:
        # --- Step 1: Reproject all rasters in parallel ---
        with ProcessPoolExecutor(max_workers=min(NUM_PROCESSES, len(rasters))) as pool:
            futures = [pool.submit(reproject_to_temp, path) for path in rasters]
            tmp_files = [
                future.result()
                for future in tqdm(futures, total=len(futures), desc=f"Reprojecting {year}")
            ]

        for tmp_file in tmp_files:
            srcs.append(rio.open(tmp_file))

        # --- Step 2: Determine unified mosaic extent (intersected with footprint) ---
        bounds = rio.coords.BoundingBox(
            left=max(min(s.bounds.left for s in srcs), footprint_bounds[0]),
            bottom=max(min(s.bounds.bottom for s in srcs), footprint_bounds[1]),
            right=min(max(s.bounds.right for s in srcs), footprint_bounds[2]),
            top=min(max(s.bounds.top for s in srcs), footprint_bounds[3]),
        )
        if bounds.right <= bounds.left or bounds.top <= bounds.bottom:
            print(f'Rasters for {year} do not overlap the footprint')
            return year

        base_meta = srcs[0].meta.copy()
        nodata = base_meta['nodata']
        if nodata is None:
            # Tiles need a nodata value to mark cells with no contributing source.
            base_meta['nodata'] = np.nan
        band_count = base_meta['count']

        # --- Step 3: Generate tiles with overlap ---
        n_cols = math.ceil((bounds.right - bounds.left) / TILE_SIZE_M)
        n_rows = math.ceil((bounds.top - bounds.bottom) / TILE_SIZE_M)

        for ty in tqdm(range(n_rows), desc=f"Merging {year}"):
            for tx in range(n_cols):
                left, bottom, right, top = _tile_extent(bounds, tx, ty)

                # Skip if no actual geometric intersection with footprint
                if not footprint_union.intersects(box(left, bottom, right, top)):
                    continue

                # Whole pixels anchored at the top-left corner. Rounding first
                # keeps float noise from adding a spurious extra row or column.
                tile_w = max(1, math.ceil(round((right - left) / target_res, 6)))
                tile_h = max(1, math.ceil(round((top - bottom) / target_res, 6)))
                right = left + tile_w * target_res
                bottom = top - tile_h * target_res

                mosaic = _mosaic_tile(
                    srcs,
                    (left, bottom, right, top),
                    (band_count, tile_h, tile_w),
                    nodata,
                )

                meta = base_meta.copy()
                meta.update({
                    'transform': from_origin(left, top, target_res, target_res),
                    'width': tile_w,
                    'height': tile_h,
                })

                # The output is opened only once the data is ready, so a
                # failure while reading does not leave an empty tile on disk.
                out_path = year_out / f"s2_composite_mean_mosaic_tile_{tx}_{ty}.tif"
                with rio.open(out_path, 'w', **meta) as dst:
                    dst.write(mosaic.astype(meta['dtype']))
    finally:
        # Cleanup
        for src in srcs:
            src.close()

        # Leaving the executor block waits for every worker, so each future is
        # finished here; collect the files of those that succeeded, including
        # ones whose result was never consumed because another one failed.
        produced = [
            future.result()
            for future in futures
            if future.done() and not future.cancelled() and future.exception() is None
        ]
        for tmp_file in produced:
            try:
                os.remove(tmp_file)
            except OSError as e:
                print(f"Warning: could not remove temp file {tmp_file}: {e}")

    return year


# --- Run for all years ---
# Guarded so that worker processes started by ProcessPoolExecutor under the
# "spawn" start method, which re-import this module, do not re-run the loop.
if __name__ == "__main__":
    for year in tqdm(range(2020, 2025), desc='Processing all years'):
        process_year(year)

    # The message previously contained a mis-decoded check mark
    # (UTF-8 bytes read as cp1252).
    print("✅ All mosaics complete.")


Processing all years:   0%|          | 0/5 [00:00<?, ?it/s]

Reprojecting 2020:   0%|          | 0/50 [00:00<?, ?it/s]