# Module

In [None]:
import os

import torch
import numpy as np
import pandas as pd

from time import time
from tqdm import tqdm

from scipy.interpolate import griddata

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.patches import Rectangle

# Coarse Tune

In [None]:
def cond_alpha(t):
    """Return the conditional mean scaling alpha(t) = 1 - (1 - eps_alpha) * t.

    ``eps_alpha`` is looked up as a module-level global each time this is
    called, so it follows whatever value the tuning loop currently holds.
    """
    return 1 - (1-eps_alpha)*t


def cond_sigma_sq(t):
    """Return the conditional variance sigma^2(t) = eps_beta + t * (1 - eps_beta).

    ``eps_beta`` is looked up as a module-level global at call time.
    """
    return eps_beta + t * (1 - eps_beta)


def f(t):
    """Return the drift coefficient of the forward SDE at time ``t``.

    This is -(1 - eps_alpha) / alpha(t), i.e. the time derivative of
    ``cond_alpha`` divided by ``cond_alpha``.
    """
    return -(1-eps_alpha) / cond_alpha(t)


def g_sq(t):
    """Return the squared diffusion coefficient of the forward SDE at ``t``.

    NOTE(review): the leading term here is the constant 1, while the time
    derivative of ``cond_sigma_sq`` is (1 - eps_beta) -- confirm that 1 is
    the intended value.
    """
    return 1 - 2 * f(t) * cond_sigma_sq(t)


def g(t):
    """Return the diffusion coefficient, the square root of ``g_sq(t)``."""
    return np.sqrt(g_sq(t))

# generate sample with reverse SDE
def reverse_SDE(x0, score_likelihood=None, time_steps=100,
                drift_fun=f, diffuse_fun=g, alpha_fun=cond_alpha, sigma2_fun=cond_sigma_sq, save_path=False):
    """Draw samples by integrating the reverse-time SDE from t = 1 towards t = 0.

    The integration starts from standard-normal noise of shape
    ``(ensemble_size, n_dim)`` on ``device``; those three names are read from
    module-level globals, not from the arguments. Each of the ``time_steps``
    Euler steps moves the state by ``dt = 1 / time_steps`` backwards in time.

    Parameters
    ----------
    x0 : torch.Tensor
        Conditioning samples used in the prior score term
        ``(xt - alpha_t * x0) / sigma2_t``; must broadcast against the
        ``(ensemble_size, n_dim)`` state. It is not modified.
    score_likelihood : callable or None
        Optional ``score_likelihood(xt, t)`` whose result is added to the
        score. When ``None`` only the prior score is used.
    time_steps : int
        Number of Euler steps.
    drift_fun, diffuse_fun, alpha_fun, sigma2_fun : callable
        Functions of the scalar time ``t``. The defaults are bound once, when
        this ``def`` statement runs.
    save_path : bool
        When true, return the whole trajectory instead of the final state.

    Returns
    -------
    torch.Tensor
        The final state, when ``save_path`` is false.
    tuple[list, list]
        ``(path_all, t_vec)`` when ``save_path`` is true: ``time_steps + 1``
        states (initial noise first) and the time each state belongs to.
    """
    # Step size of the time mesh on [0, 1]
    dt = 1.0/time_steps

    # Initialization: pure noise at the terminal time t = 1
    xt = torch.randn(ensemble_size, n_dim, device=device)
    t = 1.0

    # Trajectory storage (only allocated on request)
    if save_path:
        path_all = [xt]
        t_vec = [t]

    for _ in range(time_steps):
        # Coefficients of the conditional Gaussian at the current time
        alpha_t = alpha_fun(t)
        sigma2_t = sigma2_fun(t)

        # Diffusion coefficient at the current time
        diffuse = diffuse_fun(t)

        # Drift of the reverse SDE: f(t)*x - g(t)^2 * score, where the prior
        # score is -(xt - alpha_t*x0)/sigma2_t and the likelihood score is
        # added when supplied.
        drift = drift_fun(t)*xt + diffuse**2 * ((xt - alpha_t*x0)/sigma2_t)
        if score_likelihood is not None:
            drift = drift - diffuse**2 * score_likelihood(xt, t)

        # Out-of-place update: every saved state must be its own tensor.
        # An in-place `+=` would make all entries of path_all alias the
        # final state.
        xt = xt + (-dt*drift + np.sqrt(dt)*diffuse*torch.randn_like(xt))

        # Advance the clock before recording, so t_vec[k] is the time that
        # path_all[k] actually lives at.
        t = t - dt

        if save_path:
            path_all.append(xt)
            t_vec.append(t)

    if save_path:
        return path_all, t_vec
    return xt

