In [None]:
# Global variables
import torch
import inr_src as inr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from scipy import interpolate
from sklearn.metrics import mean_squared_error, mean_absolute_error


# Compute device: use the GPU whenever PyTorch can see one.
gpu = torch.cuda.is_available()
tdevice = torch.device("cuda" if gpu else "cpu")
device = tdevice.type  # plain string form ("cuda" / "cpu") for APIs that take a str

# Project settings: `opt.name` identifies the run whose artefacts are read from meta/.
opt = inr.AttrDict()
opt.name = "wires_notebook_unnormalised__test546"

# Hyper-parameters saved next to the trained model, wrapped in an AttrDict and
# passed through inr.util_train.clean_hp before use.
npz = np.load(f"meta/{opt.name}.npz")
model_hp = inr.util_train.clean_hp(inr.AttrDict(npz))

# Evaluation inputs.
path = "./data/test_data.npy"
path_coherence = "./data/coherence.npy"

In [None]:
# Load the evaluation dataset and the per-sample coherence values.
# NOTE(review): train_fraction=0.0 with train_fold=False presumably places every
# sample in the evaluation split — confirm against inr.XYTZ.
# Normalisation statistics (nv / nv_targets) come from the saved model
# hyper-parameters so inputs are scaled the same way as during training.
xytz_ds = inr.XYTZ(
    path,
    train_fold=False,
    train_fraction=0.0,
    seed=42,
    pred_type="pc",
    nv=tuple(model_hp.nv),
    nv_targets=tuple(model_hp.nv_target),
    normalise_targets=model_hp.normalise_targets,
    gpu=gpu,
)
# Coherence is thresholded later to select subsets of samples for the metrics.
coherence = np.load(path_coherence)


In [None]:
# Rebuild the network from its saved hyper-parameters and load the trained weights.
weights = f"meta/{opt.name}.pth"

model = inr.ReturnModel(
    model_hp.input_size,
    output_size=model_hp.output_size,
    arch=model_hp.architecture,
    args=model_hp,
)
# Inference only: put the parameters on the evaluation device and switch off
# training-mode behaviour (e.g. dropout, if the architecture has any).
# load_state_dict copies into the existing parameters, so device and mode are kept.
model.to(tdevice)
model.eval()
print(f"loading weight: {weights}")
print(f"Model_hp: {model_hp}")
model.load_state_dict(torch.load(weights, map_location=tdevice))

In [None]:
# Run the trained model over every sample of the evaluation dataset.
prediction = inr.predict_loop(xytz_ds, 2048, model, device=device, verbose=False)


def scale(x):
    """Convert an error magnitude from normalised units back to target units.

    Predictions are de-normalised in this notebook as
    ``y = y_norm * nv_target[0, 1] + nv_target[0, 0]``. The offset
    ``nv_target[0, 0]`` cancels in the difference of two values, so an error
    measure (MAE, RMSE) only needs to be multiplied by ``nv_target[0, 1]``.

    NOTE(review): assumes targets/predictions are in normalised units here,
    consistent with the plotting cell; if ``normalise_targets`` is False this
    scaling may not apply — confirm.
    """
    s = model_hp.nv_target[0, 1]
    return x * s


n = prediction.shape[0]
for val in [0, 0.7, 0.8, 0.95]:
    # Threshold 0 means "all samples"; otherwise keep only coherent samples.
    if val == 0:
        idx = np.arange(n)
    else:
        idx = np.where(coherence > val)[0]
    if len(idx) == 0:
        # sklearn metrics raise on empty inputs; report and move on instead.
        print(f"No samples with coherence > {val}")
        continue
    mse_norm = mean_squared_error(xytz_ds.targets[idx], prediction[idx])
    mae_norm = mean_absolute_error(xytz_ds.targets[idx], prediction[idx])
    # RMSE is the square root of the MSE; take it before rescaling to target units.
    rmse = scale(np.sqrt(mse_norm))
    mae = scale(mae_norm)
    print(f"RMSE: {rmse:.3f} MAE: {mae:.3f} Size: {len(idx)}, coherence > {val}")


In [None]:


# Split samples into three time terciles (column 2 of the inputs) and plot
# the de-normalised predictions of each as a 3D scatter.
# NOTE(review): the "Jan./Feb./Mar." titles assume the time axis covers three
# months split evenly by these quantiles — confirm.
time_col = xytz_ds.samples[:, 2]
q33, q66 = np.quantile(time_col, [0.33, 0.66])

idx_0 = time_col < q33
idx_1 = (time_col >= q33) & (time_col < q66)
idx_2 = time_col >= q66

# Cap on points drawn per figure; a seeded generator keeps the subsample reproducible.
MAX_PLOT_POINTS = int(1e5)
rng = np.random.default_rng(42)

for mask, title in [(idx_0, "Jan."), (idx_1, "Feb."), (idx_2, "Mar.")]:
    # Undo the x * s + m normalisation for inputs and predictions.
    samples = xytz_ds.samples[mask] * model_hp.nv[:, 1] + model_hp.nv[:, 0]
    pred = prediction[mask, 0] * model_hp.nv_target[0, 1] + model_hp.nv_target[0, 0]

    # Sampling without replacement fails if asking for more points than exist.
    n_plot = min(MAX_PLOT_POINTS, samples.shape[0])
    sub = rng.choice(samples.shape[0], size=n_plot, replace=False)
    fig = px.scatter_3d(
        x=samples[sub, 0],
        y=samples[sub, 1],
        z=pred[sub],
        color=samples[sub, 2],
        labels={"color": "Time"},  # continuous colour bar title (legend_title_text does not apply)
    )
    fig.update_layout(title=title, legend_title_text="Time")
    fig.update_traces(marker_size=2)
    fig.show()
