In [1]:
import glob
from utils import segnn_utils
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset

sys.argv = [
    'main.py', '--dataset=gravity', '--epochs=5', '--max_samples=3000',
    '--model=segnn', '--lmax_h=1', '--lmax_attr=1', '--layers=4',
    '--hidden_features=64', '--subspace_type=weightbalanced', '--norm=none',
    '--batch_size=100', '--gpu=1', '--weight_decay=1e-12', '--target=pos'
]
parser = segnn_utils.create_argparser()
args = parser.parse_args()

torch.manual_seed(42)

print(sys.executable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run = os.path.join("segnn_runs", "2024-03-11 16-42_gravity_segnn")

models = glob.glob(run + "/" + '*.pth')
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

model = torch.load(models[0], map_location=device)

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


# Getting prediction from batches, where a batch is all steps of a single simulation



In [2]:
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

simulation_index = 0

loc, vel, force, mass = dataset_train.data

output_dims = loc.shape[-1]
batch_size = loc.shape[-3]
n_nodes = loc.shape[-2]
t_delta = 2

loc = torch.from_numpy(loc[simulation_index]).view(-1, output_dims)
vel = torch.from_numpy(vel[simulation_index]).view(-1, output_dims)
force = torch.from_numpy(force[simulation_index]).view(-1, output_dims)
mass = torch.from_numpy(mass[simulation_index]).repeat(batch_size, 1)
data = [loc, vel, force, mass]

if args.target == 'pos':
    y = loc
else:
    y = force

data = [d.to(device) for d in data]
loc, vel, force, mass = data

graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
batch = torch.arange(0, batch_size)
graph.batch = batch.repeat_interleave(n_nodes).long()
graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

graph = transform(graph)  # Add O3 attributes
graph = graph.to(device)
batch_prediction = model(graph).cpu().detach().numpy()



# Batch simulation where batch is [simulations x steps x nodes], be careful this requires lots of memory 



In [None]:
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform
from torch_geometric.data import Batch

transform = O3Transform(args.lmax_attr)

loc, vel, force, mass = dataset_train.data

max_simulations = 500

output_dims = loc.shape[-1]
batch_size = loc.shape[-3]
n_nodes = loc.shape[-2]
t_delta = 2

loc = torch.from_numpy(loc)
vel = torch.from_numpy(vel)
force = torch.from_numpy(force)
mass = torch.from_numpy(mass)

if args.target == 'pos':
    y = loc
else:
    y = force

graphs = []
for i in range(max_simulations):
    mss = mass[i].repeat(batch_size, 1)
    data_i = Data(pos=loc[i].view(-1, output_dims), vel=vel[i].view(-1, output_dims),
                  force=force[i].view(-1, output_dims), mass=mss, y=y[i].view(-1, output_dims))
    data_i.batch = torch.arange(batch_size).repeat_interleave(n_nodes).long()
    data_i.edge_index = knn_graph(data_i.pos, args.neighbours, data_i.batch)
    data_i = transform(data_i)
    graphs.append(data_i)

batched_graph = Batch.from_data_list(graphs)
batched_graph = batched_graph.to(device)

from torch_geometric.data import Batch

batched_graph = Batch.from_data_list(graphs)
batched_graph = batched_graph.to(device)
pred = model(batched_graph)

# Stepwise simulation


In [9]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

simulation_index = 0

loc, vel, force, mass = dataset_train.data

output_dims = loc.shape[-1]
n_nodes = loc.shape[-2]
t_delta = 2

# Assuming simulation_steps is defined or calculated from your dataset
simulation_steps = len(loc[simulation_index])  # Or however you determine the number of steps
stepwise_prediction = []
for step in range(simulation_steps):
    loc_step = torch.from_numpy(loc[simulation_index][step]).view(-1, output_dims).to(device)
    vel_step = torch.from_numpy(vel[simulation_index][step]).view(-1, output_dims).to(device)
    force_step = torch.from_numpy(force[simulation_index][step]).view(-1, output_dims).to(device)
    mass_step = torch.tensor(mass[simulation_index], dtype=torch.float).repeat(n_nodes, 1).to(device)

    if args.target == 'pos':
        y_step = loc_step
    else:
        y_step = force_step

    graph = Data(pos=loc_step, vel=vel_step, force=force_step, mass=mass_step, y=y_step)
    # Since we're dealing with single steps, no need to batch
    graph.edge_index = knn_graph(loc_step, args.neighbours)

    graph = transform(graph)  # Add O3 attributes
    graph = graph.to(device)
    stepwise_prediction.append(model(graph).detach().numpy())

stepwise_prediction = np.stack(stepwise_prediction)

# Batch vs stepwise comparison
they are identical

In [15]:
import torch
from sklearn.metrics import mean_squared_error


def compare_predictions(batch_preds_np, stepwise_preds_np):
    # Calculate MSE
    mse_batch = mean_squared_error(batch_preds_np.reshape(-1, 3), stepwise_preds_np.reshape(-1, 3))
    mse_stepwise = mean_squared_error(stepwise_preds_np.reshape(-1, 3), batch_preds_np.reshape(-1, 3))

    return {
        "MSE_Batch": mse_batch,
        "MSE_Stepwise": mse_stepwise,
        "MSE_Difference": abs(mse_batch - mse_stepwise)
    }


comparison_results = compare_predictions(batch_prediction, stepwise_prediction)

print(comparison_results)


{'MSE_Batch': 4.053788708247991e-31, 'MSE_Stepwise': 4.053788708247991e-31, 'MSE_Difference': 0.0}


# Plot the whole simulation

In [7]:
import utils.nbody_utils as ut
%matplotlib qt

particle_index = 4
boxSize = 5

loc_orig = loc.view(batch_size, n_nodes, output_dims)

targets = []
for i in range(0, batch_size - t_delta):
    targets.append(loc_orig[i + t_delta, :, :])

targets_np = np.array(targets)

predicted_data = batch_prediction.reshape(batch_size, n_nodes, output_dims)

ut.plot_trajectory(targets_np[: (batch_size - t_delta), ...], predicted_data, particle_index=particle_index,
                   loggers=[],
                   epoch=1,
                   dims=output_dims)

# Interactive plot of the simulation

In [8]:
sim_steps = batch_size - t_delta
ut.interactive_trajectory_plot_all_particles_3d(targets_np[0:sim_steps, ...], predicted_data[0:sim_steps, ...],
                                                particle_index,
                                                boxSize=boxSize, dims=output_dims, offline_plot=False, loggers=[],
                                                video_tag=f"One step prediction of a particle {particle_index}",
                                                trace_length=10)

In [10]:
import datasets.nbody.dataset_gravity as dataset_gravity
%matplotlib inline
import importlib
importlib.reload(dataset_gravity)

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

dataset_train.plot_energy_statistics()


NameError: name 'importlib' is not defined

# Check whether the inference yields the same result as during the training


In [15]:
dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True)

