In [None]:
from aux_functions import *
from sphere_vector_kernels import *
from sphere_vector_gp import *
from blender_file_generation import *

import matplotlib.pyplot as plt
import os
import netCDF4
import numpy as np
import pandas as pd
from tqdm import tqdm
import itertools as it
import pickle
import seaborn as sns
import re
import jax
import jax.numpy as jnp
import cartopy.crs as ccrs

from sklearn.gaussian_process.kernels import ConstantKernel, WhiteKernel, Matern, RBF, DotProduct

### Read data

The data are monthly averaged ERA5 wind fields at the 500 hPa pressure level, from January 2010 to December 2014. Each month is used as a separate experiment.

In [None]:
# Path to the NetCDF file, relative to the notebook's working directory.
fpath = os.path.join("era5", "monthly_averaged_wind_500hP.nc")
# Opened read-only; the handle is kept open because later cells index `data.variables`.
data = netCDF4.Dataset(fpath,'r')

In [None]:
# List the names of the variables stored in the file.
data.variables.keys()

In [None]:
# Inspect the metadata of the longitude variable.
data.variables['longitude']

In [None]:
# Length of the 'time' variable; used below as the default number of experiments in run_experiments.
n_times = data.variables['time'].shape[0]

In [None]:
# Raw coordinate axes (the `.data` attribute drops the mask of the masked array).
# Presumably in degrees: train_test_sets converts them to radians with pi / 180.
lons = data.variables['longitude'][:].data
lats = data.variables['latitude'][:].data
# 2-D coordinate grids; read_data flattens them into one row per grid point.
lon_mesh, lat_mesh = np.meshgrid(lons, lats)

In [None]:
def read_data(time):
    """Return the wind field at one time index as a flat table.

    Parameters
    ----------
    time
        Index into the first axis of the ``u`` and ``v`` variables of the
        open dataset ``data``.

    Returns
    -------
    pandas.DataFrame
        Columns ``lon``, ``lat``, ``u``, ``v`` with one row per grid point,
        in the flattening order of ``lon_mesh`` / ``lat_mesh``.
    """
    columns = {
        "lon": lon_mesh.flatten(),
        "lat": lat_mesh.flatten(),
    }
    # Both wind components are read the same way, so loop instead of repeating.
    for component in ("u", "v"):
        columns[component] = data.variables[component][time].data.flatten()
    return pd.DataFrame(columns)

In [None]:
# train sets: an "orbit" or some points around the sphere
def orbit(df):
    """Select the "orbit" training points from a table built by ``read_data``.

    Keeps every 40th value of the module-level ``lats`` axis, on the two
    longitudes 90 and 270, and returns the matching rows with a fresh index.
    """
    # These two names are referenced from the query string via `@`, so they
    # must keep exactly these spellings.
    lats_train = lats[::40]
    lons_train = np.array([90., 270.])
    selected = df.query("lat in @lats_train and lon in @lons_train")
    return selected.reset_index(drop=True).copy()

def crystal(df):
    """Select the "crystal" training points from a table built by ``read_data``.

    The points lie on five latitude rings (-90, -45, 0, 45, 90). The two
    extreme rings use the single longitude 0; the others use regularly strided
    subsets of the module-level ``lons`` axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with at least ``lat`` and ``lon`` columns.

    Returns
    -------
    pandas.DataFrame
        The selected rows, concatenated ring by ring, with a fresh index.
    """
    lat_lons = {  # lat => lons
        -90: [0.],
        -45: lons[::180],
        0: lons[::90],
        45: lons[::180],
        90: [0.],
    }
    # Explicit boolean masks instead of `df.query` with `@` variables: the
    # original loop variable was also called `lons`, shadowing the module-level
    # axis and depending on how pandas resolves names from the calling frame.
    rings = [
        df[(df["lat"] == ring_lat) & df["lon"].isin(ring_lons)].copy()
        for ring_lat, ring_lons in lat_lons.items()
    ]
    return pd.concat(rings).reset_index(drop=True)

