In [1]:
import matplotlib
%matplotlib inline
import utils as ut
import importlib
from torch_scatter import scatter
from e3nn import o3, nn
from e3nn.math import soft_one_hot_linspace
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Re-import the local helper module so edits to utils.py are picked up without a kernel restart.
importlib.reload(ut)

seed_value = 42
# Seed torch as well: model weight initialisation and the shuffling DataLoader below
# are stochastic, and seed_value is otherwise only handed to the simulator.
torch.manual_seed(seed_value)

N = 5  # Number of particles
tEnd = 10.0  # time at which simulation ends
dt = 0.01  # timestep
softening = 0.15  # softening length
G = 1.0  # Newton's Gravitational Constant
boxSize = 1.0
mass_coef = 10.0
dims = 3
LOG_WANDB = False

# Record of every value handed to the simulator, so a logged run can be reproduced.
hparams = {
    'seed': seed_value,  # Random seed
    'N': N,  # Number of particles
    'tEnd': tEnd,  # Time at which simulation ends
    'dt': dt,  # Timestep
    'softening': softening,  # Softening length
    'G': G,  # Newton's Gravitational Constant
    'boxSize': boxSize,  # Size of the simulation box
    'mass_coef': mass_coef,  # Mass coefficient
    'dims': dims  # Number of spatial dimensions
}

combined_data = ut.simulate_gravitational_system(seed_value, N, tEnd, dt, softening, G, boxSize, mass_coef, dims=dims,
                                                 init_boxsize=boxSize)


OSError: [WinError 127] The specified procedure could not be found

In [None]:
# Split the simulated trajectory into model inputs and prediction targets.
# NOTE(review): the exact array layout is defined in utils.process_data; later cells index
# both arrays by the same leading index and slice the first `dims` feature columns as positions.
inputs_np, targets_np = ut.process_data(combined_data, dims=dims)

# Model


In [None]:
class Convolution(torch.nn.Module):
    """Equivariant message-passing layer built on an e3nn tensor product.

    For every edge, the source node's features are combined with the edge's
    angular attributes through a ``FullyConnectedTensorProduct`` whose weights
    are produced per edge by a small MLP applied to the edge's scalar embedding.
    The resulting messages are summed onto their destination node and divided
    by ``sqrt(num_neighbors)``.

    Args:
        irreps_in: irreps of the node features.
        irreps_sh: irreps of the edge attributes passed as ``edge_attr`` to ``forward``.
        irreps_out: irreps of the produced node features.
        num_neighbors: neighbour count used only for normalising the summed messages.
        hidden_layer: width of the single hidden layer of the weight MLP.
        embedding_dim: size of the per-edge scalar embedding fed to the weight MLP.
    """

    def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors, hidden_layer=256, embedding_dim=10) -> None:
        super().__init__()

        self.num_neighbors = num_neighbors

        # No internal/shared weights: the weights are supplied per edge in forward().
        tp = o3.FullyConnectedTensorProduct(
            irreps_in1=irreps_in,
            irreps_in2=irreps_sh,
            irreps_out=irreps_out,
            internal_weights=False,
            shared_weights=False,
        )
        # FullyConnectedNet is provided by e3nn.nn (imported here as `nn`), not by e3nn.o3.
        self.fc = nn.FullyConnectedNet([embedding_dim, hidden_layer, tp.weight_numel], torch.relu)
        self.tp = tp
        self.irreps_out = self.tp.irreps_out

    def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
        """Compute new node features from messages sent along the edges.

        Args:
            node_features: per-node features, indexed by ``edge_src``.
            edge_src: source node index of every edge.
            edge_dst: destination node index of every edge.
            edge_attr: per-edge second input of the tensor product.
            edge_scalars: per-edge scalar embedding fed to the weight MLP.

        Returns:
            One row of output features per node in ``node_features``.
        """
        weight = self.fc(edge_scalars)
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        # dim_size pins the output to one row per input node; without it, trailing
        # nodes that receive no edge would be silently dropped from the result.
        aggregated = scatter(edge_features, edge_dst, dim=0, dim_size=node_features.shape[0])
        return aggregated.div(self.num_neighbors ** 0.5)


class NbodyConv(torch.nn.Module):
    """Equivariant model: one ``Convolution`` over the graph, then a pooled output per graph.

    Node inputs and outputs both use irreps ``2x1o`` (two odd vectors per node;
    presumably position and velocity -- confirm against the data pipeline).

    Args:
        l: maximum degree of the spherical harmonics used as edge attributes.
        hidden_layers: hidden width of the convolution's weight MLP.
        max_radius: stored on the module; not used by the layers defined here.
        num_basis: size of the per-edge scalar embedding expected in ``data.edge_attr``.
        num_nodes: nodes per graph, used for normalisation. Defaults to the
            notebook-level ``N`` when omitted.
    """

    def __init__(self, l=2, hidden_layers=256, max_radius=None, num_basis=10, num_nodes=None) -> None:
        super().__init__()
        self.irreps_sh: o3.Irreps = o3.Irreps.spherical_harmonics(l)
        self.irreps_input = o3.Irreps("2x1o")
        self.irreps_output = o3.Irreps("2x1o")
        self.max_radius = max_radius
        self.num_basis = num_basis
        self.hidden_layers = hidden_layers
        # Fall back to the notebook global so existing `NbodyConv()` calls keep working.
        self.num_nodes = N if num_nodes is None else num_nodes

        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.tensor_product = o3.FullyConnectedTensorProduct(
            self.irreps_input,
            self.irreps_sh,
            self.irreps_output,
            shared_weights=False)

        self.conv = Convolution(self.irreps_input, self.irreps_sh, self.irreps_output, self.num_nodes - 1,
                                hidden_layer=self.hidden_layers, embedding_dim=num_basis)
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.gate = nn.Gate(
            "16x0e + 16x0o",
            [torch.relu, torch.abs],  # scalar
            "8x0e + 8x0o + 8x0e + 8x0o",
            [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates (scalars)
            "16x1o + 16x1e",  # gated tensors, num_irreps has to match with gates
        )

    def forward(self, data) -> torch.Tensor:
        """Run the convolution on a batched graph and pool node outputs per graph.

        Args:
            data: batched graph exposing ``x``, ``edge_index``, ``edge_vec``,
                ``edge_attr`` and ``batch``.

        Returns:
            One row per graph in the batch: node outputs scattered by ``data.batch``
            and divided by ``sqrt(num_nodes)``.
        """
        edge_src, edge_dst = data.edge_index

        spherical_harmonics = o3.spherical_harmonics(self.irreps_sh, data.edge_vec, normalize=True,
                                                     normalization='component')

        # The convolution was previously evaluated twice on the same input with the
        # first result discarded; a single evaluation yields the same output.
        x = self.conv(data.x, edge_src, edge_dst, spherical_harmonics, data.edge_attr)

        # NOTE(review): this pools to one row per graph. If the targets in `data.y` are
        # per node, the output shape will not match them -- confirm the intended output.
        return scatter(x, data.batch, dim=0).div(self.num_nodes ** 0.5)


# Dataloader

In [None]:
def create_fully_connected_data_with_edge_features(node_features, positions, targets, max_radius, num_basis):
    """Build a fully connected graph (no self-loops) for one simulation step.

    Args:
        node_features: per-node feature tensor; its first dimension is the node count.
        positions: per-node coordinates used to compute edge vectors.
        targets: tensor stored as ``y`` (detached, moved to the node device, cast to float32).
        max_radius: upper end of the distance range covered by the radial basis.
        num_basis: number of radial basis functions per edge.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr`` (radial embedding
        of the edge length), ``y`` and ``edge_vec``.
    """
    device = node_features.device
    node_count = node_features.size(0)

    # Enumerate every ordered (source, destination) pair, then discard i == i pairs.
    node_ids = torch.arange(node_count, device=device)
    src = node_ids.repeat_interleave(node_count)
    dst = node_ids.repeat(node_count)
    not_self_loop = src != dst
    edge_index = torch.stack([src[not_self_loop], dst[not_self_loop]], dim=0)

    # Relative vector along each edge (source position minus destination position).
    edge_vec = positions[edge_index[0]] - positions[edge_index[1]]
    edge_lengths = edge_vec.norm(dim=1)

    # Embed the edge lengths on a smooth finite basis over [0, max_radius].
    radial_embedding = soft_one_hot_linspace(
        edge_lengths,
        0.0,
        max_radius,
        num_basis,
        basis='smooth_finite',
        cutoff=True
    )
    edge_features = radial_embedding.mul(num_basis ** 0.5)

    y = targets.detach().to(device).to(torch.float32)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_features, y=y, edge_vec=edge_vec)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
# Length of the box diagonal in `dims` dimensions (previously hard-coded for three
# dimensions); used as the cutoff of the radial basis.
d = (dims * boxSize**2)**0.5
max_radius = d
num_basis = 10

inputs_tensor = torch.tensor(inputs_np, dtype=torch.float32)
targets_tensor = torch.tensor(targets_np, dtype=torch.float32)

# One fully connected graph per entry along the first axis of the inputs.
data_list = []

