In [96]:
### Imports
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import xarray as xr
from pyproj import Transformer


In [None]:
# Data source.
# THREDDS catalog page listing the available files (reference only).
URL_DATA = "https://opendap.4tu.nl/thredds/catalog/data2/test/spatial/catalog.html"
# OPeNDAP endpoint of the 15-day average / standard-deviation NetCDF file.
# Change for your folder with data if run locally.
DATA = "https://opendap.4tu.nl/thredds/dodsC/data2/test/spatial/15_days_avg_std.nc"
LOCAL = False  # TODO: local-run switch; not read by any active code yet

ds = xr.open_dataset(DATA)

In [4]:
# Analysis period, given as datetime(year, month, day), e.g. datetime(2000, 1, 1).
# Displayed 15-day time slots start on or after start_date (a new slot starts
# at 01:00:00) and end on or before end_date (the last slot ends at 00:00:00).
start_date = datetime(1976, 1, 1)
end_date = datetime(1976, 4, 1)

In [80]:
# Field to display: "S" selects salinity; any other value selects temperature ("T").
variable_name: str = "S"

In [None]:
# Boundary polyline of the DWS area, drawn on top of both facets.
# NOTE(review): assumes `bdr_dws` is an (n, 2) array of (x, y) vertices in the
# same units as `xc` / `yc` — confirm against the dataset.
boundary_points = ds.bdr_dws.values

# Each time stamp t is shown as the 15-day slot [t - 7.5 d, t + 7.5 d - 1 h]
# (the hour is trimmed so as to show the exact period). Keep only the slots
# that fit entirely inside [start_date, end_date].
delta_left = timedelta(days=7.5)
delta_right = timedelta(days=7.5) - timedelta(hours=1)
time_steps = pd.to_datetime(ds["time"].values)
mask_ind = (time_steps - delta_left >= start_date) & (
    time_steps + delta_right <= end_date
)
if not mask_ind.any():
    # Without this check the empty selection fails further down with an
    # unhelpful error from the plotting / colour-range code.
    raise ValueError(
        f"No 15-day time slot fits between {start_date:%d/%m/%Y} and "
        f"{end_date:%d/%m/%Y}; widen the date range."
    )

# "S" selects salinity; any other value of `variable_name` selects temperature.
is_salinity = variable_name == "S"
var_prefix = "S" if is_salinity else "T"

# Average and standard deviation for the selected slots. The boolean mask
# indexes axis 0, so time is expected to be the first axis of both variables.
avg = ds[f"{var_prefix}_avg"].values[mask_ind]
sd = ds[f"{var_prefix}_sd"].values[mask_ind]
time_steps_update = time_steps[mask_ind]

# Axis 0 drives the animation and axis 1 the two facets (average, std);
# assuming avg/sd are (time, y, x) this is (time, 2, y, x).
merged_data = np.stack([avg, sd], axis=1)
n_facets = merged_data.shape[1]

# TODO: display avg and sd on the rotated grid (see "Matrix rotation" below).

# `ds` is deliberately left open: the coordinates below and later cells
# still read from it.
fig = px.imshow(
    merged_data,
    x=ds["xc"].values,
    y=ds["yc"].values,
    facet_col=1,
    animation_frame=0,
    origin="lower",  # first data row at the bottom of the plot
    title=("Salinity" if is_salinity else "Temperature")
    + " : 15 days average (in facet_col=0) and standard deviation (in facet_col=1)",
)

# Overlay the boundary on every facet (col=1: average, col=2: standard deviation).
for facet_col in range(1, n_facets + 1):
    fig.add_trace(
        go.Scatter(
            x=boundary_points[:, 0],
            y=boundary_points[:, 1],
            mode="lines",
            line=dict(color="black", width=2),
            name="",
            showlegend=False,
        ),
        row=1,
        col=facet_col,
    )

# Drop the play/pause buttons; the custom slider below is used instead.
fig["layout"].pop("updatemenus")

# Shared colour scale for both facets. The lower bound stays 0 for
# non-negative data but follows the data below zero (e.g. sub-zero winter
# temperatures would otherwise all saturate at the bottom colour).
cmin = min(0, int(np.floor(np.nanmin(merged_data))))
cmax = int(np.nanmax(merged_data)) + 1
fig.update_layout(
    coloraxis=dict(
        cmin=cmin,
        cmax=cmax,
        colorbar=dict(
            title=(
                "Salinity (g kg<sup>-1</sup>)"
                if is_salinity
                else "Temperature (°C)"
            ),
        ),
    )
)

# Label and format the x axis of every facet. Both facets use the same y
# coordinates, so the y title is set on the first facet only to avoid a
# duplicate label between the panels.
fig.update_xaxes(title_text="xc", tickformat=".1f")
fig.update_yaxes(tickformat=".1f")
fig.update_layout(yaxis_title_text="yc")

# Slider with one step per displayed slot, labelled "dd/mm/YYYY-dd/mm/YYYY".
slot_labels = [
    f'{(t - delta_left).strftime("%d/%m/%Y")}-{(t + delta_right).strftime("%d/%m/%Y")}'
    for t in time_steps_update
]
fig.update_layout(
    sliders=[
        {
            "currentvalue": {
                "prefix": "15 days time slot: ",
                "visible": True,
                "xanchor": "center",
            },
            "len": 0.9,
            "steps": [
                {
                    "label": label,
                    "method": "animate",
                    "args": [[i], {"frame": {"duration": 500, "redraw": True}}],
                }
                for i, label in enumerate(slot_labels)
            ],
        }
    ],
)

