In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import dask
import dask.array as darr
from dask.distributed import Client
import math

import numpy as np
from scipy.interpolate import RectBivariateSpline as rbv
from scipy.interpolate import interpn

In [None]:
# Start a local Dask client (two workers, processes=False) before opening any data.
cl = Client(processes=False, n_workers=2)

# Grid file: supplies lon/lat, mask, metric (pm, pn) and Coriolis (f) fields used below.
grd = xr.open_dataset("grd_wcofs_large_visc200.nc")

# Model output files, concatenated by their coordinates and opened lazily in chunks.
history_chunks = {"ocean_time": 48, "eta_rho": 600, "xi_rho": 300}
dat = xr.open_mfdataset(
    "zuvt_qck_Exp41_35*.nc",
    combine="by_coords",
    chunks=history_chunks,
    parallel=True,
)

# Display the client summary as the cell output.
cl

In [None]:
# Geostrophic surface currents from the time-averaged free surface:
#   v_g =  (g / f) * d(zeta)/dx   -> evaluated on faces between xi-neighbours  ("u-type" points)
#   u_g = -(g / f) * d(zeta)/dy   -> evaluated on faces between eta-neighbours ("v-type" points)
GRAVITY = 9.8     # gravitational acceleration (presumably m s^-2)
AVG_WINDOW = 48   # length of the rolling mean, in ocean_time steps

zeta = dat["zeta"]  # indexed below as (ocean_time, eta_rho, xi_rho)
lonrs = grd.coords["lon_rho"]
latrs = grd.coords["lat_rho"]
lonu = grd["lon_u"]
latu = grd["lat_u"]
lonv = grd["lon_v"]
latv = grd["lat_v"]
mr = grd["mask_rho"]
wet = mr > 0  # True where mask_rho is non-zero (presumably water points)

# Grid spacing (1/pm, 1/pn) and Coriolis parameter at rho-points.
# Land is set to NaN instead of being multiplied by the mask: zeros on land make the
# divisions below return inf (which the final fillna cannot remove) and halve dx, dy
# and f on faces adjacent to land.  With NaN, every face touching land ends up as NaN
# and is zeroed by the fillna at the end of this cell.
dx = (1/grd["pm"]).where(wet)
dy = (1/grd["pn"]).where(wet)
cp = grd["f"].where(wet)  # Coriolis parameter

# Average neighbouring rho-points onto the faces between them.
# NOTE(review): the shifted slices are combined positionally, which assumes eta_rho/xi_rho
# carry no index coordinates (xarray would otherwise align the operands by label) - confirm.
dx1, dx2 = dx[:, :-1], dx[:, 1:]
dy1, dy2 = dy[:-1, :], dy[1:, :]
dx_u = .5*(dx1 + dx2)
dy_v = .5*(dy1 + dy2)

fu1, fu2 = cp[:, :-1], cp[:, 1:]
fv1, fv2 = cp[:-1, :], cp[1:, :]
f_u = .5*(fu1 + fu2)
f_v = .5*(fv1 + fv2)

# Rolling mean of zeta along time; the first AVG_WINDOW - 1 steps have no complete
# window, so they are dropped.  zavg[k] is therefore the mean ending at zeta[k + AVG_WINDOW - 1].
zr = zeta.rolling(ocean_time=AVG_WINDOW)
zavg = zr.mean()[AVG_WINDOW - 1:]

# Free-surface differences between neighbouring rho-points, divided by the face spacing.
dzetamu = zavg[:, :, 1:] - zavg[:, :, :-1]
dzetamv = zavg[:, 1:, :] - zavg[:, :-1, :]
termu = dzetamu/dx_u
termv = dzetamv/dy_v

vg_u = (GRAVITY/f_u)*termu
ug_v = -(GRAVITY/f_v)*termv
# The 2-D factor is the first operand above, so restore the time-first dimension order.
vg_u = vg_u.transpose("ocean_time", "eta_rho", "xi_rho", transpose_coords=False)
ug_v = ug_v.transpose("ocean_time", "eta_rho", "xi_rho", transpose_coords=False)
# Relabel the dimensions: the results no longer live on rho-points.
vg_u = vg_u.rename(eta_rho="eta_u", xi_rho="xi_u")
ug_v = ug_v.rename(eta_rho="eta_v", xi_rho="xi_v")
# Faces touching land (NaN) get zero velocity.
vg_u = vg_u.fillna(0.0)
ug_v = ug_v.fillna(0.0)

