In [3]:
import os
import sys

# Make the local gpkan package (one directory up) importable
sys.path.append(os.path.abspath(".."))

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # imported for its side effect: provides the "science" style used below
from matplotlib.cm import ScalarMappable
from sklearn.model_selection import train_test_split

from gpkan.gpKAN import GPKAN

# Plot styling: SciencePlots "science" style with a background grid
plt.style.use(["science", "grid"])

# Numerical setup: double precision everywhere, plus a fixed root PRNG key
jax.config.update("jax_enable_x64", True)
key = jr.key(123)

# Size of one pixel in inches (for pixel-exact figure sizes) and larger title/label fonts
px = 1 / plt.rcParams["figure.dpi"]
plt.rcParams.update({"axes.titlesize": 18})
plt.rcParams.update({"axes.labelsize": 15})

  torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)


### Load dataset

In [None]:
trollveggen_df = pd.read_csv('../datasets/troll3.csv')
print(trollveggen_df.shape)
trollveggen = jnp.array(trollveggen_df.values)
function_name = "trollveggen"

# Value range of the first three columns (plotted later as longitude, latitude, elevation)
for column in range(3):
    print(jnp.min(trollveggen[:, column]), jnp.max(trollveggen[:, column]))

# Number of distinct coordinate values along each of the two spatial columns
for column in range(2):
    print(jnp.unique(trollveggen[:, column]).shape)

print(trollveggen.shape)

In [None]:
# Quick look at the raw point cloud, coloured by the third column
fig_raw, ax_raw = plt.subplots()
ax_raw.scatter(trollveggen[:, 0], trollveggen[:, 1], c=trollveggen[:, 2])
fig_raw.tight_layout()
plt.show()

In [None]:
x_min, x_max = 7.55, 7.96
# Alternative, larger latitude window: y_min, y_max = 62.17125, 62.66875
y_min, y_max = 62.3, 62.66875

# Keep only the points that lie inside [x_min, x_max] x [y_min, y_max] (bounds inclusive)
in_x_window = (trollveggen[:, 0] >= x_min) & (trollveggen[:, 0] <= x_max)
in_y_window = (trollveggen[:, 1] >= y_min) & (trollveggen[:, 1] <= y_max)
filtered_trollveggen = trollveggen[in_x_window & in_y_window]

fig_window, ax_window = plt.subplots()
ax_window.scatter(filtered_trollveggen[:, 0], filtered_trollveggen[:, 1], c=filtered_trollveggen[:, 2])
plt.show()

In [None]:
# Number of distinct coordinates per axis inside the selected window
for axis_column in (0, 1):
    print(jnp.unique(filtered_trollveggen[:, axis_column]).shape)

In [None]:
# Sorted distinct coordinates along each axis (jnp.unique already returns them sorted)
x1 = jnp.unique(filtered_trollveggen[:, 0])
x2 = jnp.unique(filtered_trollveggen[:, 1])

# Map every point to its (row, column) cell on the len(x2) x len(x1) grid with one vectorised
# scatter, instead of a per-row jnp.where search plus a full-array copy on every .at[].set.
# Every coordinate is an element of x1 / x2, so searchsorted returns its exact index.
ix = jnp.searchsorted(x1, filtered_trollveggen[:, 0])
iy = jnp.searchsorted(x2, filtered_trollveggen[:, 1])
# NOTE(review): assumes each (x, y) pair occurs at most once; for repeated indices JAX does
# not specify which value .at[].set keeps (the old sequential loop kept the last one).
y_grid = jnp.zeros((len(x2), len(x1))).at[iy, ix].set(filtered_trollveggen[:, 2])

# Cells without a data point stay at 0 and would later be treated as real observations
n_filled = jnp.zeros(y_grid.shape, dtype=bool).at[iy, ix].set(True).sum()
n_missing = int(y_grid.size - n_filled)
if n_missing:
    print(f"Warning: {n_missing} grid cells have no data point and are left at 0.")
print(x1.shape, x2.shape, y_grid.shape)

