In [None]:
import glob
import os
from pathlib import Path

import numpy as np
import xarray as xr
import xesmf as xe
from scipy.interpolate import interp1d

In [None]:
# Directory holding the OF160_avg_*.nc inputs and receiving the forcing output file.
dirpath = str(Path.home() / "FjordsSim_data" / "oslofjord")

In [None]:
# Target vertical grid: layer interfaces, negative downward with 0.0 at the top
# (presumably metres, to match ROMS z_rho), and the layer mid-points that every
# field is interpolated to.
z_faces = [-340.0, -300.0, -260.0, -220.0, -180.0, -140.0, -100.0, -80.0, -60.0, -40.0, -20.0, -10.0, -5.0, 0.0]
z_centers = [(lower + upper) / 2 for lower, upper in zip(z_faces[:-1], z_faces[1:])]

# Target horizontal grid: extent (degrees) and number of cell-centre points.
TARGET_LAT_RANGE = (59.1, 59.98)
TARGET_LON_RANGE = (10.2, 10.85)
TARGET_NY = 440
TARGET_NX = 66


def _target_grid(n_lat, n_lon):
    """Return a rectilinear xESMF target grid of ``n_lat`` x ``n_lon`` points.

    Both axes span TARGET_LAT_RANGE / TARGET_LON_RANGE with the endpoints included.
    """
    return xr.Dataset(
        {
            "lat": (["lat"], np.linspace(*TARGET_LAT_RANGE, num=n_lat), {"units": "degrees_north"}),
            "lon": (["lon"], np.linspace(*TARGET_LON_RANGE, num=n_lon), {"units": "degrees_east"}),
        }
    )


# Centre (tracer) grid, u-face grid (one extra longitude) and v-face grid (one extra latitude).
# NOTE(review): centres and faces both span the full range with endpoints included, so the
# centres are not the mid-points of the faces (e.g. the first centre coincides with the
# first face). Confirm this matches how the consuming model interprets the forcing grid.
ds_out_c = _target_grid(TARGET_NY, TARGET_NX)
ds_out_u = _target_grid(TARGET_NY, TARGET_NX + 1)
ds_out_v = _target_grid(TARGET_NY + 1, TARGET_NX)


def tranform_to_z(ds):
    """
    Compute the depth of every rho point for the Vtransform = 2 stretching.

    Evaluates ``z = zeta + (zeta + h) * (hc * s_rho + Cs_r * h) / (hc + h)``
    from the attributes of ``ds`` and returns the result with all of its
    dimensions reversed (``transpose()`` called with no arguments).

    The function name keeps its historical spelling so existing callers work.
    """
    stretched_level = (ds.hc * ds.s_rho + ds.Cs_r * ds.h) / (ds.hc + ds.h)
    water_column = ds.zeta + ds.h
    depth = ds.zeta + water_column * stretched_level
    return depth.transpose()


def regrid_from_s_to_depths(values, z_values, target_z=None):
    """
    Linearly interpolate a 4-D field from per-column levels onto fixed depths.

    Parameters
    ----------
    values : array_like
        Field indexed as ``[t, level, row, col]``.
    z_values : array_like
        Depth of every entry of ``values``; must have exactly the same shape.
    target_z : sequence of float, optional
        Depths to interpolate to. Defaults to the module-level ``z_centers``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(T, len(target_z), rows, cols)``. Target depths
        outside a column's z range come out as NaN.

    Raises
    ------
    ValueError
        If ``values`` and ``z_values`` differ in shape.
    """
    if target_z is None:
        target_z = z_centers
    values = np.asarray(values)
    z_values = np.asarray(z_values)
    if values.shape != z_values.shape:
        raise ValueError(
            f"values shape {values.shape} does not match z_values shape {z_values.shape}"
        )

    n_time, _, n_rows, n_cols = values.shape
    # NaN-initialised so nothing uninitialised can leak into the output.
    interpolated_values = np.full((n_time, len(target_z), n_rows, n_cols), np.nan)

    for t in range(n_time):
        for row in range(n_rows):
            for col in range(n_cols):
                # bounds_error=False: depths outside this column's range get
                # interp1d's default fill_value (NaN) instead of raising.
                column_interp = interp1d(
                    z_values[t, :, row, col],
                    values[t, :, row, col],
                    kind="linear",
                    bounds_error=False,
                )
                interpolated_values[t, :, row, col] = column_interp(target_z)

    return interpolated_values

In [None]:
def regrid(ds_in):
    """
    Regrid one ROMS output dataset horizontally and vertically onto the target grids.

    Horizontal step: xESMF bilinear regridding from the rho/u/v lon-lat coordinates
    onto ``ds_out_c`` / ``ds_out_u`` / ``ds_out_v`` (unmapped target points -> NaN).
    Vertical step: per-column linear interpolation to ``z_centers`` via
    ``regrid_from_s_to_depths``, using rho-point depths from ``tranform_to_z``.

    Parameters
    ----------
    ds_in : xarray.Dataset
        Must provide ``temp``, ``salt``, ``u``, ``v``, ``ocean_time``, the
        ``lon_rho/lat_rho``, ``lon_u/lat_u`` and ``lon_v/lat_v`` coordinates, and
        the variables read by ``tranform_to_z``. It is not modified.

    Returns
    -------
    tuple of np.ndarray
        ``(time, temp, salt, u, v)``. The fields follow the dim order of the
        regridded ``temp``; presumably ``[time, depth, lat, lon]`` given ROMS'
        usual ``(ocean_time, s_rho, eta, xi)`` layout — confirm for new inputs.
    """
    # NOTE(review): weights are rebuilt on every call; if all input files share one
    # grid, building these regridders once would save substantial time.
    regridder_rho = xe.Regridder(
        ds_in.rename({"lon_rho": "lon", "lat_rho": "lat"}), ds_out_c, "bilinear", unmapped_to_nan=True
    )
    regridder_u = xe.Regridder(
        ds_in.rename({"lon_u": "lon", "lat_u": "lat"}), ds_out_u, "bilinear", unmapped_to_nan=True
    )
    regridder_v = xe.Regridder(
        ds_in.rename({"lon_v": "lon", "lat_v": "lat"}), ds_out_v, "bilinear", unmapped_to_nan=True
    )

    da_temp = regridder_rho(ds_in["temp"])
    da_salt = regridder_rho(ds_in["salt"])
    da_u = regridder_u(ds_in["u"])
    da_v = regridder_v(ds_in["v"])

    # Depths are computed into a local (not stored on ds_in) so the caller's dataset is
    # untouched. tranform_to_z returns its dims reversed, so reorder the regridded depths
    # by name to match the tracer layout instead of relying on a positional axis swap.
    da_zrho = regridder_rho(tranform_to_z(ds_in))
    zrho = da_zrho.transpose(*da_temp.dims).values

    np_temp = regrid_from_s_to_depths(da_temp.values, zrho)
    np_salt = regrid_from_s_to_depths(da_salt.values, zrho)

    # Depths on the face grids: reuse the rho-point depths and repeat the last column
    # (u grid, one extra longitude) or the last row (v grid, one extra latitude).
    # NOTE(review): this ignores the half-cell offset between centres and faces. u/v are
    # also regridded as-is; if the ROMS grid is rotated (its `angle` variable), they may
    # need rotating to east/north first — confirm.
    zu = np.concatenate([zrho, zrho[:, :, :, -1:]], axis=3)
    zv = np.concatenate([zrho, zrho[:, :, -1:, :]], axis=2)

    np_u = regrid_from_s_to_depths(da_u.values, zu)
    np_v = regrid_from_s_to_depths(da_v.values, zv)

    np_time = ds_in.ocean_time.values

    return np_time, np_temp, np_salt, np_u, np_v

In [None]:
# ROMS average files, sorted by file name (presumably chronological — confirm), which
# fixes the order of the concatenated time axis. glob.escape keeps glob metacharacters
# such as "[" in the directory path from being interpreted as patterns.
filepaths = sorted(glob.glob(os.path.join(glob.escape(dirpath), "OF160_avg_*.nc")))
if not filepaths:
    raise FileNotFoundError(f"No OF160_avg_*.nc files found in {dirpath}")

In [None]:
# Per-file results, concatenated along time in the next cell.
time_list = []
temp_list = []
salt_list = []
u_list = []
v_list = []

for filepath in filepaths:
    # regrid returns plain numpy arrays, so the file can be closed right after it.
    # Without the `with`, every opened dataset kept its file handle alive.
    with xr.open_dataset(filepath) as ds_in:
        np_time, np_temp, np_salt, np_u, np_v = regrid(ds_in)

    time_list.append(np_time)
    temp_list.append(np_temp)
    salt_list.append(np_salt)
    u_list.append(np_u)
    v_list.append(np_v)

In [None]:
def _stack_clipped(chunks, lower, upper):
    """Concatenate per-file arrays along axis 0, cast to float32 and clip to [lower, upper]."""
    stacked = np.concatenate(chunks, axis=0).astype(np.float32)
    return np.clip(stacked, lower, upper)


# Join the per-file results along time and clamp each field to a fixed range
# (temperature bounds presumably in °C — confirm sub-zero water cannot occur).
# NOTE(review): clipping silently hides interpolation outliers; NaNs pass through unchanged.
np_time = np.concatenate(time_list, axis=0)
np_temp = _stack_clipped(temp_list, 0, 30)
np_salt = _stack_clipped(salt_list, 0, 36)
np_u = _stack_clipped(u_list, -1, 1)
np_v = _stack_clipped(v_list, -1, 1)

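#### Create output

Assemble the forcing dataset: the regridded `T`, `S`, `u` and `v` fields, each paired with a `*_lambda` rate field (in 1/s). Then write it to NetCDF.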

In [None]:
# NaN-filled float32 rate fields, each shaped like its regridded variable.
Tout_lambda, Sout_lambda, Uout_lambda, Vout_lambda = (
    np.full(field.shape, np.nan, dtype=np.float32) for field in (np_temp, np_salt, np_u, np_v)
)

In [None]:
# Rate of 1 / (20 days), expressed in 1/s, on the last depth index of every field.
# z_centers runs deep -> shallow, so index -1 is the shallowest layer; other levels stay NaN.
rate_per_second = 1 / 60 / 60 / 24 / 20
for lambda_field in (Tout_lambda, Sout_lambda, Uout_lambda, Vout_lambda):
    lambda_field[:, -1, :, :] = rate_per_second

In [None]:
# Output coordinates taken directly from the regridding target grids, so they cannot
# drift out of sync with the grid the data were actually interpolated onto.
lons = ds_out_c["lon"].values
lons_faces = ds_out_u["lon"].values
lats = ds_out_c["lat"].values
lats_faces = ds_out_v["lat"].values

In [None]:
# Forcing dataset: tracers on cell centres, u on x-faces, v on y-faces, each paired with
# its *_lambda rate field.
center_dims = ("time", "Nz", "Ny", "Nx")
u_face_dims = ("time", "Nz", "Ny", "Nx_faces")
v_face_dims = ("time", "Nz", "Ny_faces", "Nx")

dsout = xr.Dataset(
    data_vars={
        "T": (center_dims, np_temp),
        "T_lambda": (center_dims, Tout_lambda),
        "S": (center_dims, np_salt),
        "S_lambda": (center_dims, Sout_lambda),
        "u": (u_face_dims, np_u),
        "u_lambda": (u_face_dims, Uout_lambda),
        "v": (v_face_dims, np_v),
        "v_lambda": (v_face_dims, Vout_lambda),
    },
    coords=dict(
        time=np_time,
        Nz=z_centers,
        Ny=lats,
        Ny_faces=lats_faces,
        Nx=lons,
        Nx_faces=lons_faces,
    ),
)

In [None]:
# Display the assembled forcing dataset for inspection.
dsout

In [None]:
# Quick-look map of temperature at time index 1, depth index 12 (the last, i.e. shallowest,
# of the z_centers). Item access is used because `.T` is xarray's transpose shorthand
# (on DataArray, and on Dataset in older releases), so `dsout.T` is ambiguous.
dsout["T"].isel(time=1, Nz=12).plot()

In [None]:
# Quick-look map of salinity at time index 1, depth index 12 (shallowest layer).
dsout["S"].isel(time=1, Nz=12).plot()

In [None]:
# Quick-look map of u (x-face grid) at time index 1, depth index 12 (shallowest layer).
dsout["u"].isel(time=1, Nz=12).plot()

In [None]:
# Quick-look map of v (y-face grid) at time index 1, depth index 12 (shallowest layer).
dsout["v"].isel(time=1, Nz=12).plot()

In [None]:
# Replace missing values with the -999 sentinel and write the forcing file.
# NOTE(review): this also turns the NaN levels of the *_lambda fields into -999, and
# -999 is not declared as _FillValue/missing_value — confirm the consumer treats it as missing.
output_path = os.path.join(dirpath, "OF_inner_66to440_forcing.nc")
dsout = dsout.fillna(-999)
dsout.to_netcdf(output_path)