# Visualize Ground Truth vs Model Rollout
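This notebook compares ground-truth particle trajectories with those produced by an autoregressive rollout of the trained GNS model. Starting from the ground-truth initial state, each model prediction is fed back in as the input for the next step.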

In [None]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
from models.gns import GNS

## Load Ground Truth Data

In [None]:
# Path to the HDF5 file holding the reference simulation (update with your file).
hdf5_path = '../tools/data_gen/sample_output.h5'
with h5py.File(hdf5_path, 'r') as h5_file:
    # Copy the full 'positions' dataset into a NumPy array before the file closes.
    # Assumed layout: [timesteps, particles, 3] -- TODO confirm against the data generator.
    gt_positions = h5_file['positions'][:]

## Load Trained Model

In [None]:
# Build the model with the same hyperparameters used at training time, then
# restore its trained weights.
model = GNS(node_dim=3, edge_dim=0, hidden_dim=64)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the rollout below works with CPU tensors (it calls .numpy()).
# NOTE(review): this path points into a 'torchscript' directory. If the file was
# written with torch.jit.save it is a scripted module, not a state_dict, and must be
# loaded with torch.jit.load instead. Update the path if needed.
state_dict = torch.load('../models/torchscript/gns.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch layers such as dropout/batch-norm (if any) to inference mode

## Run Model Rollout

In [None]:
# Autoregressive rollout: start from the ground-truth initial frame and feed each
# prediction back in as the next input, so prediction errors compound over time.
pred_positions = [gt_positions[0]]

# Placeholder graph for the demo, built once since it never changes: an edge list
# with zero edges and no edge features.
# NOTE(review): with no edges the model presumably performs no message passing;
# replace with a real neighbour graph for a meaningful comparison.
edge_index = torch.empty((2, 0), dtype=torch.long)
edge_attr = None

# Only forward outputs are needed, so skip building an autograd graph per step.
with torch.no_grad():
    for _ in range(1, len(gt_positions)):
        x = torch.as_tensor(pred_positions[-1], dtype=torch.float)
        # NOTE(review): assumes the model returns next-step positions shaped like x;
        # if it predicts velocities/accelerations, integrate them here instead.
        pred = model(x, edge_index, edge_attr).numpy()
        pred_positions.append(pred)

## Plot Trajectories

In [None]:
# Plot the X coordinate of every particle over time: solid = ground truth,
# dashed = model rollout. Both lines for a particle share one color so they can
# be paired visually (otherwise matplotlib's color cycle advances per call).
gt_arr = np.asarray(gt_positions)
pred_arr = np.asarray(pred_positions)
n_particles = gt_arr.shape[1]
# Two legend entries per particle become unreadable (and slow to lay out) for
# large systems, so the legend is only drawn for small particle counts.
max_legend_particles = 10

plt.figure(figsize=(10, 5))
for i in range(n_particles):
    (gt_line,) = plt.plot(gt_arr[:, i, 0], label=f'GT Particle {i}')
    plt.plot(
        pred_arr[:, i, 0],
        '--',
        color=gt_line.get_color(),
        label=f'Pred Particle {i}',
    )
plt.xlabel('Timestep')
plt.ylabel('X Position')
if n_particles <= max_legend_particles:
    plt.legend()
plt.title('Ground Truth vs Model Rollout (X Position)')
plt.show()