**Import:** We import the modules needed for this analysis: xarray (to read the NetCDF data), numpy, matplotlib.pyplot, scikit-learn's `GaussianProcessRegressor` with the `RBF` and `WhiteKernel` kernels (for Gaussian process interpolation), and `mean_squared_error` (to quantify interpolation accuracy).

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import mean_squared_error

We load the NetCDF dataset containing the sea surface height (SSH) data using xarray.

In [None]:
path = 'C:/Users/ABHISHEK/OneDrive/Documents/SSH Data/ostst-single-layer-fd-lat-40-urms-5-kf-13-kr-4-beta.nc'

ds = xr.open_dataset(path)

In [None]:
ds

In [None]:
ssh = ds["ssh"]  

%matplotlib inline

We extract the SSH time series at a single grid point (`x=0`, `y=0`, selected by index).
We convert its time coordinate `t` to a pandas index so we can do datetime arithmetic.
We compute the elapsed time since the first sample in days by dividing by `np.timedelta64(1, "D")`.
We reshape the elapsed times into the 2D feature array of shape `(n_samples, 1)` that scikit-learn expects.
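A self-contained sketch of the elapsed-time conversion, on three made-up timestamps (the dates are illustrative assumptions, not values from the dataset):

```python
import numpy as np
import pandas as pd

# Hypothetical timestamps standing in for the dataset's time coordinate "t"
times = pd.to_datetime(["2020-01-01 00:00", "2020-01-01 12:00", "2020-01-02 00:00"])

# Elapsed time since the first sample, expressed in (fractional) days
days = (times - times[0]) / np.timedelta64(1, "D")

# scikit-learn expects a 2D feature array of shape (n_samples, n_features)
X_time = np.asarray(days, dtype=float).reshape(-1, 1)
print(X_time.ravel())   # elapsed days: 0.0, 0.5, 1.0
print(X_time.shape)
```

The same division also works if `t` is stored as a timedelta rather than a datetime, since subtracting the first element yields timedeltas in both cases.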

In [None]:
ts = ssh.isel(x=0, y=0)         
times = ts["t"].to_index()              
days = (times - times[0]) / np.timedelta64(1, "D")
X_time = days.values.reshape(-1, 1)   
y_time = ts.values

#times = times / (86400 * 1e9)  # unused alternative: nanoseconds to days (1 day = 86400 s)
#times

In [None]:
print(" times dtype:", times.dtype)
print(" first five times:", times[:5])
print(" days array (first five):", days[:5])
print(" X_time shape & min/max:", X_time.shape, X_time.min(), X_time.max())

In [None]:
n = len(X_time)
m = int(n * 0.4)
idx = np.random.choice(n, size=m, replace=False)
mask = np.zeros(n, bool); mask[idx] = True

In [None]:
X_train = X_time[mask];    y_train = y_time[mask]
X_test  = X_time[~mask];   y_test  = y_time[~mask]

In [None]:
print(" n, m:", n, m)
print(" # train points:", mask.sum())
print(" # test points:", (~mask).sum())

In [None]:
ssh.values

In [None]:
times

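We instantiate a Gaussian process (GP) regressor whose kernel is a scaled RBF (squared-exponential) term plus a `WhiteKernel` noise term; the RBF length scale is measured in days because the input is elapsed time in days.
We fit the GP to the training points; with `n_restarts_optimizer=5`, the kernel hyperparameters are optimized from the initial values and from five additional random starting points.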
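To make the kernel concrete, the sketch below writes the covariance out by hand and checks it against scikit-learn's kernel object. The hyperparameters `amp`, `ell`, and `noise` are names introduced here with illustrative values, not fitted ones:

```python
import numpy as np
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

# Illustrative hyperparameters (not fitted values): amplitude, length scale in days, noise
amp, ell, noise = 1.0, 1.0, 0.01
kernel = amp * RBF(length_scale=ell) + WhiteKernel(noise_level=noise)

t = np.array([[0.0], [0.5], [2.0]])          # three times, in days

# Same covariance written out by hand:
# k(t, t') = amp * exp(-(t - t')^2 / (2 * ell^2)) + noise * [t == t']
sq_dist = (t - t.T) ** 2
K_manual = amp * np.exp(-sq_dist / (2 * ell**2)) + noise * np.eye(len(t))

K_sklearn = kernel(t)                        # covariance of t with itself
print(np.allclose(K_manual, K_sklearn))
```

On the training inputs, `WhiteKernel` adds `noise` only to the diagonal; between two different sets of points (for example test versus training), scikit-learn's `WhiteKernel` contributes zeros.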

In [None]:
kernel = 1.0 * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.01)
gp_time = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5)
gp_time.fit(X_train, y_train)


In [None]:
print(" Learned kernel:", gp_time.kernel_)
print(" Noise level :", gp_time.kernel_.k2.noise_level)

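We predict SSH at the held-out test times with the fitted GP, also returning the predictive standard deviation (`y_std`).
We compute the mean squared error (MSE) between the predictions and the true test values to quantify interpolation accuracy.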
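As a check on what `predict` returns, this sketch recomputes the GP posterior mean directly. It assumes synthetic sine data in place of SSH and fixed hyperparameters (`optimizer=None`) so that both calculations use the same kernel:

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)
X_train = np.sort(rng.uniform(0, 10, 20)).reshape(-1, 1)   # synthetic "days"
y_train = np.sin(X_train).ravel()
X_test = np.linspace(0, 10, 7).reshape(-1, 1)
y_test = np.sin(X_test).ravel()

# Fixed hyperparameters (optimizer=None) so the hand calculation uses the same kernel
kernel = 1.0 * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.01)
gp = GaussianProcessRegressor(kernel=kernel, optimizer=None).fit(X_train, y_train)
y_pred = gp.predict(X_test)

# Posterior mean: k(X*, X) [k(X, X) + alpha*I]^{-1} y, with sklearn's default alpha=1e-10
K = kernel(X_train) + gp.alpha * np.eye(len(X_train))
K_star = kernel(X_test, X_train)      # WhiteKernel contributes zeros between distinct sets
mean_manual = K_star @ np.linalg.solve(K, y_train)

# MSE is the average squared residual on the held-out points
mse_manual = np.mean((y_test - y_pred) ** 2)
print(np.allclose(mean_manual, y_pred), np.isclose(mse_manual, mean_squared_error(y_test, y_pred)))
```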

In [None]:
y_pred, y_std = gp_time.predict(X_test, return_std=True)
mse = mean_squared_error(y_test, y_pred)
print(f"Time-series interpolation MSE: {mse:.4f}")

We use matplotlib to visualize the results: a scatter of predicted versus true SSH on the test set, and the interpolated values overlaid on the full time series.

In [None]:
plt.scatter(y_test, y_pred, s=25, alpha=0.7)
lims = [min(y_test.min(), y_pred.min()), max(y_test.max(), y_pred.max())]
plt.plot(lims, lims, "--", color="gray")
plt.xlabel("True SSH"); plt.ylabel("Predicted SSH")
plt.title("True vs. Predicted (Test Set)")
plt.show()

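# Time series overlay: plot both series against elapsed days so they share an x-axis
plt.plot(X_time.ravel(), y_time, "k.", label="true")
plt.plot(X_test.ravel(), y_pred, "rx", label="interp")
plt.xlabel("Days since first sample"); plt.ylabel("SSH")
plt.legend(); plt.title("Interpolation Overlay")
plt.show()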
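The time-series prediction cell also returns `y_std`, which the plots do not use. A minimal sketch of showing it as an approximate 95% band (mean ± 1.96 standard deviations, assuming a Gaussian predictive distribution), on synthetic data standing in for the SSH series:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

# Synthetic stand-in for the SSH time series (elapsed days on the x-axis)
rng = np.random.default_rng(1)
X_time = np.linspace(0, 20, 200).reshape(-1, 1)
y_time = np.sin(X_time).ravel() + 0.05 * rng.standard_normal(200)
train = rng.choice(200, size=80, replace=False)

kernel = 1.0 * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.01)
gp = GaussianProcessRegressor(kernel=kernel).fit(X_time[train], y_time[train])
mean, std = gp.predict(X_time, return_std=True)

# Approximate 95% band under the Gaussian assumption
lower, upper = mean - 1.96 * std, mean + 1.96 * std

fig, ax = plt.subplots()
ax.plot(X_time.ravel(), y_time, "k.", ms=3, label="true")
ax.plot(X_time.ravel(), mean, "r-", label="GP mean")
ax.fill_between(X_time.ravel(), lower, upper, color="r", alpha=0.2, label="~95% band")
ax.set_xlabel("Days since first sample"); ax.set_ylabel("SSH")
ax.legend()
plt.show()
```

The band widens away from training points, which makes gaps in the sampled times visible at a glance.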

For spatial interpolation, we take a single snapshot of the field: we read the first timestamp from the time coordinate `t`.
We then select the 2D SSH field at that time with `.sel`.
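Because the grid values must be flattened in the same order as the meshgrid coordinates, the sketch below checks that pairing on a small hypothetical `(t, y, x)` DataArray. The dims, coordinates, and values are assumptions for illustration; check `ssh.dims` for the real order:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical SSH cube with dims (t, y, x); the real file's dim order may differ
t = pd.date_range("2020-01-01", periods=3, freq="D")
y = np.array([0.0, 1.0])
x = np.array([10.0, 20.0, 30.0])
data = np.arange(3 * 2 * 3, dtype=float).reshape(3, 2, 3)
ssh = xr.DataArray(data, coords={"t": t, "y": y, "x": x}, dims=("t", "y", "x"))

t0 = ssh["t"].to_index()[0]      # first timestamp
field = ssh.sel(t=t0)            # 2D snapshot, same values as ssh.isel(t=0)

# Flatten coordinates and values in the same (y, x) order
XX, YY = np.meshgrid(field["x"].values, field["y"].values)   # each has shape (ny, nx)
values = field.transpose("y", "x").values.ravel()
coords = np.column_stack([XX.ravel(), YY.ravel()])

# Each row of coords should match the value stored at that (x, y) point
i = 4
print(coords[i], values[i], float(field.sel(x=coords[i, 0], y=coords[i, 1])))
```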

In [None]:
t0 = ssh["t"].to_index()[0]
field = ssh.sel(t=t0) 

In [None]:
# Grid coordinates of the snapshot: x plays the role of longitude, y of latitude
lons = field["x"].values
lats = field["y"].values
XX, YY = np.meshgrid(lons, lats)   # each has shape (len(lats), len(lons)); lon -> x, lat -> y

# Flatten the field in (y, x) order so each value lines up with XX.ravel(), YY.ravel()
values = field.transpose("y", "x").values.ravel()

# normalize both coordinates to [0, 1]
XXn = (XX - lons.min()) / (lons.max() - lons.min())
YYn = (YY - lats.min()) / (lats.max() - lats.min())

coords_norm = np.vstack([XXn.ravel(), YYn.ravel()]).T

In [None]:
n_pts = coords_norm.shape[0]
k     = int(n_pts * 0.1)
idx   = np.random.choice(n_pts, size=k, replace=False)

In [None]:
train_coords_norm = coords_norm[idx]
train_vals        = values[idx]

test_coords_norm  = np.delete(coords_norm, idx, axis=0)
test_vals         = np.delete(values,    idx, axis=0)

print(f"Train coords (norm): {train_coords_norm.shape}, Test coords (norm): {test_coords_norm.shape}")


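We instantiate a GP for the spatial data with kernel `ConstantKernel * RBF + WhiteKernel`: the constant (signal variance) is initialized to the variance of the training values, the RBF length scale is in normalized coordinates, and the white-noise term absorbs small-scale noise. Bounds on each hyperparameter restrict the optimizer's search.
We fit the GP to the training points and print the learned kernel.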
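A sketch of how this kernel's hyperparameters are exposed to the optimizer, on synthetic 2D data standing in for the normalized SSH grid (the data and `n_restarts_optimizer=0` are assumptions chosen to keep it fast):

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, RBF, WhiteKernel

# Synthetic stand-in for the normalized spatial training set
rng = np.random.default_rng(2)
train_coords_norm = rng.uniform(0, 1, size=(60, 2))
train_vals = np.sin(4 * train_coords_norm[:, 0]) * np.cos(3 * train_coords_norm[:, 1])

amp = np.var(train_vals)   # prior signal variance initialized from the data
kernel = (
    ConstantKernel(amp, (1e-3 * amp, 1e3 * amp))
    * RBF(length_scale=0.1, length_scale_bounds=(1e-2, 10))
    + WhiteKernel(noise_level=1e-4, noise_level_bounds=(1e-6, 1))
)

# The optimizer works on log-hyperparameters: theta = log([amp, length_scale, noise])
print(np.exp(kernel.theta))
print(np.exp(kernel.bounds))      # one (low, high) row per hyperparameter

gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=0).fit(train_coords_norm, train_vals)
print(gp.kernel_, gp.log_marginal_likelihood_value_)
```

scikit-learn optimizes `theta` (the log of each hyperparameter) within the log of the stated bounds, which is why bounds spanning several orders of magnitude are practical.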

In [None]:
from sklearn.gaussian_process.kernels import ConstantKernel
amp = np.var(train_vals)

kernel = (
    ConstantKernel(amp, (1e-3 * amp, 1e3 * amp))
    * RBF(length_scale=0.1, length_scale_bounds=(1e-2, 10))
    + WhiteKernel(noise_level=1e-4, noise_level_bounds=(1e-6, 1))
)

gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5)
gp.fit(train_coords_norm, train_vals)
print("Learned kernel:", gp.kernel_)

We predict SSH at the test locations using the fitted GP, in batches of 2000 points to limit memory use.
We compute the mean squared error to quantify interpolation accuracy.
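MSE is in squared SSH units, so its size depends on the field's amplitude. A sketch of scale-aware companions, using small made-up arrays rather than results from this notebook:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

# Small illustrative arrays (not results from this notebook)
test_vals = np.array([0.10, -0.20, 0.05, 0.30, -0.15])
pred_vals = np.array([0.12, -0.18, 0.00, 0.25, -0.10])

mse = mean_squared_error(test_vals, pred_vals)
rmse = np.sqrt(mse)                        # same units as SSH
nmse = mse / np.var(test_vals)             # error relative to the field's own variance
r2 = r2_score(test_vals, pred_vals)        # equals 1 - nmse

print(f"RMSE={rmse:.4f}  NMSE={nmse:.4f}  R2={r2:.4f}")
```

An NMSE near 0 indicates a close fit; an NMSE of 1 is what predicting the test mean everywhere would score.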

In [None]:
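# Helper (not defined elsewhere in this notebook): predict in chunks to bound memory use
def batch_predict(model, X, batch_size=2000):
    return np.concatenate([model.predict(X[i:i + batch_size]) for i in range(0, len(X), batch_size)])
pred_vals = batch_predict(gp, test_coords_norm, batch_size=2000)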
mse       = mean_squared_error(test_vals, pred_vals)
print("Spatial interpolation MSE:", mse)

We use matplotlib to plot predicted versus true SSH at the spatial test points; points on the dashed 1:1 line are predicted exactly.

In [None]:
plt.figure(figsize=(5,5))
plt.scatter(test_vals, pred_vals, s=15, alpha=0.6)
lims = [min(test_vals.min(), pred_vals.min()), max(test_vals.max(), pred_vals.max())]
plt.plot(lims, lims, "--", color="gray")
plt.xlabel("True SSH")
plt.ylabel("Predicted SSH")
plt.title("True vs. Predicted (Spatial Interpolation)")
plt.show()