In [None]:
import argparse
import numpy as np
import torch
import models
import utils
import matplotlib.pyplot as plt
from scipy.stats import norm
from utils import crps_norm, mse_inference
import os


def compute_pit_gaussian(y_true, mu, sigma):
    """Closed-form PIT computation for Gaussian predictive distributions.

    The probability integral transform of an observation under a Gaussian
    forecast is the standard normal CDF evaluated at the standardized
    residual ``(y_true - mu) / sigma``.

    Args:
        y_true: Observed value(s); scalar or array-like.
        mu: Predictive mean(s); scalar or array-like that broadcasts
            against ``y_true``.
        sigma: Predictive standard deviation(s); scalar or array-like that
            broadcasts against ``y_true``. Not validated here: a zero entry
            goes through a division by zero.

    Returns:
        The PIT value(s) as returned by ``scipy.stats.norm.cdf``, with the
        broadcast shape of the inputs.
    """
    # Coerce to arrays so plain Python lists/tuples work too; subtracting
    # two lists would otherwise raise a TypeError.
    y_true = np.asarray(y_true)
    mu = np.asarray(mu)
    sigma = np.asarray(sigma)

    standardized_residual = (y_true - mu) / sigma
    return norm.cdf(standardized_residual)


def save_pit_values(pit_dict, save_path="pit_values_tripletformer.npz"):
    """
    Saves or updates PIT values for TripletFormer.
    If the file exists, append new values; otherwise create a new file.

    Args:
        pit_dict: Mapping from a key (used as the array name inside the
            archive) to the PIT values to store under that key. Values may
            be arrays, sequences, or scalars.
        save_path: Path of the ``.npz`` archive. If it does not end in
            ``.npz`` the suffix is added, matching what ``np.savez`` does
            when writing.

    Returns:
        None. The merged archive is written to disk and a confirmation
        line is printed.
    """
    # np.savez silently appends ".npz" when it is missing. Normalize the path
    # up front so the existence check below looks at the file that was
    # actually written by a previous call; otherwise earlier values would be
    # overwritten instead of appended to.
    save_path = os.fspath(save_path)
    if not save_path.endswith(".npz"):
        save_path += ".npz"

    existing = {}
    if os.path.exists(save_path):
        # NOTE(review): allow_pickle=True can execute arbitrary code if the
        # archive comes from an untrusted source. It is kept for
        # compatibility with archives already on disk; consider dropping it
        # if only numeric arrays are ever stored.
        # The context manager closes the underlying file before it is
        # rewritten below.
        with np.load(save_path, allow_pickle=True) as archive:
            existing = {key: archive[key] for key in archive.files}

    for key, values in pit_dict.items():
        # atleast_1d: np.concatenate rejects zero-dimensional arrays, so
        # scalars are stored as length-1 arrays.
        values = np.atleast_1d(np.asarray(values))
        if key in existing:
            existing[key] = np.concatenate([np.atleast_1d(existing[key]), values])
        else:
            existing[key] = values

    np.savez(save_path, **existing)
    print(f"PIT values updated and saved to {save_path}")


