In [None]:
# ===============================================================================================
# @des: Ensemble Kalmer Filter: We will later integrate this EnKF with several ICE sheet models.
# @date: 12-09-2024
# @authors: Brian KYANJO and Alex Robel
# ===============================================================================================

In [None]:
# import libraries --------------------------------------------------------------
import numpy as np
import jax
import jax.numpy as jnp
from jax import jacfwd
import matplotlib.pyplot as plt
from scipy import optimize
import warnings
warnings.filterwarnings("ignore")

jax.config.update("jax_enable_x64", True) # Set the precision in JAX to use float64

## Steps to follow
- Start with replicating the Julia Code in Python
- This will help in comparing results later 
- After when python agrees with Julia
- Form General Classes for the code to make it more genral using OBP
- Profile the code to see how long the filter takes for very big vlaues of N
- Rewrite the Filter to C and call it in python 
- if Still slow, add parallel capabilities using MPI or OpenMP and still wrap it in python.
- Once this works, now proceed to trying to couple the filter with another problem
- If sucess then, now try with an actual ICE-sheet model
- Continue to remodel and develop the wrapper ....
- Finnally convert the wrapper into a python script

In [None]:
# Parameters definition

def params_define():
    # Prescribed initial values of model parameters in a dictionary
    params = {}
    params["A"]         = 4e-26         
    params["year"]      = 3600 * 24 * 365
    params["n"]         = 3
    params["C"]         = 3e6
    params["rho_i"]     = 900
    params["rho_w"]     = 1000
    params["g"]         = 9.8
    params["B"]         = params["A"] ** (-1 / params["n"])
    params["m"]         = 1 / params["n"]
    params["accum"]     = 0.65 / params["year"]
    params["facemelt"]  = 5 / params["year"]

    # Scaling parameters
    params["hscale"]    = 1000
    params["ascale"]    = 1.0 / params["year"]
    params["uscale"]    = (params["rho_i"] * params["g"] * params["hscale"] * params["ascale"] / params["C"]) ** (1 / (params["m"] + 1))
    params["xscale"]    = params["uscale"] * params["hscale"] / params["ascale"]
    params["tscale"]    = params["xscale"] / params["uscale"]
    params["eps"]       = params["B"] * ((params["uscale"] / params["xscale"]) ** (1 / params["n"])) / (2 * params["rho_i"] * params["g"] * params["hscale"])
    params["lambda"]    = 1 - (params["rho_i"] / params["rho_w"])

    # Grid parameters
    params["NT"]        = 1
    params["TF"]        = params["year"]
    params["dt"]        = params["TF"] / params["NT"]
    params["transient"] = 0
    params["tcurrent"]  = 1

    params["N1"]        = 40
    params["N2"]        = 10
    params["sigGZ"]     = 0.97
    params["NX"]        = params["N1"] + params["N2"]

    # Bed params
    # params["b0"] = -400
    params["xsill"]      = 50e3
    params["sillamp"]    = 500
    params["sillsmooth"] = 1e-5
    # params["bxr"] = 1e-3
    # params["bxp"] = -1e-3

    # EnKF params
    params["inflation"] = 1.0
    params["assim"]     = False

    # Generating sigma values
    sigma1    = np.linspace(params["sigGZ"] / (params["N1"] + 0.5), params["sigGZ"], int(params["N1"]))
    sigma2    = np.linspace(params["sigGZ"], 1, int(params["N2"] + 1))
    sigma     = np.concatenate((sigma1, sigma2[1:params["N2"] + 1]))

    # Create the grid dictionary
    grid                = {"sigma": sigma}
    grid["sigma_elem"]  = np.concatenate(([0], (sigma[:-1] + sigma[1:]) / 2))
    grid["dsigma"]      = np.diff(grid["sigma"])

    return params, grid