In [None]:
# Elevation map of the selected window: white contour lines over a filled contour plot
fig_data, ax_data = plt.subplots(figsize=(10, 7))
contour = ax_data.contour(x1, x2, y_grid, levels=15, colors="white", alpha=0.3)
contourf = ax_data.contourf(x1, x2, y_grid, levels=100)
cbar = fig_data.colorbar(contourf, ax=ax_data, label="Elevation (m)")
ax_data.set_xlabel("Longitude")
ax_data.set_ylabel("Latitude")
# The dataset is the Trollveggen DEM (troll3.csv), not the Grand Canyon
ax_data.set_title("Trollveggen Elevation Contour Map")
plt.show()

In [None]:
# Disabled experiment (kept for reference only): mean-centred, range-scaled inputs (X_norm).
# Superseded by the z-score standardisation (X_std) in the preprocessing cell below.
# Note the hard-coded (89, 98) grid shape in the commented reshape.
# y = y_grid.flatten()
# # y_norm = ( y - jnp.mean(y) ) / (jnp.max(y) - jnp.min(y))
# # y_norm_grid = y_norm.reshape(89, 98)

# x1_norm = ( x1 - jnp.mean(x1) ) / (jnp.max(x1) - jnp.min(x1))
# x2_norm = ( x2 - jnp.mean(x2) ) / (jnp.max(x2) - jnp.min(x2))

# X1_norm, X2_norm = jnp.meshgrid(x1_norm, x2_norm)
# X_norm = jnp.column_stack((X1_norm.flatten(), X2_norm.flatten()))

In [None]:
# Disabled experiment (kept for reference only): log-transformed targets (y_log) with
# standardised inputs. Superseded by the sqrt-transform preprocessing cell below.
# eps = 1e-6

# X1, X2 = jnp.meshgrid(x1, x2)
# X = jnp.column_stack((X1.flatten(), X2.flatten()))
# y_clean = y_grid.flatten().reshape(-1, 1)
# y = jnp.where(y_clean < 0, 0, y_clean) + eps
# y_log = jnp.log(y)
# # y = y_clean + eps
# X1_std = (X1.flatten() - jnp.mean(X1.flatten())) / jnp.std(X1.flatten())
# X2_std = (X2 - jnp.mean(X2)) / jnp.std(X2)
# X_std = jnp.column_stack((X1_std.flatten(), X2_std.flatten()))
# print(jnp.min(X1_std), jnp.max(X1_std))
# print(jnp.min(X2_std), jnp.max(X2_std))

In [None]:
eps = 1e-6  # small positive offset on the targets (carried over from the log-transform variant)

# Flattened (x1, x2) coordinates of every grid cell; rows of the grid follow x2
X1, X2 = jnp.meshgrid(x1, x2)
X = jnp.stack([X1.ravel(), X2.ravel()], axis=1)

# z-score each input column
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

# Targets as a column vector: negative elevations clipped to 0, then square-root transformed
y_clean = y_grid.ravel()[:, None]
y = jnp.maximum(y_clean, 0) + eps
y_sqrt = jnp.sqrt(y)

print(X_std.min(axis=0), X_std.max(axis=0))
print(y_sqrt.min(), y_sqrt.max())

### Model setup

In [None]:
# 80/20 random split of the grid points (train_test_split shuffles by default; seeded for
# reproducibility). The test points serve as held-out "missing" data.
X_train, X_test, y_train, y_test = train_test_split(
    X_std, y_sqrt, test_size=0.2, random_state=42
)
# Targets as column vectors; y_sqrt is already (n, 1), so this only guards against 1-D input
y_train = y_train.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)

In [None]:
# Layer widths passed to GPKAN (2 inputs = the two standardized coordinates).
# model_size is derived from this list so that file names written later always describe the
# model actually trained (it previously read "2-7-7-7-1" while layers was [2, 7, 7, 1]).
model_layers = [2, 7, 7, 1]
model_size = "-".join(str(width) for width in model_layers)
model = GPKAN(layers=model_layers,
              n_grid_points=10,
              # Grid spans the full range of the standardized inputs
              grid_min=jnp.min(X_std),
              grid_max=jnp.max(X_std),
              # NOTE(review): "init_paramters" presumably matches GPKAN's keyword spelling — confirm before renaming
              init_paramters=[1.0, 1.0],
              obs_stddev=0.5
              )