fig.show()

#### Matrix rotation

In [None]:
# DWS grid / mask file (provides xc, yc and mask_dws).
DWS_BATHY_FILE = ""  # TODO: add proper file name
if not DWS_BATHY_FILE:
    raise ValueError(
        "DWS_BATHY_FILE is empty: set it to the DWS NetCDF file (with xc, yc "
        "and mask_dws) before running this cell."
    )
dws_b = xr.open_dataset(DWS_BATHY_FILE)

# First time step of S_avg; indexing before `.values` loads only that slice
# instead of the whole variable.
val = ds["S_avg"][0].values

#### Rotation
# Project the model cell centres from EPSG:4326 (lon/lat, WGS84) to
# EPSG:28992 (Amersfoort / RD New, metres), then divide by 1e2 so the
# coordinates are in units of 100 m.
inproj = Transformer.from_crs("epsg:4326", "epsg:28992", always_xy=True)
xct = ds.lonc.values  # lon; TODO: read from dws_b instead
yct = ds.latc.values  # lat
xctp, yctp, _z = inproj.transform(xct, yct, xct * 0.0)  # height 0; _z unused
xctp = xctp / 1e2
yctp = yctp / 1e2
# First projected point: anchors the rotated local grid to projected coordinates.
xctp0 = xctp[0, 0]
yctp0 = yctp[0, 0]

# Rotation matrix built from ang = -17 degrees. As written,
# [[cos, sin], [-sin, cos]] @ v turns v by -ang, i.e. 17 degrees counter-clockwise.
ang = -17 * np.pi / 180
angs = np.array(
    [
        [np.cos(ang), np.sin(ang)],
        [-np.sin(ang), np.cos(ang)],
    ]
)

# Original model grid points (metres), rotated and scaled to 100 m units.
xct2, yct2 = np.meshgrid(dws_b.xc.values, dws_b.yc.values)
xy = np.column_stack((xct2.ravel(), yct2.ravel()))
xyp = (angs @ xy.T).T / 1e2
xyp0 = xyp[0, :]  # first grid point; local coordinates are taken relative to it

# DWS area: grid cells where mask_dws is non-zero.
# NOTE(review): assumes mask_dws is 2-D with dims (yc, xc); if it is a float
# array using NaN as fill, NaN counts as True here — confirm dtype.
values = dws_b.mask_dws.values
xc = dws_b.xc.values
yc = dws_b.yc.values
y_idx, x_idx = np.nonzero(values)
x_true = xc[x_idx]
y_true = yc[y_idx]
points = np.column_stack((x_true, y_true))

# Rotate the DWS points like the grid, then shift them so the first grid
# point coincides with the first projected point.
points_rot = (angs @ points.T).T / 1e2 - xyp0 + np.array([xctp0, yctp0])

In [177]:
# Rasterise the rotated DWS points into a pixel grid (1 pixel = 1 unit of
# points_rot, index = truncated coordinate) and keep only the plotting window.
# Only the window is allocated (1000 x 1400 instead of the full 6400 x 6400
# grid), and each patch is clipped to it explicitly so that negative start
# indices cannot wrap around to the opposite edge as plain NumPy slicing would.
ROW_START, ROW_STOP = 5400, 6400
COL_START, COL_STOP = 1000, 2400
HALF_WIDTH = 1  # each point paints a (2*HALF_WIDTH+1) x (2*HALF_WIDTH+1) = 3x3 patch

assert np.isfinite(points_rot).all(), "non-finite rotated coordinates"

mask = np.full((ROW_STOP - ROW_START, COL_STOP - COL_START), np.nan)
rows = points_rot[:, 1].astype(int)  # truncation, like int()
cols = points_rot[:, 0].astype(int)
# NOTE(review): val is indexed with DWS-grid indices (y_idx, x_idx), which
# assumes ds and dws_b share the same grid — confirm.
for row, col, yi, xi in zip(rows, cols, y_idx, x_idx):
    r0 = max(row - HALF_WIDTH - ROW_START, 0)
    r1 = row + HALF_WIDTH + 1 - ROW_START
    c0 = max(col - HALF_WIDTH - COL_START, 0)
    c1 = col + HALF_WIDTH + 1 - COL_START
    if r1 <= 0 or c1 <= 0:
        continue  # patch lies entirely before the window
    # Slice ends past the window are clipped by NumPy; later points
    # overwrite earlier ones where patches overlap.
    mask[r0:r1, c0:c1] = val[yi, xi]

In [None]:
# Plotting
fig, ax = plt.subplots(figsize=(10,5))

# Plot original mask
cs = ax.imshow(mask,)
plt.title("Original - Square (7x7)")
plt.legend()
plt.gca().invert_yaxis()

yticks=np.arange(0,1000,200); ax.set_yticks(yticks);ax.set_yticklabels((yticks)/10);
xticks=np.arange(0,1400,200); ax.set_xticks(xticks);ax.set_xticklabels((xticks)/10);
ax.axis('equal'); ax.axis([xticks[0],xticks[-1],yticks[0],yticks[-1]]);

cbar=fig.colorbar(cs,ax=ax,ticks=np.arange(0,40,10),aspect=10,pad=0.03);
cbar.set_label(label="Residence time (days)",rotation=90)
cbar.ax.tick_params(labelsize=10)

plt.show()