In [None]:
def train_test_sets(time, train_set):
    """Build the train/test inputs and normalised targets for one time index.

    Parameters
    ----------
    time
        Time index forwarded to ``read_data``.
    train_set : str
        ``"crystal"`` or ``"orbit"``; any other value raises ``KeyError``.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``. The ``X`` arrays hold
        ``(lat, lon)`` converted from degrees to radians, the ``y`` arrays hold
        ``(v, u)`` divided by the mean vector norm of the training targets.
    """
    frame = read_data(time)
    selectors = {
        "crystal": crystal,
        "orbit": orbit,
    }
    df_train = selectors[train_set](frame)
    df_test = test_set(frame)

    def _lat_lon_radians(part):
        # Degrees to radians.
        return np.pi * part[["lat", "lon"]].to_numpy() / 180

    X_train = _lat_lon_radians(df_train)
    X_test = _lat_lon_radians(df_test)

    y_train = df_train[["v", "u"]].to_numpy()
    y_test = df_test[["v", "u"]].to_numpy()
    # The scale is estimated on the training targets only and applied to both sets.
    norm_constant = jax.vmap(jnp.linalg.norm)(y_train).mean()
    # In-place division keeps both targets as NumPy arrays.
    y_train /= norm_constant
    y_test /= norm_constant
    return X_train, X_test, y_train, y_test

In [None]:
# test set: input locations for blender, matched to closest point in the data

# for each of the inputs, match the closest point in the data

# Folder holding the Blender input CSVs; generated CSVs go to its "outputs" subfolder.
blender_folder = "blender-data"
out_folder = os.path.join("blender-data", "outputs")

# Header-less CSV of 3-D points; these are matched to grid points below and also
# written next to the predictions in prediction_and_samples.
mean_inputs = pd.read_csv(
    os.path.join(blender_folder, "input_locations.csv"),
    names=["x", "y", "z"]
).to_numpy()

# Header-less CSV of 3-D points at which prediction_and_samples evaluates the uncertainty.
std_inputs = pd.read_csv(
    os.path.join(blender_folder, "std_inputs.csv"),
    names=["x", "y", "z"]
).to_numpy()

# Every grid point as (lat, lon) in radians, passed through the project helper sph_to_car
# (presumably spherical -> Cartesian, so the points are comparable to the x/y/z inputs — TODO confirm).
all_points = sph_to_car(read_data(0)[["lat", "lon"]].to_numpy() * np.pi / 180)

@jax.jit
def _match_point_idx(input_car, all_points):
    idx = jnp.argmin(
        jax.vmap(lambda a, b: jnp.linalg.norm(a - b), in_axes=(None, 0))(input_car, all_points)
    )
    return idx

# For each row of mean_inputs, the index of the nearest row of all_points; test_set uses it to pick rows.
MATCH_IDXS = jax.vmap(_match_point_idx, in_axes=(0, None))(mean_inputs, all_points)

In [None]:
def test_set(df):
    """Return a copy of the rows of ``df`` at the positions in ``MATCH_IDXS``.

    One row per Blender input location, in the order of ``mean_inputs``.
    """
    return df.iloc[MATCH_IDXS].copy()

In [None]:
# Visual check of the "orbit" split for the first time index, plus CSV export of the same data.
time = 0

X_train, X_test, y_train, y_test = train_test_sets(time, train_set="orbit")

# Create the figure with a single map axes. Calling plt.subplots() and then
# plt.axes(projection=...) can leave an extra, empty rectilinear axes underneath the map.
fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={"projection": ccrs.PlateCarree()})
ax.coastlines()
# X columns are (lat, lon) in radians and y columns are (v, u); quiver takes lon/lat in degrees and (u, v).
q = ax.quiver(X_test[:, 1] * 180 / np.pi, X_test[:, 0] * 180 / np.pi, y_test[:, 1], y_test[:, 0], angles="uv")
# Private matplotlib call that forces the arrow scale to be computed, so that q.scale can be
# reused for the training arrows below — TODO confirm it still exists in the installed version.
q._init()
ax.quiver(X_train[:, 1] * 180 / np.pi, X_train[:, 0] * 180 / np.pi, y_train[:, 1], y_train[:, 0], angles="uv", scale=q.scale, color="r")

X_train_car, y_train_car = v_sph_to_car(X_train, y_train)
X_test_car, y_test_car = v_sph_to_car(X_test, y_test)