def loss_ll(y_true, mean, covariance):
    """Negative Gaussian log-likelihood of ``y_true`` using only the diagonal of ``covariance``.

    Off-diagonal covariance entries are ignored: the observations are treated as
    independent with variances ``diag(covariance)``.

    Args:
        y_true: Targets; flattened to 1-D before use.
        mean: Predictive mean; also flattened so an (n, 1) mean cannot broadcast against
            the (n,) targets into an (n, n) residual matrix.
        covariance: Square predictive covariance; only its diagonal is read.

    Returns:
        Scalar negative log-likelihood (lower is better).
    """
    variances = jnp.diag(covariance)
    residual = jnp.ravel(y_true) - jnp.ravel(mean)
    n_obs = residual.shape[0]

    # Diagonal Mahalanobis term computed elementwise instead of materialising an (n, n) inverse
    mahalanobis = jnp.sum(residual**2 / variances)
    log_det = jnp.sum(jnp.log(variances))

    return 0.5 * (n_obs * jnp.log(2 * jnp.pi) + log_det + mahalanobis)

def _batch_loss(Xs_latent, ys_latent, kernel_params, X_batch, y_batch):
    """Loss of ``model`` on one batch, as a function of its trainable quantities.

    The first three positional arguments (latent grids, latent supports, kernel
    parameters) are the ones differentiated by ``val_grad_loss``.
    """
    mean, covariance = model.sample_statistics(
        Xs_latent, ys_latent, X_batch, kernel_params, n_samples=10)
    return loss_ll(y_batch, mean, covariance)


# Jitted (loss, (d_grids, d_supports, d_kernel_params)) for one batch
val_grad_loss = jax.jit(jax.value_and_grad(_batch_loss, argnums=(0, 1, 2)))

In [None]:
@jax.jit
def get_learning_rate(epoch, initial_lr=0.0001):
    """Step-decay schedule: ``initial_lr`` multiplied by 0.95 once every 50 epochs."""
    decay_steps = epoch // 50
    return initial_lr * 0.95 ** decay_steps

@jax.jit
def clip_gradients(grads, max_norm=1.0):
    """Rescale a gradient pytree so its global L2 norm is at most ``max_norm``.

    Every leaf is multiplied by the same factor ``min(1, max_norm / ||grads||)``,
    so the overall gradient direction is preserved.
    """
    leaves = jax.tree.leaves(grads)
    total_squared = sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)
    scale = jnp.minimum(1.0, max_norm / jnp.sqrt(total_squared))
    return jax.tree.map(lambda leaf: leaf * scale, grads)

In [None]:
epochs = 500
# learning_rate = 0.0001
learning_rate = 0.01
loss_history = []  # one entry per epoch: mean batch loss

batch_size = 32
patience = 100  # Number of epochs to wait for improvement
best_loss = float('inf')
patience_counter = 0

