In [1]:
import argparse


def create_argparser():
    parser = argparse.ArgumentParser()

    # Run parameters
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size. Does not scale with number of gpus.')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-8,
                        help='weight decay')
    parser.add_argument('--print', type=int, default=100,
                        help='print interval')
    parser.add_argument('--log', type=bool, default=False,
                        help='logging flag')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Num workers in dataloader')
    parser.add_argument('--save_dir', type=str, default="saved models",
                        help='Directory in which to save models')

    # Data parameters
    parser.add_argument('--dataset', type=str, default="qm9",
                        help='Data set')
    parser.add_argument('--root', type=str, default="datasets",
                        help='Data set location')
    parser.add_argument('--download', type=bool, default=False,
                        help='Download flag')

    # QM9 parameters
    parser.add_argument('--target', type=str, default="alpha",
                        help='Target value, also used for gravity dataset [pos, force]')
    parser.add_argument('--radius', type=float, default=2,
                        help='Radius (Angstrom) between which atoms to add links.')
    parser.add_argument('--feature_type', type=str, default="one_hot",
                        help='Type of input feature: one-hot, or Cormorants charge thingy')

    # Nbody parameters:
    parser.add_argument('--nbody_name', type=str, default="nbody_small",
                        help='Name of nbody data [nbody, nbody_small]')
    parser.add_argument('--max_samples', type=int, default=3000,
                        help='Maximum number of samples in nbody dataset')
    parser.add_argument('--time_exp', type=bool, default=False,
                        help='Flag for timing experiment')
    parser.add_argument('--test_interval', type=int, default=5,
                        help='Test every test_interval epochs')
    parser.add_argument('--n_nodes', type=int, default=5,
                        help='How many nodes are in the graph.')

    # Gravity parameters:
    parser.add_argument('--neighbours', type=int, default=6,
                        help='Number of connected nearest neighbours')

    # Model parameters
    parser.add_argument('--model', type=str, default="segnn",
                        help='Model name')
    parser.add_argument('--hidden_features', type=int, default=128,
                        help='max degree of hidden rep')
    parser.add_argument('--lmax_h', type=int, default=2,
                        help='max degree of hidden rep')
    parser.add_argument('--lmax_attr', type=int, default=3,
                        help='max degree of geometric attribute embedding')
    parser.add_argument('--subspace_type', type=str, default="weightbalanced",
                        help='How to divide spherical harmonic subspaces')
    parser.add_argument('--layers', type=int, default=7,
                        help='Number of message passing layers')
    parser.add_argument('--norm', type=str, default="instance",
                        help='Normalisation type [instance, batch]')
    parser.add_argument('--pool', type=str, default="avg",
                        help='Pooling type type [avg, sum]')
    parser.add_argument('--conv_type', type=str, default="linear",
                        help='Linear or non-linear aggregation of local information in SEConv')

    # Parallel computing stuff
    parser.add_argument('-g', '--gpus', default=0, type=int,
                        help='number of gpus to use (assumes all are on one node)')

    return parser

In [2]:
import glob
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset

torch.manual_seed(42)

print(sys.executable)

import utils.nbody_utils as ut

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run = os.path.join("segnn_runs", "2024-03-11 16-42_gravity_segnn")

models = glob.glob(run + "/" + '*.pth')
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

model = torch.load(models[0], map_location=device)

sys.argv = [
    'main.py', '--dataset=gravity', '--epochs=5', '--max_samples=3000',
    '--model=segnn', '--lmax_h=1', '--lmax_attr=1', '--layers=4',
    '--hidden_features=64', '--subspace_type=weightbalanced', '--norm=none',
    '--batch_size=100', '--gpu=1', '--weight_decay=1e-12', '--target=pos'
]
parser = create_argparser()

args = parser.parse_args()

dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


# Batch simulation

In [3]:
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity import O3Transform

transform = O3Transform(args.lmax_attr)

simulation_index = 0

loc, vel, force, mass = dataset_train.data

output_dims = 3
batch_size = 50
n_nodes = 5
t_delta = 2

