In [None]:
import dask_geopandas
import rioxarray
import geopandas as gpd
import os
from xrspatial.zonal import stats

In [None]:
def load_raster(vpu_id, layer):
    """Open one raster layer from a VPU's HRNHDPlus raster folder.

    The file location is built relative to the current working directory
    from ``vpu_id`` and ``layer``.

    Parameters
    ----------
    vpu_id
        Value formatted into the two VPU-specific directory names.
    layer
        Raster file name without the ``.tif`` extension.

    Returns
    -------
    Whatever ``rioxarray.open_rasterio`` returns for the file when called
    with ``chunks=True``.

    Raises
    ------
    FileNotFoundError
        If the computed path does not exist.
    """
    tif_location = f'high_res_data/NHDPLUS_H_{vpu_id}_HU4_RASTERS/HRNHDPlusRasters{vpu_id}/{layer}.tif'
    if os.path.exists(tif_location):
        return rioxarray.open_rasterio(tif_location, chunks=True)
    # Fail early with the full path so a missing download is easy to spot.
    raise FileNotFoundError(f"Raster file not found: {tif_location}")

In [None]:
def load_vector_data(in_zone_data_path, vpu_id=None, layer='NHDPlusCatchment', npartitions=4):
    """Read one layer of an NHDPlus geodatabase with ``dask_geopandas``.

    Parameters
    ----------
    in_zone_data_path
        When ``vpu_id`` is falsy, the dataset passed straight to
        ``dask_geopandas.read_file``. Otherwise, the directory that contains
        the per-VPU geodatabase ``NHDPLUS_H_{vpu_id}_HU4_GDB.gdb``.
    vpu_id
        Optional VPU identifier used to locate the per-VPU geodatabase.
    layer
        Name of the layer to read. Defaults to ``'NHDPlusCatchment'``, the
        layer this function always read before the parameter existed.
    npartitions
        Number of partitions for the resulting dask GeoDataFrame.
        ``dask_geopandas.read_file`` needs either ``npartitions`` or
        ``chunksize``; the original call passed neither.

    Returns
    -------
    The object returned by ``dask_geopandas.read_file``.

    Raises
    ------
    FileNotFoundError
        If ``vpu_id`` is given and the per-VPU geodatabase does not exist.
    """
    if vpu_id:
        gdb_path = os.path.join(in_zone_data_path, f'NHDPLUS_H_{vpu_id}_HU4_GDB.gdb')
        if not os.path.exists(gdb_path):
            raise FileNotFoundError(f"VPU geodatabase not found: {gdb_path}")
    else:
        # No existence check here: the reader reports its own error, and the
        # path may be something os.path.exists cannot evaluate.
        gdb_path = in_zone_data_path

    return dask_geopandas.read_file(gdb_path, layer=layer, npartitions=npartitions)

In [None]:
def compute_zonal_stats(raster, vector_data):
    """Compute per-zone statistics of ``raster`` for zones taken from ``vector_data``.

    Parameters
    ----------
    raster
        Object exposing the ``.rio`` accessor (``rio.crs`` and ``rio.clip``
        are used). It is also passed unchanged as ``values`` to
        ``xrspatial.zonal.stats``.
    vector_data
        Object providing ``to_crs``, ``geometry``, ``crs`` and an
        ``'nhdplusid'`` column.

    Returns
    -------
    The result of ``xrspatial.zonal.stats`` requested with
    ``return_type='pandas.DataFrame'`` for the mean, median, min, max and
    count statistics.
    """
    # Reproject the vector data so it shares the raster's CRS.
    vector_data = vector_data.to_crs(raster.rio.crs)

    # Build the zones array: clip the raster by the geometries (drop=False),
    # then overwrite the clipped array's data with the 'nhdplusid' values.
    # NOTE(review): rio.clip appears to mask the raster to the geometries
    # rather than burn per-feature ids into cells — confirm.
    # NOTE(review): the 'nhdplusid' column has one entry per feature, so it
    # is unlikely to match the raster's shape; this assignment looks like it
    # needs a real rasterization step. Also verify the column's exact name
    # and case as read from the geodatabase.
    zones = raster.rio.clip(vector_data.geometry, vector_data.crs, drop=False)
    zones.values = vector_data['nhdplusid'].values

    # Compute zonal statistics.
    # NOTE(review): load_raster opens rasters with chunks=True; check that
    # xrspatial supports 'median' for chunked inputs and accepts the array
    # dimensions that open_rasterio returns — TODO confirm.
    stats_df = stats(
        zones=zones,
        values=raster,
        stats_funcs=['mean', 'median', 'min', 'max', 'count'],
        return_type='pandas.DataFrame'
    )

    return stats_df

In [None]:
def process_vpu(vpu_id, layer, in_zone_data_path, output_path):
    """Run the zonal-statistics pipeline for one VPU and write the result as CSV.

    Loads the raster layer and the vector data for ``vpu_id``, computes the
    zonal statistics, and saves them to
    ``<output_path>/zonal_stats_<vpu_id>.csv``.

    Parameters
    ----------
    vpu_id
        VPU identifier passed to the loaders and used in the output file name.
    layer
        Raster layer name passed to ``load_raster``.
    in_zone_data_path
        Vector data location passed to ``load_vector_data``.
    output_path
        Directory for the CSV. It is created if missing; an empty string
        means the current working directory.

    Returns
    -------
    None

    Notes
    -----
    Processing is best-effort: any exception is caught and reported with
    ``print`` so that one failing VPU does not abort a batch run.
    """
    try:
        # Load inputs and compute the statistics.
        raster = load_raster(vpu_id, layer)
        vector_data = load_vector_data(in_zone_data_path, vpu_id)
        zonal_stats = compute_zonal_stats(raster, vector_data)

        # Create the output directory first; to_csv does not create missing
        # parents, so a fresh checkout would otherwise fail only after all
        # the expensive work had been done.
        if output_path:
            os.makedirs(output_path, exist_ok=True)

        # Save the results.
        output_file = os.path.join(output_path, f'zonal_stats_{vpu_id}.csv')
        zonal_stats.to_csv(output_file)

        print(f"Processed VPU {vpu_id} and saved results to {output_file}")
    except Exception as e:
        print(f"Error processing VPU {vpu_id}: {e}")

In [None]:
def main(layer, in_zone_data_path, output_path, vpu_ids):
    """Process every VPU in ``vpu_ids`` in order.

    Each identifier is handed to ``process_vpu`` together with the shared
    ``layer``, ``in_zone_data_path`` and ``output_path`` arguments.
    """
    for current_vpu in vpu_ids:
        process_vpu(
            current_vpu,
            layer,
            in_zone_data_path,
            output_path,
        )

In [None]:
if __name__ == "__main__":
    # Batch configuration: which VPUs to process and which raster layer to summarise.
    vpu_ids = [1710, 1709]
    layer = 'filldepth'

    # Input geodatabase location and the directory receiving the CSV results.
    # NOTE(review): this points at a .gdb itself, while load_vector_data joins a
    # per-VPU .gdb name onto it whenever a vpu_id is supplied — confirm the
    # intended directory layout.
    in_zone_data_path = 'high_res_data/NHDPlus_H_National_Release_1_GDB/NHDPlus_H_National_Release_1_GDB.gdb'
    output_path = 'high_res_data/output'

    main(
        layer=layer,
        in_zone_data_path=in_zone_data_path,
        output_path=output_path,
        vpu_ids=vpu_ids,
    )