### Take ERA5 data and compute the weighted sums or averages over the water basins

In [2]:
import xarray as xr
import geopandas as gpd
import pandas as pd
import numpy as np
from pathlib import Path
from shapely.geometry import box
from shapely.prepared import prep
from glob import glob
import warnings
warnings.filterwarnings('ignore')

def create_grid_cells(lons, lats):
    """
    Build one rectangular polygon per grid point, centred on that point.

    The cell width/height is taken from the spacing between the first two
    coordinate values on each axis (0.25 is used when an axis has a single
    value), so the grid is treated as regularly spaced.

    Returns a dict keyed by (lat_index, lon_index) whose values are the
    shapely box geometries, inserted in row-major (lat, then lon) order.
    """
    # Grid spacing along each axis; only the first pair of values is inspected
    if len(lats) > 1:
        lat_step = abs(lats[1] - lats[0])
    else:
        lat_step = 0.25
    if len(lons) > 1:
        lon_step = abs(lons[1] - lons[0])
    else:
        lon_step = 0.25

    # Each cell extends half a step either side of its centre coordinate
    half_lat = lat_step / 2
    half_lon = lon_step / 2

    return {
        (row, col): box(
            centre_lon - half_lon,
            centre_lat - half_lat,
            centre_lon + half_lon,
            centre_lat + half_lat,
        )
        for row, centre_lat in enumerate(lats)
        for col, centre_lon in enumerate(lons)
    }

def calculate_basin_weights(basins, grid_cells, station_nums):
    """
    Calculate fractional overlap weights for specified basins with grid cells.

    Parameters
    ----------
    basins : table of basins with a 'StationNum' column and a geometry per row
    grid_cells : dict mapping (i, j) -> cell geometry (see create_grid_cells)
    station_nums : collection of StationNum values to keep

    Returns dict mapping basin_id -> {(i,j): overlap_fraction}, where
    overlap_fraction is intersection area / cell area, both measured in the
    coordinate units of the geometries. Basins with no overlapping cell are
    reported with a warning and omitted from the result.
    """
    basin_weights = {}

    # Filter basins to only those in station_nums
    basins_filtered = basins[basins['StationNum'].isin(station_nums)]
    n_basins = len(basins_filtered)

    print(f"Calculating overlaps for {n_basins} basins...")

    # iterrows() yields the frame's index *label*; after filtering that label
    # is neither contiguous nor guaranteed to be an integer, so progress is
    # tracked with an explicit 0-based counter instead.
    for position, (_, basin) in enumerate(basins_filtered.iterrows()):
        basin_id = basin['StationNum']
        basin_geom = basin.geometry

        # Prepared geometry makes the repeated intersects() predicate cheaper
        prepared_basin = prep(basin_geom)

        weights = {}

        # Check each grid cell for overlap
        for (i, j), cell_geom in grid_cells.items():
            # Cheap rejection: skip cells that do not touch the basin at all
            if not prepared_basin.intersects(cell_geom):
                continue

            # Calculate actual intersection
            try:
                intersection = basin_geom.intersection(cell_geom)
                if not intersection.is_empty:
                    # Weight is the fraction of the cell that overlaps
                    weights[(i, j)] = intersection.area / cell_geom.area
            except Exception as e:
                # Best effort: a failing cell is reported and skipped
                print(f"  Warning: Error calculating overlap for basin {basin_id}, cell ({i},{j}): {e}")
                continue

        if weights:
            basin_weights[basin_id] = weights
            if position % 10 == 0:  # Progress update every 10 basins
                print(f"  Processed {position + 1}/{n_basins} basins...")
        else:
            print(f"  Warning: No grid cell overlap found for basin {basin_id}")

    return basin_weights

