In [51]:
import glob
from utils import segnn_utils
import os
import numpy as np
import torch
import sys
from datasets.nbody.dataset_gravity import GravityDataset
import importlib
import json
from types import SimpleNamespace
%matplotlib inline

torch.manual_seed(42)

print(sys.executable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#run = os.path.join("segnn_runs", "2024-03-18 18-38_gravityV2_segnn")
run = os.path.join("segnn_runs", "2024-03-20 11-01_gravityV2_segnn")

# training args
metadata_path = os.path.join(run, 'training_args.json')
if os.path.exists(metadata_path):
    with open(metadata_path, 'r') as json_file:
        args_dict = json.load(json_file)
    args = SimpleNamespace(**args_dict)
    args = SimpleNamespace(**args.args)

models = glob.glob(run + "/" + '*.pth')
if len(models) > 1:
    print("MORE MODELS FOUND IN THE DIR, LOADING THE FIRST:", models[0])

model = torch.load(models[0], map_location=device)
model.eval()

dataset = GravityDataset(partition='test', dataset_name=args.nbody_name,
                         max_samples=args.max_samples, neighbours=args.neighbours, target=args.target,
                         path=os.path.join(run, "gravity"))

t_delta = 2
loc, vel, force, mass = dataset.data
dims = loc.shape[-1]
boxSize = 5
particle_index = None

C:\Users\MartinKaras(AI)\.conda\envs\n_body_approx_3_10\python.exe


In [52]:
import glob
import os
import argparse
import copy
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from datasets.nbody.train_gravity_V2 import O3Transform
import multiprocessing

import torch

def rotate_positions(positions, rotation_matrix):
    """Rotate positions by a given rotation matrix."""
    return torch.einsum(
        "ij,nj->ni",
        rotation_matrix,
        positions,
    )


transform = O3Transform(args.lmax_attr)

simulation_index = 0
all_predictions = []
loc, vel, force, mass = copy.deepcopy(dataset.data)

data_dims = loc.shape[-1]
batch_size = loc.shape[-3]
#batch_size = 10000
n_nodes = loc.shape[-2]



loc = torch.from_numpy(loc[simulation_index]).view(-1, data_dims)
vel = torch.from_numpy(vel[simulation_index]).view(-1, data_dims)
force = torch.from_numpy(force[simulation_index]).view(-1, data_dims)

# new_loc = torch.rand(batch_size, data_dims)
# new_vel = torch.rand(batch_size, data_dims)
# new_force = torch.rand(batch_size, data_dims)
mass = torch.from_numpy(mass[simulation_index]).repeat(batch_size, 1)

loc, vel, force, mass = [d.to(device) for d in [loc, vel, force, mass]]

graph = Data(pos=loc, vel=vel, force=force, mass=mass)
batch = torch.arange(0, batch_size)
graph.batch = batch.repeat_interleave(n_nodes).long()
graph.edge_index = knn_graph(loc, args.neighbours, graph.batch)

graph = transform(graph)  # Add O3 attributes
graph = graph.to(device)
batch_prediction = model(graph).cpu().detach().numpy()
output_dims = batch_prediction.shape[-1]
pred_before_rotation = batch_prediction.copy()
pred_before_rotation_reshaped = batch_prediction.reshape(batch_size, n_nodes, output_dims)



In [53]:
# Generate a random rotation matrix
rotation_matrix = torch.randn(3, 3).to(device)
rotation_matrix, _ = torch.linalg.qr(rotation_matrix)
rotation_matrix = rotation_matrix.to(dtype=loc.dtype)
# Rotate the input
loc_rotated = rotate_positions(loc.squeeze(), rotation_matrix).view(-1, dims)
vel_rotated = rotate_positions(vel.squeeze(), rotation_matrix).view(-1, dims)
force_rotated = rotate_positions(force.squeeze(), rotation_matrix).view(-1, dims)

In [54]:
graph = Data(pos=loc_rotated, vel=vel_rotated, force=force_rotated, mass=mass)
batch = torch.arange(0, batch_size)
graph.batch = batch.repeat_interleave(n_nodes).long()
graph.edge_index = knn_graph(loc_rotated, args.neighbours, graph.batch)

graph = transform(graph)  # Add O3 attributes
graph = graph.to(device)
batch_prediction = model(graph).cpu().detach().numpy()
output_dims = batch_prediction.shape[-1]
pred_after_rotation = batch_prediction.copy()
pred_after_rotation_reshaped = batch_prediction.reshape(batch_size, n_nodes, output_dims)


In [61]:

# Rotate the prediction back
pred_after_rotation_back_pos = rotate_positions(torch.from_numpy(pred_after_rotation[...,:dims]), rotation_matrix.T)
pred_after_rotation_back_vel = rotate_positions(torch.from_numpy(pred_after_rotation[...,dims:]), rotation_matrix.T)

# pred_after_rotation_back_pos = torch.from_numpy(pred_after_rotation[...,:dims])
# pred_after_rotation_back_vel = torch.from_numpy(pred_after_rotation[...,dims:])
# 
# # Check if the prediction before rotation is close to the rotated-back prediction after rotation
# assert torch.allclose(
#     pred_before_rotation,
#     pred_after_rotation_back,
#     atol=1e-3,
# ), "The model is not equivariant under rotation"
# 
# print("Equivariance test passed!")


In [57]:
pred_after_rotation_back_pos.mean()

tensor(5.3161e-05, dtype=torch.float64)

In [62]:
(pred_after_rotation_back_pos - pred_before_rotation[...,:dims]).mean()

tensor(-2.2696e-05, dtype=torch.float64)

In [63]:
torch.allclose(
    pred_after_rotation_back_pos,
    torch.from_numpy(pred_before_rotation[...,:dims]),
)

False

In [65]:
loc.diff().mean()

tensor(-0.3825, dtype=torch.float64)