# Postprocessing

In [10]:
import xarray as xr
import numpy as np
from pathlib import Path
import sys
import json

# Repository root. Assumes the notebook runs one directory below it (the configs are
# read from "../configs/..."). It is put on sys.path so the local `utils` package imports.
project_root = Path.cwd().parents[0]
# Guard against duplicate sys.path entries when this cell is re-run in the same kernel.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from utils.bathymetry import generate_bathymetry
from utils.config import load_config, default_params
from utils.io import read_raw_output, ensure_dir
from utils.grid import prepare_dsH, interp_ds, mean_onH

In [11]:
# Output folder for postprocessed data. No trailing slash: every later path is built
# as output_folder + "/<subdir>/...", so a trailing "/" here would produce "//".
output_folder = str(project_root / "output" / "processed")


# Focus depth contour (index along the depth-following j axis), chosen to lie
# in the middle of the slope.
focus_j = 45


# Subregion for analysis, per forcing-period case.
# - focus_time_start / focus_time_end are negative sample indices counted back from
#   the end of the run and are used as slice(start, end). The factor 8 is presumably
#   output samples per day (the "last 16 days" slices elsewhere use 16*8) — confirm
#   against params["outputtime"].
# - Both windows span 128*8 samples. "short" ends 8*8 samples before the end of the
#   run, "long" ends 64*8 samples before it (half of the 16- and 128-day periods).
# - focus_j_start / focus_j_end bound the j (or yC) index range.
analysis_region = {
    "short": {
        "focus_time_start": -(128 + 8) * 8,
        "focus_time_end": -8 * 8,
        "focus_j_start": 20,
        "focus_j_end": 70,
    },
    "long": {
        "focus_time_start": -(128 + 64) * 8,
        "focus_time_end": -64 * 8,
        "focus_j_start": 20,
        "focus_j_end": 70,
    },
}



def calculate_analytical_estimates_xr(
    ds: xr.Dataset,
    forcing_vars: list[str],
    params: dict,
) -> xr.DataArray:
    """
    Linearly damped response to the summed forcing, integrated step by step in time.

    For every non-time index (e.g. each depth contour ``j``) this evaluates

        y(t_n) = (1 / H) * ∫_0^{t_n} exp[-k (t_n - τ)] F(τ) dτ,   k = R / H,

    where F is the sum of ``ds[v]`` over ``forcing_vars`` and H is ``ds["depth"]``.
    H does not vary in time, so dividing the integral by H at the end is the same
    as integrating F / H.

    The integral I_n is advanced with an exponential integrator. It holds F at its
    value from the start of each step and is exact for piecewise-constant forcing:

        I_0 = 0,
        I_n = decay * I_{n-1} + alpha * F_{n-1},
        decay = exp(-k Δt),
        alpha = (1 - decay) / k   (→ Δt as k → 0).

    Δt is the constant ``params["outputtime"]``. The values of the time coordinate
    are not used, so samples are assumed to be equally spaced by Δt.

    Parameters
    ----------
    ds : xr.Dataset
        Must contain ``"depth"``, a ``"time"`` coordinate and every name in
        ``forcing_vars`` (each with a ``time`` dimension).
    forcing_vars : list of str
        Names of the forcing variables to sum. Must be non-empty.
    params : dict
        Needs ``"R"`` (damping coefficient; k = R / H) and ``"outputtime"`` (Δt).

    Returns
    -------
    xr.DataArray
        y on the same time coordinate as ``ds``. y is zero at the first time.

    Raises
    ------
    ValueError
        If ``forcing_vars`` is empty or the time axis is empty.
    """
    if not forcing_vars:
        # sum() of nothing is the int 0, which would fail later with an obscure error.
        raise ValueError("forcing_vars must name at least one variable.")

    R = float(params["R"])
    dt = float(params["outputtime"])

    H = ds["depth"]

    # Sum forcing fields (their dims must include "time").
    F = sum(ds[v] for v in forcing_vars)

    t = ds["time"]
    nT = t.sizes["time"]
    if nT == 0:
        raise ValueError("Empty time axis.")
    if nT == 1:
        return xr.zeros_like(F)

    # k = R/H, broadcast over all non-time dims of F.
    k = (R / H).broadcast_like(F.isel(time=0))

    # Coefficients are constant in time, so compute them once.
    decay = np.exp(-k * dt)
    # -expm1(-x) equals 1 - exp(-x) but avoids cancellation when k*dt is small.
    # At k == 0 the quotient is 0/0. errstate silences that warning, and
    # xr.where substitutes the analytic limit dt.
    with np.errstate(divide="ignore", invalid="ignore"):
        alpha = xr.where(np.abs(k) > 0, -np.expm1(-k * dt) / k, dt)

    # Recurrence in time. The time coordinate is dropped from each slice so the
    # steps can be stacked along a fresh "time" dimension afterwards.
    ys = [xr.zeros_like(F.isel(time=0, drop=True))]
    for i in range(1, nT):
        F_prev = F.isel(time=i - 1, drop=True)
        ys.append(decay * ys[-1] + alpha * F_prev)
    y = xr.concat(ys, dim="time").assign_coords(time=t) / H

    return y

