In [7]:
import torch
from torch_geometric.data import Data

# Register torch_geometric's Data as a safe global when this PyTorch version
# provides `add_safe_globals` (PyTorch 2.6+ defaults torch.load to
# weights_only=True). NOTE: the allowlist only affects weights_only=True loads;
# the load below uses weights_only=False, so it does NOT restrict that call.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([Data])

# The saved file may hold either a bare list of Data objects or a dict of the
# form {"dataset": [...], "coef_norm": ...}.
path = "./train_dataset"  # or "Dataset/train_dataset.pt" if that's your filename

# SECURITY: weights_only=False performs a full pickle load, which can execute
# arbitrary code embedded in the file. Only load files you created or trust.
obj = torch.load(path, map_location="cpu", weights_only=False)

if isinstance(obj, dict):
    # Without this check, a dict lacking "dataset" would be treated as the
    # sample list itself and fail later with an opaque KeyError on [0].
    if "dataset" not in obj:
        raise KeyError(
            f"{path!r} contains a dict without a 'dataset' entry (keys: {list(obj)})"
        )
    data_list = obj["dataset"]
    coef_norm = obj.get("coef_norm", None)
else:
    data_list = obj
    coef_norm = None

# Fail with a clear message instead of an IndexError on data_list[0].
if len(data_list) == 0:
    raise ValueError(f"{path!r} contains no samples; nothing to inspect")

print(f"Loaded {len(data_list)} foils")
d0 = data_list[0]
print(d0)
print(
    "x:", d0.x.shape, "y:", d0.y.shape, "pos:", d0.pos.shape,
    "surf:", d0.surf.shape, "has g:", hasattr(d0, "g"),
)


Loaded 180 foils
Data(x=[10, 7], y=[10, 1], pos=[10, 2], surf=[10], g=[1024])
x: torch.Size([10, 7]) y: torch.Size([10, 1]) pos: torch.Size([10, 2]) surf: torch.Size([10]) has g: True


In [8]:
# Show actual values for the first foil (limited for readability).
PREVIEW_ROWS = 5     # how many points of x / y / pos to print
PREVIEW_GLOBAL = 10  # how many leading entries of g to print

# Labels are computed from the data so they stay correct when the sample has
# fewer points than PREVIEW_ROWS or a global vector of a different length.
n_rows_shown = min(PREVIEW_ROWS, d0.x.shape[0])

print("\n--- First foil sample ---")

# x: first few surface points and their features
print(f"x (first {n_rows_shown} rows):")
print(d0.x[:PREVIEW_ROWS].detach().cpu().numpy())

# y: wall pressure target
print(f"\ny (first {n_rows_shown} values):")
print(d0.y[:PREVIEW_ROWS].detach().cpu().numpy().squeeze())

# pos: 2D coordinates on the airfoil
print(f"\npos (first {n_rows_shown} points):")
print(d0.pos[:PREVIEW_ROWS].detach().cpu().numpy())

# g: first few global features; the total is read from the tensor rather than
# hardcoded, so the label matches whatever g length the dataset carries.
if hasattr(d0, "g"):
    g_flat = d0.g.reshape(-1)
    g_len = g_flat.numel()
    print(f"\ng (first {min(PREVIEW_GLOBAL, g_len)} of {g_len} features):")
    print(g_flat[:PREVIEW_GLOBAL].detach().cpu().numpy())
else:
    print("\nNo global features (g) attached to this sample.")



--- First foil sample ---
x (first 5 rows):
[[-1.1518016  -0.95302624 -1.4449525   0.34496346  0.         -0.10202727
  -1.0380168 ]
 [ 1.1142102   0.5466951  -1.4449525   0.34496346  0.          1.0764974
   0.97981066]
 [ 0.06793416 -1.0244389  -1.4449525   0.34496346  0.          0.10060588
  -1.0379212 ]
 [-1.2205467   1.7621026  -1.4449525   0.34496346  0.         -1.2968853
   0.9696862 ]
 [ 0.93350536  0.7670427  -1.4449525   0.34496346  0.          1.0284183
   0.9820153 ]]

y (first 5 values):
[ 0.5701196   0.32005188  0.42749652 -0.6711935   0.2931787 ]

pos (first 5 points):
[[ 0.15744904 -0.03342144]
 [ 0.81518626  0.04554505]
 [ 0.511492   -0.03718161]
 [ 0.13749492  0.10954125]
 [ 0.7627345   0.05714726]]

g (first 10 of 1024 features):
[ 0.0000000e+00  3.0793010e-03  7.5377457e-02  1.3654685e-03
 -5.8002864e-10  5.5359541e-03 -7.0725920e-10  7.5342233e-04
 -5.9096994e-10 -8.2222712e-10]