def process_netcdf_file(nc_path, basin_weights, lats, lons):
    """
    Process a single NetCDF file and return basin precipitation per time step.

    For every time step and every basin the overlap-weighted mean of the
    non-NaN grid cells is computed (sum(value * weight) / sum(weight)).

    Parameters
    ----------
    nc_path : pathlib.Path
        NetCDF file to read.
    basin_weights : dict
        basin_id -> {(i, j): weight}; (i, j) index the first and second axis
        of each time slice.
    lats, lons :
        Unused here; kept so existing callers do not need to change.

    Returns DataFrame with columns: date, StationNum, total_precip_mm.
    total_precip_mm is NaN when none of a basin's cells has data for that
    time step, so missing data is not mistaken for zero precipitation.

    Raises ValueError if no precipitation variable or time coordinate is found.
    """
    print(f"  Processing {nc_path.name}...")

    # The weights do not change between time steps, so flatten each basin's
    # dict into index/weight arrays once instead of walking it per step.
    basin_arrays = {}
    for basin_id, weights in basin_weights.items():
        rows = np.array([i for i, _ in weights], dtype=np.intp)
        cols = np.array([j for _, j in weights], dtype=np.intp)
        cell_weights = np.array(list(weights.values()), dtype=float)
        basin_arrays[basin_id] = (rows, cols, cell_weights)

    # Store results for this file
    results = []

    # The context manager closes the file even when one of the checks raises
    with xr.open_dataset(nc_path) as ds:
        # Identify precipitation variable
        precip_var = None
        for var in ['tp', 'precip', 'precipitation', 'total_precipitation']:
            if var in ds.variables:
                precip_var = var
                break

        if precip_var is None:
            raise ValueError(f"Could not find precipitation variable in {nc_path}")

        # Get time dimension
        if 'valid_time' in ds.coords:
            time_dim = 'valid_time'
        elif 'time' in ds.coords:
            time_dim = 'time'
        else:
            raise ValueError(f"No time dimension found in {nc_path}")
        times = pd.to_datetime(ds[time_dim].values)

        # Process each time step
        for t_idx, time in enumerate(times):
            # Converts metres to mm; assumes the variable is stored in metres
            # (as ERA5 'tp' is) and that each time slice is 2-D -- TODO confirm
            # for the other accepted variable names.
            precip_mm = ds[precip_var].isel(**{time_dim: t_idx}).values * 1000

            for basin_id, (rows, cols, cell_weights) in basin_arrays.items():
                cell_precip = precip_mm[rows, cols]
                valid = ~np.isnan(cell_precip)
                total_weight = cell_weights[valid].sum()

                # Weighted mean over the cells that actually have data
                if total_weight > 0:
                    weighted_sum = (cell_precip[valid] * cell_weights[valid]).sum()
                    basin_total_precip = float(weighted_sum / total_weight)
                else:
                    basin_total_precip = np.nan

                results.append({
                    'date': time,
                    'StationNum': basin_id,
                    'total_precip_mm': basin_total_precip
                })

    return pd.DataFrame(results)

