In [1]:
import glob
from utils import segnn_utils
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset

sys.argv = [
    'main.py', '--dataset=gravity', '--epochs=5', '--max_samples=3000',
    '--model=segnn', '--lmax_h=1', '--lmax_attr=1', '--layers=4',
    '--hidden_features=64', '--subspace_type=weightbalanced', '--norm=none',
    '--batch_size=100', '--gpu=1', '--weight_decay=1e-12', '--target=pos'
]
parser = segnn_utils.create_argparser()
args = parser.parse_args()

torch.manual_seed(42)

print(sys.executable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run = os.path.join("segnn_runs", "2024-03-18 18-38_gravityV2_segnn")

models = glob.glob(run + "/" + '*.pth')
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

model = torch.load(models[0], map_location=device)

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


# Getting prediction from batches, where a batch is all steps of a single simulation



In [2]:
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

simulation_index = 0

loc, vel, force, mass = dataset_train.data

output_dims = loc.shape[-1]
batch_size = loc.shape[-3]
n_nodes = loc.shape[-2]

t_delta = 2

loc = torch.from_numpy(loc[simulation_index]).view(-1, output_dims)
vel = torch.from_numpy(vel[simulation_index]).view(-1, output_dims)
force = torch.from_numpy(force[simulation_index]).view(-1, output_dims)
mass = torch.from_numpy(mass[simulation_index]).repeat(batch_size, 1)
data = [loc, vel, force, mass]

if args.target == 'pos':
    y = loc
else:
    y = force

data = [d.to(device) for d in data]
loc, vel, force, mass = data

graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
batch = torch.arange(0, batch_size)
graph.batch = batch.repeat_interleave(n_nodes).long()
graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

graph = transform(graph)  # Add O3 attributes
graph = graph.to(device)
batch_prediction = model(graph).cpu().detach().numpy()



# Plot the whole simulation

In [3]:
article_index = None
boxSize = 5
model_output_dims = batch_prediction.shape[-1]

loc_orig = loc.view(batch_size, n_nodes, output_dims)
predicted_position_changes = batch_prediction.reshape(batch_size, n_nodes, model_output_dims)[..., :output_dims]
predicted_positions = loc_orig + predicted_position_changes

targets = []
for i in range(0, batch_size - t_delta):
    targets.append(loc_orig[i + t_delta, :, :])

targets_np = np.array(targets)

dataset_train.simulation.interactive_trajectory_plot_all_particles_3d(targets_np, predicted_positions,
                                                                      None,
                                                                      boxSize=boxSize, dims=output_dims,
                                                                      offline_plot=False)


# !!! SELF FEED PREDICTION !!!

In [56]:
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

loc, vel, force, mass = dataset_train.data

output_dims = loc.shape[-1]
n_nodes = loc.shape[-2]

n_sims = 10
num_steps = 6

t_delta = 2

loc = torch.from_numpy(loc)
vel = torch.from_numpy(vel)
force = torch.from_numpy(force)
mass = torch.from_numpy(mass)

# get just initial states
loc, vel, force, mass = [d[:n_sims, 0, ...].to(device) for d in [loc, vel, force, mass]]
loc, vel, force = [d.reshape(-1, 3) for d in [loc, vel, force]]
mass = mass.repeat(n_nodes, 1)
loc_orig = loc.clone()

simulation_instance = dataset_train.simulation

states = []
self_feed_predictions = []

for step in range(num_steps):
    
    batch = torch.arange(0, n_sims)
    graph = Data(pos=loc, vel=vel, force=force, mass=mass)
    
    graph.batch = batch.repeat_interleave(n_nodes).long()
    graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)
    graph = transform(graph)  # Add O3 attributes
    graph = graph.to(device)

    # Model prediction
    prediction = model(graph).cpu().detach().numpy()

    # Update states based on prediction
    delta_loc, delta_vel = prediction[:, :output_dims], prediction[:, output_dims:]

    # Update position and velocity
    loc = loc + torch.from_numpy(delta_loc).to(device)
    vel = vel + torch.from_numpy(delta_vel).to(device)

    force = simulation_instance.compute_force_batched(loc.cpu().detach().numpy(), mass.cpu().detach().numpy(),
                                                      simulation_instance.interaction_strength,
                                                      simulation_instance.softening, batch_size)

    force = torch.from_numpy(force)

    states.append((loc.clone(), vel.clone(), force.clone()))
    self_feed_predictions.append(loc.clone().view(n_nodes, n_sims, output_dims))