# Stage 1: optimise only the latent supports with plain SGD. val_grad_loss also returns
# gradients for the latent grids and kernel parameters; they are not used here.
for epoch in range(epochs):
    epoch_losses = []
    current_lr = get_learning_rate(epoch, initial_lr=learning_rate)
    diverged = False

    for i in range(0, X_train.shape[0], batch_size):
        batch_X = X_train[i:i+batch_size, :]
        batch_y = y_train[i:i+batch_size, :]

        loss, (_, grad_supports, _) = val_grad_loss(
            model.latent_grids,
            model.latent_supports,
            model.kernel_parameters,
            batch_X, batch_y
            )

        # Stop the whole training run (not just this epoch) on a negative or NaN loss.
        # NOTE(review): a Gaussian NLL can legitimately be negative when predictive variances
        # are small; confirm that stopping on a negative loss is intended.
        if loss < 0 or jnp.isnan(loss):
            print(f"Stopping training at epoch {epoch} as loss became negative or NaN: {loss}")
            diverged = True
            break  # leaves the batch loop; the epoch loop is left below

        # Gradient clipping is deliberately disabled for this stage ("nogradclip" in saved file names)
        # grad_supports = clip_gradients(grad_supports)

        model.latent_supports = jax.tree.map(
            lambda latent_supports, grad_supports_:
            latent_supports - grad_supports_ * current_lr,
            model.latent_supports,
            grad_supports
        )

        epoch_losses.append(loss)

    if not epoch_losses:
        break  # nothing was trained this epoch (diverged on the first batch, or no training data)

    # Record exactly one value per epoch: the mean over its batches
    epoch_loss = float(np.mean(epoch_losses))
    loss_history.append(epoch_loss)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss: {epoch_loss:.6f}, LR: {current_lr:.6f}")

    if diverged:
        break

    # Early stopping on the epoch-mean loss rather than the noisy last-batch loss
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        patience_counter = 0
    else:
        patience_counter += 1
    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}, best loss: {best_loss}")
        break

In [None]:
# learning_rate = 0.01
loss_history_params = []  # one entry per epoch: mean batch loss

batch_size = 32
patience = 100  # Number of epochs to wait for improvement
best_loss = float('inf')
patience_counter = 0

# Stage 2: optimise only the kernel parameters (gradients clipped to global norm 1.0);
# the latent supports learned in the previous cell stay fixed. Reuses `epochs` and
# `learning_rate` from that cell.
for epoch in range(epochs):
    epoch_losses_params = []
    current_lr = get_learning_rate(epoch, initial_lr=learning_rate)
    diverged = False

    for i in range(0, X_train.shape[0], batch_size):
        batch_X = X_train[i:i+batch_size, :]
        batch_y = y_train[i:i+batch_size, :]

        loss, (_, _, grad_params) = val_grad_loss(
            model.latent_grids,
            model.latent_supports,
            model.kernel_parameters,
            batch_X, batch_y
            )

        # Stop the whole training run (not just this epoch) on a negative or NaN loss.
        # NOTE(review): a Gaussian NLL can legitimately be negative when predictive variances
        # are small; confirm that stopping on a negative loss is intended.
        if loss < 0 or jnp.isnan(loss):
            print(f"Stopping training at epoch {epoch} as loss became negative or NaN: {loss}")
            diverged = True
            break  # leaves the batch loop; the epoch loop is left below

        grad_params = clip_gradients(grad_params)

        model.kernel_parameters = jax.tree.map(
            lambda kernel_params, grad_params_:
            kernel_params - grad_params_ * current_lr,
            model.kernel_parameters,
            grad_params
        )

        epoch_losses_params.append(loss)

    if not epoch_losses_params:
        break  # nothing was trained this epoch (diverged on the first batch, or no training data)

    # Record exactly one value per epoch, in this stage's own history list
    epoch_loss_params = float(np.mean(epoch_losses_params))
    loss_history_params.append(epoch_loss_params)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss: {epoch_loss_params:.6f}, LR: {current_lr:.6f}")

    if diverged:
        break

    # Early stopping on the epoch-mean loss rather than the noisy last-batch loss
    if epoch_loss_params < best_loss:
        best_loss = epoch_loss_params
        patience_counter = 0
    else:
        patience_counter += 1
    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}, best loss: {best_loss}")
        break

In [None]:
# Visualise every neuron of the trained model; pass save_fig=True to also write the figures
neuron_fig_dir = f"figs/function_predictions/{function_name}/"
model.plot_neurons(save_fig=False, save_path=neuron_fig_dir)

In [None]:
batch_size = 32  # Adjust as needed
n = X_std.shape[0]
mu_batches = []
cov_blocks = []

