# Test d'inférence de la première version entraînée du GNN (sans critères "physics-informed")

In [15]:
from model.GNN import GNN_NBody, InteractionNetwork
from data import solarSystemDataSet
import torch
from torch_geometric.data import Data
import torch_geometric.nn
import torch_geometric.inspector
import inspect
import _operator
import typing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from joblib import load

# Allow-list every object that PyTorch's restricted (weights_only) unpickler has to rebuild:
# the checkpoint below holds the whole GNN_NBody module object, not only a state_dict.
# NOTE(review): allow-listing `type` weakens the weights_only protection — only load
# checkpoints from a trusted source.
torch.serialization.add_safe_globals([
    # Project model classes.
    GNN_NBody,
    InteractionNetwork,
    # torch.nn building blocks used by the model (same objects as torch.nn.modules.*).
    torch.nn.Linear,
    torch.nn.ModuleList,
    torch.nn.Sequential,
    torch.nn.ReLU,
    # PyTorch Geometric objects embedded in the pickled model.
    torch_geometric.nn.aggr.basic.SumAggregation,
    torch_geometric.inspector.Inspector,
    torch_geometric.inspector.Signature,
    torch_geometric.inspector.Parameter,
    # Plain Python / typing objects that also appear in the pickle.
    inspect._empty,
    _operator.getitem,
    typing.OrderedDict,
    typing.Union,
    type,
    int,
])

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Machine-specific absolute path to the model artifact under the local mlruns directory.
# Despite its name, the file holds the full model object (torch.load returns a GNN_NBody).
model_state_dict_path = "C:\\Repos\\GIF-7005-Project\\mlruns\\304169667892621182\\models\\m-9e94d3ae95cd448ebd765b3674c0cf1f\\artifacts\\data\\model.pth"
model: GNN_NBody = torch.load(model_state_dict_path, map_location=torch.device('cpu'))

# Switch to evaluation mode and move to the selected device; both calls return the module
# itself, and the bare expression lets the notebook display the model summary.
model.eval().to(device)

Using device: cpu


GNN_NBody(
  (node_encoder): Linear(in_features=7, out_features=128, bias=True)
  (interaction_layers): ModuleList(
    (0-2): 3 x InteractionNetwork()
  )
  (output_decoder): Linear(in_features=128, out_features=6, bias=True)
)

In [16]:
# Restore the previously saved StandardScaler; it is used to normalize the dataset inputs
# and to de-normalize the model predictions.
# NOTE(review): joblib.load unpickles the file — only load scalers from a trusted source.
scaler: StandardScaler = load(filename="scaler.joblib")

# Reference body coordinates and velocities (JSON Lines: one record per line) as a DataFrame.
df_targets: pd.DataFrame = pd.read_json(
    "data/body_coordinates_and_velocities_from_1749-12-31_to_2200-01-09.json",
    lines=True,
)

In [26]:
# For this test, keep only the reference data for the year 2025.
df_targets_2025 = df_targets[df_targets["datetime_str"].str.startswith("A.D. 2025-")]
display(df_targets_2025)

X_dataset: solarSystemDataSet.SolarSystemDataset = solarSystemDataSet.SolarSystemDataset(df_targets_2025, scaler)

# Columns of the de-normalized prediction, matching the target columns displayed below.
# NOTE(review): assumes the scaler was fitted on exactly these 7 columns in this order
# (the mass being last) — confirm against the training code.
prediction_columns = ['x', 'y', 'z', 'vx', 'vy', 'vz', 'body_mass']

with torch.no_grad():
    # Input graph features: the first state of the 2025 dataset (one row per body).
    state = X_dataset.states[0]

    # Keep only the last column, i.e. the body masses as normalized by the scaler. The model
    # outputs 6 features (no mass), so the masses are re-attached before inverse_transform.
    normalized_body_masses = state.cpu().numpy()[:, -1:]

    # Predict the features of every body. The model was moved to `device` when it was
    # loaded, so the graph must live on the same device (otherwise this fails on CUDA).
    graph = Data(x=state, edge_index=X_dataset.edge_index).to(device)
    predicted_normalized = model(graph)

    # `.numpy()` only accepts CPU tensors: bring the prediction back, append the masses,
    # de-normalize, and store the result in the predictions DataFrame.
    predicted_normalized_numpy = np.concatenate(
        [predicted_normalized.cpu().numpy(), normalized_body_masses], axis=1
    )
    predicted = scaler.inverse_transform(predicted_normalized_numpy)
    df_predictions_2025: pd.DataFrame = pd.DataFrame(predicted, columns=prediction_columns)

    display(pd.DataFrame(predicted_normalized_numpy))
    display(df_predictions_2025)

    # Reference values at the input date.
    # NOTE(review): the printed masses show that the prediction rows follow the dataset's
    # node order, which differs from the row order of this target table — align bodies
    # before comparing row by row. If the model predicts the state at the next time step,
    # the matching target is the next date — TODO confirm.
    display(df_targets_2025[df_targets_2025["datetime_str"] == "A.D. 2025-Jan-01 00:00:00.0000"][prediction_columns])