## Focus cases
Process the focus cases, which use uniform along-slope forcing. There is one long-period case (128 days) and one short-period case (16 days).

In [12]:
def _shift_time_to_zero(obj):
    """Shift the ``time`` coordinate of ``obj`` in place so it starts at 0, and return ``obj``."""
    obj["time"] = obj["time"] - obj["time"].values[0]
    return obj


for case in ["long", "short"]:
    # Focus window for this case. It is looked up once and reused for both grids.
    window = analysis_region[case]
    time_window = slice(window["focus_time_start"], window["focus_time_end"])
    j_window = slice(window["focus_j_start"], window["focus_j_end"])

    # Load the raw model output.
    params = load_config(f"../configs/baseline_forcing/{case}.json")
    ds = read_raw_output(params)

    #####################################################
    #### DEPTH-FOLLOWING INTERPOLATION AND DIAGNOSES ####
    #####################################################

    # Target depths: the mean of ds.bath over xC at each yC.
    # NOTE(review): H_targets outlives this loop (it holds the value from the last
    # case, "short"), and later cells reuse it.
    H_targets = ds.bath.mean("xC").values
    dsH = prepare_dsH(ds, params, H_targets)

    # Momentum terms on the depth-following grid.
    dsH["circulation"] = mean_onH(dsH, variable="ui")
    dsH["BS"] = -dsH["circulation"] * params["R"]
    dsH["RVF"] = mean_onH(dsH, variable="zetaflux") * dsH["depth"]
    dsH["SS"] = mean_onH(dsH, variable="forcing_i")

    # Analytical circulation estimates: forcing only, and forcing plus RVF.
    dsH["linear_estimate"] = calculate_analytical_estimates_xr(dsH, ["SS"], params)
    dsH["nonlinear_estimate"] = calculate_analytical_estimates_xr(dsH, ["SS", "RVF"], params)

    # Time series at the focus depth contour, restricted to the focus period.
    ts = xr.Dataset()
    for name in ("circulation", "linear_estimate", "nonlinear_estimate"):
        ts[name] = dsH[name].isel(j=focus_j)
    ts = _shift_time_to_zero(ts.isel(time=time_window))

    ensure_dir(output_folder + "/timeseries/")
    ts.squeeze().to_netcdf(output_folder + f"/timeseries/analytical_estimates_{case}.nc")

    # Focus region and period, with time restarted at 0.
    dsH = _shift_time_to_zero(dsH.isel(time=time_window, j=j_window))

    if case == "short":
        # Keep the last 16*8 + 1 samples.
        # NOTE(review): the Cartesian branch below keeps only 16*8 samples, so the
        # saved H and y files for "short" differ in length by one. Confirm intent.
        dsH = _shift_time_to_zero(dsH.isel(time=slice(-16 * 8 - 1, None)))

    ensure_dir(output_folder + "/momentum_terms_H/")
    dsH[["circulation", "RVF", "BS", "SS"]].to_netcdf(
        output_folder + f"/momentum_terms_H/momentum_terms_H_{case}.nc"
    )

    #####################################################
    ############### CARTESIAN DIAGNOSES #################
    #####################################################

    dsY = interp_ds(ds, params, ["u", "v", "zetav", "forcing_x", "detadx", "duvhdy"])
    dsY = dsY.isel(time=time_window, yC=j_window)
    if case == "short":
        dsY = dsY.isel(time=slice(-16 * 8, None))  # only last 16 days for short case
    dsY = _shift_time_to_zero(dsY)

    # Mean bathymetry over xC, and the local deviation from it.
    H0 = dsY.bath.mean("xC")
    h = H0 - dsY.bath

    # xC-mean momentum terms.
    dsY["circulation"] = dsY.u.mean("xC")
    dsY["BS"] = -dsY.circulation * params["R"]
    dsY["TFS"] = (-params["gravitational_acceleration"] * dsY.detadx * dsY.bath).mean("xC")
    dsY["MFC"] = (-dsY.duvhdy).mean("xC")
    dsY["SS"] = dsY.forcing_x.mean("xC")
    dsY["QGPVF"] = dsY.zetav.mean("xC") * H0 + (dsY.v * h).mean("xC") * params["f"]

    ensure_dir(output_folder + "/momentum_terms_y/")
    dsY[["circulation", "BS", "TFS", "MFC", "SS", "QGPVF"]].to_netcdf(
        output_folder + f"/momentum_terms_y/momentum_terms_y_{case}.nc"
    )