In [52]:
# Observation operator ----------------------
def Obs(huxg_virtual_obs, m_obs):
    # Determine the size of the observation vector
    n = huxg_virtual_obs.shape[0]
    m = m_obs

    # Initialize the H matrix with zeros
    H = np.zeros((m * 2 + 1, n))

    # Calculate the distance between measurements
    di = int((n - 2) / (2 * m))  # Python uses int for integer division

    # Fill in the H matrix
    for i in range(1, m + 1):  # Python's range is 0-based, so start from 1
        H[i - 1, i * di] = 1  # Adjust for 0-based indexing
        H[m + i - 1, int((n - 2) / 2) + i * di] = 1  # Adjust for 0-based indexing

    # Final element of H matrix
    H[m * 2, n - 1] = 1  # Adjust for 0-based indexing

    # Perform matrix multiplication
    z = H @ huxg_virtual_obs  # '@' operator for matrix multiplication in Python
    return z

# Jacobian of the observation operator ----------------------
def JObs(n_model, m_obs):
    # Initialize the H matrix with zeros
    n = n_model
    m = m_obs
    H = np.zeros((m * 2 + 1, n))

    # Calculate the distance between measurements
    di = int((n - 2) / (2 * m))  # Convert distance to an integer

    # Fill in the H matrix
    for i in range(1, m + 1):  # Python uses 0-based indexing
        H[i - 1, i * di] = 1  # Adjust for 0-based indexing
        H[m + i - 1, int((n - 2) / 2) + i * di] = 1  # Adjust for 0-based indexing

    # Final element of H matrix
    H[m * 2, n - 1] = 1  # Adjust for 0-based indexing

    return H

# EnKF function ------------------------------
def EnKF(huxg_ens, huxg_obs, ObsFun, JObsFun, Cov_obs, Cov_model, params, taper):
    """
    Analysis step of the Ensemble Kalman Filter (EnKF).
    
    huxg_ens: ndarray (n x N) - The ensemble matrix of model states (n is state size, N is ensemble size).
    huxg_obs: ndarray (m,) - The observation vector (m is measurement size).
    ObsFun: Function - Observation function.
    JObsFun: Function - Jacobian of the observation function.
    Cov_obs: ndarray (m x m) - Observation covariance matrix.
    Cov_model: ndarray (n x n) - Model covariance matrix.
    params: dict - Dictionary containing parameters like "m_obs".
    taper: ndarray (n x n) - Covariance taper matrix.
    
    Returns:
    analysis_ens: ndarray (n x N) - The updated ensemble after analysis.
    analysis_cov: ndarray (n x n) - The updated covariance after analysis.
    """
    n, N = huxg_ens.shape  # n is the state size, N is the ensemble size
    m = huxg_obs.shape[0]  # Measurement size (m)

    # Compute the ensemble mean
    huxg_ens_mean = np.mean(huxg_ens, axis=1, keepdims=True)  # mean of model forecast ensemble
    
    # Compute the Jacobian of the observation operator
    Jobs = JObsFun(n, params["m_obs"])  # Jacobian of the observation operator
    
    # Compute the Kalman Gain
    KalGain = Cov_model @ Jobs.T @ np.linalg.inv(Jobs @ Cov_model @ Jobs.T + Cov_obs)

    # Initialize variables
    obs_virtual = np.zeros((m, N))  # Virtual observations for the ensemble
    analysis_ens = np.zeros_like(huxg_ens)  # Analysis ensemble

    # Perform the analysis for each ensemble member
    for i in range(N):
        # Generate virtual observations using multivariate normal distribution
        obs_virtual[:, i] = huxg_obs + multivariate_normal.rvs(mean=np.zeros(m), cov=Cov_obs)

        # Update the ensemble member with the Kalman gain
        analysis_ens[:, i] = huxg_ens[:, i] + KalGain @ (obs_virtual[:, i] - ObsFun(huxg_ens[:, i], params["m_obs"]))
    
    # Compute the mean of the analysis ensemble
    analysis_ens_mean = np.mean(analysis_ens, axis=1, keepdims=True)
    
    # Compute the analysis error covariance
    analysis_cov = (1 / (N - 1)) * (analysis_ens - analysis_ens_mean) @ (analysis_ens - analysis_ens_mean).T
    analysis_cov = analysis_cov * taper  # Apply covariance tapering
    
    return analysis_ens, analysis_cov


