# N-Body Experiment
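This notebook demonstrates how `equiadapt` simplifies a canonicalization workflow on the N-body dynamics task. We first train a canonicalization network together with a GNN predictor using hand-written canonicalization code. We then train the same pair of networks with that code replaced by `equiadapt`'s `EuclideanGroupNBody` canonicalizer.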

In [31]:
import torch
from tqdm import tqdm

from equiadapt.nbody.canonicalization.euclidean_group import EuclideanGroupNBody
from equiadapt.nbody.canonicalization_networks.custom_equivariant_networks import VNDeepSets
from equiadapt.common.utils import gram_schmidt

from examples.nbody.networks.euclideangraph_base_models import GNN
from examples.nbody.prepare.nbody_data import NBodyDataModule
from examples.nbody.model_utils import get_edges


In [13]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


## Hyperparameters

In [3]:
class Hyperparameters:
    """Top-level settings for the N-body experiment.

    Sub-configurations for the canonicalization and prediction networks are
    expected to be attached as attributes after construction.
    """

    def __init__(self):
        # Keyword order matches attribute-definition order, so vars(self) is unchanged.
        vars(self).update(
            model="NBodyPipeline",
            canon_model_type="vndeepsets",
            # NOTE(review): the notebook builds a GNN predictor (PredictionHyperparameters
            # uses architecture="GNN"); confirm whether anything downstream reads this value.
            pred_model_type="Transformer",
            batch_size=100,
            dryrun=False,
            use_wandb=False,
            checkpoint=False,
            # Not read by the training cells in this notebook, which hard-code their own epoch count.
            num_epochs=1000,
            num_workers=0,
            auto_tune=False,
            seed=0,
            learning_rate=1e-3,
            weight_decay=1e-12,
            patience=1000,
        )

class CanonicalizationHyperparameters:
    """Settings for the canonicalization network.

    In this notebook the object is passed to ``VNDeepSets`` and to ``EuclideanGroupNBody``.
    """

    def __init__(self):
        vars(self).update(
            architecture="vndeepsets",
            num_layers=4,
            hidden_dim=16,
            layer_pooling="mean",
            final_pooling="mean",
            out_dim=4,
            # Mirrors the top-level batch size (presumably must match it -- confirm).
            batch_size=100,
            nonlinearity="relu",
            canon_feature="p",
            canon_translation=False,
            angular_feature="pv",
            dropout=0.5,
        )

class PredictionHyperparameters:
    """Settings for the prediction network (passed to ``GNN`` in this notebook)."""

    def __init__(self):
        vars(self).update(
            architecture="GNN",
            num_layers=4,
            hidden_dim=32,
            input_dim=6,
            # Presumably the single velocity-norm node feature built by get_data -- confirm.
            in_node_nf=1,
            # Presumably the raw edge attribute plus the squared distance appended by get_data -- confirm.
            in_edge_nf=2,
        )


In [4]:
hyperparams = Hyperparameters()
# Bind each sub-configuration both to a module-level name (used directly when the
# networks are built) and as an attribute of the top-level config object.
hyperparams.canon_hyperparams = canon_hyperparams = CanonicalizationHyperparameters()
hyperparams.pred_hyperparams = pred_hyperparams = PredictionHyperparameters()

## Data

### Preparing Data

In [5]:
# Build the N-body data module from the top-level config, then create the loaders.
nbody_data = NBodyDataModule(hyperparams)

nbody_data.setup()  # default stage
train_loader = nbody_data.train_dataloader()

# NOTE(review): this sets up the "test" stage but then requests the *validation*
# loader; confirm whether `test_dataloader()` was intended. `test_loader` is not
# used by the training cells in this notebook.
nbody_data.setup(stage="test")
test_loader = nbody_data.val_dataloader()


  edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)


In [6]:
def get_data(batch):
    """Flatten an N-body batch and assemble the graph inputs for the models.

    Args:
        batch: sequence of five 3-D tensors ``(loc, vel, edge_attr, charges, loc_end)``.
            The first is read as ``(batch_size, n_nodes, feature_dim)``. Each tensor is
            flattened to ``(-1, last_dim)``.

    Returns:
        ``(loc, vel, nodes, edges, edge_attr, charges, loc_end)``:
        - ``nodes`` is the per-row velocity norm, shape ``(N, 1)``.
        - ``edges`` is whatever ``get_edges(batch_size, n_nodes)`` returns, unpacked here
          as ``(rows, cols)``. An earlier note says each graph is complete on 5 nodes
          (20 directed edges); confirm against ``get_edges``.
        - ``edge_attr`` has the *squared* distance between each edge's endpoint locations
          appended as an extra column.
        ``nodes`` and ``edge_attr`` are detached from the autograd graph.
    """
    batch_size, n_nodes, _ = batch[0].size()
    # Merge the batch and node axes so every tensor becomes a 2-D matrix.
    loc, vel, edge_attr, charges, loc_end = (t.view(-1, t.size(2)) for t in batch)

    edges = get_edges(batch_size, n_nodes)
    rows, cols = edges

    # Node feature: Euclidean norm of each velocity vector.
    nodes = vel.pow(2).sum(dim=1).sqrt().unsqueeze(1).detach()

    # Edge feature: squared distance between the two endpoint locations.
    loc_dist = (loc[rows] - loc[cols]).pow(2).sum(dim=1, keepdim=True)
    edge_attr = torch.cat((edge_attr, loc_dist), dim=1).detach()

    return loc, vel, nodes, edges, edge_attr, charges, loc_end

## Training

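### Training without `equiadapt`
Here the canonicalization and its inverse are implemented by hand, between the marked comments in the training loop.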

In [26]:
# Seed PyTorch's global RNG with the configured seed so the networks' initial weights
# (and the loader's later shuffle order) are reproducible across runs.
torch.manual_seed(hyperparams.seed)
canonicalization_network = VNDeepSets(canon_hyperparams).to(device)
prediction_network = GNN(pred_hyperparams).to(device)

In [27]:
# One Adam parameter group per network, both at the configured learning rate:
# prediction network first, then canonicalization network.
optimizer = torch.optim.Adam(
    [
        {"params": network.parameters(), "lr": hyperparams.learning_rate}
        for network in (prediction_network, canonicalization_network)
    ]
)
loss_fn = torch.nn.MSELoss()

In [28]:
# Short, fixed epoch budget for this demo (hyperparams.num_epochs is not read here).
epochs = 20

for epoch in range(epochs):
    tqdm_bar = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))

    # Running sums for the epoch-average losses shown in the progress bar.
    total_loss, total_task_loss = 0.0, 0.0
    for batch_idx, batch in tqdm_bar:
        optimizer.zero_grad()

        batch = [b.to(device) for b in batch]

        # Split batch into inputs and targets
        loc, vel, nodes, edges, edge_attr, charges, loc_end = get_data(batch)

        # ------------------- code starting here is replaced by equiadapt -------------------

        # Obtain rotation and translation vectors for canonicalization, then orthonormalize
        # the rotation vectors into batched rotation matrices (one per row of `loc`).
        rotation_vectors, translation_vectors = canonicalization_network(nodes, loc, edges, vel, edge_attr, charges)
        rotation_matrix = gram_schmidt(rotation_vectors)
        # Orthonormal matrices, so the transpose is the inverse.
        rotation_matrix_inverse = rotation_matrix.transpose(1, 2)

        # Canonicalize node locations, i.e. (loc - t) R^-1 as row-vector products.
        # squeeze(1) removes only the singleton axis created by [:, None, :]; a bare
        # squeeze() would also collapse the leading axis if it ever had size 1.
        canonical_loc = (
            torch.bmm(loc[:, None, :], rotation_matrix_inverse).squeeze(1)
            - torch.bmm(translation_vectors[:, None, :], rotation_matrix_inverse).squeeze(1)
        )
        # Canonicalize node velocities (rotation only; velocities are not translated).
        canonical_vel = torch.bmm(vel[:, None, :], rotation_matrix_inverse).squeeze(1)
        # Make prediction using canonical inputs
        canonical_pred_loc = prediction_network(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges)
        # Un-canonicalize the predicted locations: rotate back, then re-apply the translation.
        pred_loc = torch.bmm(canonical_pred_loc[:, None, :], rotation_matrix).squeeze(1) + translation_vectors

        # -----------------------------------------------------------------------------------

        # No auxiliary loss terms are added in this demo, so the total loss is the task loss.
        task_loss = loss_fn(pred_loc, loc_end)
        loss = task_loss

        # Usual training steps
        loss.backward()
        optimizer.step()

        # Log running epoch averages of the training metrics.
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        tqdm_bar.set_postfix({
            "task_loss": total_task_loss / (batch_idx + 1),
            "loss": total_loss / (batch_idx + 1),
        })



        

Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.67it/s, task_loss=1.89, loss=1.89]
Epoch 1: 100%|██████████| 30/30 [00:01<00:00, 17.58it/s, task_loss=0.135, loss=0.135]
Epoch 2: 100%|██████████| 30/30 [00:01<00:00, 16.49it/s, task_loss=0.0769, loss=0.0769]
Epoch 3: 100%|██████████| 30/30 [00:01<00:00, 16.13it/s, task_loss=0.0711, loss=0.0711]
Epoch 4: 100%|██████████| 30/30 [00:01<00:00, 15.57it/s, task_loss=0.0677, loss=0.0677]
Epoch 5: 100%|██████████| 30/30 [00:01<00:00, 16.25it/s, task_loss=0.065, loss=0.065]  
Epoch 6: 100%|██████████| 30/30 [00:02<00:00, 10.09it/s, task_loss=0.0626, loss=0.0626]
Epoch 7: 100%|██████████| 30/30 [00:02<00:00, 14.06it/s, task_loss=0.0619, loss=0.0619]
Epoch 8: 100%|██████████| 30/30 [00:03<00:00,  9.61it/s, task_loss=0.0571, loss=0.0571]
Epoch 9: 100%|██████████| 30/30 [00:01<00:00, 16.93it/s, task_loss=0.0527, loss=0.0527]
Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 11.15it/s, task_loss=0.0502, loss=0.0502]
Epoch 11: 100%|██████████| 30/30 [00:

### Training with `equiadapt`
With `equiadapt`, the hand-written canonicalization code above is replaced by an instance of `EuclideanGroupNBody`. Calling the instance canonicalizes the inputs (the `canonicalize` step), and its `invert_canonicalization` method maps the predicted locations back to the original frame.

In [29]:
# Seed PyTorch's global RNG with the configured seed so the networks' initial weights
# (and the loader's later shuffle order) are reproducible across runs.
torch.manual_seed(hyperparams.seed)

# Move both networks to `device`. The training loop moves every batch there, so
# parameters left on the CPU would raise a device-mismatch error on a GPU machine.
canonicalization_network = VNDeepSets(canon_hyperparams).to(device)
prediction_network = GNN(pred_hyperparams).to(device)
# The canonicalizer wraps the (already on-device) canonicalization network.
canonicalizer = EuclideanGroupNBody(canonicalization_network, canon_hyperparams)

# One Adam parameter group per network, both at the configured learning rate.
optimizer = torch.optim.Adam(
    [
        {"params": network.parameters(), "lr": hyperparams.learning_rate}
        for network in (prediction_network, canonicalization_network)
    ]
)
loss_fn = torch.nn.MSELoss()

In [30]:
# Short, fixed epoch budget for this demo (hyperparams.num_epochs is not read here).
epochs = 20

for epoch in range(epochs):
    tqdm_bar = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))

    # Running sums for the epoch-average losses shown in the progress bar.
    total_loss, total_task_loss = 0.0, 0.0
    for batch_idx, batch in tqdm_bar:
        optimizer.zero_grad()

        batch = [b.to(device) for b in batch]

        # Split batch into inputs and targets
        loc, vel, nodes, edges, edge_attr, charges, loc_end = get_data(batch)

        ## ------------------- equiadapt code -------------------

        # Canonicalize the inputs, predict in the canonical frame, then map the
        # predicted locations back to the original frame.
        canonical_loc, canonical_vel = canonicalizer(
            x=nodes, targets=None, loc=loc, edges=edges, vel=vel, edge_attr=edge_attr, charges=charges
        )
        canonical_pred_loc = prediction_network(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges)
        pred_loc = canonicalizer.invert_canonicalization(canonical_pred_loc)

        ## -----------------------------------------------------

        # No auxiliary loss terms are added in this demo, so the total loss is the task loss.
        task_loss = loss_fn(pred_loc, loc_end)
        loss = task_loss

        # Usual training steps
        loss.backward()
        optimizer.step()

        # Log running epoch averages of the training metrics.
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        tqdm_bar.set_postfix({
            "task_loss": total_task_loss / (batch_idx + 1),
            "loss": total_loss / (batch_idx + 1),
        })


Epoch 0: 100%|██████████| 30/30 [00:02<00:00, 14.93it/s, task_loss=1.63, loss=1.63]
Epoch 1: 100%|██████████| 30/30 [00:01<00:00, 16.78it/s, task_loss=0.132, loss=0.132]
Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 14.36it/s, task_loss=0.079, loss=0.079]  
Epoch 3: 100%|██████████| 30/30 [00:01<00:00, 15.51it/s, task_loss=0.0703, loss=0.0703]
Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 10.34it/s, task_loss=0.0675, loss=0.0675]
Epoch 5: 100%|██████████| 30/30 [00:01<00:00, 17.05it/s, task_loss=0.065, loss=0.065]  
Epoch 6: 100%|██████████| 30/30 [00:03<00:00,  9.92it/s, task_loss=0.0624, loss=0.0624]
Epoch 7: 100%|██████████| 30/30 [00:01<00:00, 16.50it/s, task_loss=0.0604, loss=0.0604]
Epoch 8: 100%|██████████| 30/30 [00:01<00:00, 16.10it/s, task_loss=0.0591, loss=0.0591]
Epoch 9: 100%|██████████| 30/30 [00:02<00:00, 14.37it/s, task_loss=0.0542, loss=0.0542]
Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 12.92it/s, task_loss=0.0467, loss=0.0467]
Epoch 11: 100%|██████████| 30/30 [00: