In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from evaluator import Evaluator
from trainer import Trainer
from utils import ax_plot

def process_real_data_for_mogp(real_data, keys):
    """
    Split the first sample of ``real_data`` into per-output train/test sets.

    ``real_data[0]`` is converted to a float32 tensor and read as a 2-D table
    with at least ``2 * len(keys)`` columns: column ``i`` holds the values of
    output ``i`` and column ``len(keys) + i`` holds its mask. Rows whose mask
    equals 1 become training points, rows whose mask equals 0 become test
    points, and the row number itself is used as the input coordinate.

    Args:
        real_data: Indexable collection of samples; only ``real_data[0]`` is
            read.
        keys: One label per output; only its length is used here.

    Returns:
        A tuple ``(train_x, train_idx, train_y, test_x, test_y, test_idx)``.
        Note the order: the index list comes second for the training half but
        last for the test half. Each element is a list with one 1-D tensor per
        output: ``*_x`` are float row positions, ``*_y`` are the float values
        at those rows, and ``*_idx`` are long tensors filled with the output
        number ``i`` so the outputs stay distinguishable after concatenation.
    """
    dim = len(keys)
    sample = torch.tensor(real_data[0], dtype=torch.float32)

    train_x, train_idx, train_y = [], [], []
    test_x, test_y, test_idx = [], [], []

    for i in range(dim):
        values = sample[:, i]
        mask = sample[:, dim + i].int()

        # Row numbers double as the input coordinate of each point.
        train_rows = torch.where(mask == 1)[0].long()
        test_rows = torch.where(mask == 0)[0].long()

        train_x.append(train_rows.float())
        train_y.append(values[train_rows].float())
        # Tag every point with its output number. Using zeros for all outputs
        # would collapse them into a single task once the lists are joined.
        train_idx.append(torch.full_like(train_rows, i, dtype=torch.long))

        test_x.append(test_rows.float())
        test_y.append(values[test_rows].float())
        test_idx.append(torch.full_like(test_rows, i, dtype=torch.long))

    return train_x, train_idx, train_y, test_x, test_y, test_idx

def run_mogp_on_real_data(real_data, keys, output_file="coverage_test.npz"):
    """
    Train a MOGP on the real data, plot the predictions, and save them.

    The data are split with ``process_real_data_for_mogp``, a ``Trainer`` is
    fitted on the concatenated training tensors, and an ``Evaluator`` produces
    one prediction object per key. Each key gets its own subplot, and the
    train/test points, predicted means and confidence bounds are written to
    ``output_file`` with ``np.savez``.

    Args:
        real_data: Data passed straight through to
            ``process_real_data_for_mogp``.
        keys: Labels of the outputs; used for subplot labels and as prefixes
            of the saved array names.
        output_file: Destination passed to ``np.savez``.
    """
    (train_x, train_idx, train_y,
     test_x, test_y, test_idx) = process_real_data_for_mogp(real_data, keys)
    n_outputs = len(keys)

    # The trainer receives all outputs joined into flat tensors.
    trainer = Trainer(
        torch.cat(train_x),
        torch.cat(train_idx),
        torch.cat(train_y),
        'gaussian',
        iterations=1000,
        max_early_stop=30,
    )
    model, likelihood = trainer.train_model()

    observed_preds = Evaluator(
        model, likelihood, test_x, test_idx, n_outputs
    ).evaluate()

    # One stacked panel per key, all sharing the x axis.
    _fig, axes = plt.subplots(
        n_outputs, 1, figsize=(8, 2 * n_outputs), sharex=True
    )
    if n_outputs == 1:
        # A single subplot comes back as a bare Axes; wrap it so it can be indexed.
        axes = [axes]

    saved_data = {}  # everything that ends up in the output archive

    for pos, key in enumerate(keys):
        obs_x = train_x[pos].detach().numpy()
        obs_y = train_y[pos].detach().numpy()
        held_x = test_x[pos].detach().numpy()
        held_y = test_y[pos].detach().numpy()

        prediction = observed_preds[pos]
        means = prediction.mean.cpu().detach().numpy()
        lower_t, upper_t = prediction.confidence_region()
        lower = lower_t.cpu().detach().numpy()
        upper = upper_t.cpu().detach().numpy()

        # Store this key's arrays under "<key>_<quantity>" names.
        saved_data.update({
            f"{key}_train_x": obs_x,
            f"{key}_train_y": obs_y,
            f"{key}_test_x": held_x,
            f"{key}_test_y": held_y,
            f"{key}_predicted_means": means,
            f"{key}_lower_bound": lower,
            f"{key}_upper_bound": upper,
        })

        # Draw this key's panel with the shared plotting helper.
        ax_plot(axes[pos], obs_y, obs_x, held_y, held_x, means, lower, upper, label=key)

    axes[-1].set_xlabel("Timesteps")
    plt.tight_layout()
    plt.show()

    # Persist all collected arrays in a single archive.
    np.savez(output_file, **saved_data)

    print(f"Results saved to {output_file}")

# Example usage: load the stored 'real_data' array and run the full pipeline.
# The archive is opened in a `with` block so the .npz file handle is closed as
# soon as the array has been read into memory.
with np.load('../Analysis/coverage_test_data_mogp.npz') as archive:
    real_data = archive['real_data']
run_mogp_on_real_data(real_data, ["X", 'NIR', "IR", "Sub-mm"])