# Print progress roughly every 5% of the points. max(1, ...) avoids a modulo-by-zero when
# int(n * 0.05) < batch_size, i.e. for small grids.
report_every = max(1, int(n * 0.05) // batch_size)
for batch_idx, i in enumerate(range(0, n, batch_size)):
    X_batch = X_std[i:i+batch_size]
    mu_batch, cov_batch = model.sample_statistics(
        model.latent_grids, model.latent_supports, X_batch, model.kernel_parameters, 5, key=jr.key(233 + i)
    )
    mu_batches.append(mu_batch)
    cov_blocks.append(cov_batch)
    if batch_idx % report_every == 0:
        percent = int(100 * i / n)
        print(f"{percent}% done predicting...")

# Batches are predicted independently, so the full covariance is block-diagonal.
# Note: block_diag materialises a dense (n, n) matrix.
mu_full = jnp.concatenate(mu_batches)
cov_full = jax.scipy.linalg.block_diag(*cov_blocks)
y_stddev = jnp.sqrt(jnp.diag(cov_full))

In [None]:
# Sanity-check the predictions assembled in the previous cell (instead of rebuilding the
# dense block-diagonal covariance a second time): one mean and one stddev per grid point.
n_points = X_std.shape[0]
assert mu_full.size == n_points, (mu_full.shape, n_points)
assert y_stddev.size == n_points, (y_stddev.shape, n_points)

In [None]:
# By default the predictions computed above are used as-is. The previous version of this cell
# always replaced them with CSVs from an earlier run, whose names did not match the save path.
#   SAVE_PREDICTIONS        – write mu_full / cov_full to CSV (cov_full is a dense (n, n) matrix,
#                             so its CSV can be very large)
#   LOAD_CACHED_PREDICTIONS – replace mu_full / cov_full / y_stddev with previously saved CSVs
#                             (older runs saved e.g. "mu_nogradclip.csv" without the model size;
#                             adjust the paths to load those)
SAVE_PREDICTIONS = False
LOAD_CACHED_PREDICTIONS = False

predictions_dir = "figs/function_predictions/trollveggen"
mu_path = f"{predictions_dir}/mu_nogradclip_{model_size}.csv"
cov_path = f"{predictions_dir}/cov_nogradclip_{model_size}.csv"

if SAVE_PREDICTIONS:
    os.makedirs(predictions_dir, exist_ok=True)
    pd.DataFrame(np.asarray(mu_full)).to_csv(mu_path)
    pd.DataFrame(np.asarray(cov_full)).to_csv(cov_path)

if LOAD_CACHED_PREDICTIONS:
    # The first CSV column is the index written by DataFrame.to_csv
    mu_full = jnp.array(pd.read_csv(mu_path, index_col=0).values)
    cov_full = jnp.array(pd.read_csv(cov_path, index_col=0).values)
    y_stddev = jnp.sqrt(jnp.diag(cov_full))

print(mu_full.shape, cov_full.shape)

In [None]:
# Residuals on the sqrt-transformed scale at every grid point (training and test points alike)
residuals = jnp.ravel(y_sqrt) - jnp.ravel(mu_full)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(22, 5))
# Grid fields have rows along x2 (latitude) and columns along x1 (longitude), as built by jnp.meshgrid(x1, x2)
grid_shape = (x2.shape[0], x1.shape[0])

# Explicit shared contour levels for the data and mean panels. With an integer `levels`,
# contourf picks levels from each panel's own range, so the two colorbars would differ.
vmin = min(y_sqrt.min(), mu_full.min())
vmax = max(y_sqrt.max(), mu_full.max())
shared_levels = np.linspace(float(vmin), float(vmax), 50)

contourf_test = axs[0].contourf(x1, x2, y_sqrt.reshape(grid_shape),
                                levels=shared_levels, cmap="viridis", vmin=vmin, vmax=vmax)
axs[0].set_title("Observed data (sqrt elevation)")
fig.colorbar(contourf_test, ax=axs[0])

contourf_pred = axs[1].contourf(x1, x2, mu_full.reshape(grid_shape),
                                levels=shared_levels, cmap="viridis", vmin=vmin, vmax=vmax)
