# N-Body Experiment
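The goal of this notebook is to demonstrate how `equiadapt` simplifies a canonicalization workflow. We train the same canonicalization and prediction networks on an N-body dataset twice. The first run uses hand-written canonicalization code. The second uses `equiadapt`, which replaces that code with a few library calls.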

In [32]:
import torch
from tqdm import tqdm

from equiadapt.nbody.canonicalization.euclidean_group import EuclideanGroupNBody
from equiadapt.nbody.canonicalization_networks.custom_equivariant_networks import VNDeepSets
from equiadapt.common.utils import gram_schmidt

from examples.nbody.networks.euclideangraph_base_models import GNN
from examples.nbody.prepare.nbody_data import NBodyDataModule
from examples.nbody.model_utils import get_edges


In [33]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Hyperparameters

In [34]:
class Hyperparameters:
    """Top-level configuration for the N-body experiment.

    An instance is handed to ``NBodyDataModule``. The notebook later attaches the
    canonicalization and prediction sub-configs to it as ``canon_hyperparams`` and
    ``pred_hyperparams``.
    """

    def __init__(self):
        self.model = "NBodyPipeline"
        # Matches CanonicalizationHyperparameters.architecture.
        self.canon_model_type = "vndeepsets"
        # Kept consistent with PredictionHyperparameters.architecture: the prediction
        # network built in this notebook is a GNN.
        self.pred_model_type = "GNN"
        self.batch_size = 100
        self.dryrun = False
        self.use_wandb = False
        self.checkpoint = False
        # NOTE: the demo training loops in this notebook define their own, shorter `epochs`.
        self.num_epochs = 1000
        self.num_workers = 0
        self.auto_tune = False
        self.seed = 0
        # Used as the learning rate of both optimizer parameter groups.
        self.learning_rate = 1e-3
        self.weight_decay = 1e-12
        self.patience = 1000

class CanonicalizationHyperparameters:
    """Settings for the VNDeepSets canonicalization network.

    In this notebook the instance is passed to ``VNDeepSets`` and to
    ``EuclideanGroupNBody``.
    """

    def __init__(self):
        # Every setting becomes a plain instance attribute, in declaration order.
        vars(self).update(
            architecture="vndeepsets",
            num_layers=4,
            hidden_dim=16,
            layer_pooling="mean",
            final_pooling="mean",
            out_dim=4,
            batch_size=100,
            nonlinearity="relu",
            canon_feature="p",
            canon_translation=False,
            angular_feature="pv",
            dropout=0.5,
        )

class PredictionHyperparameters:
    """Settings for the GNN prediction network (passed to ``GNN`` in this notebook)."""

    def __init__(self):
        # Every setting becomes a plain instance attribute, in declaration order.
        vars(self).update(
            architecture="GNN",
            num_layers=4,
            hidden_dim=32,
            # presumably 3 location + 3 velocity components — TODO confirm
            input_dim=6,
            # presumably the single per-node speed feature built in get_data — confirm
            in_node_nf=1,
            # presumably dataset edge attribute + squared distance from get_data — confirm
            in_edge_nf=2,
        )


In [35]:
hyperparams = Hyperparameters()
# Build each sub-config and attach it to the top-level config in a single statement,
# while also keeping a module-level handle for the networks built further down.
hyperparams.canon_hyperparams = canon_hyperparams = CanonicalizationHyperparameters()
hyperparams.pred_hyperparams = pred_hyperparams = PredictionHyperparameters()

## Data

### Preparing Data

In [36]:
nbody_data = NBodyDataModule(hyperparams)

# Set up the default stage and take the training loader from it.
nbody_data.setup()
train_loader = nbody_data.train_dataloader()

# setup(stage="test") is assumed to prepare the held-out test split, so read it through
# the matching test loader (val_dataloader() would not serve the split set up here).
nbody_data.setup(stage="test")
test_loader = nbody_data.test_dataloader()


In [37]:
# Splits the batch into location features, velocity features,
# node features, edges, edge features, charges, and end locations (ie. targets)
def get_data(batch):
    """Flatten an N-body batch and derive graph inputs.

    Args:
        batch: sequence ``(loc, vel, edge_attr, charges, loc_end)`` of 3-D tensors whose
            first two axes are (batch_size, n_nodes) for ``loc`` (used to read the sizes).

    Returns:
        ``(loc, vel, nodes, edges, edge_attr, charges, loc_end)``. Every tensor is flattened
        to 2-D. ``nodes`` is the per-node velocity norm with shape (N, 1). ``edges`` is a
        ``[rows, cols]`` pair of index tensors on the same device as ``loc``. ``edge_attr``
        gets the squared distance between the two endpoint locations appended as a new column.
    """
    batch_size, n_nodes, _ = batch[0].size()
    batch = [d.view(-1, d.size(2)) for d in batch]  # converts to 2D matrices
    loc, vel, edge_attr, charges, loc_end = batch

    # get_edges builds the index tensors without reference to the batch's device; move them
    # next to the data so indexing and the downstream networks don't mix devices.
    edges = [e.to(loc.device) for e in get_edges(batch_size, n_nodes)]

    # norm of velocity vectors, used as the scalar node feature
    nodes = torch.sqrt(torch.sum(vel**2, dim=1)).unsqueeze(1).detach()

    rows, cols = edges
    # squared distances between the two endpoints of every edge
    loc_dist = torch.sum((loc[rows] - loc[cols]) ** 2, 1).unsqueeze(1)
    # concatenate all edge properties
    edge_attr = torch.cat([edge_attr, loc_dist], 1).detach()

    return loc, vel, nodes, edges, edge_attr, charges, loc_end

## Training

### Training Without `equiadapt`

In [38]:
# Seed the global torch RNG from the configured seed so network initialization is reproducible.
torch.manual_seed(hyperparams.seed)
canonicalization_network = VNDeepSets(canon_hyperparams).to(device)
prediction_network = GNN(pred_hyperparams).to(device)

In [39]:
# One Adam parameter group per network (prediction first, then canonicalization),
# both using the configured learning rate.
optimizer = torch.optim.Adam(
    [
        {"params": network.parameters(), "lr": hyperparams.learning_rate}
        for network in (prediction_network, canonicalization_network)
    ]
)
loss_fn = torch.nn.MSELoss(reduction="mean")

In [40]:
epochs = 20

for epoch in range(epochs):
    tqdm_bar = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader))

    total_loss, total_task_loss = 0.0, 0.0
    for batch_idx, batch in tqdm_bar:

        optimizer.zero_grad()

        training_metrics = {}

        batch = [b.to(device) for b in batch]

        # Split batch into inputs and targets
        loc, vel, nodes, edges, edge_attr, charges, loc_end = get_data(batch)

        # ------------------- code starting here is replaced by equiadapt -------------------

        # Obtain rotation and translation vectors for canonicalization
        rotation_vectors, translation_vectors = canonicalization_network(nodes, loc, edges, vel, edge_attr, charges)
        # gram_schmidt is expected to return orthonormal (rotation) matrices,
        # so each matrix's transpose is its inverse.
        rotation_matrix = gram_schmidt(rotation_vectors)
        rotation_matrix_inverse = rotation_matrix.transpose(1, 2)

        # Canonicalize node locations: remove the predicted translation, then rotate.
        # Points are row vectors, so they are multiplied by the inverse rotation on the right.
        # squeeze(1) drops only the axis added by [:, None, :]; a bare squeeze() would also
        # collapse a one-row batch to 1-D.
        canonical_loc = torch.bmm((loc - translation_vectors)[:, None, :], rotation_matrix_inverse).squeeze(1)
        # Canonicalize node velocities (rotation only; no translation is applied to velocities)
        canonical_vel = torch.bmm(vel[:, None, :], rotation_matrix_inverse).squeeze(1)
        # Make prediction using canonical inputs
        canonical_pred_loc = prediction_network(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges)
        # Un-canonicalize the predicted locations: rotate back, then re-apply the translation
        pred_loc = torch.bmm(canonical_pred_loc[:, None, :], rotation_matrix).squeeze(1) + translation_vectors

        # -----------------------------------------------------------------------------------

        task_loss = loss_fn(pred_loc, loc_end)

        # There are no auxiliary loss terms here, so the total loss is the task loss.
        loss = task_loss

        # Logging the training metrics (running averages over the epoch so far)
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        training_metrics.update({
                "task_loss": total_task_loss / (batch_idx + 1),
                "loss": total_loss / (batch_idx + 1),
            })

        # Usual training steps
        loss.backward()

        optimizer.step()

        # Log the training metrics
        tqdm_bar.set_postfix(training_metrics)



        

Epoch 0:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.88it/s, task_loss=1.99, loss=1.99]
Epoch 1: 100%|██████████| 30/30 [00:02<00:00, 14.98it/s, task_loss=0.156, loss=0.156]
Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 12.26it/s, task_loss=0.0819, loss=0.0819]
Epoch 3: 100%|██████████| 30/30 [00:02<00:00, 10.37it/s, task_loss=0.0717, loss=0.0717]
Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 10.86it/s, task_loss=0.0692, loss=0.0692]
Epoch 5: 100%|██████████| 30/30 [00:02<00:00, 14.54it/s, task_loss=0.0663, loss=0.0663]
Epoch 6: 100%|██████████| 30/30 [00:03<00:00,  9.71it/s, task_loss=0.0616, loss=0.0616]
Epoch 7: 100%|██████████| 30/30 [00:01<00:00, 15.01it/s, task_loss=0.0614, loss=0.0614]
Epoch 8: 100%|██████████| 30/30 [00:02<00:00, 13.51it/s, task_loss=0.065, loss=0.065]  
Epoch 9: 100%|██████████| 30/30 [00:02<00:00, 12.67it/s, task_loss=0.0527, loss=0.0527]
Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 14.25it/s, task_loss=0.0454, loss=0.0454]
Epoch 11: 100%|██████████| 30/30 [00:

### Training with `equiadapt`
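With `equiadapt`, an instance of `EuclideanGroupNBody` handles both canonicalization and its inverse. Calling the instance (`canonicalizer(...)`) canonicalizes the inputs (the `.canonicalize` step), and its `.invert_canonicalization` method maps the predictions back to the original frame. Together these replace the hand-written block marked in the loop above.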

In [41]:
# Fresh networks for the equiadapt run. Seed first so initialization is reproducible, and
# move them to `device` because the training loop moves every batch there.
torch.manual_seed(hyperparams.seed)
canonicalization_network = VNDeepSets(canon_hyperparams).to(device)
prediction_network = GNN(pred_hyperparams).to(device)
# EuclideanGroupNBody wraps the canonicalization network; the training loop uses it to
# canonicalize inputs and to invert the canonicalization of predictions.
canonicalizer = EuclideanGroupNBody(canonicalization_network, canon_hyperparams)
optimizer = torch.optim.Adam(
    [
        {
            "params": prediction_network.parameters(),
            "lr": hyperparams.learning_rate,
        },
        {"params": canonicalization_network.parameters(), "lr": hyperparams.learning_rate},
    ]
)
loss_fn = torch.nn.MSELoss()

In [42]:
epochs = 20

for epoch in range(epochs):
    # Wrap the loader itself in tqdm and enumerate the wrapped iterator; the bar,
    # its total and its postfix behave exactly as with tqdm(enumerate(...)).
    tqdm_bar = tqdm(train_loader, desc=f"Epoch {epoch}", total=len(train_loader))
    total_loss = total_task_loss = 0.0

    for batch_idx, batch in enumerate(tqdm_bar):
        optimizer.zero_grad()

        batch = [tensor.to(device) for tensor in batch]
        loc, vel, nodes, edges, edge_attr, charges, loc_end = get_data(batch)

        ## ------------------- equiadapt code -------------------
        # 1) map the inputs into the canonical frame chosen by the canonicalizer
        canonical_loc, canonical_vel = canonicalizer(
            x=nodes,
            targets=None,
            loc=loc,
            edges=edges,
            vel=vel,
            edge_attr=edge_attr,
            charges=charges,
        )
        # 2) predict in the canonical frame
        canonical_pred_loc = prediction_network(nodes, canonical_loc, edges, canonical_vel, edge_attr, charges)
        # 3) map the prediction back to the original frame
        pred_loc = canonicalizer.invert_canonicalization(canonical_pred_loc)
        ## -----------------------------------------------------

        task_loss = loss_fn(pred_loc, loc_end)
        loss = task_loss  # no auxiliary terms: the total loss is the task loss

        # Running averages over the batches seen so far in this epoch.
        total_loss += loss.item()
        total_task_loss += task_loss.item()
        training_metrics = {
            "task_loss": total_task_loss / (batch_idx + 1),
            "loss": total_loss / (batch_idx + 1),
        }

        loss.backward()
        optimizer.step()

        tqdm_bar.set_postfix(training_metrics)


Epoch 0: 100%|██████████| 30/30 [00:01<00:00, 15.27it/s, task_loss=1.82, loss=1.82]
Epoch 1: 100%|██████████| 30/30 [00:02<00:00, 14.59it/s, task_loss=0.128, loss=0.128]
Epoch 2: 100%|██████████| 30/30 [00:02<00:00, 13.11it/s, task_loss=0.0774, loss=0.0774]
Epoch 3: 100%|██████████| 30/30 [00:02<00:00, 12.53it/s, task_loss=0.0698, loss=0.0698]
Epoch 4: 100%|██████████| 30/30 [00:02<00:00, 14.04it/s, task_loss=0.0679, loss=0.0679]
Epoch 5: 100%|██████████| 30/30 [00:02<00:00, 12.82it/s, task_loss=0.0754, loss=0.0754]
Epoch 6: 100%|██████████| 30/30 [00:02<00:00, 13.56it/s, task_loss=0.0639, loss=0.0639]
Epoch 7: 100%|██████████| 30/30 [00:02<00:00, 12.07it/s, task_loss=0.0603, loss=0.0603]
Epoch 8: 100%|██████████| 30/30 [00:02<00:00, 13.41it/s, task_loss=0.0554, loss=0.0554]
Epoch 9: 100%|██████████| 30/30 [00:01<00:00, 15.00it/s, task_loss=0.0514, loss=0.0514]
Epoch 10: 100%|██████████| 30/30 [00:02<00:00, 12.85it/s, task_loss=0.0472, loss=0.0472]
Epoch 11: 100%|██████████| 30/30 [00: