In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import optuna
from optuna.samplers import TPESampler

from pykrige.ok import OrdinaryKriging

from src.datasets.vitae_dataset import load_data
from src.utils.evaluation import compute_all_metrics, compute_relative_error

seed = 42

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def kriging_interpolate_image(
    img: np.ndarray,
    variogram_model: str,
    variogram_parameters: dict,
    nlags: int,
    weight: bool,
    anisotropy_scaling: float,
    anisotropy_angle: float,
    exact_values: bool
) -> np.ndarray:
    h, w = img.shape
    y_idx, x_idx = np.where(img != 0)
    values = img[y_idx, x_idx]

    # if too few points, skip
    if len(values) < 3:
        return img.copy()

    # build the Ordinary Kriging model with hyperparameters
    OK = OrdinaryKriging(
        x_idx.astype(float),
        y_idx.astype(float),
        values.astype(float),
        variogram_model=variogram_model,
        variogram_parameters=variogram_parameters,
        nlags=nlags,
        weight=weight,
        anisotropy_scaling=anisotropy_scaling,
        anisotropy_angle=anisotropy_angle,
        exact_values=exact_values,
        verbose=False,
        enable_plotting=False
    )

    grid_x = np.arange(w, dtype=float)
    grid_y = np.arange(h, dtype=float)
    z_interp, _ = OK.execute("grid", grid_x, grid_y)
    return z_interp


def kriging_interpolate_tensor(
    tensor_np: np.ndarray,
    variogram_model: str,
    variogram_parameters: dict,
    nlags: int,
    weight: bool,
    anisotropy_scaling: float,
    anisotropy_angle: float,
    exact_values: bool
) -> np.ndarray:
    out = np.zeros_like(tensor_np)
    for i in range(tensor_np.shape[0]):      # time
        for c in range(tensor_np.shape[1]):  # channel
            out[i, c] = kriging_interpolate_image(
                tensor_np[i, c],
                variogram_model,
                variogram_parameters,
                nlags,
                weight,
                anisotropy_scaling,
                anisotropy_angle,
                exact_values
            )
    return out

In [3]:
from src.datasets.real_obs_dataset import load_data as load_real

In [4]:
dataset, stats = load_real("vitae", 1)

In [5]:
obs = torch.stack([obs for obs, _, _ in dataset])
gts = torch.stack([gt for _, gt, _ in dataset])
mask = torch.stack([mask for _, _, mask in dataset])

In [6]:
best_params = {'variogram_model': 'linear', 'slope': 2.0506246505745622, 'nugget': 0.24566496231367718, 'nlags': 4, 'weight': True, 'anisotropy_scaling': 1.527749730358887, 'anisotropy_angle': 179.58372854952466, 'exact_values': False}


pred = kriging_interpolate_tensor(
            obs.numpy(),
            best_params.get("variogram_model", None),
            best_params.get("variogram_parameters", None),
            best_params.get("nlags", None),
            best_params.get("weight", None),
            best_params.get("anisotropy_scaling", None),
            best_params.get("anisotropy_angle", None),
            best_params.get("exact_values", None)
        )

KeyboardInterrupt: 

In [None]:
err = compute_relative_error(gts * mask, torch.from_numpy(pred) * mask)

In [None]:
print(np.mean(err))

1.0062051776445304


### Finding the pers parameters for the algorithms

In [None]:
def objective(trial):

    train_dataset, _, _, _ = load_data(sensor_type="real", combine_train_val=True)

    train_obs = torch.stack([obs for obs, _, _ in train_dataset]).numpy()
    train_gts = torch.stack([gt for _, gt, _ in train_dataset])

    # Define the hyperparameters here

    variogram_model = trial.suggest_categorical(
        "variogram_model", ["linear", "power", "gaussian", "spherical", "exponential"]
    )

    if variogram_model == "power":
        # power: scale, exponent, nugget
        scale   = trial.suggest_float("scale",   1e-3, float(np.nanstd(train_obs)), log=True)
        exponent= trial.suggest_float("exponent", 0.1, 2.0)
        nugget  = trial.suggest_float("nugget",  0.0, scale * 0.5)
        variogram_parameters = {
            "scale": scale,
            "exponent": exponent,
            "nugget": nugget
        }

    elif variogram_model == "linear":
        # linear: slope, nugget
        slope   = trial.suggest_float("slope",   0.1, 10.0, log=True)
        nugget  = trial.suggest_float("nugget",  0.0, slope * 0.5)
        variogram_parameters = {
            "slope": slope,
            "nugget": nugget
        }

    else:
        # gaussian / spherical / exponential: range, sill, nugget
        v_range = trial.suggest_float("range",  1e-1, max(train_obs.shape), log=True)
        sill    = trial.suggest_float("sill",   1e-3, float(np.nanstd(train_obs)), log=True)
        nugget  = trial.suggest_float("nugget", 0.0, sill * 0.5)
        variogram_parameters = {
            "range": v_range,
            "sill": sill,
            "nugget": nugget
        }

    nlags = trial.suggest_int("nlags", 2, 20)
    weight = trial.suggest_categorical("weight", [True, False])
    anisotropy_scaling = trial.suggest_float("anisotropy_scaling", 0.1, 10.0, log=True)
    anisotropy_angle = trial.suggest_float("anisotropy_angle", 0.0, 180.0)
    exact_values = trial.suggest_categorical("exact_values", [True, False])

    # Call the kriging here

    pred = kriging_interpolate_tensor(
        train_obs,
        variogram_model,
        variogram_parameters,
        nlags,
        weight,
        anisotropy_scaling,
        anisotropy_angle,
        exact_values
    )

    # Evaluate the performance

    relative_errs = compute_relative_error(train_gts, torch.from_numpy(pred))

    return np.mean(relative_errs)

In [None]:
study = optuna.create_study(study_name="Optimizing Kriging", direction="minimize", sampler=TPESampler(seed=seed))
study.optimize(objective, n_trials=100)

print("Best trial:")
print(study.best_trial.params)

### Randomly sampled sensors

In [None]:
sensor_numbers = [5, 10, 15, 20, 25, 30]

preds = []
errors = []
ssims = []
psnrs = []
local_errors = []

for sensor_number in sensor_numbers:
    _, _, test_dataset, _ = load_data(sensor_type="random", sensor_number=sensor_number)

    obs = torch.stack([obs for obs, _, _ in test_dataset])
    gts = torch.stack([gt for _, gt, _ in test_dataset])

    pred = kriging_interpolate_tensor(obs.numpy())

    error, ssim, psnr, local_error = compute_all_metrics(gts, torch.from_numpy(pred))

    preds.append(pred)
    errors.append(error)
    ssims.append(ssim)
    psnrs.append(psnr)
    local_errors.append(local_error)

    preds_dir = f"results/predictions/kriging"
    os.makedirs(preds_dir, exist_ok=True)
    pred_file = os.path.join(preds_dir, f"random_random_{sensor_number}_predictions.npz")

    np.savez_compressed(
        pred_file,
        observations=obs.numpy(),
        ground_truth=gts.numpy(),
        predictions=np.array(pred),
        errors=np.array(error),
        ssim=np.array(ssim),
        psnr=np.array(psnr),
        local_errors=local_error.numpy(),
    )

    print(f"Finished Kriging with {sensor_number} sensors.")