In [None]:
# import necessary modules
import os
import numpy as np
import pandas as pd
import geopandas as gpd
from matplotlib import path
import matplotlib.pyplot as plt

import time
import xarray as xr

In [None]:
# # https://docs.xarray.dev/en/stable/user-guide/reshaping.html

# array = xr.DataArray(
#     np.random.randn(2, 2, 3), coords=[("time", [0, 3600]),("lon", ["a", "b"]), ("lat", [0, 1, 2])]
# )

# stacked = array.stack(stations=("lon", "lat"))
# stacked

In [None]:
from pyproj import CRS

# Projected coordinate reference system shared by the notebook: `epsg` is
# passed to the grid setup and `crs` is used to reproject input geometries.
epsg = 32617  # WGS 84 / UTM zone 17N
crs = CRS.from_epsg(epsg)
# Bare last expression: display the CRS definition as the cell output.
crs

In [None]:
# logger = setuplog("sfincs_outerbanks_hydromt", log_level=10)
from hydromt_sfincs import SfincsModel
from hydromt_sfincs import utils

# Data catalog (yml) handed to the model as its data library.
ymlfile = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\dem\new_flosup.yml"
sf_qt = SfincsModel(data_libs=ymlfile, root="test_clip_ERA5", mode="w+")  # , logger=logger)

# --- refinement level 1: 400 m resolution for SnapWave -----------------------
# NOTE(review): the file name says "refinelevel2" but it is used as level 1 here — confirm.
file_name = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\02_model_setup\shpfiles\include_snapwave_offshore_plus_coast_refinelevel2.shp"
refine_gdf_lev1 = sf_qt.data_catalog.get_geodataframe(file_name, crs=4326)
# Bring the lon/lat polygon into the projected model CRS.
new_refine_gdf_lev1 = refine_gdf_lev1.to_crs(crs)

# --- refinement level 2: 200 m resolution of the FloSup model ----------------
file_name = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\from_flosup\include_polygon_flosup_v5_UTM17N_tekal.pol"
lev2_feats = utils.read_geoms(fn=file_name)
refine_gdf_lev2 = utils.polygon2gdf(feats=lev2_feats, crs=crs)

# One row per refinement level; each geometry is the union of that level's polygons.
refinement_levels = [1, 2]
refinement_shapes = [
    new_refine_gdf_lev1.unary_union,
    refine_gdf_lev2.unary_union,
]
gdf_refinement = gpd.GeoDataFrame(
    {"refinement_level": refinement_levels},
    geometry=refinement_shapes,
    crs=crs,
)

output_file = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\02_model_setup\shpfiles\include_polygon_flosup_v5_UTM17N_tekal.geojson"
# gdf_refinement.to_file(output_file, driver='GeoJSON')

# Base grid definition (800 m cells, rotated 40 degrees) plus the refinement polygons.
grid_settings = dict(
    x0=656551,
    y0=3470589,
    dx=800.0,
    dy=800.0,
    nmax=647,
    mmax=954,
    rotation=40.0,
    epsg=epsg,  # WGS 84 / UTM zone 17N
    refinement_polygons=gdf_refinement,
)
sf_qt.setup_grid(**grid_settings)

In [None]:
# Raw ERA5 wave dataset for 2018 (netCDF on the project drive).
filename = r"p:\archivedprojects\11206515-flosup2021\01_data\ERA5\ERA5_waves2018.nc"

# Open the file as an xarray Dataset; later cells derive `era5new` from it.
era5 = xr.open_dataset(filename)
# Bare last expression: display the dataset summary as the cell output.
era5

### Try clipping ERA5 data to the model domain:

In [None]:
# Start from a copy of the raw dataset so this cell can be re-run safely.
era5new = era5.copy()

# Variables needed downstream: "hs", "tp", "wd", "ds".
era5new = era5new.rename({"swh": "hs", "mwp": "tp", "mwd": "wd"})

# Add directional spreading as a constant field (30.0) with the same shape as "hs".
# NOTE(review): presumably degrees — confirm the unit expected by the wave forcing.
era5new = era5new.assign(ds=30.0 * xr.ones_like(era5new["hs"]))

# Rename the horizontal coordinates to "lon"/"lat" and keep them flagged as coordinates.
era5new = era5new.set_coords(["longitude", "latitude"])
# ds = ds.rename_vars({"point_hm0":"hs", "point_tp":"tp", "point_wavdir":"wd", "point_dirspr":"ds"})
era5new = era5new.rename({"longitude": "lon", "latitude": "lat"})  # , "stations": "index"})
era5new = era5new.set_coords(["lon", "lat"])

# Convert longitudes from [0, 360] to [-180, 180).
# Only values >= 180 are shifted by -360; values already below 180 are left
# alone (subtracting 360 from every value would push them below -180).
lon_values = era5new.coords["lon"].values
era5new.coords["lon"] = np.where(lon_values >= 180, lon_values - 360, lon_values)