loc = torch.from_numpy(loc[simulation_index]).view(-1, output_dims)
vel = torch.from_numpy(vel[simulation_index]).view(-1, output_dims)
force = torch.from_numpy(force[simulation_index]).view(-1, output_dims)
mass = mass[simulation_index].repeat(batch_size, 1)
data = [loc, vel, force, mass]

if args.target == 'pos':
    y = loc
else:
    y = force

data = [d.to(device) for d in data]
loc, vel, force, mass = data

graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
batch = torch.arange(0, batch_size)
graph.batch = batch.repeat_interleave(n_nodes).long()
graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

graph = transform(graph)  # Add O3 attributes
graph = graph.to(device)
pred = model(graph)



# Stepwise simulation

In [4]:
# import torch
# from torch_geometric.data import Data
# from torch_geometric.nn import knn_graph
# from datasets.nbody.train_gravity import O3Transform
# 
# transform = O3Transform(args.lmax_attr)
# 
# simulation_index = 0
# 
# loc, vel, force, mass = dataset_train.data
# 
# output_dims = 3
# n_nodes = 5
# t_delta = 2
# 
# # Assuming simulation_steps is defined or calculated from your dataset
# simulation_steps = len(loc[simulation_index])  # Or however you determine the number of steps
# 
# for step in range(simulation_steps):
#     loc_step = torch.from_numpy(loc[simulation_index][step]).view(-1, output_dims).to(device)
#     vel_step = torch.from_numpy(vel[simulation_index][step]).view(-1, output_dims).to(device)
#     force_step = torch.from_numpy(force[simulation_index][step]).view(-1, output_dims).to(device)
#     mass_step = torch.tensor(mass[simulation_index], dtype=torch.float).repeat(n_nodes, 1).to(device)
# 
#     if args.target == 'pos':
#         y_step = loc_step
#     else:
#         y_step = force_step
# 
#     graph = Data(pos=loc_step, vel=vel_step, force=force_step, mass=mass_step, y=y_step)
#     # Since we're dealing with single steps, no need to batch
#     graph.edge_index = knn_graph(loc_step, args.neighbours)
# 
#     graph = transform(graph)  # Add O3 attributes
#     graph = graph.to(device)
#     pred = model(graph)
# 
#     # Process prediction here or store for later analysis


# Plot the whole simulation

In [5]:
import utils.nbody_utils as ut
%matplotlib qt

particle_index = 4
boxSize = 5

predicted_data = pred.view(batch_size, n_nodes, output_dims).detach().cpu().numpy()
loc_orig = loc.view(batch_size, n_nodes, output_dims)

targets = []
for i in range(0, batch_size - t_delta):
    targets.append(loc_orig[i + t_delta, :, :])

targets_np = np.array(targets)
hmm = loc_orig

ut.plot_trajectory(targets_np[: (batch_size - t_delta), ...], predicted_data, particle_index=particle_index, loggers=[],
                   epoch=1,
                   dims=output_dims)

# Interactive plot of the simulation

# Pre istotu overenie, ci pri inferenci dosiahnem rovnaky result ako pocas treningu
...spoiler alert, ano dostanem

In [6]:
dataset_train = GravityDataset(partition='train', dataset_name=args.nbody_name,
                               max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True)

dataset_val = GravityDataset(partition='val', dataset_name=args.nbody_name,
                             neighbours=args.neighbours, target=args.target)
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False)

dataset_test = GravityDataset(partition='test', dataset_name=args.nbody_name,
                              neighbours=args.neighbours, target=args.target)
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False)

from torch import nn

model.eval()
criterion = nn.MSELoss()
#criterion = nn.L1Loss()


loaders = {"train": loader_train,
           # "valid": loader_val,
           # "test": loader_test,
           }

batch_size = args.batch_size