# One row per point: converted position followed by converted vector.
np.savetxt(os.path.join(out_folder, f"ERA5_orbit_{time}__mercator__mean.csv"), np.hstack([X_test_car, y_test_car]), delimiter=",")
np.savetxt(os.path.join(out_folder, f"ERA5_orbit_{time}__mercator__observations.csv"), np.hstack([X_train_car, y_train_car]), delimiter=",")

In [None]:
# Visual check of the "crystal" split for the first time index, plus CSV export of the same data.
time = 0

X_train, X_test, y_train, y_test = train_test_sets(time, train_set="crystal")

# Create the figure with a single map axes. Calling plt.subplots() and then
# plt.axes(projection=...) can leave an extra, empty rectilinear axes underneath the map.
fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={"projection": ccrs.PlateCarree()})
ax.coastlines()
# X columns are (lat, lon) in radians and y columns are (v, u); quiver takes lon/lat in degrees and (u, v).
q = ax.quiver(X_test[:, 1] * 180 / np.pi, X_test[:, 0] * 180 / np.pi, y_test[:, 1], y_test[:, 0], angles="uv")
# Private matplotlib call that forces the arrow scale to be computed, so that q.scale can be
# reused for the training arrows below — TODO confirm it still exists in the installed version.
q._init()
ax.quiver(X_train[:, 1] * 180 / np.pi, X_train[:, 0] * 180 / np.pi, y_train[:, 1], y_train[:, 0], angles="uv", scale=q.scale, color="r")

X_train_car, y_train_car = v_sph_to_car(X_train, y_train)
X_test_car, y_test_car = v_sph_to_car(X_test, y_test)

# One row per point: converted position followed by converted vector.
np.savetxt(os.path.join(out_folder, f"ERA5_crystal_{time}__mercator__mean.csv"), np.hstack([X_test_car, y_test_car]), delimiter=",")
np.savetxt(os.path.join(out_folder, f"ERA5_crystal_{time}__mercator__observations.csv"), np.hstack([X_train_car, y_train_car]), delimiter=",")

### Experiment utilities

In [None]:
def mse(y_true, y_pred):
    """Mean over the leading axis of the squared Euclidean error.

    For each pair ``(y_true[i], y_pred[i])`` the squared differences are summed
    over all remaining axes, then the per-item values are averaged.

    Parameters
    ----------
    y_true, y_pred : array-like
        Arrays with matching shapes; the first axis indexes the items.

    Returns
    -------
    Scalar JAX array.
    """
    # No inner jax.jit: wrapping a freshly created lambda on every call never
    # hits the compilation cache, so it only added tracing/compilation work.
    squared_errors = jax.vmap(lambda a, b: jnp.sum((a - b) ** 2))(y_true, y_pred)
    return squared_errors.mean()

def pred_nll(y_true, y_pred, std_pred):
    """Mean negative log-density of ``y_true`` under per-point Gaussians.

    Each item ``i`` is scored with a multivariate normal whose mean is
    ``y_pred[i]`` and whose third ``logpdf`` argument is ``std_pred[i]``.
    ``jax.scipy.stats.multivariate_normal.logpdf`` takes a covariance there,
    so despite its name ``std_pred`` is expected to hold one covariance matrix
    per point.
    """
    pointwise_logpdf = jax.vmap(jax.scipy.stats.multivariate_normal.logpdf)
    log_density = pointwise_logpdf(y_true, y_pred, std_pred)
    return -log_density.mean()