Loading configuration from ../configs/baseline_forcing/long.json
Directory created: /itf-fi-ml/home/alsjur/temporal-topo-flow/output/processed/timeseries
Directory created: /itf-fi-ml/home/alsjur/temporal-topo-flow/output/processed/momentum_terms_H
Directory created: /itf-fi-ml/home/alsjur/temporal-topo-flow/output/processed/momentum_terms_y
Loading configuration from ../configs/baseline_forcing/short.json


## No-bumps

In [13]:
for case in ["long_nobumps", "short_nobumps", "long_nobumps_crosswind", "short_nobumps_crosswind"]:
    # The first token of the case name ("long" or "short") selects the analysis window.
    period = case.split("_")[0]
    region = analysis_region[period]

    # Load the raw model output.
    params = load_config(f"../configs/baseline_forcing/{case}.json")
    ds = read_raw_output(params)

    # Interpolate to the Cartesian grid and cut out the focus region and period.
    dsY = interp_ds(ds, params, ["u", "v", "zetav", "forcing_x", "detadx", "duvhdy"]).isel(
        time=slice(region["focus_time_start"], region["focus_time_end"]),
        yC=slice(region["focus_j_start"], region["focus_j_end"]),
    )
    if period == "short":
        dsY = dsY.isel(time=slice(-16 * 8, None))  # only last 16 days for short case
    dsY["time"] = dsY["time"] - dsY["time"].values[0]  # reset time to start at 0

    # xC-mean momentum terms.
    grav = params["gravitational_acceleration"]
    dsY["circulation"] = dsY["u"].mean("xC")
    dsY["TFS"] = (-grav * dsY["detadx"] * dsY["bath"]).mean("xC")
    dsY["MFC"] = (-dsY["duvhdy"]).mean("xC")
    dsY["SS"] = dsY["forcing_x"].mean("xC")
    dsY["BS"] = -dsY["circulation"] * params["R"]

    # "<term>y" is the time mean (still a function of yC).
    # "<term>t" is the yC mean (still a function of time).
    terms = ("TFS", "MFC", "SS", "BS")
    for term in terms:
        dsY[f"{term}y"] = dsY[term].mean("time")
    for term in terms:
        dsY[f"{term}t"] = dsY[term].mean("yC")

    ensure_dir(output_folder + "/momentum_terms_y/")
    dsY[["circulation", "TFSt", "MFCt", "SSt", "BSt", "TFSy", "MFCy", "SSy", "BSy"]].to_netcdf(
        output_folder + f"/momentum_terms_y/momentum_terms_y_{case}.nc"
    )

Loading configuration from ../configs/baseline_forcing/long_nobumps.json
Loading configuration from ../configs/baseline_forcing/short_nobumps.json
Loading configuration from ../configs/baseline_forcing/long_nobumps_crosswind.json
Loading configuration from ../configs/baseline_forcing/short_nobumps_crosswind.json


# Wave calculation

## Arrest speed

