In [None]:
import pcxarray as pcx
import rioxarray as rxr
import geopandas as gpd
from tqdm import tqdm
from warnings import filterwarnings
from joblib import Parallel, delayed, dump, load
from shapely import STRtree
from shapely import Point, Geometry
from rtree import index
import numpy as np
import odc.geo.xr
from rasterio.enums import Resampling
from sklearn.linear_model import Ridge, LinearRegression
import os
from threading import Lock
import pandas as pd
from planetary_computer import sign_url
import rasterio as rio
import math

filterwarnings('ignore')

max_num_samples = 500000
min_sample_dist = 2000 # minimum distance between samples
out_dir = os.path.join('..', 's2_naip_pairs')
target_resolution = 1.0
tile_size = 256

os.makedirs(out_dir, exist_ok=True)

In [None]:
conus_gdf = pcx.utils.load_census_shapefile(level='state')
conus_gdf = conus_gdf.loc[~conus_gdf['STUSPS'].isin(['AK', 'HI', 'PR', 'GU', 'VI', 'MP', 'AS'])]
# south_states = ["AL","AR","DE","DC","FL","GA","KY","LA","MD","MS","NC","OK","SC","TN","TX","VA","WV"]
# conus_gdf = conus_gdf.loc[conus_gdf['STUSPS'].isin(south_states)]
conus_geom = conus_gdf.dissolve().buffer(0.1).simplify(0.1).geometry.iloc[0]

In [3]:
naip_observations_gdf = pd.concat([
    pcx.pc_query(
        collections='naip',
        geometry=conus_geom,
        crs=conus_gdf.crs,
        datetime='2021',
    ),
    pcx.pc_query(
        collections='naip',
        geometry=conus_geom,
        crs=conus_gdf.crs,
        datetime='2022',
    )
])

In [4]:
def _query_row(geom, date_str, crs):
    try:
        res = pcx.pc_query(
            collections="sentinel-2-l2a",
            geometry=geom,
            crs=crs,
            datetime=date_str,
        )
        if res is None or len(res) == 0:
            return []
        else:
            return [geom.intersection(g) for g in res.geometry]
    except Exception:
        return []

inputs = []
for _, row in naip_observations_gdf.iterrows():
    dt = row['properties.datetime']
    if dt is None:
        continue
    inputs.append((row.geometry, dt.date().isoformat(), naip_observations_gdf.crs))

n_jobs = min(16, (len(inputs) or 1))  # tune for your machine / network
matching_regions = []
with tqdm(total=len(inputs)) as pbar:
    def _wrap(args):
        res = _query_row(*args)
        pbar.update()
        return res
    results = Parallel(n_jobs=n_jobs, prefer="threads", batch_size=1)(
        delayed(_wrap)(arg) for arg in inputs[:1000]
    )

for r in results:
    matching_regions.extend(r)
candidate_regions_gdf = gpd.GeoDataFrame({'geometry': matching_regions}, crs=naip_observations_gdf.crs)

candidate_regions_gdf = gpd.GeoDataFrame({'geometry': matching_regions}, crs=naip_observations_gdf.crs)
candidate_regions_gdf = candidate_regions_gdf.dissolve().explode().reset_index()
candidate_regions_gdf = candidate_regions_gdf.to_crs(5070)
candidate_regions_gdf.to_parquet(os.path.join(out_dir, 'candidate_regions.par'))

  2%|▏         | 1000/61975 [00:50<50:55, 19.95it/s] 


In [None]:
valid_regions_strtree = STRtree(candidate_regions_gdf.geometry)
invalid_regions_rtree = index.Index() 
pixel_buffer = int(math.ceil((tile_size * target_resolution) / 2.0))

minx, miny, maxx, maxy = candidate_regions_gdf.total_bounds

n_points = max_num_samples
sampled_points = []
np.random.seed(1701)

