# In [None]:
import os
import glob
import numpy as np
import rasterio
from rasterio.transform import from_origin
from netCDF4 import Dataset, num2date
from swe_calculator import load_interpolators, compute_swe
from datetime import datetime

# ========== User Parameters ==========
input_folder = "/Users/marzi/OneDrive - University of New Mexico/SNOW RESEARCH/SNOW DATA/Banded Peak Ranch/S1Retrievals/SpicySnow S1Retrievals"  # Folder with snow depth NetCDFs
snow_depth_var = "snow_depth"              # Variable name in NetCDF file
# Coordinate variable names. By the usual xarray/CF convention "x" is the
# east-west axis (longitude) and "y" the north-south axis (latitude).
# NOTE(review): compute_swe is fed these values as lat/lon, and the output
# raster is tagged EPSG:4326 — confirm the input files are in degrees.
lat_var = "y"
lon_var = "x"
time_var = "time"                          # Time coordinate (CF units/calendar)
# Text tables handed to load_interpolators (format defined in swe_calculator).
td_path = "/Users/marzi/Desktop/td_final.txt"
pptwt_path = "/Users/marzi/Desktop/ppt_wt_final.txt"
# One SWE raster per date is written here; created if missing.
output_folder = "/Users/marzi/OneDrive - University of New Mexico/SNOW RESEARCH/SNOW DATA/Banded Peak Ranch/S1Retrievals/SpicySnowS1SWE"
os.makedirs(output_folder, exist_ok=True)

# ========== Load Interpolators ==========
# Loaded once and reused for every file/date passed to compute_swe.
f_td, f_ppt = load_interpolators(td_path, pptwt_path)

# ========== Process Each NetCDF File ==========
# For every snow-depth NetCDF, convert each time slice from depth to SWE with
# compute_swe and write one single-band raster per date to output_folder.
nc_files = sorted(glob.glob(os.path.join(input_folder, "BPR_snd*.nc")))
if not nc_files:
    print(f"No BPR_snd*.nc files found in {input_folder}")

for nc_file in nc_files:
    print(f"Processing {nc_file}...")
    # `with` closes the file even if compute_swe or the raster write raises.
    with Dataset(nc_file) as ds:
        depth_nc = ds.variables[snow_depth_var]
        depth_dims = list(depth_nc.dimensions)

        # netCDF4 returns masked arrays. Fill masked cells with NaN so the
        # isfinite test below drops them: the raw value under a mask can be
        # the finite, positive default _FillValue (~9.97e36), which would
        # otherwise pass the validity check and be sent to compute_swe.
        snow_depth = np.ma.filled(depth_nc[:].astype(np.float64), np.nan)
        lats = np.ma.filled(ds.variables[lat_var][:].astype(np.float64), np.nan)
        lons = np.ma.filled(ds.variables[lon_var][:].astype(np.float64), np.nan)

        time_var_data = ds.variables[time_var]
        time_units = time_var_data.units
        calendar = getattr(time_var_data, "calendar", "standard")
        # atleast_1d so a scalar time coordinate still iterates once.
        times = np.atleast_1d(
            num2date(time_var_data[:], units=time_units, calendar=calendar)
        )

        # Put the array in (time, lat, lon) order using the variable's own
        # dimension names, so each flattened slice pairs element-for-element
        # with the flattened coordinate grid, whatever axis order the file uses.
        if lat_var in depth_dims and lon_var in depth_dims:
            new_order = [d for d in depth_dims if d not in (lat_var, lon_var)]
            new_order += [lat_var, lon_var]
            snow_depth = np.transpose(
                snow_depth, [depth_dims.index(d) for d in new_order]
            )
        if snow_depth.ndim == 2:
            # Single-date file stored as (lat, lon): add a time axis.
            snow_depth = snow_depth[np.newaxis, :, :]

        expected_shape = (len(times), lats.size, lons.size)
        if snow_depth.shape != expected_shape:
            print(f"  Skipping: {snow_depth_var} shape {snow_depth.shape} "
                  f"!= (time, lat, lon) {expected_shape}")
            continue
        if lats.size < 2 or lons.size < 2:
            print("  Skipping: need at least 2 points per axis to infer pixel size.")
            continue

        # A north-up raster needs row 0 at the northernmost latitude and
        # column 0 at the westernmost longitude; flip axes stored otherwise.
        if lats[0] < lats[-1]:
            lats = lats[::-1]
            snow_depth = snow_depth[:, ::-1, :]
        if lons[0] > lons[-1]:
            lons = lons[::-1]
            snow_depth = snow_depth[:, :, ::-1]

        # from_origin expects positive pixel sizes and the outer corner of the
        # upper-left pixel. Coordinates are treated as pixel centres
        # (xarray/CF convention — NOTE(review): confirm for these files),
        # hence the half-pixel shift. Same for every date, so built once.
        x_res = abs(float(lons[1] - lons[0]))
        y_res = abs(float(lats[1] - lats[0]))
        transform = from_origin(lons[0] - x_res / 2, lats[0] + y_res / 2, x_res, y_res)

        # 2-D coordinate grids in the same (lat, lon) layout as each slice.
        lon2d, lat2d = np.meshgrid(lons, lats)
        flat_lat = lat2d.ravel()
        flat_lon = lon2d.ravel()

        for t_index, dt in enumerate(times):
            print(f"  Date: {dt.strftime('%Y-%m-%d')}")

            # Depth assumed to be metres in the file; compute_swe gets mm.
            H = snow_depth[t_index].ravel() * 1000

            # NaN (formerly masked) and non-positive depths are skipped.
            valid = np.isfinite(H) & (H > 0)
            if not valid.any():
                print("  No valid snow depth data on this date.")
                continue

            n_valid = int(valid.sum())
            Y = np.full(n_valid, dt.year)
            M = np.full(n_valid, dt.month)
            D = np.full(n_valid, dt.day)

            SWE, DOY = compute_swe(
                Y, M, D,
                H[valid],
                flat_lat[valid],
                flat_lon[valid],
                f_td, f_ppt
            )

            # Scatter SWE back onto the grid; cells that were skipped stay NaN.
            SWE_map = np.full_like(H, np.nan)
            SWE_map[valid] = SWE
            SWE_map = SWE_map.reshape(snow_depth.shape[1:])

            # Save output as NetCDF using rasterio (GDAL netCDF driver).
            # NOTE(review): the name uses only the calendar date, so two
            # acquisitions on the same day overwrite each other.
            out_path = os.path.join(output_folder, f"swe_{dt.strftime('%Y%m%d')}.nc")

            with rasterio.open(
                out_path,
                'w',
                driver='NETCDF',
                height=SWE_map.shape[0],
                width=SWE_map.shape[1],
                count=1,
                dtype='float32',
                crs='EPSG:4326',
                transform=transform,
                nodata=np.nan
            ) as dst:
                dst.write(SWE_map.astype(np.float32), 1)
                dst.set_band_description(1, "SWE_mm")

print("\n✅ SWE processing complete. NetCDF files saved.")