In [54]:
# bed topography function --------------------------------------------------------------
def bed(x,params):
    b = params['sillamp'] * (-2 * jnp.arccos((1 - params['sillsmooth']) * jnp.sin(jnp.pi * x / (2 * params['xsill'])))/jnp.pi - 1)
    return b

In [86]:
# Implicit flowline model function --------------------------------------------------------------
# JAX version of the flowline function --------------------------------------------------------------
def flowline(F, varin, varin_old, params, grid, bedfun):
    # Unpack grid
    NX          = params["NX"]
    N1          = params["N1"]
    dt          = params["dt"] / params["tscale"]
    ds          = grid["dsigma"]
    sigma       = grid["sigma"]
    sigma_elem  = grid["sigma_elem"]

    # Unpack parameters
    tcurrent    = params["tcurrent"]
    xscale      = params["xscale"]
    hscale      = params["hscale"]
    lambd       = params["lambda"]
    m           = params["m"]
    n           = params["n"]
    a           = params["accum"] / params["ascale"]
    mdot        = params["facemelt"] / params["uscale"]
    eps         = params["eps"]
    transient   = params["transient"]

    # Unpack variables
    h   = varin[:NX]
    u   = varin[NX:2*NX]
    xg  = varin[2*NX]

    h_old   = varin_old[:NX]
    xg_old  = varin_old[2*NX]


    # Calculate bed 
    hf  = -bedfun(xg * xscale, params) / (hscale * (1 - lambd))
    hfm = -bedfun(xg * sigma_elem[-1] * xscale, params) / (hscale * (1 - lambd))
    b   = -bedfun(xg * sigma * xscale, params) / hscale

    # Calculate thickness functions
    F = jnp.zeros_like(F)  # Initialize F to be sure it's JAX-compatible
    F = F.at[0].set(transient * (h[0] - h_old[0]) / dt + (2 * h[0] * u[0]) / (ds[0] * xg)  - a)
    # print(F[0])
    
    F = F.at[1].set(
        transient * (h[1] - h_old[1]) / dt
        - transient * sigma_elem[1] * (xg - xg_old) * (h[2] - h[0]) / (2 * dt * ds[1] * xg)
        + (h[1] * (u[1] + u[0])) / (2 * xg * ds[1]) - a
    )

    F = F.at[2:NX-1].set(
        transient * (h[2:NX-1] - h_old[2:NX-1]) / dt
        - transient * sigma_elem[2:NX-1] * (xg - xg_old) * (h[3:NX] - h[1:NX-2]) / (2 * dt * ds[2:NX-1] * xg)
        + (h[2:NX-1] * (u[2:NX-1] + u[1:NX-2]) - h[1:NX-2] * (u[1:NX-2] + u[0:NX-3])) / (2 * xg * ds[2:NX-1]) - a
    )

    F = F.at[N1-1].set(
        (1 + 0.5 * (1 + (ds[N1-1] / ds[N1-2]))) * h[N1-1]
        - 0.5 * (1 + (ds[N1-1] / ds[N1-2])) * h[N1-2]
        - h[N1]
    )

    F = F.at[NX-1].set(
        transient * (h[NX-1] - h_old[NX-1]) / dt
        - transient * sigma[NX-1] * (xg - xg_old) * (h[NX-1] - h[NX-2]) / (dt * ds[NX-2] * xg)
        + (h[NX-1] * (u[NX-1] + mdot * hf / h[NX-1] + u[NX-2]) - h[NX-2] * (u[NX-2] + u[NX-3])) / (2 * xg * ds[NX-2])
        - a
    )

    # Calculate velocity functions
    F = F.at[NX].set(
        ((4 * eps / (xg * ds[0]) ** ((1 / n) + 1)) * (h[1] * (u[1] - u[0]) * abs(u[1] - u[0]) ** ((1 / n) - 1)
            - h[0] * (2 * u[0]) * abs(2 * u[0]) ** ((1 / n) - 1)))
        - u[0] * abs(u[0]) ** (m - 1)
        - 0.5 * (h[0] + h[1]) * (h[1] - b[1] - h[0] + b[0]) / (xg * ds[0])
    )

    F = F.at[NX+1:2*NX-1].set(
        (4 * eps / (xg * ds[1:NX-1]) ** ((1 / n) + 1))
        * (h[2:NX] * (u[2:NX] - u[1:NX-1]) * abs(u[2:NX] - u[1:NX-1]) ** ((1 / n) - 1)
           - h[1:NX-1] * (u[1:NX-1] - u[0:NX-2]) * abs(u[1:NX-1] - u[0:NX-2]) ** ((1 / n) - 1))
        - u[1:NX-1] * abs(u[1:NX-1]) ** (m - 1)
        - 0.5 * (h[1:NX-1] + h[2:NX]) * (h[2:NX] - b[2:NX] - h[1:NX-1] + b[1:NX-1]) / (xg * ds[1:NX-1])
    )

    F = F.at[NX+N1-1].set((u[N1] - u[N1-1]) / ds[N1-1] - (u[N1-1] - u[N1-2]) / ds[N1-2])
    F = F.at[2*NX-1].set(
        (1 / (xg * ds[NX-2]) ** (1 / n)) * (abs(u[NX-1] - u[NX-2]) ** ((1 / n) - 1)) * (u[NX-1] - u[NX-2])
        - lambd * hf / (8 * eps)
    )

    # Calculate grounding line functions
    F = F.at[2*NX].set(3 * h[NX-1] - h[NX-2] - 2 * hf)

    return F