with tqdm(desc='Sampling points', total=n_points) as pbar:
    while len(sampled_points) < n_points:
        
        sample_x = np.random.uniform(minx, maxx)
        sample_y = np.random.uniform(miny, maxy)
        sample_point = Point(sample_x, sample_y)
        
        is_valid_region = False
        for potential_intersection in valid_regions_strtree.query(sample_point):
            if candidate_regions_gdf.iloc[potential_intersection].geometry.contains(sample_point.buffer(pixel_buffer)):
                is_valid_region = True
                break
        
        if not is_valid_region:
            continue
        
        for potential_intersection in invalid_regions_rtree.intersection(sample_point.bounds):
            if sampled_points[potential_intersection].buffer(min_sample_dist).intersects(sample_point):
                is_valid_region = False
                break
            
        if not is_valid_region:
            continue
        
        invalid_regions_rtree.insert(len(sampled_points), sample_point.buffer(min_sample_dist).bounds)
        sampled_points.append(sample_point)
        pbar.update(1)


Sampling points: 100%|██████████| 1000/1000 [00:01<00:00, 764.60it/s]


In [6]:
sampled_points_gdf = gpd.GeoDataFrame(geometry=sampled_points, crs=candidate_regions_gdf.crs)
n_digits = len(str(len(sampled_points_gdf) - 1))
sampled_points_gdf['id'] = sampled_points_gdf.index.map(lambda x: str(x).zfill(n_digits))
sampled_points_gdf.to_parquet(os.path.join(out_dir, 'sampled_points.par'))

In [8]:
for i in range(len(sampled_points_gdf)):
    try:
        lcmap_item_gdf = pcx.pc_query(
            collections='io-lulc-annual-v02',
            geometry=sampled_points_gdf.iloc[i].geometry,
            crs=sampled_points_gdf.crs,
            datetime='2021-06-01',
        )

        url = lcmap_item_gdf.iloc[0]['assets.data.href']
        with rio.open(sign_url(url)) as src:
            cmap = src.colormap(1)
        
        break
    except:
        continue

In [None]:
valid_scl_values = [4, 5, 6, 0] 

os.makedirs(os.path.join(out_dir, 'sentinel2'), exist_ok=True)
os.makedirs(os.path.join(out_dir, 'naip'), exist_ok=True)
os.makedirs(os.path.join(out_dir, 'lcmap'), exist_ok=True)