Unnamed: 0,body_id,body_name,body_mass,datetime_jd,datetime_str,x,y,z,vx,vy,vz
100443,10,Soleil,1.989000e+30,2460676.5,A.D. 2025-Jan-01 00:00:00.0000,-0.005731,-0.004911,0.000180,0.000007,-0.000004,-1.177000e-07
100444,10,Soleil,1.989000e+30,2460677.5,A.D. 2025-Jan-02 00:00:00.0000,-0.005723,-0.004914,0.000179,0.000007,-0.000004,-1.178000e-07
100445,10,Soleil,1.989000e+30,2460678.5,A.D. 2025-Jan-03 00:00:00.0000,-0.005716,-0.004918,0.000179,0.000007,-0.000004,-1.180000e-07
100446,10,Soleil,1.989000e+30,2460679.5,A.D. 2025-Jan-04 00:00:00.0000,-0.005709,-0.004922,0.000179,0.000007,-0.000004,-1.181000e-07
100447,10,Soleil,1.989000e+30,2460680.5,A.D. 2025-Jan-05 00:00:00.0000,-0.005702,-0.004925,0.000179,0.000007,-0.000004,-1.183000e-07
...,...,...,...,...,...,...,...,...,...,...,...
1415755,899,Neptune,1.024130e+26,2461036.5,A.D. 2025-Dec-27 00:00:00.0000,29.869422,0.497619,-0.698619,-0.000073,0.003157,-6.324300e-05
1415756,899,Neptune,1.024130e+26,2461037.5,A.D. 2025-Dec-28 00:00:00.0000,29.869348,0.500777,-0.698682,-0.000074,0.003157,-6.293360e-05
1415757,899,Neptune,1.024130e+26,2461038.5,A.D. 2025-Dec-29 00:00:00.0000,29.869275,0.503934,-0.698745,-0.000074,0.003158,-6.303590e-05
1415758,899,Neptune,1.024130e+26,2461039.5,A.D. 2025-Dec-30 00:00:00.0000,29.869201,0.507092,-0.698809,-0.000074,0.003158,-6.343520e-05


Processing data groups: 100%|██████████| 365/365 [00:00<00:00, 2379.36it/s]


Unnamed: 0,0,1,2,3,4,5,6
0,0.601308,0.23415,-0.462682,-0.493333,-0.072953,-0.181986,2.828426
1,0.63188,0.265201,-0.502169,-0.542106,-0.068598,-0.171772,-0.353948
2,0.621668,0.263929,-0.492368,-0.534275,-0.086022,-0.164767,-0.353177
3,0.617141,0.247116,-0.496133,-0.522425,-0.082053,-0.195438,-0.354086
4,0.62909,0.26144,-0.501149,-0.532824,-0.078227,-0.175739,-0.354077
5,0.621534,0.266391,-0.489374,-0.54103,-0.0889,-0.157196,-0.353923
6,0.633872,0.268541,-0.501888,-0.539534,-0.075985,-0.163814,-0.354079
7,0.627977,0.263073,-0.498854,-0.534526,-0.076585,-0.1722,-0.35105
8,0.628353,0.260042,-0.49873,-0.53281,-0.077857,-0.175134,-0.354086


Unnamed: 0,0,1,2,3,4,5,6
0,5.339189,2.007836,-0.110181,-0.0049,-0.000723,-0.000154,1.989e+30
1,5.613423,2.276541,-0.119852,-0.005382,-0.000679,-0.000145,8.68107e+25
2,5.521821,2.265533,-0.117451,-0.005305,-0.000852,-0.000139,5.68341e+26
3,5.481213,2.120039,-0.118373,-0.005187,-0.000813,-0.000166,3.347293e+23
4,5.588396,2.243988,-0.119602,-0.00529,-0.000775,-0.000149,5.96379e+24
5,5.520616,2.286835,-0.116718,-0.005371,-0.000881,-0.000133,1.024134e+26
6,5.631288,2.305439,-0.119783,-0.005357,-0.000753,-0.000139,4.868201e+24
7,5.578413,2.258122,-0.11904,-0.005307,-0.000759,-0.000146,1.898197e+27
8,5.581785,2.231894,-0.119009,-0.00529,-0.000771,-0.000148,6.369607e+23


Unnamed: 0,x,y,z,vx,vy,vz,body_mass
100443,-0.005731,-0.004911,0.00018,7e-06,-4e-06,-1.177e-07,1.989e+30
264812,-0.393034,-0.166635,0.022487,0.005032,-0.024747,-0.002483042,3.3011e+23
429181,0.447688,0.557306,-0.018262,-0.015799,0.012606,0.0010851,4.8675e+24
593550,-0.184414,0.962072,0.000128,-0.017198,-0.003197,-1.122e-07,5.97237e+24
757919,-0.527416,1.520324,0.044935,-0.012705,-0.003343,0.0002416611,6.4171e+23
922288,1.050303,4.966542,-0.044099,-0.007469,0.001921,0.0001591398,1.8982e+27
1086657,9.455337,-1.769525,-0.345697,0.000717,0.005471,-0.0001233523,5.6834e+26
1251026,11.097898,16.089573,-0.084019,-0.003267,0.00205,4.99682e-05,8.681e+25
1415395,29.874197,-0.639099,-0.67532,4.7e-05,0.003157,-6.64877e-05,1.02413e+26