# Fill NaN values with zero.
# NOTE(review): after this step missing cells can no longer be told apart from true zeros.
era5new = era5new.fillna(0)
era5new.attrs["crs"] = "EPSG:4326"  # Example CRS: WGS84

era5new
# era5new.to_netcdf(r'd:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\ERA5\ERA5_waves2018_rename_added_ds.nc')

In [None]:
# Inspect the converted longitude values.
era5new["lon"].values

In [None]:
# 1D coordinate axes of the ERA5 grid.
lon = era5new["lon"].values
lat = era5new["lat"].values

# Expand the axes to a full 2D grid and flatten it to one (lon, lat) pair per grid point.
lon2d, lat2d = np.meshgrid(lon, lat)
lon2dflat = lon2d.ravel()
lat2dflat = lat2d.ravel()

# One point geometry per ERA5 grid node, in geographic coordinates (EPSG:4326).
era5_point_geoms = gpd.points_from_xy(lon2dflat, lat2dflat)
gdf = gpd.GeoDataFrame(geometry=era5_point_geoms, crs=4326)

# Quick-look plot of all grid points.
gdf.plot()
output_file = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\02_model_setup\shpfiles\ERA5_points.geojson"
# gdf.to_file(output_file, driver='GeoJSON')

In [None]:
# Polygon shapefile outlining which ERA5 points to keep.
file_name = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\ERA5\select_era5_points.shp"
catalog = sf_qt.data_catalog
include_era5 = catalog.get_geodataframe(file_name, crs=4326)
include_era5

In [None]:
# Display the prepared ERA5 dataset again for reference before selecting points.
era5new

In [None]:
# Flag every ERA5 grid point that lies inside the (merged) selection polygon.
gdf["inside_polygon"] = gdf.within(include_era5.unary_union)

# print(gdf['inside_polygon'] == True)
# Spatial join keeping only the points that fall within a selection polygon.
# `predicate=` replaces the `op=` keyword, which was deprecated in
# geopandas 0.10 and removed in 1.0.
points_within_polygon = gpd.sjoin(
    gdf, include_era5, how="inner", predicate="within"
)

# index = points_within_polygon.index
# index
# The geometries are points, so read x/y directly; going through `.centroid`
# gives the same values but warns about centroids in a geographic CRS.
xpoint = points_within_polygon.geometry.x
ypoint = points_within_polygon.geometry.y

# Convert the geometry column of the GeoDataFrame to a list of point objects.
geometry_list = points_within_polygon["geometry"].tolist()
geometry_list

In [None]:
# Overview figure: all ERA5 points, the selection polygon, and the selected points.
fig, ax = plt.subplots()

gdf.plot(ax=ax)
include_era5.plot(ax=ax, color="r", ls="--")  # , facealpha=0.5)
ax.scatter(xpoint, ypoint, color="y")

# Zoom to the area of interest (lon/lat degrees).
ax.set_xlim(left=-82, right=-74)
ax.set_ylim(bottom=31, top=38)

In [None]:
# Boolean (lon, lat) grid on the ERA5 coordinates, initially False everywhere;
# True marks the grid points that were selected by the polygon.
mask_dims = ("lon", "lat")
mask = xr.DataArray(
    False,
    dims=mask_dims,
    coords={axis: era5new[axis] for axis in mask_dims},
)

# Switch on each selected point by exact coordinate label.
for pt in geometry_list:
    # mask.loc[x_coordinate,y_coordinate] = True
    mask.loc[dict(lon=pt.x, lat=pt.y)] = True

# Keep data only at the masked points; everything else becomes NaN (nothing is dropped).
selected_points = era5new.where(mask)

print(selected_points)
# print(mask)
mask.plot(x="lon", y="lat")

In [None]:
# Map of "hs" at the first index of the leading dimension.
selected_points.hs[0].plot()

In [None]:
# Collapse the 2D (lon, lat) grid into a single 1D "stations" dimension.
stacked = selected_points.stack({"stations": ("lon", "lat")})
stacked

In [None]:
# Drop the stations that were masked out, i.e. those holding only NaN values.
ds_cleaned = stacked.dropna("stations", how="all")
ds_cleaned

In [None]:
# "hs" series at index 1 of the second dimension (one of the remaining stations).
ds_cleaned.hs[:, 1].plot()

In [None]:
# reset_index(): remove the "stations" index created by the stacking step.
# The call is made once; its result is stored and displayed (a second,
# discarded call only repeated the same work).
ds_reset = ds_cleaned.reset_index("stations")
ds_reset

In [None]:
# set_spatial_dims
# Rename the horizontal coordinates from lon/lat to x/y in a single pass.
ds_reset = ds_reset.rename({"lon": "x", "lat": "y"})
ds_reset