tartets_across_sims = []
predicted_data_across_sims = []
for name, loader in loaders.items():
    res = {'dataset': "test", 'loss': 0, 'counter': 0}
    for batch_idx, data in enumerate(loader):
        batch_size, n_nodes, _ = data[0].size()
        data = [d.to(device) for d in data]
        data = [d.view(-1, d.size(2)) for d in data]
        loc, vel, force, mass, y = data

        graph = Data(pos=loc, vel=vel, force=force, mass=mass, y=y)
        batch = torch.arange(0, batch_size)
        graph.batch = batch.repeat_interleave(n_nodes).long()
        graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

        graph = transform(graph)  # Add O3 attributes
        graph = graph.to(device)

        tartets_across_sims.append(graph.y)
        pred = model(graph)
        predicted_data_across_sims.append(pred)

        loss = criterion(pred, graph.y)

        #print("loss:", loss.item() * batch_size)
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size

        #break

    print('%s epoch avg loss: %.5f' % (loader.dataset.partition, res['loss'] / res['counter']))

    print(res['loss'] / res['counter'])




train epoch avg loss: 0.00782
0.007820523008358886


# Interaktivny plot 

In [7]:
import importlib

importlib.reload(ut)

batch_index = 0
bodies = 5
dims = 3

target_batch = tartets_across_sims[batch_index].view(args.batch_size, bodies, dims).detach().cpu().numpy()
predicted_batch = predicted_data_across_sims[batch_index].view(args.batch_size, bodies, dims).detach().cpu().numpy()

sim_steps = batch_size - t_delta
ut.interactive_trajectory_plot_all_particles_3d(target_batch[0:sim_steps, ...], predicted_batch[0:sim_steps, ...],
                                                particle_index,
                                                boxSize=boxSize, dims=output_dims, offline_plot=False, loggers=[],
                                                video_tag=f"One step prediction of a particle {particle_index}", trace_length=1)

In [8]:
import datasets.nbody.dataset.synthetic_sim as synthetic_sim
import importlib
import numpy as np

importlib.reload(synthetic_sim)
importlib.reload(synthetic_sim)

np.random.seed(43)
sim = synthetic_sim.GravitySim(n_balls=5, loc_std=1)

loc, vel, force, mass = sim.sample_trajectory(T=50000, sample_freq=1)
sim.plot_energies(loc, vel, mass)
sim.plot_trajectory_static(loc)

(50000, 5, 3)
idem logov at 0
logging 0
[[ 0.25739993 -0.90848143 -0.37850311]
 [-0.5349156   0.85807335 -0.41300998]
 [ 0.49818858  2.01019925  1.26286154]
 [-0.43921486 -0.34643789  0.45531966]
 [-1.66866271 -0.8620855   0.49291085]]
idem logov at 1
logging 1
[[ 0.25741966 -0.90745698 -0.3790664 ]
 [-0.53581799  0.85627241 -0.41294084]
 [ 0.49817205  2.01151844  1.26251717]
 [-0.43901608 -0.34646472  0.45526621]
 [-1.66796228 -0.86260137  0.49380283]]
idem logov at 2
logging 2
[[ 0.25743871 -0.90643189 -0.37962911]
 [-0.53672025  0.8544708  -0.41287121]
 [ 0.49815533  2.01283727  1.26217255]
 [-0.43881743 -0.34649158  0.45521211]
 [-1.667261   -0.86311683  0.49469463]]
idem logov at 3
logging 3
[[ 0.25745707 -0.90540618 -0.38019124]
 [-0.53762239  0.85266854 -0.41280109]
 [ 0.49813842  2.01415574  1.26182768]
 [-0.43861889 -0.34651847  0.45515735]
 [-1.66655886 -0.86363187  0.49558626]]