# Drift of the state dynamics used in the prediction step.
# NOTE(review): despite the name this is the linear drift 2*x, not a
# Lorenz-96 right-hand side -- confirm that this is intended.
def lorenz96_drift(x):
    """Return the drift ``2 * x`` for state ``x``."""
    return 2*x


# state dimension (used as n_dim/2 interleaved (x, y) pairs) and SDE noise level
n_dim = 2500
SDE_sigma = 0.5

# filtering setup: time step of the state SDE and number of filtering steps
dt = 0.05
filtering_steps = 30

# observation noise standard deviation
obs_sigma = 0.1

####################################################################
# EnSF setup

# ensemble size
ensemble_size = 100

# number of forward Euler steps per reverse-SDE solve
euler_steps = 100


# damping function (tau(0) = 1;  tau(1) = 0)
def g_tau(t):
    """Return the likelihood damping factor ``1 - t``."""
    return 1-t


# computation setting
torch.set_default_dtype(torch.float64)  # double precision
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# coarse grid: both hyperparameters sweep 0.1, 0.2, ... in steps of 0.1
# (rounded to remove floating-point noise from np.arange)
eps_alpha_list, eps_beta_list = np.round(np.arange(0.1, 1, 0.1), 1), np.round(np.arange(0.1, 1, 0.1), 1)

# One row per grid point: [eps_alpha, eps_beta, mean RMSE vs truth, mean RMSE vs observation]
tune_result = []

# NOTE: the loop variables must be named eps_alpha / eps_beta, because
# cond_alpha, cond_sigma_sq and f read those names as module-level globals.
for eps_alpha in tqdm(eps_alpha_list):
    for eps_beta in eps_beta_list:
        ####################################################################
        ####################################################################
        # Seed BEFORE the first random draw so that every grid point starts
        # from the same initial ensemble and sees the same noise sequence.
        # Seeding after the ensemble is drawn leaves the first grid point
        # unseeded.
        torch.manual_seed(114514)
        torch.cuda.empty_cache()

        # initial true state: n_dim/2 points on the unit circle, stored as
        # interleaved (x, y) pairs
        angles = np.linspace(-2 * np.pi, 2 * np.pi, int(n_dim/2), endpoint=False)
        x = 1.00 * np.cos(angles)
        y = 1.00 * np.sin(angles)
        state_target = torch.tensor(np.vstack((x, y)).T.flatten(), device=device)

        # filtering initial ensemble: a slightly perturbed copy of the truth
        x = 1.15 * np.cos(angles)
        y = 1.05 * np.sin(angles)
        x_prop = torch.tensor(np.vstack((x, y)).T.flatten(), device=device) # Initial set up

        x_state = x_prop.repeat(ensemble_size, 1) + 0.1 * torch.randn(ensemble_size, n_dim, device=device)

        rmse_sf = []
        rmse_o = []

        for _step in range(filtering_steps):
            # noise-free propagation of the initial guess (not used in the scores below)
            x_prop += dt*lorenz96_drift(x_prop)

            # prediction step ############################################
            # ensemble forward in time (Euler-Maruyama)
            x_state += dt*lorenz96_drift(x_state) + np.sqrt(dt)*SDE_sigma*torch.randn_like(x_state)

            # ground truth forward in time
            state_target += dt*lorenz96_drift(state_target) + np.sqrt(dt)*SDE_sigma*torch.randn_like(state_target)

            # update step ################################################
            # linear observation of the truth with additive Gaussian noise
            obs = 0.25 * state_target + torch.randn_like(state_target) * obs_sigma

            # likelihood score for the linear observation, damped by g_tau
            # obs: (d)
            # xt: (ensemble, d)
            score_likelihood = lambda xt, t: -(0.25*xt - obs) / obs_sigma**2 * g_tau(t) * 0.25

            # generate posterior sample
            x_state = reverse_SDE(x0=x_state, score_likelihood=score_likelihood, time_steps=euler_steps)

            # state estimate: ensemble mean
            x_est = torch.mean(x_state, dim=0)

            # RMSE of the estimate against the truth, and against the
            # observation mapped back to state space (4*obs undoes the 0.25).
            # compute_rmse is not defined in this part of the file; it is
            # given the data as rows of (x, y) pairs.
            rmse_ensf = compute_rmse(x_est.reshape(-1,2).cpu().numpy(), state_target.reshape(-1,2).cpu().numpy())
            rmse_ori = compute_rmse(x_est.reshape(-1,2).cpu().numpy(), (4*obs).reshape(-1,2).cpu().numpy())

            rmse_sf.append(rmse_ensf)
            rmse_o.append(rmse_ori)

            if x_state.device.type == 'cuda':
                # Wait for the work queued on the current CUDA stream to finish.
                torch.cuda.current_stream().synchronize()

            # Stop early on blow-up. NaN/inf must be tested explicitly,
            # because `nan > 1000` is False.
            if not np.isfinite(rmse_ensf) or rmse_ensf > 1000:
                print('diverge!')
                break

        # average over the run, excluding the first filtering step
        tune_result.append([eps_alpha, eps_beta, np.mean(rmse_sf[1:]), np.mean(rmse_o[1:])])

