In [1]:
import glob
from utils import segnn_utils
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset
import importlib
import json
from types import SimpleNamespace
%matplotlib inline

torch.manual_seed(42)

print(sys.executable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training run directory: holds the saved checkpoint(s), training_args.json and the dataset copy.
#run = os.path.join("segnn_runs", "2024-03-18 18-38_gravityV2_segnn")
run = os.path.join("segnn_runs", "2024-03-20 11-01_gravityV2_segnn")

# Training args. The JSON nests the actual argument dict under the "args" key,
# hence the second SimpleNamespace unwrap.
metadata_path = os.path.join(run, 'training_args.json')
if not os.path.exists(metadata_path):
    # Fail loudly here: otherwise `args` stays undefined and the dataset
    # construction below dies with an unhelpful NameError.
    raise FileNotFoundError(f"Training arguments not found: {metadata_path}")
with open(metadata_path, 'r') as json_file:
    args_dict = json.load(json_file)
args = SimpleNamespace(**args_dict)
args = SimpleNamespace(**args.args)

# glob order is filesystem-dependent, so sort to make "the first model" deterministic.
# glob.escape protects against glob metacharacters (e.g. '[') in the run directory name.
models = sorted(glob.glob(os.path.join(glob.escape(run), '*.pth')))
if not models:
    raise FileNotFoundError(f"No .pth checkpoint found in {run}")
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

# NOTE: torch.load unpickles the whole model object - only load trusted checkpoints.
model = torch.load(models[0], map_location=device)
model.eval()

dataset = GravityDataset(partition='test', dataset_name=args.nbody_name,
                         max_samples=args.max_samples, neighbours=args.neighbours, target=args.target,
                         path=os.path.join(run, "gravity"))

t_delta = 2  # passed as t_delta to segnn_utils.get_targets; presumably the target time stride - TODO confirm
loc, vel, force, mass = dataset.data
dims = loc.shape[-1]  # spatial dimensionality (last axis of the positions)
boxSize = 5
particle_index = None

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


---
# Self-feed batch prediction
---

In [19]:
# Rollout settings for the self-fed batch prediction (presumably: number of
# simulations to roll out and number of prediction steps for each).
n_sims = 30
steps = 300
importlib.reload(segnn_utils)

(self_feed_batch_pos,
 self_feed_batch_vel,
 self_feed_batch_force) = segnn_utils.self_feed_batch_prediction(
    model, dataset.data, dataset.simulation, args, device,
    n_sims=n_sims, steps=steps,
)

In [17]:
# Plot one simulation's targets together with its self-fed rollout in 3D.
importlib.reload(segnn_utils)

import datasets.nbody.dataset.synthetic_sim as synthetic_sim
importlib.reload(synthetic_sim)

particle_index, simulation_index = None, 2

sim = synthetic_sim.GravitySim(n_balls=5, loc_std=1)

self_feed_batch_targets = segnn_utils.get_targets(
    dataset.data,
    simulation_index=simulation_index,
    t_delta=t_delta,
)
sim.interactive_trajectory_plot_all_particles_3d(
    self_feed_batch_targets,
    self_feed_batch_pos[simulation_index],
    particle_index,
    boxSize=10,
    dims=dims,
    offline_plot=False,
    alpha=0.2,
)

Showing plot, you might need to bring the plot window in focus


In [23]:
# Export one interactive HTML animation per rolled-out simulation
# (targets vs. self-fed positions for the same simulation index).
offline_plots_dir = "offline_plots"
# Create the output directory up front in case the plotting helper does not.
os.makedirs(offline_plots_dir, exist_ok=True)
for i in range(len(self_feed_batch_pos)):
    self_feed_batch_targets = segnn_utils.get_targets(dataset.data, simulation_index=i, t_delta=t_delta)
    sim.interactive_plotly_offline_plot(self_feed_batch_targets, self_feed_batch_pos[i], duration=12, output_file=os.path.join(offline_plots_dir, f"sim{i}.html"))

In [21]:
# Compare simulation 15's self-fed rollout against its own targets.
# The targets are recomputed here: `self_feed_batch_targets` is reassigned by
# other cells, so its current value may belong to a different simulation.
plot_sim_index = 15
self_feed_batch_targets = segnn_utils.get_targets(dataset.data, simulation_index=plot_sim_index,
                                                  t_delta=t_delta)
sim.interactive_trajectory_plot_all_particles_3d(self_feed_batch_targets,
                                                 self_feed_batch_pos[plot_sim_index],
                                                 particle_index,
                                                 boxSize=10, dims=dims,
                                                 offline_plot=False, alpha=0.2)

Showing plot, you might need to bring the plot window in focus


In [39]:
# Reload the plotting/simulation helpers, then export one interactive HTML
# animation per rolled-out simulation (targets vs. self-fed positions).
import datasets.nbody.dataset.synthetic_sim as synthetic_sim
importlib.reload(synthetic_sim)

sim = synthetic_sim.GravitySim(n_balls=5, loc_std=1)

importlib.reload(segnn_utils)

offline_plots_dir = "offline_plots"
# Create the output directory up front in case the plotting helper does not.
os.makedirs(offline_plots_dir, exist_ok=True)
for i in range(len(self_feed_batch_pos)):
    self_feed_batch_targets = segnn_utils.get_targets(dataset.data, simulation_index=i, t_delta=t_delta)
    sim.interactive_plotly_offline_plot(self_feed_batch_targets, self_feed_batch_pos[i], duration=12, output_file=os.path.join(offline_plots_dir, f"sim{i}.html"))
    


In [31]:
import datasets.nbody.dataset.synthetic_sim as synthetic_sim
importlib.reload(synthetic_sim)

sim = synthetic_sim.GravitySim(n_balls=5, loc_std=1)

# Compare simulation 0's self-fed rollout against its own targets.
# The targets are recomputed here: `self_feed_batch_targets` is reassigned by
# other cells (e.g. export loops leave it holding their last simulation).
plot_sim_index = 0
self_feed_batch_targets = segnn_utils.get_targets(dataset.data, simulation_index=plot_sim_index,
                                                  t_delta=t_delta)
sim.interactive_trajectory_plot_all_particles_3d(self_feed_batch_targets,
                                                 self_feed_batch_pos[plot_sim_index],
                                                 particle_index,
                                                 boxSize=20, dims=dims,
                                                 offline_plot=False, alpha=0.2)

Showing plot, you might need to bring the plot window in focus


array([[[ 0.25599378,  0.50036839, -0.68117733],
        [-0.58148081, -1.1037777 , -0.33030632],
        [ 0.32889949, -0.85687084, -1.03891683],
        [-0.96231441, -0.07366148, -2.29508088],
        [ 1.27814505, -1.51114328,  0.8399791 ]]])