dataset_val = GravityDataset(partition='val', dataset_name=args.nbody_name,
                             neighbours=args.neighbours, target=args.target)
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False)

dataset_test = GravityDataset(partition='test', dataset_name=args.nbody_name,
                              neighbours=args.neighbours, target=args.target)
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False)

from torch import nn

model.eval()
criterion = nn.MSELoss()
#criterion = nn.L1Loss()


loaders = {"train": loader_train,
           # "valid": loader_val,
           # "test": loader_test,
           }

batch_size = args.batch_size

tartets_across_sims = []
predicted_data_across_sims = []
for name, loader in loaders.items():
    res = {'dataset': "test", 'loss': 0, 'counter': 0}
    for batch_idx, data in enumerate(loader):
        batch_size, n_nodes, _ = data[0].size()
        data = [d.to(device) for d in data]
        data = [d.view(-1, d.size(2)) for d in data]
        loc, vel, force, mass, y = data

        graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
        batch = torch.arange(0, batch_size)
        graph.batch = batch.repeat_interleave(n_nodes).long()
        graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

        graph = transform(graph)  # Add O3 attributes
        graph = graph.to(device)

        tartets_across_sims.append(graph.y)
        pred = model(graph)
        predicted_data_across_sims.append(pred)

        loss = criterion(pred, graph.y)

        #print("loss:", loss.item() * batch_size)
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size

        #break

    print('%s epoch avg loss: %.5f' % (loader.dataset.partition, res['loss'] / res['counter']))

    print(res['loss'] / res['counter'])




train epoch avg loss: 0.00782
0.007820523008358886
