In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from src.model.gnn import GNNModel
from src.model.sfno import TemporalSFNO
from src.data.ncep_dataloader import get_ncep_test_data
from data.preprocess import get_normalizer

In [None]:
# Set device
# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load models
# Both models are built with their default constructor arguments and moved to `device`.
# NOTE(review): no checkpoint / state_dict is loaded in this cell. Unless the
# constructors restore trained weights themselves, every metric and plot below
# is computed from freshly initialised parameters. TODO confirm.
gnn_model = GNNModel().to(device)
sfno_model = TemporalSFNO().to(device)
# Call .eval() on both models before running predictions.
gnn_model.eval()
sfno_model.eval()

In [None]:
# Load test data
# get_ncep_test_data() is unpacked into the (inputs, targets) pair and each tensor
# is moved to `device` on the way. Later cells index both as (B, T, C, H, W).
test_inputs, targets = (tensor.to(device) for tensor in get_ncep_test_data())

In [None]:
# Load normalizer (optional - only needed if your data is normalized)
normalizer = get_normalizer()
# Keep the statistics on the same device as the tensors they are applied to.
mean = normalizer.mean.to(device)
std = normalizer.std.to(device)

In [None]:
# Prediction
# Gradients are not needed for evaluation, so both forward passes run under no_grad.
with torch.no_grad():
    gnn_preds, sfno_preds = gnn_model(test_inputs), sfno_model(test_inputs)

In [None]:
# Denormalize (optional — only if your data was normalized)
def denormalize(x, mean, std):
    """Invert standardisation: scale ``x`` by ``std``, then shift by ``mean``.

    Args:
        x: Normalized values.
        mean: Offset that was subtracted during normalization.
        std: Scale that ``x`` was divided by during normalization.

    Returns:
        ``x * std + mean``, using the operands' own broadcasting rules.
    """
    rescaled = x * std
    return rescaled + mean

# Bring predictions and targets back to un-normalized values using the same statistics.
# NOTE: the three names are rebound in place, so running this cell a second time
# applies the transform twice. Re-run the data-loading and prediction cells first.
gnn_preds, sfno_preds, targets = (
    denormalize(tensor, mean, std) for tensor in (gnn_preds, sfno_preds, targets)
)

In [None]:
# Evaluation Metrics
def get_metrics(pred, target):
    """Compute aggregate error metrics over every element of ``pred`` and ``target``.

    Args:
        pred: Predicted tensor (floating point).
        target: Reference tensor; must broadcast against ``pred``.

    Returns:
        dict of Python floats:
            "rmse": root-mean-square error.
            "mae": mean absolute error.
            "acc": ``1 - ||pred - target||_2 / ||target||_2``, i.e. one minus the
                relative L2 error over the flattened tensors. Not finite when
                ``target`` has zero norm.
    """
    # Compute the residual once; all three metrics are derived from it.
    error = pred - target
    rmse = torch.sqrt(torch.mean(error ** 2)).item()
    mae = torch.mean(torch.abs(error)).item()
    # torch.norm is deprecated; vector_norm computes the same flattened 2-norm.
    relative_l2 = torch.linalg.vector_norm(error) / torch.linalg.vector_norm(target)
    acc = 1 - relative_l2.item()
    return {"rmse": rmse, "mae": mae, "acc": acc}

# Score both models against the same targets and print the resulting dicts.
gnn_metrics = get_metrics(pred=gnn_preds, target=targets)
sfno_metrics = get_metrics(pred=sfno_preds, target=targets)
print("GNN Metrics:", gnn_metrics)
print("SFNO Metrics:", sfno_metrics)

In [None]:
# Plotting Functions
def plot_1d_time_series(y_true, y_preds, labels, time_axis):
    """Plot the ground-truth series together with one line per prediction.

    Args:
        y_true: Ground-truth values, one per entry of ``time_axis``.
        y_preds: Iterable of predicted series.
        labels: Legend labels, paired with ``y_preds`` by position.
        time_axis: x-coordinates shared by every series.
    """
    _, ax = plt.subplots(figsize=(10, 4))
    ax.plot(time_axis, y_true, label="Ground Truth", color='black')
    for series, name in zip(y_preds, labels):
        ax.plot(time_axis, series, label=name)
    ax.legend()
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    ax.set_title("1D Time Series Forecast")
    ax.grid(True)
    plt.show()