In [None]:
# Sweep results as a table: one row per (eps_alpha, eps_beta) pair.
data = pd.DataFrame(
    tune_result,
    columns=['eps_alpha', 'eps_beta', 'rmse_sf', 'rmse_o'],
)

# RMSE against the ground truth, laid out as an eps_alpha (rows) by
# eps_beta (columns) grid.
heatmap_data = pd.pivot_table(
    data,
    values='rmse_sf',
    index='eps_alpha',
    columns='eps_beta',
)

def tune_heatmap(heatmap_data, title):
    """Draw an annotated heatmap, outline its three lowest cells, save and show it.

    Parameters
    ----------
    heatmap_data : pandas.DataFrame
        Pivoted grid of scores. Rows are labelled epsilon_alpha on the y-axis
        and columns epsilon_beta on the x-axis.
    title : str
        Despite the name, this is the output path passed to ``plt.savefig``;
        no title is drawn on the figure.
    """
    plt.figure(figsize = (8*1.25, 7*1.25))
    ax = sns.heatmap(
        heatmap_data,
        annot=True,
        fmt=".3f",
        cmap='Blues_r', # reversed Blues colormap
        linewidths=1, # Slightly thicker lines
        linecolor='white', # White lines for good contrast
        annot_kws={"size": 16, "weight": "bold"} # Bold annotations
    )

    cbar = ax.collections[0].colorbar
    cbar.set_label(
        'RMSE - Haversine Distance (km)',
        size=20,
        weight='bold'
    )
    cbar.ax.tick_params(labelsize=14)  # Set tick font size

    # Find the 3 lowest values.
    # Unstacking gives a Series keyed by (column label, row label). Missing
    # cells are dropped first so that a NaN cell is never marked as "best"
    # when fewer than three finite scores exist.
    lowest_three = heatmap_data.unstack().dropna().sort_values().head(3)

    # Outline the lowest cells with red rectangles
    for col, row in lowest_three.index:
        # Translate labels into cell positions on the heatmap axes
        col_idx = heatmap_data.columns.get_loc(col)
        row_idx = heatmap_data.index.get_loc(row)

        # Each heatmap cell is a unit square whose lower-left corner is (col, row)
        rect = Rectangle((col_idx, row_idx), 1, 1, color='red', linewidth=2.5, fill=False, zorder=5)
        ax.add_patch(rect)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel(r'$\epsilon_\beta$', fontsize=18)
    plt.ylabel(r'$\epsilon_\alpha$', fontsize=18)
    plt.tight_layout()
    plt.savefig(title, bbox_inches='tight')
    plt.show()
# Heatmap of the filter RMSE measured against the ground-truth state.
tune_heatmap(heatmap_data, 'vs_gt.pdf')

# Same grid, scored against the observations instead (rmse_o column).
heatmap_data = pd.pivot_table(
    data,
    values='rmse_o',
    index='eps_alpha',
    columns='eps_beta',
)
tune_heatmap(heatmap_data, 'vs_obs.pdf')

# Fine Tune

