# Test d'inférence de la première version entraînée du GNN (sans critères "physics informed")

## 1. Définition des chemins des fichiers (à ajuster selon votre environnement local)

In [1]:
model_file_path = "C:\\Repos\\GIF-7005-Project\\model_25-11-25\\model.pth"
scaler_file_path = "C:\\Repos\\GIF-7005-Project\\model_25-11-25\\scaler.joblib"
dataset_path = "data/body_coordinates_and_velocities_from_1749-12-31_to_2200-01-09.json"

## 2. Chargement des librairies

In [2]:
from model.GNN import GNN_NBody, InteractionNetwork
from data import solarSystemDataSet
import torch
from torch_geometric.data import Data
import torch_geometric.nn
import torch_geometric.inspector
import inspect
import _operator
import typing
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from joblib import load

  from .autonotebook import tqdm as notebook_tqdm


## 3. Chargement du modèle

In [3]:
torch.serialization.add_safe_globals(
    [
        GNN_NBody, 
        torch.nn.modules.linear.Linear, 
        torch.nn.modules.container.ModuleList, 
        InteractionNetwork, 
        torch_geometric.nn.aggr.basic.SumAggregation, 
        torch.nn.modules.container.Sequential,
        torch.nn.modules.activation.ReLU,
        torch_geometric.inspector.Inspector,
        torch_geometric.inspector.Signature,
        torch_geometric.inspector.Parameter,
        inspect._empty,
        _operator.getitem,
        typing.OrderedDict,
        typing.Union,
        type,
        int
    ])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model: GNN_NBody = torch.load(model_file_path, map_location=torch.device('cpu'))

model.eval()
model.to(device)

Using device: cpu


GNN_NBody(
  (node_encoder): Linear(in_features=7, out_features=128, bias=True)
  (interaction_layers): ModuleList(
    (0-2): 3 x InteractionNetwork()
  )
  (output_decoder): Linear(in_features=128, out_features=6, bias=True)
)

## 4. Chargement du scaler et du jeu de données

In [4]:
# Chargement du scaler sauvegardé.
scaler: StandardScaler = load(scaler_file_path)

# Chargement des données de référence dans un DataFrame.
df_targets: pd.DataFrame = pd.read_json(dataset_path, lines=True)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


## 5. Tests d'inférence avec le modèle

In [None]:
# Pour les fins du test, on ne conserve que les données pour l'année 2025.
df_targets_2025 = df_targets[(df_targets["datetime_str"].str.startswith("A.D. 2025-"))]
df_targets_2025['body_mass'] = np.log10(df_targets_2025['body_mass'])

X_dataset: solarSystemDataSet.SolarSystemDataset = solarSystemDataSet.SolarSystemDataset(df_targets_2025, scaler)

df_predictions_2025: pd.DataFrame = pd.DataFrame()

with torch.no_grad():
    state = X_dataset.states[0]
    state_numpy = state.numpy()

    # On ne garde que la dernière colonne, qui correspond à la masse des planètes normalisée par le scaler.
    normalized_body_masses = state_numpy[:, -1:]

    # Prédiction des features pour l'ensemble des corps.
    predicted_normalized = model(Data(x=X_dataset.states[0], edge_index=X_dataset.edge_index))

    # On dénormalise les valeurs et on les ajoute au DataFrame des prédictions.
    predicted_normalized_numpy = predicted_normalized.numpy()
    predicted_normalized_numpy = np.append(predicted_normalized_numpy, normalized_body_masses, axis=1)
    predicted = scaler.inverse_transform(predicted_normalized_numpy)
    
    display(pd.DataFrame(predicted))
    display(df_targets_2025[(df_targets_2025["datetime_str"] == "A.D. 2025-Jan-01 00:00:00.0000")][['x', 'y', 'z', 'vx', 'vy', 'vz', 'body_mass']])


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_targets_2025['body_mass'] = np.log10(df_targets_2025['body_mass'])
Processing data groups: 100%|██████████| 365/365 [00:00<00:00, 2558.12it/s]


Unnamed: 0,0,1,2,3,4,5,6
0,-0.049722,-0.050638,0.003118,-2.3e-05,-6.782042e-07,-2.675547e-07,30.298634
1,5.695029,-24.183161,-2.532762,1.507695,0.6810218,-0.08216775,23.518658
2,-16.181604,12.640042,1.098791,-0.349358,-0.4445933,0.0139732,24.687305
3,-17.223604,-3.325155,-0.002227,0.057881,-0.3004286,6.784775e-05,24.776146
4,-12.693141,-3.334376,0.248217,0.037296,-0.1070991,-0.003159416,23.807339
5,-7.617223,1.761917,0.16486,-0.002586,-0.01102562,0.0001398865,27.278341
6,0.51316,5.318707,-0.098745,-0.00296,0.0006159866,0.0001725495,26.754608
7,-3.379974,1.958013,0.052442,-0.000424,-0.0005957657,5.651053e-06,25.93857
8,-0.847815,3.903936,0.035774,0.001191,-0.0003321152,2.305446e-05,26.010355


Unnamed: 0,x,y,z,vx,vy,vz,body_mass
100443,-0.005731,-0.004911,0.00018,7e-06,-4e-06,-1.177e-07,30.298635
264812,-0.393034,-0.166635,0.022487,0.005032,-0.024747,-0.002483042,23.518659
429181,0.447688,0.557306,-0.018262,-0.015799,0.012606,0.0010851,24.687306
593550,-0.184414,0.962072,0.000128,-0.017198,-0.003197,-1.122e-07,24.776147
757919,-0.527416,1.520324,0.044935,-0.012705,-0.003343,0.0002416611,23.807339
922288,1.050303,4.966542,-0.044099,-0.007469,0.001921,0.0001591398,27.278342
1086657,9.455337,-1.769525,-0.345697,0.000717,0.005471,-0.0001233523,26.754608
1251026,11.097898,16.089573,-0.084019,-0.003267,0.00205,4.99682e-05,25.93857
1415395,29.874197,-0.639099,-0.67532,4.7e-05,0.003157,-6.64877e-05,26.010355