In [None]:
def plot_scatter(pred, true, title="Prediction vs Ground Truth"):
    """Scatter predictions against ground truth with a dashed y = x reference line.

    Args:
        pred: Predicted tensor (any shape, any device).
        true: Ground-truth tensor with the same number of elements as ``pred``.
        title: Figure title.
    """
    # Convert to host NumPy arrays once. The diagonal used to be built from
    # tensor-valued true.min()/true.max(), which matplotlib cannot convert when
    # the tensors live on a CUDA device.
    true_np = true.detach().flatten().cpu().numpy()
    pred_np = pred.detach().flatten().cpu().numpy()
    lo, hi = true_np.min(), true_np.max()

    plt.figure(figsize=(6, 6))
    sns.scatterplot(x=true_np, y=pred_np, alpha=0.3)
    plt.xlabel("Ground Truth")
    plt.ylabel("Prediction")
    plt.title(title)
    # Perfect predictions fall on this diagonal.
    plt.plot([lo, hi], [lo, hi], 'k--')
    plt.axis('equal')
    plt.show()

In [None]:
def plot_rmse_map(pred, target, title="Spatial RMSE"):
    """Show the per-grid-point RMSE, averaged over every leading dimension.

    The last two dimensions are the 2-D grid drawn by ``imshow``; all dimensions
    before them (e.g. batch, time, channel) are averaged out. For 5-D input this
    is the same reduction as ``dim=(0, 1, 2)``.

    Args:
        pred: Predicted tensor with at least two dimensions.
        target: Reference tensor; must broadcast against ``pred``.
        title: Figure title.

    Raises:
        ValueError: if the error tensor has fewer than two dimensions.
    """
    sq_err = (pred - target) ** 2
    if sq_err.ndim < 2:
        raise ValueError(
            f"plot_rmse_map needs at least 2 dimensions, got shape {tuple(sq_err.shape)}"
        )
    # Reduce all leading dims rather than hard-coding (0, 1, 2), so 3-D and 4-D
    # inputs also yield an H x W map.
    leading_dims = tuple(range(sq_err.ndim - 2))
    # Skip the mean for 2-D input: torch does not treat an empty dim tuple as a no-op.
    mse = torch.mean(sq_err, dim=leading_dims) if leading_dims else sq_err
    rmse_map = torch.sqrt(mse).detach().cpu().numpy()  # shape: HxW

    plt.figure(figsize=(6, 5))
    plt.imshow(rmse_map, cmap="magma", origin='lower')
    plt.colorbar(label="RMSE")
    plt.title(title)
    plt.xlabel("Longitude")
    plt.ylabel("Latitude")
    plt.show()

In [None]:
# Select a location for time series plot (e.g., center grid point)
lat_idx = test_inputs.shape[-2] // 2
lon_idx = test_inputs.shape[-1] // 2
time_axis = np.arange(targets.shape[1])
# First sample, first channel, all time steps at the chosen grid point.
y_true, y_gnn, y_sfno = (
    field[0, :, 0, lat_idx, lon_idx].cpu().numpy()
    for field in (targets, gnn_preds, sfno_preds)
)

In [None]:
# Compare both models against the ground truth at the selected grid point.
plot_1d_time_series(
    y_true=y_true,
    y_preds=[y_gnn, y_sfno],
    labels=["GNN", "SFNO"],
    time_axis=time_axis,
)

In [None]:
# GNN: predicted vs. true values over every element.
plot_scatter(pred=gnn_preds, true=targets, title="GNN Predictions")

In [None]:
# SFNO: predicted vs. true values over every element.
plot_scatter(pred=sfno_preds, true=targets, title="SFNO Predictions")

In [None]:
# GNN: RMSE at each grid point.
plot_rmse_map(pred=gnn_preds, target=targets, title="GNN Spatial RMSE")

In [None]:
# SFNO: RMSE at each grid point.
plot_rmse_map(pred=sfno_preds, target=targets, title="SFNO Spatial RMSE")