axs[1].set_title("Approximated Mean Function")
fig.colorbar(contourf_pred, ax=axs[1])

contourf_res = axs[2].contourf(x1, x2, jnp.abs(residuals).reshape(grid_shape),
                               levels=100, cmap="jet")
axs[2].set_title("Absolute residuals")
fig.colorbar(contourf_res, ax=axs[2])

contourf_var = axs[3].contourf(x1, x2, y_stddev.reshape(grid_shape),
                               levels=100, cmap="jet")
axs[3].set_title("Standard deviation")
fig.colorbar(contourf_var, ax=axs[3])

# Every panel is drawn over the same (longitude, latitude) grid
for ax in axs:
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

plt.show()

In [None]:
# Ranges of the predictive standard deviation and the predictive mean
for values in (y_stddev, mu_full):
    print(jnp.min(values), jnp.max(values))

In [None]:
# Mean absolute deviation between the (sqrt-transformed) data and the predicted mean
mad = jnp.mean(jnp.abs(residuals))
print("Mean Absolute Deviation (MAD):", mad)

# Per-point predictive standard deviation (not a MAD; the variable name is kept for the
# cells below). Taken from y_stddev, which equals sqrt(diag) of every block in cov_blocks,
# so it always matches the predictions being plotted, including cached ones.
pointwise_mad = jnp.ravel(y_stddev)
print("Mean predictive standard deviation:", jnp.mean(pointwise_mad))

In [None]:
# Sorted distinct standardized coordinates along each axis (jnp.unique already returns them sorted)
x1_std = jnp.unique(X_std[:, 0])
x2_std = jnp.unique(X_std[:, 1])
for axis_values in (x1_std, x2_std):
    print(jnp.min(axis_values), jnp.max(axis_values))

In [None]:
# Map of the per-point predictive standard deviation on the standardized grid
fig2, ax2 = plt.subplots(figsize=(5, 5))
q = ax2.contourf(x1_std, x2_std, pointwise_mad.reshape(x2_std.shape[0], x1_std.shape[0]),
                 levels=50,
                 cmap="viridis",
                 )
fig2.colorbar(q, ax=ax2, label="Predictive standard deviation")
ax2.set_title("Predictive uncertainty")
ax2.set_xlabel("$x_1$")
ax2.set_ylabel("$x_2$")
plt.show()

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(10, 10), constrained_layout=True)
axs = axs.flatten()
# Output name for this poster figure. It is kept in its own variable so that the
# notebook-wide `function_name` ("trollveggen") used by earlier cells is not overwritten.
poster_name = "åndalsnes"
grid_shape = (x2_std.shape[0], x1_std.shape[0])

# ---- Row 1: data and predicted mean on one shared colour scale -------------------------
vmin = min(y_sqrt.min(), mu_full.min())
vmax = max(y_sqrt.max(), mu_full.max())
norm1 = plt.Normalize(vmin, vmax)

contourf_test = axs[0].contourf(x1_std, x2_std, y_sqrt.reshape(grid_shape),
                                levels=50,
                                cmap="viridis",
                                vmin=vmin,
                                vmax=vmax
                                )
axs[0].set_title("Underlying data")

contourf_pred = axs[1].contourf(x1_std, x2_std, mu_full.reshape(grid_shape),
                                cmap="viridis",
                                levels=50,
                                vmin=vmin,
                                vmax=vmax
                                )
axs[1].set_title("Approximated Mean Function")

# One colorbar for the whole row, driven by an explicit norm rather than either panel
sm1 = ScalarMappable(cmap="viridis", norm=norm1)
sm1.set_array([])  # Empty array - using the norm instead
cbar_row1 = fig.colorbar(sm1, ax=[axs[0], axs[1]], location='right', shrink=0.98)
cbar_row1.set_label("Function value")

n_ticks = 9  # Number of ticks including min and max
ticks1 = np.linspace(vmin, vmax, n_ticks)
cbar_row1.set_ticks(ticks1)
cbar_row1.set_ticklabels([f"{tick:.2f}" for tick in ticks1])