def process_sample(
    sample_row,
    crs,
    out_dir,
    valid_scl_values=(0, 4, 5, 6), # nodata (will check again), vegetation, not-vegetated, water
    target_resolution=target_resolution,
    resampling_method=Resampling.lanczos,
    tile_size=tile_size,
    write_opts=None
) -> None:
    
    try:
        sid = sample_row.id
        pixel_buffer = int(math.ceil((tile_size * target_resolution) / 2.0))
        geom = sample_row.geometry.buffer(pixel_buffer).envelope

        naip_items_gdf = pcx.pc_query(
            collections='naip',
            geometry=geom,
            crs=crs,
            datetime='2021'
        )
        if naip_items_gdf is None or len(naip_items_gdf) == 0:
            return 

        if len(naip_items_gdf) > 1:
            if not all(naip_items_gdf['properties.datetime'] == naip_items_gdf['properties.datetime'].iloc[0]):
                return 

        naip_dt = naip_items_gdf['properties.datetime'].iloc[0].date().isoformat()

        s2_items_gdf = pcx.pc_query(
            collections='sentinel-2-l2a',
            geometry=geom,
            crs=crs,
            datetime=naip_dt
        )
        if s2_items_gdf is None or len(s2_items_gdf) == 0:
            return

        s2_scl = pcx.prepare_data(
            s2_items_gdf, 
            geometry=geom, 
            crs=crs, 
            bands=['SCL']
        )
        if not s2_scl.fillna(0).isin(valid_scl_values).all():
            return

        naip_xr = pcx.prepare_data(
            naip_items_gdf,
            geometry=geom,
            crs=crs,
            target_resolution=target_resolution,
            resampling_method='bilinear',
            all_touched=True
        )
        if naip_xr.isnull().any():
            return 

        s2_xr = pcx.prepare_data(
            s2_items_gdf,
            geometry=geom,
            crs=crs,
            bands=['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'],
            all_touched=True,
        )
        if s2_xr.isnull().any():
            return 

        naip_downsampled_xr = naip_xr.rio.reproject_match(s2_xr, resampling=resampling_method)

        X = naip_downsampled_xr.values.reshape(naip_downsampled_xr.shape[0], -1).T
        Y = s2_xr.sel(band=['B04', 'B03', 'B02', 'B08']).values
        Y = Y.reshape(Y.shape[0], -1).T

        valid_idx = np.where(~np.isnan(X).any(axis=1))[0]
        if valid_idx.size == 0:
            return 
        X = X[valid_idx]
        Y = Y[valid_idx]

        ridge = Ridge(solver='cholesky')
        ridge.fit(X, Y)

        X_full = naip_xr.values.reshape(naip_xr.shape[0], -1).T
        naip_harmonized = ridge.predict(X_full)
        naip_harmonized = naip_harmonized.T.reshape(naip_xr.shape)

        naip_harmonized_xr = naip_xr.copy()
        naip_harmonized_xr.values = naip_harmonized
        naip_harmonized_xr = naip_harmonized_xr.assign_coords(band=["B04", "B03", "B02", "B08"])

        try:
            naip_harmonized_xr = naip_harmonized_xr.isel(x=slice(0, tile_size), y=slice(0, tile_size))
            s2_match = s2_xr.rio.reproject_match(naip_harmonized_xr, resampling=resampling_method)
            s2_match = s2_match.isel(x=slice(0, tile_size), y=slice(0, tile_size))
        except Exception as e:
            return 

        if naip_harmonized_xr.shape[1:] != (tile_size, tile_size) or s2_match.shape[1:] != (tile_size, tile_size):
            return 
        
        lcmap_items_gdf = pcx.pc_query(
            collections='io-lulc-annual-v02',
            geometry=geom,
            crs=crs,
            datetime=naip_dt,
        )
        
        lcmap_xr = pcx.prepare_data(
            lcmap_items_gdf,
            geometry=geom,
            crs=crs,
            all_touched=True,
        )
        if lcmap_xr.isnull().any():
            return
        lcmap_xr = lcmap_xr.rio.reproject_match(naip_xr, resampling=Resampling.nearest)
        lcmap_xr = lcmap_xr.isel(x=slice(0, tile_size), y=slice(0, tile_size))

        # clip and cast
        naip_out = naip_harmonized_xr.clip(0, 10000).astype(np.uint16)
        s2_out = s2_match.clip(0, 10000).astype(np.uint16)
        lcmap_out = lcmap_xr.clip(0, 255).astype(np.uint8)

        # default write options
        if write_opts is None:
            write_opts = dict(
                driver='GTiff',
                tiled=True,
                blockxsize=tile_size,
                blockysize=tile_size,
                compress='lzw',
                interleave='pixel',
            )

        naip_path = os.path.join(out_dir, 'naip', f"{sid}.tif")
        s2_path = os.path.join(out_dir, 'sentinel2', f"{sid}.tif")
        lcmap_path = os.path.join(out_dir, 'lcmap', f"{sid}.tif")

        naip_out.rio.to_raster(naip_path, **write_opts)

        s2_out = s2_out.rio.write_nodata(None)
        s2_out.rio.to_raster(s2_path, **write_opts)
        
        lcmap_out.rio.to_raster(lcmap_path, **write_opts)
        with rio.open(lcmap_path, 'r+') as src:
            src.write_colormap(1, cmap)
        
        return 
    
    except Exception as e:
        return 

rows = list(sampled_points_gdf.itertuples())
n_jobs = min(16, max(1, os.cpu_count() // 2))  # tune

pbar = tqdm(total=len(rows), desc="Processing samples", unit="sample")
pbar_lock = Lock()

def _wrap_process_sample(row, crs, out_dir, *args, **kwargs):
    try:
        return process_sample(row, crs, out_dir, *args, **kwargs)
    finally:
        with pbar_lock:
            pbar.update(1)

# run in threads so shared tqdm updates work
results = Parallel(n_jobs=n_jobs, prefer="threads", batch_size=1)(
    delayed(_wrap_process_sample)(row, sampled_points_gdf.crs, os.path.join('..', 's2_naip_pairs'))
    for row in rows
)

# results = [process_sample(row, sampled_points_gdf.crs, out_dir) for row in rows[:100]]

pbar.close()


Processing samples:   5%|▍         | 49/1000 [01:57<50:34,  3.19s/sample]  

KeyboardInterrupt: 

Processing samples:   5%|▌         | 54/1000 [02:11<39:54,  2.53s/sample]  