In [None]:
# save: write the cleaned per-station ERA5 dataset to netCDF.
file_name = r"d:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\ERA5\ERA5_waves2018_rename_added_ds_cleaned.nc"
ds_reset.to_netcdf(path=file_name)

### Actually write data now!

In [None]:
# Simulation time settings, written into the model configuration one key at a time.
time_settings = {
    "tref": "20180904 000000",
    "tstart": "20180910 152325",
    "tstop": "20180928 190000",
}
for key, stamp in time_settings.items():
    sf_qt.config[key] = stamp

In [None]:
# ds_reprojected = ds_reset.rio.reproject(crs)
# ds_reset.raster.to_crs(self.crs)

In [None]:
# Time series table: move the station (index) dimension last before converting.
station_dim = ds_reset.vector.index_dim
df_ts = ds_reset.transpose(..., station_dim).to_dataframe()

# Station locations as a GeoDataFrame, plus a version reprojected to the model CRS.
gdf_locs = ds_reset.vector.to_gdf()
gdf_locs2 = gdf_locs.to_crs(crs)  # to_crs returns a new object; gdf_locs is untouched
gdf_locs2
# crs

In [None]:
# Register the cleaned per-station ERA5 dataset as wave forcing on the model.
# NOTE(review): ds_reset holds lon/lat values tagged "EPSG:4326" in its attrs,
# while the grid was set up in EPSG:32617 — confirm that setup_wave_forcing
# reprojects the station locations itself.
sf_qt.setup_wave_forcing(ds_reset)

In [None]:
# Write the model to disk (the model root was set to "test_clip_ERA5" at construction).
sf_qt.write()

In [None]:
# # sf_qt.setup_wave_forcing(geodataset = 'era5_waves_2018_rename')

# data_era5 = r'd:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\Outer_Banks_FloSup\01_data_analysis\ERA5\ERA5_waves2018_rename_added_ds.nc'
# ds =  xr.open_mfdataset(data_era5)

# ds = ds.set_coords(["longitude", "latitude"])
# # ds = ds.rename_vars({"point_hm0":"hs", "point_tp":"tp", "point_wavdir":"wd", "point_dirspr":"ds"})
# ds = ds.rename({"longitude": "lon", "latitude" : "lat"})#, "stations": "index"})
# ds = ds.set_coords(["lon", "lat"])
# #ds = ds.set_index(stations = ["station_x", "station_y"], append = True )

# lon = era5.coords['longitude'].values - 360 #convert [0-360] to [-180,180]
# lat = era5.coords['latitude'].values

# [lon2d, lat2d] = np.meshgrid(lon,lat)

# lon2dflat = lon2d.flat
# lat2dflat = lat2d.flat

# ds
# #TODO: try as flat 1D arrays > first selected within range? lat/lon/time
# # ds
# # sf_qt.setup_wave_forcing(ds)

In [None]:
# # fn = r'd:\repos\phd_waves\paper_wave_driven_flooding\02_modelling\LaJola\01_data_analysis\ndbc_station46254_2015_nearshore_clean_smooth.nc'
# fn = r'p:\archivedprojects\11206515-flosup2021\01_data\ERA5\ERA5_waves2018.nc'

# ndbc_deep = xr.open_dataset(fn)

# ndbc_deep

# # x&y-locations:
# x = ndbc_deep.station_x
# y = ndbc_deep.station_y

# # add to Geopandas dataframe as needed by HydroMT
# pnts = gpd.points_from_xy(x, y)
# index = [1]  # NOTE that the index should start at one
# bnd = gpd.GeoDataFrame(index=index, geometry=pnts, crs=4326)
# bnd


# # Wanted values:
# hs = ndbc_deep.hs #[[5.0], [5.0]]
# tp = ndbc_deep.tp #[[10.0], [10.0]]
# dir = ndbc_deep.mwd #[[290.0], [290.0]]

# # Assumption wave spreading
# s = 20.0
# sigma = np.sqrt(2 / (s + 1)) / np.pi * 180.0 #s  =    20 > to degress

# # ds = hs.copy()
# # ds = ds.where(ds > 0, sigma)
# ds = sigma * np.ones_like(hs)
# # ds.values = sigma

# ds
# # ds['hs'].values = sigma
# # # ds = sigma * np.ones_like(hs) #[[sigma], [sigma]]
# time = ndbc_deep.time# [0, 99999]

# df_hs = pd.DataFrame(index= time, data = hs)
# df_tp = pd.DataFrame(index= time, data = tp)
# df_dir = pd.DataFrame(index= time, data = dir)
# df_ds = pd.DataFrame(index= time, data = ds)

# list_df = [df_hs, df_tp, df_dir, df_ds]

# list_df

# sf_qt.setup_wave_forcing(timeseries = list_df, locations= bnd)