In [None]:
# import torch
# import matplotlib.pyplot as plt
# import numpy as np

# from evaluator import Evaluator
# from trainer import Trainer
# from utils import ax_plot, mean_squared_error, crps_norm

# def process_real_data_for_mogp(real_data, keys):
#     """
#     Processes real data for MOGP model input.
#     """
#     dim = len(keys)
#     sample = torch.tensor(real_data[0], dtype=torch.float32)

#     train_x = []
#     train_y = []
#     train_idx = []
#     test_x = []
#     test_y = []
#     test_idx = []

#     for i in range(dim):
#         train_xi = torch.where((sample[:, dim + i].int() == 1))[0].long()
#         train_yi = sample[:, i][train_xi].float()
#         train_idx_i = torch.zeros_like(train_xi, dtype=torch.long)

#         test_xi = torch.where((sample[:, dim + i].int() == 0))[0].long()
#         test_yi = sample[:, i][test_xi].float()
#         test_idx_i = torch.zeros_like(test_xi, dtype=torch.long)

#         train_x.append(train_xi.float())
#         train_y.append(train_yi)
#         train_idx.append(train_idx_i)

#         test_x.append(test_xi.float())
#         test_y.append(test_yi)
#         test_idx.append(test_idx_i)

#     return train_x, train_idx, train_y, test_x, test_y, test_idx

# def run_mogp_on_real_data(real_data, keys, output_file="coverage_test.npz"):
#     """
#     Runs MOGP on the real data, plots the results, saves predictions and metrics.
#     """
#     train_x, train_idx, train_y, test_x, test_y, test_idx = process_real_data_for_mogp(real_data, keys)

#     # Train the model
#     full_train_x = torch.cat(train_x)
#     full_train_idx = torch.cat(train_idx)
#     full_train_y = torch.cat(train_y)

#     trainer = Trainer(full_train_x, full_train_idx, full_train_y, likelihood='gaussian', iterations=1000, max_early_stop=30)
#     model, likelihood = trainer.train_model()

#     evaluator = Evaluator(model, likelihood, test_x, test_idx, len(keys))
#     observed_preds = evaluator.evaluate()

#     # Create a plot for each key (wavelength)
#     fig, axes = plt.subplots(len(keys), 1, figsize=(8, 2 * len(keys)), sharex=True)
#     if len(keys) == 1:
#         axes = [axes]

#     saved_data = {}
#     mse_per_feature = []
#     crps_per_feature = []

#     for i, key in enumerate(keys):
#         train_xi = train_x[i].detach().numpy()
#         train_yi = train_y[i].detach().numpy()
#         test_xi = test_x[i].detach().numpy()
#         test_yi = test_y[i].detach().numpy()

#         means = observed_preds[i].mean.cpu().detach().numpy()
#         lower, upper = observed_preds[i].confidence_region()
#         lower, upper = lower.cpu().detach().numpy(), upper.cpu().detach().numpy()

#         # Convert CI to sigma assuming Gaussian predictive distribution
#         sigma = (upper - means) / 2.0

#         # --- Metrics ---
#         mse_val = mean_squared_error(test_yi, means)
#         crps_val = crps_norm(test_yi, means, sigma)

#         mse_per_feature.append(mse_val)
#         crps_per_feature.append(crps_val)

#         print(f"[{key}] MSE: {mse_val:.6f}, CRPS: {crps_val:.6f}")

#         # Save results for this wavelength
#         saved_data[f"{key}_train_x"] = train_xi
#         saved_data[f"{key}_train_y"] = train_yi
#         saved_data[f"{key}_test_x"] = test_xi
#         saved_data[f"{key}_test_y"] = test_yi
#         saved_data[f"{key}_predicted_means"] = means
#         saved_data[f"{key}_lower_bound"] = lower
#         saved_data[f"{key}_upper_bound"] = upper
#         saved_data[f"{key}_sigma"] = sigma

#         # Plot
#         ax_plot(axes[i], train_yi, train_xi, test_yi, test_xi, means, lower, upper, label=key)

#     # Aggregate totals
#     mse_total = np.nanmean(mse_per_feature)
#     crps_total = np.nanmean(crps_per_feature)

#     print("MSE per feature:", mse_per_feature)
#     print("MSE total:", mse_total)
#     print("CRPS per feature:", crps_per_feature)
#     print("CRPS total:", crps_total)