In [None]:
# NOTE: these definitions repeat the ones from the coarse-tune cell.
# reverse_SDE bound its default drift/diffusion/alpha/sigma callables when it
# was defined, so the objects created here are only used if they are passed
# to it explicitly.

# conditional alpha
def cond_alpha(t):
    """Return alpha(t) = 1 - (1 - eps_alpha) * t; reads the global ``eps_alpha``."""
    return 1 - (1-eps_alpha)*t


# conditional sigma^2
def cond_sigma_sq(t):
    """Return sigma^2(t) = eps_beta + t * (1 - eps_beta); reads the global ``eps_beta``."""
    return eps_beta + t * (1 - eps_beta)


# drift function of forward SDE
def f(t):
    """Return the forward-SDE drift coefficient -(1 - eps_alpha) / alpha(t)."""
    return -(1-eps_alpha) / cond_alpha(t)


# diffusion function of forward SDE
def g_sq(t):
    """Return the squared forward-SDE diffusion coefficient at ``t``."""
    return 1 - 2 * f(t) * cond_sigma_sq(t)


def g(t):
    """Return the forward-SDE diffusion coefficient, sqrt(g_sq(t))."""
    return np.sqrt(g_sq(t))


# damping function (tau(0) = 1;  tau(1) = 0)
def g_tau(t):
    """Return the likelihood damping factor ``1 - t``."""
    return 1-t


# computation setting
torch.set_default_dtype(torch.float64)  # double precision
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# fine grid: eps_alpha from 0.80 and eps_beta from 0.01, both in steps of 0.01
# (rounded to remove floating-point noise from np.arange)
eps_alpha_list, eps_beta_list = np.round(np.arange(0.8, 1, 0.01), 2), np.round(np.arange(0.01, 0.2, 0.01), 2)

# One row per grid point: [eps_alpha, eps_beta, mean RMSE vs truth, mean RMSE vs observation]
tune_result = []

# NOTE: the loop variables must be named eps_alpha / eps_beta, because
# cond_alpha, cond_sigma_sq and f read those names as module-level globals.
for eps_alpha in tqdm(eps_alpha_list):
    for eps_beta in eps_beta_list:
        ####################################################################
        ####################################################################
        # Seed BEFORE the first random draw so that every grid point starts
        # from the same initial ensemble and sees the same noise sequence.
        # Seeding after the ensemble is drawn leaves the first grid point
        # unseeded.
        torch.manual_seed(114514)
        torch.cuda.empty_cache()

        # initial true state: n_dim/2 points on the unit circle, stored as
        # interleaved (x, y) pairs
        angles = np.linspace(-2 * np.pi, 2 * np.pi, int(n_dim/2), endpoint=False)
        x = 1.00 * np.cos(angles)
        y = 1.00 * np.sin(angles)
        state_target = torch.tensor(np.vstack((x, y)).T.flatten(), device=device)

        # filtering initial ensemble: a slightly perturbed copy of the truth
        x = 1.15 * np.cos(angles)
        y = 1.05 * np.sin(angles)
        x_prop = torch.tensor(np.vstack((x, y)).T.flatten(), device=device) # Initial set up

        x_state = x_prop.repeat(ensemble_size, 1) + 0.1 * torch.randn(ensemble_size, n_dim, device=device)

        rmse_sf = []
        rmse_o = []

        for _step in range(filtering_steps):
            # noise-free propagation of the initial guess (not used in the scores below)
            x_prop += dt*lorenz96_drift(x_prop)

            # prediction step ############################################
            # ensemble forward in time (Euler-Maruyama)
            x_state += dt*lorenz96_drift(x_state) + np.sqrt(dt)*SDE_sigma*torch.randn_like(x_state)

            # ground truth forward in time
            state_target += dt*lorenz96_drift(state_target) + np.sqrt(dt)*SDE_sigma*torch.randn_like(state_target)

            # update step ################################################
            # linear observation of the truth with additive Gaussian noise
            obs = 0.25 * state_target + torch.randn_like(state_target) * obs_sigma

            # likelihood score for the linear observation, damped by g_tau
            # obs: (d)
            # xt: (ensemble, d)
            score_likelihood = lambda xt, t: -(0.25*xt - obs) / obs_sigma**2 * g_tau(t) * 0.25

            # generate posterior sample
            x_state = reverse_SDE(x0=x_state, score_likelihood=score_likelihood, time_steps=euler_steps)

            # state estimate: ensemble mean
            x_est = torch.mean(x_state, dim=0)

            # RMSE of the estimate against the truth, and against the
            # observation mapped back to state space (4*obs undoes the 0.25).
            # compute_rmse is not defined in this part of the file; it is
            # given the data as rows of (x, y) pairs.
            rmse_ensf = compute_rmse(x_est.reshape(-1,2).cpu().numpy(), state_target.reshape(-1,2).cpu().numpy())
            rmse_ori = compute_rmse(x_est.reshape(-1,2).cpu().numpy(), (4*obs).reshape(-1,2).cpu().numpy())

            rmse_sf.append(rmse_ensf)
            rmse_o.append(rmse_ori)

            if x_state.device.type == 'cuda':
                # Wait for the work queued on the current CUDA stream to finish.
                torch.cuda.current_stream().synchronize()

            # Stop early on blow-up. NaN/inf must be tested explicitly,
            # because `nan > 1000` is False.
            if not np.isfinite(rmse_ensf) or rmse_ensf > 1000:
                print('diverge!')
                break

        # average over the run, excluding the first filtering step
        tune_result.append([eps_alpha, eps_beta, np.mean(rmse_sf[1:]), np.mean(rmse_o[1:])])