def plot_tripletformer_on_real_data(real_data, keys, model_path, save_path="coverage_test_tripletformer.npz",
                                    pit_save_path="pit_values_tripletformer.npz"):
    """
    Evaluates the Tripletformer model on real data, saves results and PIT values.

    Despite the name, this function produces no figure: it runs inference,
    computes per-channel MSE, CRPS and PIT values, and writes them to disk.

    Args:
        real_data: Array convertible to a float32 tensor and indexed with
            three axes. Along the last axis, the first ``len(keys)`` columns
            are read as channel values, the next ``len(keys)`` columns as a
            mask (1 = given to the model as context, 0 = held out and
            predicted), and the final column is passed to
            ``net.inference`` as both its first and third argument
            (presumably the time stamps -- confirm against ``models``).
            NOTE(review): the un-dimensioned ``squeeze()`` calls below look
            like they assume a leading batch axis of size 1 -- confirm.
        keys: Sequence of channel labels; its length sets the number of
            channels and each label prefixes that channel's entries in the
            saved archive.
        model_path: Path string of the checkpoint. It must hold a
            ``'state_dict'`` entry. The text between the last ``_`` and the
            following ``.`` is used as the experiment id.
        save_path: Destination ``.npz`` for predictions and metrics.
        pit_save_path: Destination ``.npz`` handed to ``save_pit_values``
            for the per-channel PIT values.

    Returns:
        None. Results are written to ``save_path`` and ``pit_save_path`` and
        a per-channel metrics line is printed.

    Raises:
        ValueError: If, for some channel, the predictions selected for it do
            not have the same shape as its held-out targets.
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dim = len(keys)

    # Stand-in for the command-line arguments that models.load_network reads.
    args = argparse.Namespace(
        dataset='sgra',
        experiment_id=model_path.split('_')[-1].split('.')[0],
        net='triple',
        imab_dim=128,
        cab_dim=256,
        decoder_dim=128,
        nlayers=1,
        enc_num_heads=4,
        dec_num_heads=4,
        num_ref_points=128,
        mse_weight=1.0,
        norm=True,
        sample_tp=0.1,
        sample_type='random',
    )

    # Load model
    net = models.load_network(args, dim, device=device).to(device)
    chp = torch.load(model_path, map_location=device)
    net.load_state_dict(chp['state_dict'])
    net.eval()

    with torch.no_grad():
        real_data_tensor = torch.tensor(real_data, dtype=torch.float32).to(device)

        # Build masks: the stored mask marks context points; its complement
        # marks the points the model has to reconstruct.
        channel_values = real_data_tensor[:, :, :dim]
        subsampled_mask = real_data_tensor[:, :, dim:2 * dim]
        recon_mask = torch.ones_like(subsampled_mask) - subsampled_mask
        last_column = real_data_tensor[:, :, -1]
        context_y = torch.cat((channel_values * subsampled_mask, subsampled_mask), -1)
        target_y = torch.cat((channel_values * recon_mask, recon_mask), -1)

        # Predict. The second returned value (time indices) is not needed here.
        px, _, channel_indices = net.inference(last_column, context_y, last_column, target_y)

        # Extract Gaussian parameters (std = sqrt(exp(logvar))).
        means = px.mean.squeeze().cpu().numpy()
        stds = torch.sqrt(torch.exp(px.logvar)).squeeze().cpu().numpy()
        channel_indices = channel_indices.squeeze().cpu().numpy()
        real_data_numpy = real_data_tensor[:, :, :-1].squeeze().cpu().numpy()

    # Containers
    saved_data = {}
    mse, crps = [], []
    pit_dict = {}
    # 1-based step numbers, one per row of the squeezed data.
    timesteps = np.arange(1, real_data_numpy.shape[0] + 1)

    # Evaluate each wavelength
    for chan in range(dim):
        key_label = keys[chan]
        channel_mask = real_data_numpy[:, chan + dim]
        channel_series = real_data_numpy[:, chan]

        obs_indices = np.where(channel_mask == 1)
        masked_indices = np.where(channel_mask == 0)

        obs_x, obs_y = timesteps[obs_indices], channel_series[obs_indices]
        masked_x, masked_y = timesteps[masked_indices], channel_series[masked_indices]

        pred_indices = np.where(channel_indices == chan)
        pred_means, pred_stds = means[pred_indices], stds[pred_indices]
        pred_lower, pred_upper = pred_means - 2 * pred_stds, pred_means + 2 * pred_stds

        # Targets and predictions are paired purely by position; fail loudly
        # rather than let NumPy broadcasting produce meaningless metrics.
        if pred_means.shape != masked_y.shape:
            raise ValueError(
                f"[{key_label}] prediction shape {pred_means.shape} does not match "
                f"held-out target shape {masked_y.shape}"
            )

        # --- Metrics ---
        mse_val = mse_inference(masked_y, pred_means)
        crps_val = crps_norm(masked_y, pred_means, pred_stds)
        pit_vals = compute_pit_gaussian(masked_y, pred_means, pred_stds)

        mse.append(mse_val)
        crps.append(crps_val)
        pit_dict[key_label] = pit_vals

        # Save arrays
        saved_data[f"{key_label}_train_x"] = obs_x
        saved_data[f"{key_label}_train_y"] = obs_y
        saved_data[f"{key_label}_test_x"] = masked_x
        saved_data[f"{key_label}_test_y"] = masked_y
        saved_data[f"{key_label}_predicted_means"] = pred_means
        saved_data[f"{key_label}_lower_bound"] = pred_lower
        saved_data[f"{key_label}_upper_bound"] = pred_upper

        print(f"[{key_label}] MSE: {mse_val:.6f}, CRPS: {crps_val:.6f}")

    # --- Aggregate totals ---
    saved_data["mse_per_feature"] = mse
    saved_data["mse_total"] = np.mean(mse)
    saved_data["crps_per_feature"] = crps
    saved_data["crps_total"] = np.mean(crps)

    np.savez(save_path, **saved_data)
    print(f"TripletFormer results saved to {save_path}")

    # --- Save PIT values ---
    save_pit_values(pit_dict, save_path=pit_save_path)


# Load the coverage-test array and evaluate the saved TripletFormer checkpoint
# on it. The archive is opened in a context manager so the .npz file handle is
# closed as soon as the array has been read into memory.
with np.load('../Analysis/coverage_test_data.npz') as coverage_archive:
    real_data = coverage_archive['real_data']

plot_tripletformer_on_real_data(
    real_data,
    ["X", "NIR", "IR", "Sub-mm"],
    "./saved_models/best_model_sgra_final.h5"
)

In [None]:
# import argparse
# import numpy as np
# import torch
# import models
# import utils
# import matplotlib.pyplot as plt
# from utils import crps_norm, mse_inference

# def plot_tripletformer_on_real_data(real_data, keys, model_path, save_path="coverage_test.npz"):
#     """
#     Evaluates the Tripletformer model on real data and saves the results.
#     """

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     dim = len(keys)

#     # Create a dummy argparse namespace to load the model
#     class DummyArgs:
#         def __init__(self):
#             self.dataset = 'sgra'  # Assuming 'sgra' dataset configuration
#             self.experiment_id = model_path.split('_')[-1].split('.')[0]
#             self.net = 'triple'
#             self.imab_dim = 128
#             self.cab_dim = 256
#             self.decoder_dim = 128
#             self.nlayers = 1
#             self.enc_num_heads = 4
#             self.dec_num_heads = 4
#             self.num_ref_points = 128 #added this line
#             self.mse_weight = 1.0 #added this line
#             self.norm = True #added this line
#             self.sample_tp = 0.1 #added this line
#             self.sample_type = 'random' #added this line

#     args = DummyArgs()

#     net = models.load_network(args, dim, device=device).to(device)
#     chp = torch.load(model_path, map_location=device)  # Load the model
#     net.load_state_dict(chp['state_dict'])

#     net.eval()  # Set the model to evaluation mode

#     with torch.no_grad():
#         real_data_tensor = torch.tensor(real_data, dtype=torch.float32).to(device)

#         # Create context and reconstruction masks
#         original_mask = torch.ones(real_data_tensor[:, :, dim:2 * dim].shape).to(device)
#         subsampled_mask = real_data_tensor[:, :, dim:2 * dim]
#         recon_mask = original_mask - subsampled_mask
#         context_y = torch.cat((real_data_tensor[:, :, :dim] * subsampled_mask, subsampled_mask), -1)

#         # Compute predictions
#         px, time_indices, channel_indices = net.inference(
#             real_data_tensor[:, :, -1],  # Time progression indicator
#             context_y,
#             real_data_tensor[:, :, -1],
#             torch.cat((real_data_tensor[:, :, :dim] * recon_mask, recon_mask), -1)
#         )

#         means = px.mean
#         logvars = px.logvar
#         std = torch.sqrt(torch.exp(logvars))