def Jac_calc(huxg_old, params, grid, bedfun, flowlinefun):
    """
    Use automatic differentiation to calculate Jacobian for nonlinear solver.
    """

    def f(varin):
        # Initialize F as an array of zeros with size 2*NX + 1
        F = jnp.zeros(2 * params["NX"] + 1, dtype=jnp.float64)
        # Call the flowline function with current arguments
        flowlinefun(F, varin, huxg_old, params, grid, bedfun)
        return F

    # Create a function that calculates the Jacobian using JAX
    def Jf(varin):
        # Jacobian of f with respect to varin
        return jax.jacfwd(f)(varin)

    return Jf


def flowline_run(varin, params, grid, bedfun, flowlinefun):
    nt = params["NT"]
    huxg_old = varin
    huxg_all = jnp.zeros((huxg_old.shape[0], nt))

    for i in range(nt):
        if not params["assim"]:
            params["tcurrent"] = i + 1  # Adjusting for 1-based indexing in Julia
        
        # Jacobian calculation
        Jf = Jac_calc(huxg_old, params, grid, bedfun, flowlinefun)
        
        # Solve the system of nonlinear equations
        solve_result = optimize.root(
            lambda varin: flowlinefun(jnp.zeros_like(varin), varin, huxg_old, params, grid, bedfun), 
            huxg_old, 
            jac=Jf, 
            method='hybr',  # Hybr is a commonly used solver like nlsolve
            options={'maxiter': 100}
        )
        
        # Update the old solution
        huxg_old = solve_result.x
        
        # Store the result for this time step
        huxg_all[:, i] = huxg_old

        # Check for convergence
        if not solve_result.success:
            print(f"Solver didn't converge at time step {i + 1}")
        
        if not params["assim"]:
            print(f"Step {i + 1}\n")
    
    return huxg_all



In [89]:
# Initial setup from params_define and initial guess
params, grid = params_define()

# Initial guess and steady-state
xg = 300e3 / params["xscale"]
hf = (-bed(xg * params["xscale"], params) / params["hscale"]) / (1 - params["lambda"])
h  = 1 - (1 - hf) * grid["sigma"]
u  = 1.0 * (grid["sigma_elem"] ** (1 / 3)) + 1e-3
huxg_old = np.concatenate((h, u, [xg]))

Jf = Jac_calc(huxg_old, params, grid, bed, flowline)

def flowline_wrapper(varin):
    F = np.zeros_like(varin)
    flowline(F, varin, huxg_old, params, grid, bed)
    return F

solve_result = root(flowline_wrapper, huxg_old, jac=Jf, method='hybr', options={'maxfev': 1000})
huxg_out0 = solve_result.x
huxg_out0