def run_single_experiment(X_train, y_train, X_test, y_test, model, number, verbose=True, n_restarts_optimizer=0):
    """Fit one GP model and score its predictions on the test set.

    Parameters
    ----------
    X_train, y_train, X_test, y_test
        Inputs and targets as returned by ``train_test_sets``.
    model : tuple
        ``(name, kernel)`` pair.
    number
        Experiment identifier, stored under the ``"n"`` key of the result.
    verbose : bool
        If true, display the fitted GP, its negative log marginal likelihood
        divided by the number of training points, and the metrics.
    n_restarts_optimizer : int
        Forwarded to ``SphereVectorGP``.

    Returns
    -------
    dict
        Keys ``name``, ``n``, ``fitted_gp`` (string form of the fitted GP),
        ``MSE`` and ``PNLL``.
    """
    name, k = model

    # Forward the caller's value. It used to be hard-coded to 0, which silently
    # ignored the `n_restarts_optimizer` argument.
    gp = SphereVectorGP(kernel=k, n_restarts_optimizer=n_restarts_optimizer)

    gp.fit(X_train, y_train)
    if verbose:
        display(gp)
        display("MLL:", -gp.log_marginal_likelihood_value_ / X_train.shape[0])

    mu_star, std = gp.predict(X_test, return_std=True)
    metrics = {
        "name": name,
        "n": number,
        # The string form carries the fitted hyperparameters, which are parsed back out later.
        "fitted_gp": str(gp),
        "MSE": float(mse(y_test, mu_star)),
        "PNLL": float(pred_nll(y_test, mu_star, std)),
    }
    if verbose:
        display(metrics)
    return metrics

In [None]:
def run_experiments(models, train_test, fname, n_experiments=n_times, n_restarts_optimizer=0, verbose=False):
    """Run every model on every experiment, caching results on disk.

    Results are stored in a pickle under ``temp/<fname>`` keyed by
    ``(experiment index, model name)``. Pairs already in the cache are skipped,
    and the cache is rewritten after each new result so an interrupted run can
    be resumed.

    Parameters
    ----------
    models : list of (name, kernel) tuples
    train_test : callable
        Called as ``train_test(time=i)``; must return
        ``(X_train, X_test, y_train, y_test)``.
    fname : str
        Cache file name inside the ``temp`` folder.
    n_experiments : int
        Number of experiment indices to run. The default is the value of
        ``n_times`` at the moment this function is defined.
    n_restarts_optimizer : int
        Forwarded to ``run_single_experiment``.
    verbose : bool
        Forwarded to ``run_single_experiment``.

    Returns
    -------
    pandas.DataFrame
        One row per cached result. If the run is interrupted or an experiment
        fails, a warning is emitted and the results gathered so far are returned.
    """
    import warnings

    fpath = os.path.join("temp", fname)
    # Without the folder the first cache write fails, and that failure used to be silent.
    os.makedirs(os.path.dirname(fpath), exist_ok=True)
    if os.path.exists(fpath):
        # NOTE: pickle.load can execute arbitrary code; only load cache files written by this function.
        with open(fpath, 'rb') as f:
            results = pickle.load(f)
    else:
        results = {}
    try:
        jobs = it.product(range(n_experiments), models)
        for i, model in tqdm(jobs, total=n_experiments * len(models)):
            name = model[0]
            if (i, name) in results:
                continue
            X_train, X_test, y_train, y_test = train_test(time=i)
            results[(i, name)] = run_single_experiment(X_train, y_train, X_test, y_test, model, i, verbose=verbose, n_restarts_optimizer=n_restarts_optimizer)
            with open(fpath, 'wb') as f:
                pickle.dump(results, f)
    except (Exception, KeyboardInterrupt) as exc:
        # Still best-effort (partial results are returned), but the reason is now
        # reported instead of being swallowed by a `return` inside `finally`.
        warnings.warn(
            f"run_experiments stopped early with {len(results)} results available: {exc!r}",
            RuntimeWarning,
        )
    return pd.DataFrame(results.values())

### Kernels

In [None]:
def _amp():
    """Return a fresh amplitude factor with the bounds shared by all models."""
    return ConstantKernel(constant_value_bounds=(1e-5, 1e8))


# Kernel families that are additionally run with kappa held fixed.
factory = [  # (base_name, base_kernel, args)
    (r"div-free H.--M.--$\tfrac{1}{2}$", HodgeMaternDivFreeSphereKernel, {"nu": 0.5}),
    (r"Proj.~M.--$\tfrac{1}{2}$", ProjectedMaternSphereKernel, {"nu": 0.5}),
]