#         # Move everything to CPU for processing
#         means = means.squeeze().cpu().numpy()
#         stds = std.squeeze().cpu().numpy()
#         time_indices = time_indices.squeeze().cpu().numpy()
#         channel_indices = channel_indices.squeeze().cpu().numpy()
#         real_data_numpy = real_data_tensor[:, :, :-1].squeeze().cpu().numpy()

#         # Save data
#         saved_data = {}

#         # Create subplots
#         fig, axs = plt.subplots(dim, 1, figsize=(10, 2 * dim), sharex=True, gridspec_kw={'hspace': 0})
#         timesteps = np.arange(1, real_data_numpy.shape[0] + 1)

#         if dim == 1:
#             axs = [axs]

#         mse = []
#         crps = []
#         for chan in range(dim):
#             ax = axs[chan]

#             # Observed values
#             obs_indices = np.where(real_data_numpy[:, chan + dim] == 1)
#             obs_x = timesteps[obs_indices]
#             obs_y = real_data_numpy[:, chan][obs_indices]

#             # Masked values
#             masked_indices = np.where(real_data_numpy[:, chan + dim] == 0)
#             masked_x = timesteps[masked_indices]
#             masked_y = real_data_numpy[:, chan][masked_indices]

#             # Predictions
#             pred_indices = np.where(channel_indices == chan)
#             pred_times = time_indices[pred_indices]
#             pred_means = means[pred_indices]
#             pred_stds = stds[pred_indices]
#             pred_lower = pred_means - 2 * pred_stds
#             pred_upper = pred_means + 2 * pred_stds

#             mse.append(mse_inference(masked_y, pred_means))
#             crps.append(crps_norm(masked_y, pred_means, pred_stds))

#             # Save arrays
#             key_label = keys[chan] if keys else f"Channel_{chan+1}"
#             saved_data[f"{key_label}_train_x"] = obs_x
#             saved_data[f"{key_label}_train_y"] = obs_y
#             saved_data[f"{key_label}_test_x"] = masked_x
#             saved_data[f"{key_label}_test_y"] = masked_y
#             saved_data[f"{key_label}_predicted_means"] = pred_means
#             saved_data[f"{key_label}_lower_bound"] = pred_lower
#             saved_data[f"{key_label}_upper_bound"] = pred_upper

#             # # Plot
#             # ax.scatter(obs_x, obs_y, s=3, color='#E6C229', label='Observed')
#             # ax.scatter(masked_x, masked_y, s=3, color='#1B998B', label='Masked')
#             # ax.scatter(pred_times, pred_means, linewidth=1, color='#DF2935', label='Predicted Mean', s=3)
#             # ax.fill_between(pred_times, pred_lower, pred_upper, alpha=0.2, color='#DF2935', label=r'2-$\sigma$')

#             # # Compute and print MSE
#             # tmp_mse = np.mean((masked_y - pred_means) ** 2)
#             # print(f'MSE for key {key_label} is: {tmp_mse:.5f}')

#             # # Labels
#             # ax.set_ylabel(key_label)
#             # ax.set_ylim(-3.5, 7.5)
#             # if chan == 0:
#             #     ax.legend(loc="upper right")
#             # ax.grid(True)

#         # axs[-1].set_xlabel("Timesteps")
#         # plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#         # plt.savefig('figures/tripletformer_coverage.png')

#         # Save metrics into the npz
#         saved_data["mse_per_feature"] = mse
#         saved_data["mse_total"] = np.mean(mse)
#         saved_data["crps_per_feature"] = crps
#         saved_data["crps_total"] = np.mean(crps)

#         # Save the results
#         np.savez(save_path, **saved_data)
#         print(f"Results saved to {save_path}")

# # Example Usage
# real_data = np.load('../Analysis/coverage_test_data.npz')['real_data']
# plot_tripletformer_on_real_data(real_data, ["X", 'NIR', 'IR', 'Sub-mm'], "./saved_models/best_model_sgra_final.h5")