idem logov at 4
logging 4
[[ 0.25747475 -0.90437984 -0.38075278]
 [-0.5385244   0.85086562 -0.4127

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



logging 22254
[[ -3.20115732   1.85424214   0.19331924]
 [ 10.31555411 -15.11142359  14.88329577]
 [ -3.77871128   7.51168863  -7.2629269 ]
 [ -2.60973      1.67760307  -0.95907589]
 [ -2.61316016   4.81915752  -5.43503326]]
idem logov at 22255
logging 22255
[[ -3.20120733   1.85393555   0.19317742]
 [ 10.31603864 -15.11194092  14.88389892]
 [ -3.77879149   7.51137047  -7.26308033]
 [ -2.6095183    1.67833578  -0.95939765]
 [ -2.61372617   4.8195669   -5.4350194 ]]
idem logov at 22256
logging 22256
[[ -3.20125707   1.85362889   0.19303505]
 [ 10.31652317 -15.11245824  14.88450206]
 [ -3.77887167   7.51105223  -7.26323369]
 [ -2.60930687   1.67906859  -0.95971893]
 [ -2.61429221   4.8199763   -5.43500553]]
idem logov at 22257
logging 22257
[[ -3.20130655   1.85332218   0.19289215]
 [ 10.3170077  -15.11297557  14.8851052 ]
 [ -3.77895181   7.51073391  -7.26338699]
 [ -2.6090957    1.67980151  -0.96003974]
 [ -2.61485829   4.82038575  -5.43499166]]
idem logov at 22258
logging 22258
[[ -3.

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



logging 47258
[[ -6.34878006   4.89446364  -6.74955664]
 [ 22.00662315 -27.47695374  29.40477275]
 [ -7.73346094   8.88589363 -10.56069671]
 [ -5.98461366   6.33789871  -7.05844383]
 [ -3.82697314   8.10996553  -3.61649661]]
idem logov at 47259
logging 47259
[[ -6.34839477   4.89442331  -6.74984969]
 [ 22.00707977 -27.47743416  29.40533914]
 [ -7.73404256   8.88586522 -10.56101717]
 [ -5.98474041   6.33848369  -7.05803832]
 [ -3.82710668   8.10992971  -3.616855  ]]
idem logov at 47260
logging 47260
[[ -6.34800937   4.89438342  -6.75014281]
 [ 22.0075364  -27.47791457  29.40590553]
 [ -7.73462415   8.88583677 -10.56133756]
 [ -5.98486726   6.33906831  -7.05763272]
 [ -3.82724027   8.10989384  -3.61721346]]
idem logov at 47261
logging 47261
[[ -6.34762386   4.89434399  -6.75043603]
 [ 22.00799302 -27.47839498  29.40647191]
 [ -7.7352057    8.88580827 -10.56165788]
 [ -5.98499421   6.33965256  -7.05722703]
 [ -3.82737391   8.10985794  -3.61757201]]
idem logov at 47262
logging 47262
[[ -6.

In [9]:
sim.plot_energies(loc, vel, mass)

In [10]:
sim.plot_trajectory_static(loc)

In [11]:
sim.plot_histograms(loc, vel)
# todo zo vsetkych

In [43]:
import datasets.nbody.dataset.synthetic_sim as synthetic_sim
import importlib

np.random.seed(22)
importlib.reload(ut)
importlib.reload(synthetic_sim)

sim = synthetic_sim.GravitySim(n_balls=5, loc_std=1, interaction_strength=2)
loc, vel, force, mass = sim.sample_trajectory(T=5000, sample_freq=100)
sim_steps = batch_size - t_delta
sim.interactive_trajectory_plot_all_particles_3d(loc, None,
                                                None,
                                                boxSize=boxSize, dims=output_dims, offline_plot=False)


In [44]:
new_loc, new_vel, new_force, mass = sim.sample_trajectory(T=1000, sample_freq=100, og_pos_save=loc, og_vel_save=vel, og_force_save=force)


In [14]:
sim.interactive_trajectory_plot_all_particles_3d(targets_np[0:sim_steps, ...], predicted_data[0:sim_steps],
                                                 None,
                                                 boxSize=boxSize, dims=output_dims, offline_plot=False)

In [46]:
sim.interactive_trajectory_plot_all_particles_3d(new_loc, None,
                                                 0,
                                                 boxSize=1, dims=output_dims, offline_plot=False)

In [50]:
import datasets.nbody.dataset_gravity as dataset_gravity

import importlib

importlib.reload(dataset_gravity)



dataset_train = dataset_gravity.GravityDataset(partition='train', dataset_name=args.nbody_name,
                                   max_samples=args.max_samples, neighbours=args.neighbours, target=args.target)

loc, vel, force, mass = dataset_train.get_one_sim_data(5)



In [53]:
sim.plot_histograms(loc, vel)