# Each entry is (name, kernel); all but the pure-noise baseline are amplitude * kernel + white noise.
models = [
    ("Pure noise", WhiteKernel()),
    (r"Proj.", _amp() * ProjectedSphereKernel(kappa=.2) + WhiteKernel()),
    (r"Hodge", _amp() * HodgeSphereKernel(kappa=.2) + WhiteKernel()),
    (r"H.--M.--$\tfrac{1}{2}$", _amp() * HodgeMaternSphereKernel(kappa=.2, nu=0.5) + WhiteKernel()),
    (r"div-free Hodge", _amp() * HodgeDivFreeSphereKernel(kappa=.2) + WhiteKernel()),
    (
        r"div+curl Hodge",
        _amp() * HodgeDivFreeSphereKernel(kappa=.2)
        + _amp() * HodgeCurlFreeSphereKernel(kappa=.2)
        + WhiteKernel(),
    ),
    (
        r"div+curl H.--M.--$\tfrac{1}{2}$",
        _amp() * HodgeMaternDivFreeSphereKernel(kappa=.2, nu=0.5)
        + _amp() * HodgeMaternCurlFreeSphereKernel(kappa=.2, nu=0.5)
        + WhiteKernel(),
    ),
]

for base_name, base_kernel, kwargs in factory:
    # Variant with the kernel's default kappa settings first ...
    models.append((base_name, _amp() * base_kernel(**kwargs) + WhiteKernel()))
    # ... then the two variants with kappa fixed.
    for kappa_suffix, kappa_value in ((r" $\kappa=0.5$", .5), (r" $\kappa=1.0$", 1.)):
        models.append((
            base_name + kappa_suffix,
            _amp() * base_kernel(kappa=kappa_value, kappa_bounds="fixed", **kwargs) + WhiteKernel(),
        ))

# In every model name, replace a period followed by one whitespace character with ".~".
# The pattern is a raw string: "\." and "\s" are invalid escape sequences in a normal
# string literal and trigger a warning on recent Python versions.
models = [
    (re.sub(r"\.\s", ".~", name), kernel) for name, kernel in models
]

In [None]:
def mse_and_pnll_table(df, n_drop=0, tex_fname=None):
    """Display (and optionally save as LaTeX) mean/std of MSE and PNLL per kernel.

    Parameters
    ----------
    df : pandas.DataFrame
        Results with columns ``name``, ``MSE`` and ``PNLL``.
    n_drop : int
        For each kernel and each metric separately, discard the ``n_drop``
        largest values before computing the statistics (missing values sort
        last and are therefore discarded first).
    tex_fname : str or None
        If given, the LaTeX table is written to ``tables/<tex_fname>.tex``.

    The lowest mean of each metric is rendered in bold. Nothing is returned.
    """
    df = df.rename(columns={"name": "Kernel"})
    per_metric = []
    for col in ["MSE", "PNLL"]:
        if n_drop > 0:
            # Position counted from the end of each kernel's sorted values; rows with a
            # count below n_drop are that kernel's n_drop largest. This replaces a
            # groupby.apply that operated on the grouping column (deprecated in pandas).
            ranked = df.sort_values(col)
            from_end = ranked.groupby("Kernel").cumcount(ascending=False)
            kept = ranked[from_end >= n_drop]
        else:
            kept = df
        df_stats = kept.groupby("Kernel")[col].agg(["mean", "std"]).round(2)
        # Two-level header: metric name on top, statistic below.
        df_stats.columns = pd.MultiIndex.from_product([[col], ["Mean", "Std"]])
        per_metric.append(df_stats)

    results = per_metric[0].join(per_metric[1])

    s = results.style.highlight_min(
        axis=0, subset=[("MSE", "Mean"), ("PNLL", "Mean")], props='font-weight:bold;'
    )
    s = s.format(precision=2)

    latex_table = s.to_latex(hrules=True)
    # The styler turns the CSS property into a "\font-weightbold" command; map it to "\bf".
    latex_table = latex_table.replace("\\font-weightbold", "\\bf")

    if tex_fname is not None:
        tex_path = os.path.join("tables", f"{tex_fname}.tex")
        # Create the output folder on first use instead of failing on a fresh checkout.
        os.makedirs(os.path.dirname(tex_path), exist_ok=True)
        with open(tex_path, "w") as f:
            f.write(latex_table)

    display(s)