# ---- Row 2: residuals and uncertainty relative to |mean|, in percent ------------------
epsilon = 1e-10  # avoids division by zero where the predicted mean is 0
abs_mean = jnp.abs(mu_full.flatten()) + epsilon
normalized_residuals = 100 * jnp.abs(residuals.flatten()) / abs_mean
normalized_stddev = 100 * (y_stddev.flatten() / abs_mean)

vmin_2 = min(normalized_residuals.min(), normalized_stddev.min())
cbar_limit = 250  # values above this are drawn in the colorbar's "max" extension
levels = np.linspace(vmin_2, cbar_limit, 20)

contourf_std_res = axs[2].contourf(x1_std, x2_std, normalized_residuals.reshape(grid_shape),
                                   levels=levels,
                                   cmap="jet",
                                   vmin=vmin_2,
                                   vmax=cbar_limit,
                                   extend="max",
                                   )
axs[2].set_title("Normalized Residuals")

contourf_std_var = axs[3].contourf(x1_std, x2_std, normalized_stddev.reshape(grid_shape),
                                   levels=levels,
                                   cmap="jet",
                                   vmin=vmin_2,
                                   vmax=cbar_limit,
                                   extend="max",
                                   )
axs[3].set_title("Normalized Uncertainty")

norm2 = plt.Normalize(vmin_2, cbar_limit)  # Create explicit normalization
sm2 = ScalarMappable(cmap="jet", norm=norm2)
sm2.set_array([])  # Empty array - using the norm instead
cbar_row2 = fig.colorbar(sm2, ax=[axs[2], axs[3]], location='right', shrink=0.98, extend="max")
# Raw string: "\%" escapes the percent sign for LaTeX text rendering; in a normal string
# literal "\%" is an invalid escape sequence (SyntaxWarning on recent Pythons).
cbar_row2.set_label(r'Relative Error (\%)')

ticks2 = np.arange(0, cbar_limit + 1, 25)
cbar_row2.set_ticks(ticks2)
cbar_row2.set_ticklabels([f"{tick}" for tick in ticks2])

# All four panels share the same standardized coordinates
for ax in axs:
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")

fig.suptitle("Åndalsnes and Romsdalen", fontsize=30, fontweight="bold")

poster_dir = f"figs/function_predictions/{poster_name}"
os.makedirs(poster_dir, exist_ok=True)  # savefig fails if the directory does not exist
plt.savefig(f"{poster_dir}/{poster_name}_NORAposter.png", dpi=500, bbox_inches="tight")
plt.show()

In [None]:
# Flat grid indices of the largest relative residual and of the largest relative uncertainty
max_residual_idx = jnp.argmax(normalized_residuals)
max_stddev_idx = jnp.argmax(normalized_stddev)
for message, flat_idx in (
    ("Index of max normalized_residuals:", max_residual_idx),
    ("Index of max normalized_stddev:", max_stddev_idx),
):
    print(message, flat_idx)

In [None]:
# Relative residual at the worst point, then the predicted mean and stddev at that point
for quantity in (normalized_residuals, mu_full, y_stddev):
    print(quantity[max_residual_idx])

In [None]:
# Locations of the held-out test points, coloured by their (sqrt-transformed) target value
fig_test, ax_test = plt.subplots(figsize=(8, 6))
test_points = ax_test.scatter(X_test[:, 0], X_test[:, 1], c=y_test.flatten(), cmap='viridis', s=15)
# Plain-text label: an unescaped "_" (as in the former 'y_test') breaks LaTeX text rendering,
# which the notebook presumably uses (see the "\%" escape in the poster figure).
fig_test.colorbar(test_points, ax=ax_test, label='Test target (sqrt elevation)')
ax_test.set_xlabel('x1 (standardized)')
ax_test.set_ylabel('x2 (standardized)')
ax_test.set_title('Test Set Locations (Missing Data)')
plt.show()