for index, simulation_step_graph in enumerate(inputs_tensor):
    node_features = simulation_step_graph
    # The first `dims` feature columns are treated as the particle positions.
    positions = node_features[..., :dims]
    targets = targets_tensor[index, ...]

    data = create_fully_connected_data_with_edge_features(node_features, positions, targets,
                                                          max_radius=max_radius, num_basis=num_basis)
    data_list.append(data)

data_loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

In [None]:
# Show how many graphs were built rather than dumping the repr of every one of them.
len(data_loader.dataset)

In [None]:
# Inspect the first graph in the dataset to check its attributes.
data_loader.dataset[0]

# Training

In [None]:
# Smoke test: check that an untrained model can run a forward pass on every batch.
model = NbodyConv()
outputs = None
# No gradients are needed here, so avoid building autograd graphs that are thrown away.
with torch.no_grad():
    for batch in data_loader:
        # batch.x: current state node features
        # batch.edge_index: edges for fully connected graph
        # batch.edge_attr: edge features
        # batch.y: labels for the next state of the simulation
        outputs = model(batch)

# Report output and target shapes for the last batch so a mismatch is visible early.
if outputs is not None:
    print(tuple(outputs.shape), tuple(batch.y.shape))


In [None]:
# Train the model with an L1 loss and report per-epoch averages of the tracked metrics.
model = NbodyConv().to(device)
num_epochs = 10
lr = 1e-2
log_every = 1  # print metrics every `log_every` epochs
# `nn` in this notebook is e3nn.nn, which has no L1Loss; the loss lives in torch.nn.
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model.train()
for epoch in range(num_epochs):
    total_metrics = {
        "loss": 0, "loss_pos": 0, "loss_vel": 0,
        "perc_error_pos": 0, "perc_error_vel": 0, "perc_error_pos_vs_vel_l1": 0, "perc_error_pos_vs_vel_l2": 0
    }
    num_batches = 0
    for batch in data_loader:
        # The loader yields one batched graph object, not an (inputs, targets) pair:
        # the model consumes the whole batch and the labels are stored on `batch.y`.
        batch = batch.to(device)
        targets = batch.y

        # Forward pass
        outputs = model(batch)
        if outputs.shape != targets.shape:
            # Fail with a clear message instead of an opaque or silently broadcast loss.
            raise ValueError(
                f"Model output shape {tuple(outputs.shape)} does not match "
                f"target shape {tuple(targets.shape)}")
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # First `dims` columns are positions, the remaining columns velocities.
            predicted_pos = outputs[..., :dims]
            target_pos = targets[..., :dims]

            predicted_vel = outputs[..., dims:]
            target_vel = targets[..., dims:]

            loss_pos = criterion(predicted_pos, target_pos)
            loss_vel = criterion(predicted_vel, target_vel)

            # Calculate percentage errors
            perc_error_pos = (torch.norm(predicted_pos - target_pos, dim=1) /
                              torch.norm(target_pos, dim=1)).mean() * 100

            perc_error_vel = (torch.norm(predicted_vel - target_vel, dim=1) /
                              torch.norm(target_vel, dim=1)).mean() * 100

            # NOTE(review): the numerator is a mean over the whole batch (no dim=1),
            # unlike the l2 variant below -- confirm this is intended.
            perc_error_pos_vs_vel_l1 = (torch.abs(predicted_pos - target_pos).mean() /
                                        torch.norm(target_vel, dim=1)).mean() * 100

            perc_error_pos_vs_vel_l2 = (torch.norm(predicted_pos - target_pos, dim=1) /
                                        torch.norm(target_vel, dim=1)).mean() * 100

            total_metrics["loss"] += loss.item()
            total_metrics["loss_pos"] += loss_pos.item()
            total_metrics["loss_vel"] += loss_vel.item()
            total_metrics["perc_error_pos"] += perc_error_pos.item()
            total_metrics["perc_error_vel"] += perc_error_vel.item()
            total_metrics["perc_error_pos_vs_vel_l1"] += perc_error_pos_vs_vel_l1.item()
            total_metrics["perc_error_pos_vs_vel_l2"] += perc_error_pos_vs_vel_l2.item()

        num_batches += 1

    # Skip reporting when the loader produced no batches (avoids dividing by zero).
    if num_batches and epoch % log_every == 0:
        avg_metrics = {k: v / num_batches for k, v in total_metrics.items()}
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], avg_both: {avg_metrics['loss']:.5f}, avg_pos: {avg_metrics['loss_pos']: .5f}, avg_vel: {avg_metrics['loss_vel']: .5f}, perc_pos: {avg_metrics['perc_error_pos']: .5f}%, perc_vel: {avg_metrics['perc_error_vel']: .5f}%")