In [None]:
def prediction_and_samples(kernel, time, train_set, kernel_name, n_samples=0):
    """Fit a GP, plot its prediction and export mean, uncertainty and samples.

    Parameters
    ----------
    kernel
        Kernel passed to ``SphereVectorGP``.
    time
        Time index forwarded to ``train_test_sets``.
    train_set : str
        Training-set name forwarded to ``train_test_sets``.
    kernel_name : str
        Label used in the output file names.
    n_samples : int
        Number of samples requested from ``gp.sample_y``; 0 skips sampling.

    Side effects: shows the ground truth, the predicted mean and one figure per
    sample, and writes CSV files named
    ``ERA5_<kernel_name>_<train_set>_<time>_...`` to ``blender-data/outputs``.
    The uncertainty file is only computed if it does not exist yet.
    """
    out_folder = os.path.join("blender-data", "outputs")
    file_prefix = os.path.join(out_folder, f"ERA5_{kernel_name}_{train_set}_{time}")

    X_train, X_test, y_train, y_test = train_test_sets(time, train_set=train_set)
    # X columns are (lat, lon) in radians; the map wants degrees.
    lon_deg = X_test[:, 1] * 180 / np.pi
    lat_deg = X_test[:, 0] * 180 / np.pi

    def _quiver_map(vectors, title):
        # One figure with a single map axes. plt.subplots() followed by
        # plt.axes(projection=...) can leave an empty rectilinear axes underneath.
        _, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
        ax.coastlines()
        # Vector columns are (v, u); quiver takes (u, v).
        ax.quiver(lon_deg, lat_deg, vectors[:, 1], vectors[:, 0], angles="uv")
        ax.set_title(title)

    # plot ground truth
    _quiver_map(y_test, "Ground truth")
    plt.show()

    gp = SphereVectorGP(kernel=kernel)
    gp.fit(X_train, y_train)
    display(gp)

    # mean
    y_pred = gp.predict(X_test)
    _, y_pred_car = v_sph_to_car(X_test, y_pred)
    np.savetxt(f"{file_prefix}_pred__mercator__mean.csv", np.hstack([mean_inputs, y_pred_car]), delimiter=",")

    # display predicted mean
    _quiver_map(y_pred, "Mean")
    plt.show()

    # uncertainty
    std_fname = f"{file_prefix}_pred__mercator__std.csv"
    if not os.path.exists(std_fname):
        std_inputs_sph = car_to_sph(std_inputs)
        _, std = gp.predict(std_inputs_sph, return_std=True, verbose=True)
        # One scalar per location: the norm of whatever `predict` returns for it
        # (named `cov` in the original, so presumably a covariance matrix — TODO confirm).
        uncertainty = np.array([np.linalg.norm(cov) for cov in std])

        np.savetxt(std_fname, uncertainty, delimiter=",")

    if n_samples > 0:
        # samples; the last axis of y_samples indexes the sample
        y_samples = gp.sample_y(X_test, n_samples=n_samples)
        for i in range(y_samples.shape[2]):
            sample = y_samples[:, :, i]
            _, sample_car = v_sph_to_car(X_test, sample)

            np.savetxt(f"{file_prefix}_sample_{i}__mercator__mean.csv", np.hstack([mean_inputs, sample_car]), delimiter=",")

            # display sample (all sample figures are shown together below)
            _quiver_map(sample, f"Posterior sample {i + 1}")
    plt.show()

In [None]:
def _params(s):
    r_const = r"([\d\.]+e*\+*-*\d*)\*\*2"
    r_kappa = r"kappa=([\d\.]+e*\+*-*\d*)"
    r_noise = r"noise_level=([\d\.]+e*\+*-*\d*)"
    r_nu = r"nu=([\d\.]+e*\+*-*\d*)"
    
    m_const = re.search(r_const, s)
    m_kappa = re.search(r_kappa, s)
    m_noise = re.search(r_noise, s)
    m_nu = re.search(r_nu, s)
    return pd.Series({
        "constant": float(m_const.groups()[0])**2 if m_const is not None else np.NaN,
        "kappa": float(m_kappa.groups()[0]) if m_kappa is not None else np.NaN,
        "nu": float(m_nu.groups()[0]) if m_nu is not None else np.NaN,
        "noise_level": float(m_noise.groups()[0]) if m_noise is not None else np.NaN,
    })

