In [6]:
import glob
from utils import segnn_utils
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset
import importlib

sys.argv = [
    'main.py', '--dataset=gravity', '--epochs=5', '--max_samples=3000',
    '--model=segnn', '--lmax_h=1', '--lmax_attr=1', '--layers=4',
    '--hidden_features=64', '--subspace_type=weightbalanced', '--norm=none',
    '--batch_size=100', '--gpu=1', '--weight_decay=1e-12', '--target=pos'
]
parser = segnn_utils.create_argparser()
args = parser.parse_args()

torch.manual_seed(42)

print(sys.executable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run = os.path.join("segnn_runs", "2024-03-18 18-38_gravityV2_segnn")

models = glob.glob(run + "/" + '*.pth')
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

model = torch.load(models[0], map_location=device)

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

t_delta = 2
loc, vel, force, mass = dataset_train.data
dims = loc.shape[-1]
boxSize = 5
particle_index = None

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


# Getting prediction from batches, where a batch is all steps of a single simulation



In [7]:
from utils import segnn_utils

sim_indices = (i for i in range(3))
batch_simulations_predictions = segnn_utils.batch_prediction(model, dataset_train.data, args, device,
                                                             simulation_indices=sim_indices)



# Plot the whole simulation

In [8]:
simulation_index = 0

batch_simulations_predicted_position_changes = batch_simulations_predictions[simulation_index, ..., :dims]
batch_simulations_predicted_positions = loc[simulation_index] + batch_simulations_predicted_position_changes

batch_targets = segnn_utils.get_targets(dataset_train.data, simulation_index=simulation_index, t_delta=t_delta)

dataset_train.simulation.interactive_trajectory_plot_all_particles_3d(batch_targets,
                                                                      batch_simulations_predicted_positions,
                                                                      particle_index,
                                                                      boxSize=boxSize, dims=dims,
                                                                      offline_plot=False)

# SELFFEED STEPWISE (len jedna simulacia, nie batch simulacii)


In [27]:
sim_indices = (i for i in range(3))
steps = 100

self_feed_simulations_pos, self_feed_simulations_model_output = segnn_utils.self_feed_stepwise_prediction(model,
                                                                                                          dataset_train.data,
                                                                                                          dataset_train.simulation,
                                                                                                          args, device,
                                                                                                          simulation_indices=sim_indices,
                                                                                                          steps=steps)

Simulating 0
Simulating 1
Simulating 2


### selffeed stepwise results

In [28]:
particle_index = None
simulation_index = 0

self_feed_targets = segnn_utils.get_targets(dataset.data, simulation_index=simulation_index, t_delta=t_delta)
dataset_train.simulation.interactive_trajectory_plot_all_particles_3d(self_feed_targets,
                                                                      self_feed_simulations_pos[simulation_index],
                                                                      particle_index,
                                                                      boxSize=boxSize, dims=dims,
                                                                      offline_plot=False, alpha=0.2)

# SELFFEED BATCH

In [29]:
n_sims = 3
steps = 100

self_feed_batch_pos, self_feed_batch_model_output = segnn_utils.self_feed_batch_prediction(model,
                                                                                           dataset_train.data,
                                                                                           dataset_train.simulation,
                                                                                           args, device,
                                                                                           n_sims=n_sims,
                                                                                           steps=steps)

# selffeed batch results

In [26]:
particle_index = None
simulation_index = 0

self_feed_batch_targets = segnn_utils.get_targets(dataset.data, simulation_index=simulation_index, t_delta=t_delta)
dataset_train.simulation.interactive_trajectory_plot_all_particles_3d(self_feed_batch_targets,
                                                                      self_feed_batch_pos[simulation_index],
                                                                      particle_index,
                                                                      boxSize=boxSize, dims=dims,
                                                                      offline_plot=False, alpha=0.2)

# Compare batch and stepwise results


In [33]:
importlib.reload(segnn_utils)
comparison_results = segnn_utils.compare_predictions(self_feed_batch_pos, self_feed_simulations_pos)

{'MSE_Batch': 1.0315218313824176e-18, 'MSE_Stepwise': 1.0315218313824176e-18, 'MSE_Difference': 0.0}


In [None]:
predictions_np = np.stack(predictions)

pred_loc = predictions_np[..., :3]
pred_vel = predictions_np[..., 3:]
#preds_mass = mass.repeat(len(pred_vel), 1)
%matplotlib inline
sim.plot_trajectory_static(pred_loc)
sim.plot_energies(pred_loc, pred_vel, np.array(mass))
sim.plot_histograms(pred_loc, pred_vel)
sim.plot_energy_distribution(pred_loc, pred_vel, np.array(mass), bins=50)

In [None]:


import matplotlib.pyplot as plt

num_dims = 3
opos = batch_prediction[..., :num_dims].reshape(50, 5, num_dims)
ovel = batch_prediction[..., num_dims:].reshape(50, 5, num_dims)


In [None]:

plt.figure(figsize=(10, 5))

dim_labels = ['x', 'y', 'z'][:num_dims]  # Labels for dimensions
colors = ['red', 'green', 'blue'][:num_dims]  # Color for each dimension

# Positions
# plt.subplot(1, 2, 1)
for i, (color, label) in enumerate(zip(colors, dim_labels)):
    plt.hist(opos[:, :, i].flatten(), bins=50, alpha=0.5, color=color, label=f'{label}')
plt.title('Positions')
plt.legend()

# # Velocities
# plt.subplot(1, 2, 2)
# for i, (color, label) in enumerate(zip(colors, dim_labels)):
#     plt.hist(ovel[:, :, i].flatten(), bins=50, alpha=0.5, color=color, label=f'{label}')
# plt.title('Velocities')
# plt.legend()

#plt.tight_layout()
plt.show()