In [None]:
# import numpy as np
# import torch
# import matplotlib.pyplot as plt
# from scipy.stats import norm
# import models

# def mse_inference(y, preds):
#     return np.mean((y - preds) ** 2)

# def crps_norm(y, mu, sigma):
#     w = (y - mu) / sigma
#     return np.mean(sigma * (w * (2 * norm.cdf(w) - 1) + 2 * norm.pdf(w) - 1/np.sqrt(np.pi)))

# def plot_tripletformer_on_real_data(real_data, keys, model_path, save_path="coverage_test.npz"):
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     dim = len(keys)

#     class DummyArgs:
#         def __init__(self):
#             self.dataset = 'sgra'
#             self.experiment_id = model_path.split('_')[-1].split('.')[0]
#             self.net = 'triple'
#             self.imab_dim = 128
#             self.cab_dim = 256
#             self.decoder_dim = 128
#             self.nlayers = 1
#             self.enc_num_heads = 4
#             self.dec_num_heads = 4
#             self.num_ref_points = 128
#             self.mse_weight = 1.0
#             self.norm = True
#             self.sample_tp = 0.1
#             self.sample_type = 'random'

#     args = DummyArgs()

#     net = models.load_network(args, dim, device=device).to(device)
#     chp = torch.load(model_path, map_location=device)
#     net.load_state_dict(chp['state_dict'])
#     net.eval()

#     saved_data = {}
#     mse_per_feature = []
#     crps_per_feature = []

#     with torch.no_grad():
#         real_data_tensor = torch.tensor(real_data, dtype=torch.float32).to(device)
#         original_mask = torch.ones(real_data_tensor[:, :, dim:2 * dim].shape).to(device)
#         subsampled_mask = real_data_tensor[:, :, dim:2 * dim]
#         recon_mask = original_mask - subsampled_mask
#         context_y = torch.cat((real_data_tensor[:, :, :dim] * subsampled_mask, subsampled_mask), -1)

#         px, time_indices, channel_indices = net.inference(
#             real_data_tensor[:, :, -1],
#             context_y,
#             real_data_tensor[:, :, -1],
#             torch.cat((real_data_tensor[:, :, :dim] * recon_mask, recon_mask), -1)
#         )

#         means = px.mean.squeeze().cpu()
#         stds = torch.sqrt(torch.exp(px.logvar)).squeeze().cpu()
#         time_indices = time_indices.squeeze().cpu()
#         channel_indices = channel_indices.squeeze().cpu()
#         real_data_numpy = real_data_tensor[:, :, :-1].squeeze().cpu().numpy()

#         for chan in range(dim):
#             # Masked values
#             masked_indices = np.where(real_data_numpy[:, chan + dim] == 0)
#             masked_y = real_data_numpy[masked_indices, chan].flatten()
#             pred_indices = np.where(channel_indices == chan)
#             pred_means = means[pred_indices]
#             pred_stds = stds[pred_indices]

#             masked_y = masked_y
#             pred_means = pred_means.cpu().numpy()
#             pred_stds = pred_stds.cpu().numpy()

#             mse_val = mse_inference(masked_y, pred_means)
#             crps_val = crps_norm(masked_y, pred_means, pred_stds)

#             mse_per_feature.append(mse_val)
#             crps_per_feature.append(crps_val)

#             print(f"MSE for {keys[chan]}: {mse_val:.4f}, CRPS: {crps_val:.4f}")

#     # Compute total average over features
#     total_mse = np.mean(mse_per_feature)
#     total_crps = np.mean(crps_per_feature)
#     print(f"\nTotal average MSE: {total_mse:.4f}")
#     print(f"Total average CRPS: {total_crps:.4f}")

#     # Save results
#     saved_data["mse_per_feature"] = mse_per_feature
#     saved_data["crps_per_feature"] = crps_per_feature
#     saved_data["mse_total"] = total_mse
#     saved_data["crps_total"] = total_crps

#     np.savez(save_path, **saved_data)
#     print(f"Results saved to {save_path}")