In [None]:
# 1. Build the results table from the fine sweep.
# This cell used to plot a hard-coded 4-row placeholder table, so the sweep
# stored in tune_result never reached the figure.
data = pd.DataFrame(tune_result, columns = ['eps_alpha', 'eps_beta', 'rmse_sf','rmse_o'])

# Diverged runs leave NaN/inf scores; drop them so they cannot poison the
# interpolated surface or be picked as the optimum.
data = data[np.isfinite(data.rmse_sf)].reset_index(drop=True)
if data.empty:
    raise ValueError('no finite rmse_sf values to plot')

# 2. Interpolate the scattered scores onto a regular 100 x 100 grid
# (xi spans eps_alpha, yi spans eps_beta).
xi = np.linspace(data.eps_alpha.min(), data.eps_alpha.max(), 100)
yi = np.linspace(data.eps_beta.min(), data.eps_beta.max(), 100)
zi = griddata((data.eps_alpha, data.eps_beta), data.rmse_sf, (xi[None,:], yi[:,None]), method='cubic')

# 3. Find the coordinates of the best score
# idxmin returns the row label of the lowest RMSE, which .loc then resolves;
# this stays correct whatever the index looks like.
best_row = data.loc[data.rmse_sf.idxmin()]
min_alpha = best_row.eps_alpha
min_beta = best_row.eps_beta

# 4. Generate the plot
# Filled contour plot with the reversed 'viridis' colormap ('viridis_r'):
# lower (better) RMSE maps to the light/yellow end, higher RMSE to the dark end.
fig, ax = plt.subplots(figsize=(7.5, 5))
sc = ax.contourf(xi, yi, zi, levels=15, cmap='viridis_r')

# Add a color bar showing the RMSE scale, with its label and font size.
cbar = plt.colorbar(sc)
cbar.set_label('RMSE - Haversine Distance (km)', fontsize=18)

# Tick label size of the color bar
cbar.ax.tick_params(labelsize=14)

# Add contour lines for better readability.
ax.contour(xi, yi, zi, levels=15, colors='white', alpha=0.5, linewidths=0.5)

# 5. Mark the optimal point
# Place a red 'X' marker on the best hyperparameter combination.
ax.scatter(min_alpha, min_beta, marker = 'X', color = 'red', s = 100,
         label='Optimal Point\n'
               rf'$\epsilon_\alpha$ = {min_alpha}'
               '\n'
               rf'$\epsilon_\beta$ = {min_beta}'
               '\n'
               f'RMSE: {data.rmse_sf.min():.2f}'
               )

legend = ax.legend(frameon=True, loc='best', fontsize=14) # 'best' tries to find the least obstructive location
legend.get_frame().set_edgecolor('gray') # Add a light border to the legend

# 6. Add axis labels
plt.xlabel(r'$\epsilon_\alpha$', fontsize=18)
plt.ylabel(r'$\epsilon_\beta$', fontsize=18)

# 7. Save the figure
plt.savefig('vs_mgt.pdf', bbox_inches='tight')