def compute_basin_precipitation_multifile(
    precip_dir,
    shapefile_path,
    station_metadata_path,
    output_dir='basin_averaged_climate_data'
):
    """
    Compute basin precipitation for specified water basins across multiple NetCDF files.

    The grid is read from the first matching file, basin/grid overlap weights
    are computed once, and every file is then reduced to one overlap-weighted
    mean value (mm) per basin and time step. Results are written to
    ``<output_dir>/basin_total_precipitation.csv`` in wide format
    (one row per date, one column per StationNum).

    Parameters:
    -----------
    precip_dir : str
        Directory containing ERA5 precipitation NetCDF files
        (files are matched with the pattern ``era5_precip_*.nc``)
    shapefile_path : str
        Path to basin shapefile (must include a StationNum attribute)
    station_metadata_path : str
        Path to CSV with station metadata (must include StationNum column)
    output_dir : str
        Directory to save output CSV file; created (with parents) if missing

    Returns:
    --------
    pandas.DataFrame
        The wide-format table that was written to disk.

    Raises:
    -------
    ValueError
        If no NetCDF files are found, the latitude/longitude coordinates
        cannot be identified, or no requested basin overlaps the grid.
    """

    # Create output directory (parents=True so nested paths work too)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load station metadata to get list of stations to process
    print("Loading station metadata...")
    station_metadata = pd.read_csv(station_metadata_path)
    station_nums = station_metadata['StationNum'].tolist()
    print(f"Found {len(station_nums)} stations to process")

    # Read the basin shapefile
    print("\nLoading basin shapefile...")
    basins = gpd.read_file(shapefile_path)

    # Ensure the basins have a CRS
    if basins.crs is None:
        print("Warning: Shapefile has no CRS, assuming EPSG:4326 (WGS84)")
        basins.set_crs("EPSG:4326", inplace=True)

    # Reproject basins to WGS84 if needed
    if basins.crs.to_string() != "EPSG:4326":
        print(f"Reprojecting basins from {basins.crs} to EPSG:4326")
        basins = basins.to_crs("EPSG:4326")

    print(f"Total basins in shapefile: {len(basins)}")
    print(f"Basins matching station list: {len(basins[basins['StationNum'].isin(station_nums)])}")

    # Get a sample NetCDF file to determine grid structure
    print("\nReading grid structure from sample NetCDF file...")
    precip_files = sorted(glob(str(Path(precip_dir) / "era5_precip_*.nc")))

    if not precip_files:
        raise ValueError(f"No NetCDF files found in {precip_dir}")

    print(f"Found {len(precip_files)} NetCDF files to process")

    # Load first file to get coordinate structure; the context manager closes
    # it even if the coordinate names cannot be identified.
    with xr.open_dataset(precip_files[0]) as sample_ds:
        if 'latitude' in sample_ds.coords:
            lat_name, lon_name = 'latitude', 'longitude'
        elif 'lat' in sample_ds.coords:
            lat_name, lon_name = 'lat', 'lon'
        else:
            raise ValueError("Could not identify latitude/longitude coordinates")

        lats = sample_ds[lat_name].values
        lons = sample_ds[lon_name].values

    print(f"Grid dimensions: {len(lats)} lat × {len(lons)} lon")

    # Create grid cell geometries (only need to do this once)
    # NOTE(review): every file is assumed to share the first file's grid.
    print("\nCreating grid cell geometries...")
    grid_cells = create_grid_cells(lons, lats)

    # Calculate basin-grid overlaps (only need to do this once)
    print("\nCalculating basin-grid overlaps...")
    basin_weights = calculate_basin_weights(basins, grid_cells, station_nums)

    print(f"\nSuccessfully calculated overlaps for {len(basin_weights)} basins")

    # Fail early with a clear message instead of an opaque KeyError on 'date'
    # after every file has been read for nothing.
    if not basin_weights:
        raise ValueError(
            "No requested basin overlaps the NetCDF grid; check the StationNum "
            "values and that basin and grid longitudes use the same convention"
        )

    # Process each NetCDF file
    print("\nProcessing NetCDF files...")
    all_results = []

    for nc_file in precip_files:
        nc_path = Path(nc_file)
        df = process_netcdf_file(nc_path, basin_weights, lats, lons)
        all_results.append(df)

    # Combine all results
    print("\nCombining results...")
    combined_df = pd.concat(all_results, ignore_index=True)
    combined_df['date'] = pd.to_datetime(combined_df['date']).dt.date
    combined_df = combined_df.sort_values(['date', 'StationNum']).reset_index(drop=True)

    # Pivot to wide format: rows = dates, columns = station numbers.
    # NOTE(review): pivot() raises ValueError if a (date, StationNum) pair
    # occurs more than once, e.g. sub-daily time steps or overlapping files.
    print("\nPivoting to wide format (dates × stations)...")
    wide_df = combined_df.pivot(index='date', columns='StationNum', values='total_precip_mm')
    wide_df = wide_df.sort_index()

    # Reset index to make date a column
    wide_df = wide_df.reset_index()

    # Save to single CSV file
    output_file = output_path / "basin_total_precipitation.csv"
    wide_df.to_csv(output_file, index=False)
    print(f"\nSaved results to: {output_file}")
    print(f"Shape: {wide_df.shape[0]} days × {wide_df.shape[1]-1} stations")

    # Print summary statistics
    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)
    print(f"Date range: {wide_df['date'].min()} to {wide_df['date'].max()}")
    print(f"Number of stations: {wide_df.shape[1] - 1}")
    print(f"Total days: {len(wide_df)}")
    print(f"\nPrecipitation statistics (mm/day):")
    # Get stats from all station columns (excluding date)
    precip_values = wide_df.iloc[:, 1:].values.flatten()
    precip_values = precip_values[~np.isnan(precip_values)]
    if precip_values.size:
        print(f"  Mean: {np.mean(precip_values):.2f}")
        print(f"  Std: {np.std(precip_values):.2f}")
        print(f"  Min: {np.min(precip_values):.2f}")
        print(f"  Max: {np.max(precip_values):.2f}")
    else:
        # np.min/np.max raise on an empty array, so report instead of crashing
        print("  No valid precipitation values")

    return wide_df

if __name__ == "__main__":
    # Script entry point: run the multi-file basin precipitation workflow
    # with the input/output locations spelled out below.
    run_config = {
        "precip_dir": "ERA5_data/ERA5_precip",
        "shapefile_path": "derived_shapefiles/natural_watersheds.shp",
        "station_metadata_path": "station_cluster_metadata.csv",
        "output_dir": "basin_averaged_climate_data",
    }
    results = compute_basin_precipitation_multifile(**run_config)

    # Closing banner
    banner = "="*60
    print("\n" + banner)
    print("Processing complete!")
    print(banner)

Loading station metadata...
Found 111 stations to process

Loading basin shapefile...
Reprojecting basins from PROJCS["Canada_Albers_Equal_Area_Conic",GEOGCS["NAD83",DATUM["North_American_Datum_1983",SPHEROID["GRS 1980",6378137,298.257222101,AUTHORITY["EPSG","7019"]],AUTHORITY["EPSG","6269"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4269"]],PROJECTION["Albers_Conic_Equal_Area"],PARAMETER["latitude_of_center",40],PARAMETER["longitude_of_center",-96],PARAMETER["standard_parallel_1",50],PARAMETER["standard_parallel_2",70],PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["ESRI","102001"]] to EPSG:4326
Total basins in shapefile: 555
Basins matching station list: 111

Reading grid structure from sample NetCDF file...
Found 516 NetCDF files to process
Grid dimensions: 61 lat × 97 lon

Creating grid cell geom