self_feed_predictions = np.array(self_feed_predictions)
#self_feed_predictions = self_feed_predictions.transpose(1, 0 ,2)

In [57]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

simulation_index = 0

loc, vel, force, mass = dataset_train.data

output_dims = loc.shape[-1]
n_nodes = loc.shape[-2]
t_delta = 2

# Assuming simulation_steps is defined or calculated from your dataset
simulation_steps = len(loc[simulation_index])  # Or however you determine the number of steps
stepwise_prediction = []
for step in range(simulation_steps):
    loc_step = torch.from_numpy(loc[simulation_index][step]).view(-1, output_dims).to(device)
    vel_step = torch.from_numpy(vel[simulation_index][step]).view(-1, output_dims).to(device)
    force_step = torch.from_numpy(force[simulation_index][step]).view(-1, output_dims).to(device)
    mass_step = torch.tensor(mass[simulation_index], dtype=torch.float).repeat(n_nodes, 1).to(device)

    if args.target == 'pos':
        y_step = loc_step
    else:
        y_step = force_step

    graph = Data(pos=loc_step, vel=vel_step, force=force_step, mass=mass_step, y=y_step)
    # Since we're dealing with single steps, no need to batch
    graph.edge_index = knn_graph(loc_step, args.neighbours)

    graph = transform(graph)  # Add O3 attributes
    graph = graph.to(device)
    stepwise_prediction.append(model(graph).detach().numpy())

stepwise_prediction = np.stack(stepwise_prediction)

In [ ]:
self_feed_predictions.transpose(1, )

In [15]:
particle_index = None
boxSize = 5
model_output_dims = self_feed_predictions.shape[-1]
num_sims = self_feed_predictions.shape[0]

loc_orig = loc_orig.view(batch_size, n_nodes, output_dims)
predicted_positions_self_feed = self_feed_predictions.reshape(batch_size, n_nodes, model_output_dims)

targets = []
for i in range(0, batch_size - t_delta):
    targets.append(loc_orig[i + t_delta, :, :])

targets_np = np.array(targets)

dataset_train.simulation.interactive_trajectory_plot_all_particles_3d(targets_np, predicted_positions_self_feed,
                                                                      None,
                                                                      boxSize=boxSize, dims=output_dims,
                                                                      offline_plot=False)

ValueError: cannot reshape array of size 7500 into shape (50,5,3)

In [7]:
batch_prediction.reshape(batch_size, n_nodes, output_dims)

(250, 6)

# Interactive plot of the simulation

In [10]:
import datasets.nbody.dataset_gravity as dataset_gravity
%matplotlib inline
import importlib

importlib.reload(dataset_gravity)

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

dataset_train.plot_energy_statistics()


NameError: name 'importlib' is not defined

# Check whether the inference yields the same result as during the training


In [15]:
dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True)

dataset_val = GravityDataset(partition='val', dataset_name=args.nbody_name,
                             neighbours=args.neighbours, target=args.target)
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False)

dataset_test = GravityDataset(partition='test', dataset_name=args.nbody_name,
                              neighbours=args.neighbours, target=args.target)
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False)

from torch import nn

model.eval()
criterion = nn.MSELoss()
#criterion = nn.L1Loss()


loaders = {"train": loader_train,
           # "valid": loader_val,
           # "test": loader_test,
           }

batch_size = args.batch_size

tartets_across_sims = []
predicted_data_across_sims = []
for name, loader in loaders.items():
    res = {'dataset': "test", 'loss': 0, 'counter': 0}
    for batch_idx, data in enumerate(loader):
        batch_size, n_nodes, _ = data[0].size()
        data = [d.to(device) for d in data]
        data = [d.view(-1, d.size(2)) for d in data]
        loc, vel, force, mass, y = data

        graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
        batch = torch.arange(0, batch_size)
        graph.batch = batch.repeat_interleave(n_nodes).long()
        graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

        graph = transform(graph)  # Add O3 attributes
        graph = graph.to(device)

        tartets_across_sims.append(graph.y)
        pred = model(graph)
        predicted_data_across_sims.append(pred)

        loss = criterion(pred, graph.y)

        #print("loss:", loss.item() * batch_size)
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size

        #break

    print('%s epoch avg loss: %.5f' % (loader.dataset.partition, res['loss'] / res['counter']))

    print(res['loss'] / res['counter'])




train epoch avg loss: 0.00782
0.007820523008358886