def extract_parameters(df):
    """Return a copy of ``df`` with the parsed hyperparameters as extra columns.

    ``_params`` is applied to every entry of the ``fitted_gp`` column and the
    result is stored in ``constant``, ``kappa``, ``nu`` and ``noise_level``.

    NOTE: a later cell defines another ``extract_parameters``; once that cell
    has run it replaces this definition.
    """
    parsed = df.fitted_gp.apply(_params)
    augmented = df.copy()
    augmented[["constant", "kappa", "nu", "noise_level"]] = parsed
    return augmented

# Train on orbit

In [None]:
# Run all models on the first 12 time indices with the "orbit" training set.
# Results are cached in temp/clean_era5_orbit.pickle, so re-running skips finished pairs.
df_orbit = run_experiments(
    models=models,
    train_test=lambda time: train_test_sets(time, train_set="orbit"),
    fname="clean_era5_orbit.pickle",
    n_experiments=12,
    n_restarts_optimizer=0,
    verbose=False
)

In [None]:
# Keep only the rows whose kernel name is in the current `models` list (the pickle cache
# may hold results for other names). `isin` avoids interpolating the list into a query
# string, which is fragile for names containing quotes or backslashes.
df_orbit = df_orbit[df_orbit["name"].isin([name for name, _ in models])]

In [None]:
# Split the results: models with learned kappa vs. the fixed-kappa variants
# (recognised by 'kappa' appearing in the model name).
original_orbit_experiment = df_orbit.query("not name.str.contains('kappa')").copy()
further_orbit_experiments = df_orbit.query("name.str.contains('kappa')").copy()
# Shorten the fixed-kappa names: names starting with "d" get the "H.--M. " prefix,
# all others "Proj.~M. "; the last space-separated token of the name is kept.
further_orbit_experiments["name"] = further_orbit_experiments["name"].apply(
    lambda s: ("H.--M. " if s.startswith("d") else "Proj.~M. ") + s.split(" ")[-1]
)

# Display both tables and write them to tables/era5_orbit.tex and tables/era5_orbit_fixed_kappa.tex.
mse_and_pnll_table(original_orbit_experiment, n_drop=0, tex_fname="era5_orbit")
mse_and_pnll_table(further_orbit_experiments, n_drop=0, tex_fname="era5_orbit_fixed_kappa")

In [None]:
# Show the fitted-GP strings of the rows whose kernel contains 'HodgeMaternDiv'.
list(original_orbit_experiment.query("fitted_gp.str.contains('HodgeMaternDiv')").fitted_gp)

### Study parameters

In [None]:
def _single_kernel_param(k):
    # regex
    r_const = r"([\d\.]+e*\+*-*\d*)\*\*2"
    r_kappa = r"kappa=([\d\.]+e*\+*-*\d*)"
    r_noise = r"noise_level=([\d\.]+e*\+*-*\d*)"
    r_nu = r"nu=([\d\.]+e*\+*-*\d*)"
    
    m_const = re.search(r_const, k)
    m_kappa = re.search(r_kappa, k)
    m_noise = re.search(r_noise, k)
    m_nu = re.search(r_nu, k)
    
    noise_sigma = float(m_noise.groups()[0]) if m_noise else None
    kappa = float(m_kappa.groups()[0]) if m_kappa else None
    constant = float(m_const.groups()[0])**2 if m_const else None
    nu = float(m_nu.groups()[0]) if m_nu else None
    
    # white noise
    if "WhiteKernel" in k:
        return {"noise_sigma": noise_sigma}
    elif "ProjectedSphereKernel" in k:
        # normalization
        kernel = ProjectedSphereKernel(kappa=kappa)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"p_sigma": np.sqrt(sigma_squared), "p_kappa": kappa}
    elif "ProjectedMaternSphereKernel" in k:
        # normalization
        kernel = ProjectedMaternSphereKernel(kappa=kappa, nu=nu)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"pm_sigma": np.sqrt(sigma_squared), "pm_kappa": kappa, "pm_nu": nu}
    elif "HodgeSphereKernel" in k:
        # normalization
        kernel = HodgeSphereKernel(kappa=kappa)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"h_sigma": np.sqrt(sigma_squared), "h_kappa": kappa}
    elif "HodgeMaternSphereKernel" in k:
        # normalization
        kernel = HodgeMaternSphereKernel(kappa=kappa, nu=nu)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"hm_sigma": np.sqrt(sigma_squared), "hm_kappa": kappa, "hm_nu": nu}
    elif "HodgeDivFreeSphereKernel" in k:
        # normalization
        kernel = HodgeDivFreeSphereKernel(kappa=kappa)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"hdf_sigma": np.sqrt(sigma_squared), "hdf_kappa": kappa}
    elif "HodgeMaternDivFreeSphereKernel" in k:
        # normalization
        kernel = HodgeMaternDivFreeSphereKernel(kappa=kappa, nu=nu)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"hmdf_sigma": np.sqrt(sigma_squared), "hmdf_kappa": kappa, "hmdf_nu": nu}
    elif "HodgeCurlFreeSphereKernel" in k:
        # normalization
        kernel = HodgeCurlFreeSphereKernel(kappa=kappa)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"hcf_sigma": np.sqrt(sigma_squared), "hcf_kappa": kappa}
    elif "HodgeMaternCurlFreeSphereKernel" in k:
        # normalization
        kernel = HodgeMaternCurlFreeSphereKernel(kappa=kappa, nu=nu)
        norm_const = np.trace(kernel(np.array([0., 0.])))
        sigma_squared = constant * norm_const
        return {"hmcf_sigma": np.sqrt(sigma_squared), "hmcf_kappa": kappa, "hmcf_nu": nu}
    else:
        raise NotImplementedError(k)