In [None]:
# Move each geostrophic component onto its own staggered grid:
# the u-component (computed on v-type points) goes to u-points, and the
# v-component (computed on u-type points) goes to v-points.
# Edge rows/columns that lack a pair of neighbours are left at zero.
n_times = vg_u.shape[0]

# u-component: average along axis 1 (v -> rho), then along axis 2 (rho -> u).
ug_r = .5*(ug_v[:, :-1, :] + ug_v[:, 1:, :])
ug_u_vals = np.zeros((n_times, *lonu.shape))
ug_u_vals[:, 1:-1, :] = .5*(ug_r[:, :, :-1] + ug_r[:, :, 1:])

# v-component: average along axis 2 (u -> rho), then along axis 1 (rho -> v).
vg_r = .5*(vg_u[:, :, :-1] + vg_u[:, :, 1:])
vg_v_vals = np.zeros((n_times, *lonv.shape))
vg_v_vals[:, :, 1:-1] = .5*(vg_r[:, :-1, :] + vg_r[:, 1:, :])

# Wrap the filled buffers, borrowing dims/coords from the field that already
# lives on the matching grid (vg_u is on u-type points, ug_v on v-type points).
ug_u = xr.DataArray(ug_u_vals, dims=vg_u.dims, coords=vg_u.coords)
vg_v = xr.DataArray(vg_v_vals, dims=ug_v.dims, coords=ug_v.coords)

**Verify Currents**

In [None]:
# Zoom window, counted from the end of each axis; rsl/csl are reused by a later cell.
rsl = slice(-640, -620)
csl = slice(-350, -330)
lon_win = lonrs[rsl, csl]
lat_win = latrs[rsl, csl]

# Averaged free surface (time index 4, 15 contour levels) with the regridded currents on top.
# NOTE(review): ug_u and vg_v sit on u- and v-points but are drawn at rho lon/lat;
# fine for a visual check, though the arrows are not exactly at these positions.
plt.contour(lon_win, lat_win, zavg[4, rsl, csl], 15)
plt.quiver(lon_win, lat_win, ug_u[4, rsl, csl], vg_v[4, rsl, csl])

In [None]:
# 2 x 3 panel of sanity checks on the inputs to the geostrophic calculation.

# Raw free surface at time index 1.
plt.subplot(2, 3, 1)
plt.contour(lonrs, latrs, zeta[1], 10)

# Rolling-mean free surface at index 1 (not the same instant as the raw panel,
# because the leading incomplete windows were dropped from zavg).
plt.subplot(2, 3, 2)
plt.contour(lonrs, latrs, zavg[1], 10)

# Grid spacing averaged onto u-type points.
plt.subplot(2, 3, 3)
plt.contour(lonu, latu, dx_u, 20)

# Grid spacing averaged onto v-type points.
plt.subplot(2, 3, 4)
plt.contour(lonv, latv, dy_v, 20)

# Coriolis parameter on u-type points, first 100 rows only.
first_rows = slice(None, 100)
plt.subplot(2, 3, 5)
plt.contour(lonu[first_rows, :], latu[first_rows, :], f_u[first_rows, :])

# Land/sea mask at rho-points.
plt.subplot(2, 3, 6)
plt.contour(lonrs, latrs, mr)

In [None]:
# Same zoom window as above, but with the currents before regridding
# (u-component still on v-type points, v-component still on u-type points).
lon_zoom, lat_zoom = lonrs[rsl, csl], latrs[rsl, csl]
plt.contour(lon_zoom, lat_zoom, zavg[1, rsl, csl], 15)
plt.quiver(lon_zoom, lat_zoom, ug_v[1, rsl, csl], vg_u[1, rsl, csl])

In [None]:
# Side-by-side look at the geostrophic v-component and the model field "v_sur".
# NOTE(review): vg_u is built from the rolling mean with its leading steps removed,
# so index 3 here is not the same instant as dat["v_sur"][3] - confirm this is intended.
plt.subplot(1, 2, 1)
plt.imshow(vg_u[3])

plt.subplot(1, 2, 2)
plt.imshow(dat["v_sur"][3])

**Save Geostrophic Currents**

In [None]:
# Write the regridded currents to NetCDF, 24 time steps per file.
n_steps = len(zavg.coords["ocean_time"])
gsl = [
    xr.Dataset({"vg": vg_v[start:start+24], "ug": ug_u[start:start+24]})
    for start in range(0, n_steps, 24)
]
# Files are numbered consecutively from 0, one per 24-step block.
names = [f"geostrophic_uv{k}.nc" for k in range(len(gsl))]
xr.save_mfdataset(gsl, names, format="NETCDF4")

# Shut down the Dask client once the files are written.
cl.close()