#     saved_data["mse_per_feature"] = mse_per_feature
#     saved_data["mse_total"] = mse_total
#     saved_data["crps_per_feature"] = crps_per_feature
#     saved_data["crps_total"] = crps_total

#     axes[-1].set_xlabel("Timesteps")
#     plt.tight_layout()
#     plt.show()

#     np.savez(output_file, **saved_data)
#     print(f"Results + metrics saved to {output_file}")


# # Example Usage
# real_data = np.load('../Analysis/coverage_test_data_mogp.npz')['real_data']
# run_mogp_on_real_data(real_data, ["X", 'NIR', "IR", "Sub-mm"])

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.stats import norm

from evaluator import Evaluator
from trainer import Trainer
from utils import ax_plot, mean_squared_error, crps_norm


def compute_pit_gaussian(y_true, mean_pred, std_pred):
    """
    Return probability integral transform (PIT) values under a Gaussian forecast.

    Each observation in ``y_true`` is pushed through the CDF of a normal
    distribution with mean ``mean_pred`` and standard deviation ``std_pred``.
    Arguments broadcast against each other the way NumPy arrays do.
    """
    predictive = norm(loc=mean_pred, scale=std_pred)
    return predictive.cdf(y_true)


def update_pit_file(pit_file, pit_data):
    """
    Merge ``pit_data`` into the ``.npz`` archive at ``pit_file``, creating it if needed.

    Arrays already stored in the archive are kept; entries of ``pit_data`` with the
    same name replace them. The whole archive is then rewritten with ``np.savez``.

    Args:
        pit_file: path (str or path-like) of the archive. Like ``np.savez``, a
            ``.npz`` suffix is appended when missing, and the existence check uses
            that same resolved path so a suffix-less name still merges correctly.
        pit_data: mapping of array name -> array-like to store.
    """
    path = os.fspath(pit_file)
    if not path.endswith(".npz"):
        # np.savez appends ".npz" to such paths; look for the file it actually writes.
        path += ".npz"

    if os.path.exists(path):
        # np.load returns an NpzFile that keeps the file handle open; read every
        # array into memory and close it before overwriting the same file.
        with np.load(path) as archive:
            existing = dict(archive)
        existing.update(pit_data)
        np.savez(path, **existing)
        print(f"Updated existing PIT file: {pit_file}")
    else:
        np.savez(path, **pit_data)
        print(f"Created new PIT file: {pit_file}")


def process_real_data_for_mogp(real_data, keys):
    """
    Split the first sample of ``real_data`` into per-feature train/test sets for the MOGP.

    ``real_data[0]`` is read as a 2-D table. Columns ``0 .. len(keys)-1`` hold the
    feature values, and columns ``len(keys) .. 2*len(keys)-1`` hold one mask per
    feature: rows where the mask is 1 become training points, rows where it is 0
    become test points. The row number is used as the input x.

    Args:
        real_data: array-like indexable with ``[0]``; only the first sample is used.
        keys: feature names; only ``len(keys)`` is used here.

    Returns:
        Tuple ``(train_x, train_idx, train_y, test_x, test_y, test_idx)`` of lists
        with one tensor per feature. x holds float32 row indices, y holds float32
        values, and idx holds long tensors filled with that feature's task number
        ``i``.
    """
    dim = len(keys)
    sample = torch.tensor(real_data[0], dtype=torch.float32)

    train_x, train_y, train_idx = [], [], []
    test_x, test_y, test_idx = [], [], []

    for i in range(dim):
        values = sample[:, i]
        mask = sample[:, dim + i].int()

        train_rows = torch.where(mask == 1)[0]
        test_rows = torch.where(mask == 0)[0]

        train_x.append(train_rows.float())
        train_y.append(values[train_rows])
        # Tag every point with its feature number so the multi-output GP treats each
        # feature as a separate task (all-zero indices would pool every output into task 0).
        train_idx.append(torch.full_like(train_rows, i, dtype=torch.long))

        test_x.append(test_rows.float())
        test_y.append(values[test_rows])
        test_idx.append(torch.full_like(test_rows, i, dtype=torch.long))

    return train_x, train_idx, train_y, test_x, test_y, test_idx