def _params(s):
    """Parse every summand of a fitted GP string into one parameter Series.

    The kernel text is taken from inside ``SphereVectorGP(...)``, split on
    ``" + "``, and each piece is handed to ``_single_kernel_param``; the
    resulting dictionaries are merged, later pieces overwriting earlier keys.
    """
    kernel_repr = re.search(r"SphereVectorGP\((.*)\)", s).groups()[0]
    collected = {}
    for summand in kernel_repr.split(" + "):
        collected.update(_single_kernel_param(summand))
    return pd.Series(collected)
    

def extract_parameters(df):
    """Return a copy of ``df`` with one extra column per parsed kernel parameter.

    ``_params`` is applied to each entry of the ``fitted_gp`` column and the
    resulting columns are appended to the right of the original ones.
    """
    parsed = df.fitted_gp.apply(_params)
    return pd.concat([df.copy(), parsed], axis=1)

In [None]:
# Parse the fitted hyperparameters of every orbit result into columns and show the table.
# (The redundant self-assignment `parameters = parameters` has been removed.)
parameters = extract_parameters(df_orbit)
parameters

In [None]:
# Show the parameters of the div-free models with learned kappa, ordered by experiment
# and name. Columns that are entirely empty are dropped, values are rounded to three
# decimals and remaining NaNs are blanked out for readability.
with pd.option_context('display.max_rows', 100 , 'display.max_columns', 20):
    display(
        parameters
        .sort_values(["n", "name"])
        .query("name.str.contains('div-free') and not name.str.contains('kappa')")
        # .query("name.str.contains('M.')")
        .dropna(how='all', axis=1)
        .round(3)
        .astype(str)
        .replace("nan", "")
    )

### Generate some samples

In [None]:
# Div-free Hodge-Matern kernel (nu = 0.5) with kappa fixed at 0.5: fit on the orbit set
# at time index 0, then export the prediction and 10 samples.
prediction_and_samples(
    kernel=ConstantKernel(1.**2, constant_value_bounds=(1e-5, 1e8)) * HodgeMaternDivFreeSphereKernel(kappa=.5, nu=.5, kappa_bounds="fixed") + WhiteKernel(),
    time=0, train_set="orbit", kernel_name="dfhm12_kappa=0.5", n_samples=10
)

In [None]:
# Projected Matern kernel (nu = 0.5) with kappa fixed at 0.5: fit on the orbit set
# at time index 0, then export the prediction and 10 samples.
prediction_and_samples(
    kernel=ConstantKernel(1.**2, constant_value_bounds=(1e-5, 1e8)) * ProjectedMaternSphereKernel(kappa=.5, nu=.5, kappa_bounds="fixed") + WhiteKernel(),
    time=0, train_set="orbit", kernel_name="projm12_kappa=0.5", n_samples=10
)