In [14]:
def estimate_prograde_retrograde_speed(dsH, selection=slice(40, 50), spread=True):
    """
    Extreme prograde (maximum) and retrograde (|minimum|) circulation over time.

    Parameters
    ----------
    dsH : xr.Dataset
        Dataset with a ``circulation`` variable that has ``time`` and ``j`` dimensions.
    selection : slice or label, default slice(40, 50)
        Passed to ``dsH.sel(j=...)``. With ``spread=True`` it must keep the ``j``
        dimension (e.g. a slice). A scalar is fine with ``spread=False``.
    spread : bool, default True
        If True, average the per-contour extremes over the selected contours and
        also report the spread of the retrograde extreme across contours.

    Returns
    -------
    tuple of float
        With ``spread=True``: ``(retrograde, prograde, retrograde_min, retrograde_max)``.
        - ``prograde`` is the j-mean of the time-maximum.
        - ``retrograde`` is |j-mean of the time-minimum|.
        - ``retrograde_min`` / ``retrograde_max`` are |max over j| / |min over j| of the
          time-minimum. These names assume the time-minima are negative.

        With ``spread=False``: ``(retrograde, prograde)``. These are |min| and max over
        every selected point.
    """
    slope = dsH.sel(j=selection).circulation

    if not spread:
        # .item() gives plain floats, matching the spread=True branch
        # (np.min / np.max on a DataArray return 0-d DataArrays instead).
        retrograde = abs(slope.min().item())
        prograde = slope.max().item()
        return retrograde, prograde

    time_min = slope.min("time")  # reused by all three retrograde statistics
    prograde = slope.max("time").mean("j").item()
    retrograde = abs(time_min.mean("j").item())
    retrograde_max = abs(time_min.min("j").item())
    retrograde_min = abs(time_min.max("j").item())

    return retrograde, prograde, retrograde_min, retrograde_max

In [15]:

# Runs labelled by bump wavelength: half, baseline and double bathymetry.
params_22km = load_config("../configs/varying_bathymetry/half.json")
params_45km = load_config("../configs/baseline_forcing/long.json")
params_90km = load_config("../configs/varying_bathymetry/double.json")

# NOTE(review): H_targets is left over from the focus-case loop (hidden state). This
# cell fails on a fresh kernel unless that loop has run first.
arrest_speed_results = {}

for params, wavelength in zip([params_22km, params_45km, params_90km], [22.5, 45, 90]):
    ds = read_raw_output(params)
    dsH = prepare_dsH(ds, params, H_targets)

    # diagnose circulation
    dsH["circulation"] = mean_onH(dsH, variable="ui")

    retrograde, prograde, retrograde_min, retrograde_max = estimate_prograde_retrograde_speed(dsH)

    arrest_speed_results[wavelength] = {
        "retrograde": retrograde,
        "prograde": prograde,
        "retrograde_min": retrograde_min,
        "retrograde_max": retrograde_max,
    }

ensure_dir(output_folder + "/wave_comparison/")
# The context manager closes (and flushes) the file deterministically. The bare
# open() left the handle to the garbage collector.
with open(output_folder + "/wave_comparison/arrest_speeds.json", "w") as fh:
    json.dump(arrest_speed_results, fh)

Loading configuration from ../configs/varying_bathymetry/half.json
Loading configuration from ../configs/baseline_forcing/long.json
Loading configuration from ../configs/varying_bathymetry/double.json
Directory created: /itf-fi-ml/home/alsjur/temporal-topo-flow/output/processed/wave_comparison


## Mode structure

In [16]:
# Baseline long-period run, restricted to its focus window.
params = load_config("../configs/baseline_forcing/long.json")
ds = read_raw_output(params)
ds = ds.isel(
    time=slice(
        analysis_region["long"]["focus_time_start"],
        analysis_region["long"]["focus_time_end"],
    ),
)
ds["time"] = ds["time"] - ds["time"].values[0]  # reset time to start at 0

# eta = h - bath (presumably the free-surface elevation; confirm against the model
# variable definitions), and its deviation from the mean over xC.
eta = ds["h"] - ds["bath"]
etanod = eta - eta.mean(dim="xC")

# Split the window into its first and second halves of 64*8 samples. The second
# half is shifted to start at time 0 as well, so the two halves align on time.
# NOTE(review): the _neg/_pos names suggest opposite forcing phases. Confirm.
half_window = 64 * 8
etanod_neg = etanod.isel(time=slice(0, half_window))
etanod_pos = etanod.isel(time=slice(half_window, None))
etanod_pos["time"] = etanod_pos["time"] - etanod_pos["time"].values[0]

# Sum of the two halves, aligned on the shared time coordinate.
mode = etanod_neg + etanod_pos