def run_mogp_on_real_data(real_data, keys, output_file="coverage_test.npz", pit_file="pit_values.npz"):
    """
    Fit the MOGP on one real sample and report how well it predicts held-out points.

    For every feature in ``keys`` this computes MSE, CRPS and PIT values on the
    test rows. It also draws one subplot with the data and predictive band, and
    collects the arrays under ``"<key>_..."`` names. All arrays and metrics go to
    ``output_file``; PIT values are merged into ``pit_file`` through
    ``update_pit_file``.

    Args:
        real_data: input accepted by ``process_real_data_for_mogp``.
        keys: feature names, one per output; used as plot labels and save prefixes.
        output_file: path handed to ``np.savez`` for predictions and metrics.
        pit_file: path of the PIT archive to create or update.

    Returns:
        None; results are printed, written to disk and shown as a figure.
    """
    n_features = len(keys)
    (train_x, train_idx, train_y,
     test_x, test_y, test_idx) = process_real_data_for_mogp(real_data, keys)

    # One model is fitted on all features at once, so stack the per-feature tensors.
    trainer = Trainer(
        torch.cat(train_x),
        torch.cat(train_idx),
        torch.cat(train_y),
        likelihood='gaussian',
        iterations=1000,
        max_early_stop=30,
    )
    model, likelihood = trainer.train_model()
    observed_preds = Evaluator(model, likelihood, test_x, test_idx, n_features).evaluate()

    _, axes = plt.subplots(n_features, 1, figsize=(8, 2 * n_features), sharex=True)
    # With a single row plt.subplots returns one Axes, not an array; wrap it for indexing.
    if n_features == 1:
        axes = [axes]

    saved_data, pit_data = {}, {}
    mse_per_feature, crps_per_feature = [], []

    for feature_no, key in enumerate(keys):
        tr_x = train_x[feature_no].detach().numpy()
        tr_y = train_y[feature_no].detach().numpy()
        te_x = test_x[feature_no].detach().numpy()
        te_y = test_y[feature_no].detach().numpy()

        pred = observed_preds[feature_no]
        means = pred.mean.cpu().detach().numpy()
        lo_band, hi_band = pred.confidence_region()
        lower = lo_band.cpu().detach().numpy()
        upper = hi_band.cpu().detach().numpy()

        # Treats the confidence region as mean +/- 2 sigma, so half the upper gap is sigma.
        sigma = (upper - means) / 2.0

        mse_val = mean_squared_error(te_y, means)
        crps_val = crps_norm(te_y, means, sigma)
        mse_per_feature.append(mse_val)
        crps_per_feature.append(crps_val)

        pit_values = compute_pit_gaussian(te_y, means, sigma)
        pit_data[f"{key}"] = pit_values

        print(f"[{key}] MSE: {mse_val:.6f}, CRPS: {crps_val:.6f}, PIT mean: {np.mean(pit_values):.4f}")

        saved_data.update({
            f"{key}_train_x": tr_x,
            f"{key}_train_y": tr_y,
            f"{key}_test_x": te_x,
            f"{key}_test_y": te_y,
            f"{key}_predicted_means": means,
            f"{key}_lower_bound": lower,
            f"{key}_upper_bound": upper,
            f"{key}_sigma": sigma,
        })

        ax_plot(axes[feature_no], tr_y, tr_x, te_y, te_x, means, lower, upper, label=key)

    # NaN-aware averages so a feature with an undefined metric does not poison the total.
    mse_total = np.nanmean(mse_per_feature)
    crps_total = np.nanmean(crps_per_feature)

    summary = (
        ("MSE per feature:", mse_per_feature),
        ("MSE total:", mse_total),
        ("CRPS per feature:", crps_per_feature),
        ("CRPS total:", crps_total),
    )
    for label, value in summary:
        print(label, value)

    saved_data.update({
        "mse_per_feature": mse_per_feature,
        "mse_total": mse_total,
        "crps_per_feature": crps_per_feature,
        "crps_total": crps_total,
    })

    np.savez(output_file, **saved_data)
    print(f"Results + metrics saved to {output_file}")

    update_pit_file(pit_file, pit_data)

    axes[-1].set_xlabel("Timesteps")
    plt.tight_layout()
    plt.show()


# Example usage: evaluate the MOGP on the stored coverage-test sample.
if __name__ == "__main__":
    # np.load on an .npz returns an NpzFile that holds the file open; close it as soon
    # as the array has been read into memory.
    with np.load('../Analysis/coverage_test_data_mogp.npz') as archive:
        real_data = archive['real_data']
    run_mogp_on_real_data(real_data, ["X", "NIR", "IR", "Sub-mm"])