array([1.00266118e+00, 1.00535648e+00, 1.00805177e+00, 1.01074707e+00,
       1.01344237e+00, 1.01613767e+00, 1.01883296e+00, 1.02152826e+00,
       1.02422356e+00, 1.02691886e+00, 1.02961415e+00, 1.03230945e+00,
       1.03500475e+00, 1.03770005e+00, 1.04039534e+00, 1.04309064e+00,
       1.04578594e+00, 1.04848124e+00, 1.05117653e+00, 1.05387183e+00,
       1.05656713e+00, 1.05926242e+00, 1.06195772e+00, 1.06465302e+00,
       1.06734832e+00, 1.07004361e+00, 1.07273891e+00, 1.07543421e+00,
       1.07812951e+00, 1.08082480e+00, 1.08352010e+00, 1.08621540e+00,
       1.08891070e+00, 1.09160599e+00, 1.09430129e+00, 1.09699659e+00,
       1.09969189e+00, 1.10238718e+00, 1.10508248e+00, 1.10777778e+00,
       1.10811111e+00, 1.10844444e+00, 1.10877778e+00, 1.10911111e+00,
       1.10944444e+00, 1.10977778e+00, 1.11011111e+00, 1.11044444e+00,
       1.11077778e+00, 1.11111111e+00, 1.00000000e-03, 3.31435469e-01,
       3.93218632e-01, 4.39983204e-01, 4.78470029e-01, 5.11587765e-01,
      

In [None]:
# call the parameters function
params, grid = params_define()

# Initial guess and steady-state 
xg = 300e3 / params["xscale"]
hf = (-bed(xg * params["xscale"], params) / params["hscale"]) / (1 - params["lambda"])
h = 1 - (1 - hf) * grid["sigma"]
u = 1.0 * (grid["sigma_elem"] ** (1 / 3)) + 1e-3

# Concatenate h, u, and xg into a single array
huxg_old = np.concatenate((h, u, [xg]))

# Calculate the Jacobian function
Jf = Jac_calc(huxg_old, params, grid, bed, flowline)

# Define the function to solve, which is equivalent to the anonymous function in Julia
def nonlinear_system(varin):
    F = np.zeros_like(varin)
    flowline(F, varin, huxg_old, params, grid, bed)
    return F

# Solve the nonlinear system
solve_result = optimize.root(nonlinear_system, huxg_old, jac=Jf, method='hybr', options={'maxiter': 1000})

# # Get the result
# huxg_out0 = solve_result.x

# # True simulation
# params["NT"] = 150
# params["TF"] = params["year"] * 150
# params["dt"] = params["TF"] / params
# params["transient"] = 1
# params["facemelt"] = np.linspace(5, 85, params["NT"] + 1) / params["year"]
# fm_dist = np.random.normal(0, 20.0)
# fm_truth = params["facemelt"] 
# params["facemelt"] = fm_truth
# huxg_out1 = flowline_run(huxg_out0, params, grid, bed, flowline)

# # wrong simulation
# fm_wrong = np.linspace(5, 45, params["NT"] + 1) / params["year"]
# params["facemelt"] = np.linspace(5, 45, params["NT"] + 1) / params["year"]

# # Run the flowline model
# huxg_out2 = flowline_run(huxg_out0, params, grid, bed, flowline)

# # Time steps array
# ts = np.linspace(0, params["TF"] / params["year"], params["NT"] + 1)

# # xg_truth and xg_wrong calculations
# xg_truth = np.concatenate(([huxg_out0[2 * params["NX"]]], huxg_out1[2 * params["NX"], :])) * params["xscale"]
# xg_wrong = np.concatenate(([huxg_out0[2 * params["NX"]]], huxg_out2[2 * params["NX"], :])) * params["xscale"]

# # Plotting the results
# plt.plot(ts, xg_truth / 1e3, lw=3, color='black', label="truth")
# plt.plot(ts, xg_wrong / 1e3, lw=3, color='red', label="wrong")
# plt.plot(ts, 250.0 * np.ones_like(ts), lw=1, color='black', linestyle='--', label="sill")

# # Add labels and legend
# plt.xlabel("Time (years)")
# plt.ylabel("xg (km)")
# plt.legend()
# plt.show()