# Create the output folder here as well, so this cell does not depend on the
# arrest-speed cell having run first.
ensure_dir(output_folder + "/wave_comparison/")
mode.to_netcdf(output_folder + "/wave_comparison/mode_structure.nc")

Loading configuration from ../configs/baseline_forcing/long.json


## 

## Max speed for varying forcing

In [17]:
# Depth contours to evaluate: entries 40, 45 and 50 of H_targets.
# NOTE(review): H_targets is left over from the focus-case loop (hidden state). This
# cell fails on a fresh kernel unless that loop has run first.
depths_to_use = [H_targets[40], H_targets[45], H_targets[50]]
periods = ["short", "long"]
n_runs = 8

# --- Pre-allocate tidy dataset ---
ds_out = xr.Dataset(
    coords=dict(
        period=("period", periods),
        run=("run", np.arange(1, n_runs + 1)),
        depth=("depth", depths_to_use),
    ),
    data_vars=dict(
        forcing_strength=(("period", "run"), np.full((len(periods), n_runs), np.nan)),
        prograde_max=(("period", "depth", "run"), np.full((len(periods), len(depths_to_use), n_runs), np.nan)),
        retrograde_max=(("period", "depth", "run"), np.full((len(periods), len(depths_to_use), n_runs), np.nan)),
    ),
)

for p_idx, p in enumerate(periods):
    for r in range(1, n_runs + 1):
        params_i = load_config(f"../configs/varying_forcing/{p}_{r:03d}.json")
        forcing = params_i["tau0"]
        # Keep only the last 128*8 output samples.
        ds_i = read_raw_output(params_i).isel(time=slice(-128 * 8, None))

        dsH_i = prepare_dsH(ds_i, params_i, depths_to_use)
        dsH_i["circulation"] = mean_onH(dsH_i, variable="ui")

        ds_out["forcing_strength"][p_idx, r - 1] = forcing

        # One (retrograde, prograde) pair per depth contour. selection=k is passed
        # to dsH_i.sel(j=k) inside estimate_prograde_retrograde_speed.
        for k, _ in enumerate(dsH_i["depth"].values):
            retro, pro = estimate_prograde_retrograde_speed(dsH_i, selection=k, spread=False)
            ds_out["prograde_max"][p_idx, k, r - 1] = pro
            ds_out["retrograde_max"][p_idx, k, r - 1] = retro

# Sort each period's runs by increasing forcing strength.
# deep=True matters: a shallow Dataset.copy shares the numpy buffers with ds_out,
# so the in-place .loc writes below would silently reorder ds_out as well.
# NOTE(review): the "run" coordinate is not reordered, so after sorting its labels
# no longer match the config file numbers.
ds_sorted = ds_out.copy(deep=True)
for p in periods:
    F = ds_out["forcing_strength"].sel(period=p)
    order = np.argsort(F.values)
    ds_sorted["forcing_strength"].loc[dict(period=p)] = F.values[order]
    for var in ["prograde_max", "retrograde_max"]:
        ds_sorted[var].loc[dict(period=p)] = ds_out[var].sel(period=p).values[..., order]

vars_to_save = ["forcing_strength", "prograde_max", "retrograde_max"]
ds_to_write = ds_sorted[vars_to_save]

ensure_dir(output_folder + "/wave_comparison/")
ds_to_write.to_netcdf(output_folder + "/wave_comparison/forcing_vs_arrest_speeds.nc")


Loading configuration from ../configs/varying_forcing/short_001.json
Loading configuration from ../configs/varying_forcing/short_002.json
Loading configuration from ../configs/varying_forcing/short_003.json
Loading configuration from ../configs/varying_forcing/short_004.json
Loading configuration from ../configs/varying_forcing/short_005.json
Loading configuration from ../configs/varying_forcing/short_006.json
Loading configuration from ../configs/varying_forcing/short_007.json
Loading configuration from ../configs/varying_forcing/short_008.json
Loading configuration from ../configs/varying_forcing/long_001.json
Loading configuration from ../configs/varying_forcing/long_002.json
Loading configuration from ../configs/varying_forcing/long_003.json
Loading configuration from ../configs/varying_forcing/long_004.json
Loading configuration from ../configs/varying_forcing/long_005.json
Loading configuration from ../configs/varying_forcing/long_006.json
Loading configuration from ../configs/va