# Lorentz-Equivariant Quantum Graph Neural Network (Lorentz-EQGNN)

In [1]:
# For Colab
!pip install torch_geometric
# !pip install torch_sparse
# !pip install torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
!pip install pennylane qiskit pennylane-qiskit pylatexenc

Collecting pennylane
  Downloading PennyLane-0.38.0-py3-none-any.whl.metadata (9.3 kB)
Collecting qiskit
  Downloading qiskit-1.2.4-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting pennylane-qiskit
  Downloading PennyLane_qiskit-0.38.1-py3-none-any.whl.metadata (6.4 kB)
Collecting pylatexenc
  Downloading pylatexenc-2.10.tar.gz (162 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.15.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting autograd (from pennylane)
  Downloading autograd-1.7.0-py3-none-any.whl.metadata (7.5 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.38 (from pennylane)
  Downloading PennyLane_Lightning-0.3

In [3]:
import pennylane as qml
import qiskit
print(qml.__version__)
print(qiskit.__version__)
import pennylane_qiskit
print(pennylane_qiskit.__version__)
import pennylane as qml
from pennylane import numpy as np
# from pennylane_qiskit import AerDevice

0.38.0
1.2.4
0.38.1


In [4]:
!pip install energyflow

Collecting energyflow
  Downloading EnergyFlow-1.3.2-py2.py3-none-any.whl.metadata (4.3 kB)
Collecting wasserstein>=0.3.1 (from energyflow)
  Downloading wasserstein-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Downloading EnergyFlow-1.3.2-py2.py3-none-any.whl (700 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m700.5/700.5 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading wasserstein-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (502 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m502.2/502.2 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: wasserstein, energyflow
Successfully installed energyflow-1.3.2 wasserstein-1.1.0


In [5]:
import torch
import numpy as np
import energyflow
from scipy.sparse import coo_matrix
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data.distributed import DistributedSampler


# we define a function to return an adjacencyy matrix
# for our graph data representing the jets.
def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    # print(edge_mask[0])
    # raise
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask *= diag_mask
    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edge_mask, edges]

def retrieve_dataloaders(batch_size, num_data = -1, use_one_hot = False, cache_dir = './data', num_workers=4):
    raw = energyflow.qg_jets.load(num_data=num_data, pad=True, ncol=4, generator='pythia',
                            with_bc=False, cache_dir=cache_dir)
    splits = ['train', 'val', 'test']
    data = {type:{'raw':None,'label':None} for type in splits}
    (data['train']['raw'],  data['val']['raw'],   data['test']['raw'],
    data['train']['label'], data['val']['label'], data['test']['label']) = \
        energyflow.utils.data_split(*raw, train=0.8, val=0.1, test=0.1, shuffle = False)

    enc = OneHotEncoder(handle_unknown='ignore').fit([[11],[13],[22],[130],[211],[321],[2112],[2212]])

    for split, value in data.items():
        pid = torch.from_numpy(np.abs(np.asarray(value['raw'][...,3], dtype=int))).unsqueeze(-1)
        p4s = torch.from_numpy(energyflow.p4s_from_ptyphipids(value['raw'],error_on_unknown=True))
        one_hot = enc.transform(pid.reshape(-1,1)).toarray().reshape(pid.shape[:2]+(-1,))
        # one_hot = enc.transform(pid.reshape(-1,1)).toarray().reshape(pid.shape[:2]+(-1,))
        one_hot = torch.from_numpy(one_hot)
        mass = torch.from_numpy(energyflow.ms_from_p4s(p4s)).unsqueeze(-1)
        charge = torch.from_numpy(energyflow.pids2chrgs(pid))

        if use_one_hot:
            nodes = one_hot

        # else:
        #     nodes = torch.cat((mass,charge),dim=-1)

        #     nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)


        else:
              # Concatenate mass and charge along the last dimension
              concatenated = torch.cat((mass, charge), dim=-1)  # Shape (batch_size, n_nodes, 2)

              # Reduce along the last dimension (e.g., by summing or averaging)
              nodes = concatenated.sum(dim=-1, keepdim=True)  # Shape (batch_size, n_nodes, 1)

              # Apply log-sign transformation if needed
              nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)

        atom_mask = (pid[...,0] != 0)

        value['p4s'] = p4s
        value['nodes'] = nodes
        value['label'] = torch.from_numpy(value['label'])
        value['atom_mask'] = atom_mask.to(torch.bool)

        if split == 'train':
            print(value['atom_mask'])

    datasets = {split: TensorDataset(value['label'], value['p4s'],
                                     value['nodes'], value['atom_mask'])
                for split, value in data.items()}

    # distributed training
    # train_sampler = DistributedSampler(datasets['train'], shuffle=True)
    # Construct PyTorch dataloaders from datasets
    dataloaders = {split: DataLoader(dataset,
                                     batch_size=batch_size,
                                     # sampler=train_sampler if (split == 'train') else DistributedSampler(dataset, shuffle=False),
                                     pin_memory=False,
                                     # persistent_workers=True,
                                     drop_last=True if (split == 'train') else False,
                                     num_workers=num_workers,
                                     collate_fn=collate_fn)
                        for split, dataset in datasets.items()}

    return dataloaders #train_sampler, dataloaders

if __name__ == '__main__':
    # train_sampler, dataloaders = retrieve_dataloaders(32, 100)
    dataloaders = retrieve_dataloaders(batch_size=16, num_data = 20, use_one_hot = True)
    for (label, p4s, nodes, atom_mask, edge_mask, edges) in dataloaders['train']:
        print(label.shape, p4s.shape, nodes.shape, atom_mask.shape,
              edge_mask.shape, edges[0].shape, edges[1].shape)
        break

Downloading QG_jets.npz from https://www.dropbox.com/s/fclsl7pukcpobsb/QG_jets.npz?dl=1 to ./data/datasets
URL fetch failure on https://www.dropbox.com/s/fclsl7pukcpobsb/QG_jets.npz?dl=1: None -- Bad Request
Failed to download QG_jets.npz from source 'dropbox', trying next source...
Downloading QG_jets.npz from https://zenodo.org/record/3164691/files/QG_jets.npz?download=1 to ./data/datasets
tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])
torch.Size([16]) torch.Size([16, 139, 4]) torch.Size([16, 139, 8]) torch.Size([16, 139]) torch.Size([16, 139, 139]) torch.Size([28736]) torch.Size([28736])


In [6]:
# Test the first batch
for label, p4s, nodes, atom_mask, edge_mask, edges in dataloaders["train"]:
    print(f"Label shape: {label.shape}")
    print(f"4-momenta shape: {p4s.shape}")
    print(f"Node features shape: {nodes.shape}")
    print(f"Atom mask shape: {atom_mask.shape}")
    print(f"Edge mask shape: {edge_mask.shape}")
    print(f"Edge indices shapes: {edges[0].shape}, {edges[1].shape}")
    break

Label shape: torch.Size([16])
4-momenta shape: torch.Size([16, 139, 4])
Node features shape: torch.Size([16, 139, 8])
Atom mask shape: torch.Size([16, 139])
Edge mask shape: torch.Size([16, 139, 139])
Edge indices shapes: torch.Size([28736]), torch.Size([28736])


In [7]:
import torch
import numpy as np
import energyflow
import os
from sklearn.preprocessing import OneHotEncoder
from scipy.sparse import coo_matrix

def save_physics_tensors(num_data=-1, use_one_hot=False, save_dir="random/data"):
    """
    Generate and save tensor data files needed for physics analysis.

    Args:
        num_data: Number of data points to generate (-1 for all)
        save_dir: Directory to save the tensor files
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Load raw data
    raw = energyflow.qg_jets.load(
        num_data=num_data,
        pad=True,
        ncol=4,
        generator="pythia",
        with_bc=False,
    )

    # Get data and labels
    data, labels = raw

    # Initialize one-hot encoder for particle IDs
    enc = OneHotEncoder(handle_unknown="ignore").fit(
        [[11], [13], [22], [130], [211], [321], [2112], [2212]]
    )

    # Process data
    pid = torch.from_numpy(np.abs(np.asarray(data[..., 3], dtype=int))).unsqueeze(-1)
    p4s = torch.from_numpy(energyflow.p4s_from_ptyphipids(data, error_on_unknown=True))

    # Create one-hot encoded nodes
    one_hot = enc.transform(pid.reshape(-1, 1)).toarray().reshape(pid.shape[:2] + (-1,))
    nodes = torch.from_numpy(one_hot)
    mass = torch.from_numpy(energyflow.ms_from_p4s(p4s)).unsqueeze(-1)
    charge = torch.from_numpy(energyflow.pids2chrgs(pid))

    if use_one_hot:
        nodes = one_hot

    else:
          # Concatenate mass and charge along the last dimension
          concatenated = torch.cat((mass, charge), dim=-1)  # Shape (batch_size, n_nodes, 2)

          # Reduce along the last dimension (e.g., by summing or averaging)
          nodes = concatenated.sum(dim=-1, keepdim=True)  # Shape (batch_size, n_nodes, 1)

          # Apply log-sign transformation if needed
          nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)

    # Create masks
    atom_mask = (pid[..., 0] != 0).to(torch.bool)

    # Create edge mask
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = edge_mask * diag_mask

    # Convert labels to tensor
    labels = torch.from_numpy(labels)

    # Calculate edges for the full dataset
    n_nodes = p4s.size(1)
    batch_size = p4s.size(0)

    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx * n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    edges = np.stack([rows, cols])

    # Save tensors
    torch.save(p4s, os.path.join(save_dir, "p4s.pt"))
    torch.save(nodes, os.path.join(save_dir, "nodes.pt"))
    torch.save(labels, os.path.join(save_dir, "labels.pt"))
    torch.save(atom_mask, os.path.join(save_dir, "atom_mask.pt"))
    np.save(os.path.join(save_dir, "edge_mask.npy"), edge_mask.numpy())
    np.save(os.path.join(save_dir, "edges.npy"), edges)

    print(f"Saved tensor files to {save_dir}")
    print(f"Shapes:")
    print(f"p4s: {p4s.shape}")
    print(f"nodes: {nodes.shape}")
    print(f"labels: {labels.shape}")
    print(f"atom_mask: {atom_mask.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"edges: {edges.shape}")

# Generate and save the tensor files
save_physics_tensors(num_data=1000, use_one_hot=False)  # Use same number of data points as before

Downloading QG_jets.npz from https://www.dropbox.com/s/fclsl7pukcpobsb/QG_jets.npz?dl=1 to /root/.energyflow/datasets
URL fetch failure on https://www.dropbox.com/s/fclsl7pukcpobsb/QG_jets.npz?dl=1: None -- Bad Request
Failed to download QG_jets.npz from source 'dropbox', trying next source...
Downloading QG_jets.npz from https://zenodo.org/record/3164691/files/QG_jets.npz?download=1 to /root/.energyflow/datasets
Saved tensor files to random/data
Shapes:
p4s: torch.Size([1000, 139, 4])
nodes: torch.Size([1000, 139, 1])
labels: torch.Size([1000])
atom_mask: torch.Size([1000, 139])
edge_mask: torch.Size([1000, 139, 139])
edges: (2, 2145950)


In [8]:
from torch.utils.data import TensorDataset, random_split

def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    # edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    # diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    # edge_mask *= diag_mask

    edge_mask = data[-2]

    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edges]


p4s = torch.load('random/data/p4s.pt')
nodes = torch.load('random/data/nodes.pt')
labels = torch.load('random/data/labels.pt')
atom_mask = torch.load('random/data/atom_mask.pt')
edge_mask = torch.from_numpy(np.load('random/data/edge_mask.npy'))
edges = torch.from_numpy(np.load('random/data/edges.npy'))


# Create a TensorDataset
dataset_all = TensorDataset(labels, p4s, nodes, atom_mask, edge_mask)

# Define the split ratios
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Calculate the lengths for each split
total_size = len(dataset_all)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
test_size = total_size - train_size - val_size  # Ensure all data is used

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_all, [train_size, val_size, test_size])

# Create a dictionary to hold the datasets
datasets = {
    "train": train_dataset,
    "val": val_dataset,
    "test": test_dataset
}

dataloaders = {split: DataLoader(dataset,
                                 batch_size=16,
                                 # sampler=train_sampler if (split == 'train') else DistributedSampler(dataset, shuffle=False),
                                 pin_memory=False,
                                 # persistent_workers=True,
                                 collate_fn = collate_fn,
                                 drop_last=True if (split == 'train') else False,
                                 num_workers=0)
                    for split, dataset in datasets.items()}

  p4s = torch.load('random/data/p4s.pt')
  nodes = torch.load('random/data/nodes.pt')
  labels = torch.load('random/data/labels.pt')
  atom_mask = torch.load('random/data/atom_mask.pt')


In [9]:
# # we can peek at a batch to see what it looks like.
# next(iter(dataloaders['val']))

In [10]:
print(p4s.shape) # p4s
print(nodes.shape) # mass
print(atom_mask.shape) # torch.ones
print(edge_mask.shape) # adj_matrix

torch.Size([1000, 139, 4])
torch.Size([1000, 139, 1])
torch.Size([1000, 139])
torch.Size([1000, 139, 139])


In [11]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7efac5965570>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7efae2850c70>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7efae2850c40>}

In [12]:
# Set desired dimensions
batch_size = 1
n_nodes = 3
device = 'cpu'
dtype = torch.float32

# Print initial shapes
print("Initial shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Select subset of data
p4s = p4s[:batch_size, :n_nodes, :]
atom_mask = atom_mask[:batch_size, :n_nodes]
edge_mask = edge_mask[:batch_size, :n_nodes, :n_nodes]
nodes = nodes[:batch_size, :n_nodes, :]

print("\nAfter selection shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Reshape tensors
atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
# Don't reshape edge_mask yet
nodes = nodes.view(batch_size * n_nodes, -1).to(device, dtype)

print("\nAfter reshape shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)  # original shape
print("nodes:", nodes.shape)

# Recalculate edges for the subset
from scipy.sparse import coo_matrix
rows, cols = [], []
for batch_idx in range(batch_size):
    nn = batch_idx * n_nodes
    # Convert edge_mask to numpy and remove any extra dimensions
    edge_mask_np = edge_mask[batch_idx].cpu().numpy().squeeze()
    x = coo_matrix(edge_mask_np)
    rows.append(nn + x.row)
    cols.append(nn + x.col)

edges = [torch.LongTensor(np.concatenate(rows)).to(device),
         torch.LongTensor(np.concatenate(cols)).to(device)]

# Now reshape edge_mask after edges are calculated
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])

Initial shapes:
p4s: torch.Size([1000, 139, 4])
atom_mask: torch.Size([1000, 139])
edge_mask: torch.Size([1000, 139, 139])
nodes: torch.Size([1000, 139, 1])

After selection shapes:
p4s: torch.Size([1, 3, 4])
atom_mask: torch.Size([1, 3])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([1, 3, 1])

After reshape shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([3, 1])

Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [13]:
batch_size = 1  #2500 #1
n_nodes = 3 #139
device = 'cpu'
dtype = torch.float32

atom_positions = p4s[:, :, :].view(batch_size * n_nodes, -1).to(device, dtype)

atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

edges = [a.to(device) for a in edges]
nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)

In [14]:
print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])


Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [15]:
atom_mask[0]#.shape

tensor([1.])

In [16]:
p4s[0][2]

tensor([ 1.1594, -0.2378, -1.1238, -0.0723], dtype=torch.float64)

In [17]:
atom_mask.shape

torch.Size([3, 1])

In [18]:
p4s.shape # batch_size (number of jets or graphs), n_nodes (particles), n_features

torch.Size([1, 3, 4])

In [19]:
# random: x(atom_pos), edge_indx_tensor (edges = adj_matrix), edge_tensor (edge_mask = adj_matrix)
print("Atom mask: {}".format(atom_mask[:2]))
print("Atom positions (x features, 4-momenta): {}".format(atom_positions[:2]))
print("Nodes (scalars: mass & charge): {}".format(nodes[:2]))
print("Edge mask: {}".format(edge_mask[:2]))
print("Edges: {}".format(edges[:2]))

Atom mask: tensor([[1.],
        [1.]])
Atom positions (x features, 4-momenta): tensor([[ 0.2861,  0.0078, -0.2687,  0.0980],
        [ 0.1653, -0.0258, -0.1580, -0.0414]])
Nodes (scalars: mass & charge): tensor([[-4.7488e-09],
        [-2.2813e-09]])
Edge mask: tensor([[False],
        [ True]])
Edges: [tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 2, 0, 1])]


In [20]:
edges[:2]#[0].shape

[tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 2, 0, 1])]

In [21]:
# model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
#                          edge_mask=edge_mask, n_nodes=n_nodes)

# 3. LorentzNet
Before delving into the realm of quantum graph neural networks (QGNNs), we shall examine the performance and structure of a very well-known equivariant GNN, **LorentzNet** ([arXiv:2201.08187](https://arxiv.org/abs/2201.08187)), which is classical, on our dataset. Understanding the structure underlying LorentzNet will allow us to understand where to fit in our quantum models, and this will be the heart of our approach.

## 3.1. Dataset Representation as Graphs

We already discussed this in the introduction, but again, let's remmber that in high-energy particle physics, **jets**—collimated streams of particles resulting from particle collisions—are complex objects that can be naturally represented as graphs. In our dataset:

- Each **jet** is modeled as a graph \( G = (V, E) \), where:
  - $V$ is the set of **nodes**, each corresponding to a constituent particle within the jet.
  - $E$ is the set of **edges**, representing interactions or relationships between particles.
- Each node (particle) is considered a point in Minkowski space $\mathbb{R}^{1,3}$, respecting the spacetime symmetries of special relativity.
- The number of particles (nodes) varies for each jet, reflecting the stochastic nature of particle collisions.

**Reconstructing Four-Momentum Vectors**

In practice, particle data may not be directly provided as four-momentum vectors. Instead, they are often given in terms of:

- **Transverse Momentum $p_T$**: Momentum perpendicular to the beam axis.
- **Pseudo-rapidity $\eta$**: A spatial coordinate describing the angle of a particle relative to the beam (forward-backward) direction.
- **Azimuthal Angle $\phi$**: Angle around the beam axis in the transverse plane.
- **Particle Identification (PID)**: Integer codes representing particle types.

**Conversion to Four-Momentum**

- Using the relationships:

  - $p_x = p_T \cos\phi$
  - $p_y = p_T \sin\phi$
  - $p_z = p_T \sinh\eta$
  - $E = \sqrt{p_T^2 \cosh^2\eta + m^2}$, where $m$ is the particle mass.

- The **[EnergyFlow](https://energyflow.network/)** package converts this for us.

**Implementation in Code**:

- The first step in the data preprocessing involves reconstructing the four-momentum vectors using the available kinematic variables, which is fundamental for us, since:
    - First, we want to ensure that the input to LorentzNet is correctly formatted and physically meaningful.
    - Also, given the limitations on current quantum hardware, and since we are performing simulations currently, then the number of particles in the jet has to be cut down.

## 3.2. Architecture Overview

The **LorentzNet** architecture is designed to process and analyze graphs while respecting the **Lorentz symmetry**, a fundamental symmetry in relativistic physics involving rotations and boosts in spacetime (changes in inertial frames).

**Key Features of LorentzNet**:

- Built upon the **universal approximation theorem** for **Lorentz-equivariant functions**. This theorem ensures that the network can approximate any Lorentz-equivariant function to arbitrary precision, given sufficient capacity.
- Incorporates **message passing** mechanisms tailored to respect Lorentz symmetry.
- Utilizes **continuous functions** modeled by neural networks to update node and edge features throughout the network layers.

**Architecture Diagram**:

<center>
<img src="../figures/LorentzNet.png" width="65%" style="margin-left:auto; margin-right:auto">
</center>

*(Figure: Schematic representation of the LorentzNet architecture.)*

**Input Layer**

The **input** to the LorentzNet consists of:

- **Four-momentum vectors** (coordinate embeddings) of particles from collision events.
  - Each particle $i$ has a four-momentum $v_i = (E_i, p_{x_i}, p_{y_i}, p_{z_i})$, where:
    - $E_i$ is the energy.
    - $p_{x_i}, p_{y_i}, p_{z_i}$ are momentum components in three-dimensional space.
- **Scalar features** (scalar embeddings) $s_i$ associated with each particle, such as:
  - Mass.
  - Electric charge.
  - Particle identification (PID) codes.

The combined feature vector for each particle is:

$$
f_i = v_i \oplus s_i,
$$

where $\oplus$ denotes concatenation.

**Lorentz Group Equivariant Block (LGEB)**

At the core of LorentzNet is the **Lorentz Group Equivariant Block (LGEB)**, which updates the features of particles (nodes) and their interactions (edges) while preserving Lorentz equivariance.

**The components of LGEB**:

1. **Edge Message Function $\phi_e$**:
   - Computes messages passed between particles.
   - Captures pairwise interactions and relativistic geometrical relationships.

2. **Coordinate Update Function $\phi_x$**:
   - Updates the coordinate embeddings of particles.
   - Incorporates attention mechanisms respecting Minkowski spacetime.

3. **Scalar Feature Update Function $\phi_h$**:
   - Updates scalar features of particles.
   - Aggregates information from neighboring particles.

These functions are modeled using neural networks capable of approximating continuous functions.

## 3.3. Detailed Formulation

1. **Edge Message Computation $\phi_e$**:

   For particles $i$ and $j$ at layer $l$, the **edge message** $m_{ij}^{l}$ is computed as:

   $$
   m_{ij}^{l} = \phi_e \left( h_i^{l}, h_j^{l}, \psi\left( \| x_i^{l} - x_j^{l} \|^2 \right), \psi\left( \langle x_i^{l}, x_j^{l} \rangle \right) \right),
   $$

   where:

   - $h_i^{l}$ and $h_j^{l}$ are the scalar features of particles $i$ and $j$ at layer $l$.
   - $x_i^{l}$ and $x_j^{l}$ are the coordinate embeddings (four-vectors) at layer $l$.
   - $\| x_i^{l} - x_j^{l} \|^2$ is the squared Minkowski **distance** between particles $i$ and $j$.
   - $\langle x_i^{l}, x_j^{l} \rangle$ is the Minkowski **inner product** (Lorentz dot product).
   - $\psi(\cdot)$ is a normalization function defined as:

     $$
     \psi(a) = \operatorname{sgn}(a) \cdot \log\left( |a| + 1 \right),
     $$

     with $\operatorname{sgn}(a)$ being the sign function.

   **Purpose of $\psi(\cdot)$**:

   - Helps normalize values that may have large magnitudes or come from different distributions.
   - Ensures numerical stability during optimization by mapping inputs to a manageable range.


2. **Coordinate Embedding Update $\phi_x$**:

   The **coordinate embeddings** of particles are updated via:

   $$
   x_i^{l+1} = x_i^{l} + c \sum_{j \in \mathcal{N}(i)} \phi_x ( m_{ij}^{l}) \cdot x_j^{l},
   $$

   where:

   - $\mathcal{N}(i)$ denotes the **neighborhood** of particle $i$, i.e., particles connected to $i$ in the graph.
   - $c$ is a scaling constant controlling the update magnitude.
   - $\phi_x ( m_{ij}^{l})$ computes an **attention weight** based on the edge message $m_{ij}^{l}$.

   **Interpretation**:

   - The update adds a weighted sum of neighboring coordinate embeddings $x_j^{l}$ to the current embedding $x_i^{l}$.
   - This mechanism allows particles to incorporate spatial information from their neighbors, guided by the learned attention weights.


3. **Scalar Feature Update $\phi_h$**:

   The **scalar features** are updated as:

   $$
   h_i^{l+1} = h_i^{l} + \phi_h \left( h_i^{l}, \sum_{j \in \mathcal{N}(i)} w_{ij}^{l} m_{ij}^{l} \right),
   $$

   where:

   - $w_{ij}^{l}$ is an **edge significance weight** calculated by:

     $$
     w_{ij}^{l} = \phi_m \left( m_{ij}^{l} \right) \in [0, 1],
     $$

     with $\phi_m$ being a neural network outputting values in the range [0, 1].

   - $\phi_h$ aggregates information from neighboring particles to update $h_i^{l}$.

   And for the Purpose of $w_{ij}^{l}$ and $\phi_h$:

   - $w_{ij}^{l}$ signifies the importance of the edge between particles $i$ and $j$.
   - $\phi_h$ integrates these weighted messages to refine the scalar features, enabling the network to learn complex interactions.


**Avoiding Redundancy**

A noteworthy aspect of LorentzNet is its approach to handling outputs:

- Although both **coordinate embeddings** $x_i^{l}$ and **scalar features** $h_i^{l}$ are updated through the layers, the final output only uses the **scalar features** $h_i^{L}$ from the last layer $L$.
- This strategy reduces redundancy and computational overhead because:

  - The edge messages $m_{ij}^{l}$ already incorporate information from both $x_i^{l}$ and $x_j^{l}$.
  - Focusing on scalar features simplifies the network output without losing critical information.

      
**Implementation Details**

To ensure fidelity with the original LorentzNet architecture and leverage existing optimizations, we utilize the official implementation provided by the authors:

- **Repository**: [LorentzNet-release](https://github.com/sdogsq/LorentzNet-release/tree/main)

In [22]:
# @title
import torch
from torch import nn
import numpy as np



"""Some auxiliary functions"""

def unsorted_segment_sum(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    return result

def unsorted_segment_mean(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_mean`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    count = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    count.index_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

def normsq4(p):
    r''' Minkowski square norm
         `\|p\|^2 = p[0]^2-p[1]^2-p[2]^2-p[3]^2`
    '''
    psq = torch.pow(p, 2)
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def dotsq4(p,q):
    r''' Minkowski inner product
         `<p,q> = p[0]q[0]-p[1]q[1]-p[2]q[2]-p[3]q[3]`
    '''
    psq = p*q
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def normA_fn(A):
    return lambda p: torch.einsum('...i, ij, ...j->...', p, A, p)

def dotA_fn(A):
    return lambda p, q: torch.einsum('...i, ij, ...j->...', p, A, q)

def psi(p):
    ''' `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`
    '''
    return torch.sign(p) * torch.log(torch.abs(p) + 1)


"""Lorentz Group-Equivariant Block"""

class LGEB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LGEB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2.
        print('Input size of phi_e: ', n_input)

        self.include_x = include_x
        self.phi_e = nn.Sequential(
            nn.Linear(n_input, n_hidden, bias=False), # n_input * 2 + n_edge_attr
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4


    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        # Reduce the dimension of 'out' to 4 using a linear layer
        out = self.dimension_reducer(out)
        out = self.phi_e(out)
        # print("m_model output: ", out.shape)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m): # norms
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # print("m: ", m.shape)
        # print("trans: ", trans.shape)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        # print("trans: ", trans.shape)
        # print("x.size: ", x.size(0))
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight # * norms[i, j], smth like that, or norms
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            # print("X: ", x)
            x = self.x_model(x, edges, x_diff, m)
            # print("phi_x(X) = ", x, '\n---\n')

        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LorentzNet(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LorentzNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LGEBs = nn.ModuleList([LGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LGEBs[i](h, x, edges, node_attr=scalars)
        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)

        # print("Final preds: \n", pred.cpu().detach().numpy())
        return pred.squeeze(1)

LGEB(self.n_hidden, self.n_hidden, self.n_hidden,\
                                    n_node_attr=n_scalar, dropout=dropout,\
                                    c_weight=c_weight, last_layer=\(i==n_layers-1), A=A, include_x=include_x)
                                    
We are using n_hidden = 4 and n_layers = 6

n_input=n_hidden, n_output=n_hidden, n_hidden=n_hidden, n_node_attr=n_scalar=8

### Now that we have the official code for the classical, just for sanity checking, let's test for equivariance

The cell below is just an auxiliary function to give us the boosts

In [23]:
# @title
from math import sqrt
import numpy as np

# Speed of light (m/s)
c = 299792458

"""Lorentz transformations describe the transition between two inertial reference
frames F and F', each of which is moving in some direction with respect to the
other. This code only calculates Lorentz transformations for movement in the x
direction with no spatial rotation (i.e., a Lorentz boost in the x direction).
The Lorentz transformations are calculated here as linear transformations of
four-vectors [ct, x, y, z] described by Minkowski space. Note that t (time) is
multiplied by c (the speed of light) in the first entry of each four-vector.

Thus, if X = [ct; x; y; z] and X' = [ct'; x'; y'; z'] are the four-vectors for
two inertial reference frames and X' moves in the x direction with velocity v
with respect to X, then the Lorentz transformation from X to X' is X' = BX,
where

    | γ  -γβ  0  0|
B = |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

is the matrix describing the Lorentz boost between X and X',
γ = 1 / √(1 - v²/c²) is the Lorentz factor, and β = v/c is the velocity as
a fraction of c.
"""


def beta(velocity: float) -> float:
    """
    Calculates β = v/c, the given velocity as a fraction of c
    >>> beta(c)
    1.0
    >>> beta(199792458)
    0.666435904801848
    """
    if velocity > c:
        raise ValueError("Speed must not exceed light speed 299,792,458 [m/s]!")
    elif velocity < 1:
        # Usually the speed should be much higher than 1 (c order of magnitude)
        raise ValueError("Speed must be greater than or equal to 1!")

    return velocity / c


def gamma(velocity: float) -> float:
    """
    Calculate the Lorentz factor γ = 1 / √(1 - v²/c²) for a given velocity
    >>> gamma(4)
    1.0000000000000002
    >>> gamma(1e5)
    1.0000000556325075
    >>> gamma(3e7)
    1.005044845777813
    >>> gamma(2.8e8)
    2.7985595722318277
    """
    return 1 / sqrt(1 - beta(velocity) ** 2)


def transformation_matrix(velocity: float) -> np.ndarray:
    """
    Calculate the Lorentz transformation matrix for movement in the x direction:

    | γ  -γβ  0  0|
    |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

    where γ is the Lorentz factor and β is the velocity as a fraction of c
    >>> transformation_matrix(29979245)
    array([[ 1.00503781, -0.10050378,  0.        ,  0.        ],
           [-0.10050378,  1.00503781,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  1.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  1.        ]])
    """
    return np.array(
        [
            [gamma(velocity), -gamma(velocity) * beta(velocity), 0, 0],
            [-gamma(velocity) * beta(velocity), gamma(velocity), 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
    )


### Now, the model

In [24]:
# n_scalar = 8 in original !
model = LorentzNet(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

Input size of phi_e:  4


### Let's start with a default prediction

In [25]:
pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

In [26]:
pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

### ... taking any random nonsense transformation in the four-momentum vectors
i.e.: multiplying by 0.1. Does the hidden rep stay the same?

In [27]:
pred = model(scalars=nodes, x= 0.1 * atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

In [28]:
pred = model(scalars=nodes, x= 0.1 * atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

### Even though the final logits in this case wasn't different, if we look the last output of h (which contains both scalar and 4-momenta information), it changed! Now, what about Lorentz transformations?

In [29]:
pred = model(scalars=nodes, x= (torch.tensor(transformation_matrix(220000000)) @ atom_positions.to(dtype=torch.float64).T).to(dtype=torch.float32).T, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

In [30]:
pred = model(scalars=nodes, x= (torch.tensor(transformation_matrix(220000000)) @ atom_positions.to(dtype=torch.float64).T).to(dtype=torch.float32).T, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

## Equivariance works. Finally, let's train on some data

In [31]:
# @title
import torch
import os, json, random, string
import torch.distributed as dist

def makedir(path):
    try:
        os.makedirs(path)
    except OSError:
        pass

def args_init(args):
    r''' Initialize seed and exp_name.
    '''
    if args.seed is None: # use random seed if not specified
        args.seed = np.random.randint(100)
    if args.exp_name == '': # use random strings if not specified
        args.exp_name = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    if (args.local_rank == 0): # master
        print(args)
        makedir(f"{args.logdir}/{args.exp_name}")
        with open(f"{args.logdir}/{args.exp_name}/args.json", 'w') as f:
            json.dump(args.__dict__, f, indent=4)

def sum_reduce(num, device):
    r''' Sum the tensor across the devices.
    '''
    if not torch.is_tensor(num):
        rt = torch.tensor(num).to(device)
    else:
        rt = num.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        warmup_epoch: target learning rate is reached at warmup_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    Reference:
        https://github.com/ildoonet/pytorch-gradual-warmup-lr
    """

    def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.warmup_epoch = warmup_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    @property
    def _warmup_lr(self):
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch + 1) / self.warmup_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * (self.last_epoch + 1) / self.warmup_epoch + 1.) for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch - 1:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lr

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch==None else epoch
        if self.last_epoch >= self.warmup_epoch - 1:
            if not self.finished:
                warmup_lr = [base_lr * self.multiplier for base_lr in self.base_lrs]
                for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                    param_group['lr'] = lr
                self.finished = True
                return
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.warmup_epoch)
            return

        for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lr):
            param_group['lr'] = lr

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.warmup_epoch)
                self.last_epoch = self.after_scheduler.last_epoch + self.warmup_epoch + 1
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        result = {key: value for key, value in self.__dict__.items() if key != 'optimizer' or key != "after_scheduler"}
        if self.after_scheduler:
            result.update({"after_scheduler": self.after_scheduler.state_dict()})
        return result

    def load_state_dict(self, state_dict):
        after_scheduler_state = state_dict.pop("after_scheduler", None)
        self.__dict__.update(state_dict)
        if after_scheduler_state:
            self.after_scheduler.load_state_dict(after_scheduler_state)


from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

def buildROC(labels, score, targetEff=[0.3,0.5]):
    r''' ROC curve is a plot of the true positive rate (Sensitivity) in the function of the false positive rate
    (100-Specificity) for different cut-off points of a parameter. Each point on the ROC curve represents a
    sensitivity/specificity pair corresponding to a particular decision threshold. The Area Under the ROC
    curve (AUC) is a measure of how well a parameter can distinguish between two diagnostic groups.
    '''
    if not isinstance(targetEff, list):
        targetEff = [targetEff]
    fpr, tpr, threshold = roc_curve(labels, score)
    idx = [np.argmin(np.abs(tpr - Eff)) for Eff in targetEff]
    eB, eS = fpr[idx], tpr[idx]
    return fpr, tpr, threshold, eB, eS

In [32]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from tqdm import tqdm

def run(model, epoch, loader, partition, N_EPOCHS=None):
    if partition == 'train':
        model.train()
    else:
        model.eval()

    res = {'time':0, 'correct':0, 'loss': 0, 'counter': 0, 'acc': 0,
           'loss_arr':[], 'correct_arr':[],'label':[],'score':[]}

    tik = time.time()
    loader_length = len(loader)

    for i, (label, p4s, nodes, atom_mask, edge_mask, edges) in tqdm(enumerate(loader)):
        if partition == 'train':
            optimizer.zero_grad()

        batch_size, n_nodes, _ = p4s.size()
        atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
        atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device)
        edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)
        nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)
        edges = [a.to(device) for a in edges]
        label = label.to(device, dtype).long()

        pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                         edge_mask=edge_mask, n_nodes=n_nodes)

        predict = pred.max(1).indices
        correct = torch.sum(predict == label).item()
        loss = loss_fn(pred, label)

        if partition == 'train':
            loss.backward()
            optimizer.step()
        elif partition == 'test':
            # save labels and probilities for ROC / AUC
            # print("Preds ", pred)
            score = torch.nn.functional.softmax(pred, dim = -1)
            # print("Score test ", score)
            # raise
            res['label'].append(label)
            res['score'].append(score)

        res['time'] = time.time() - tik
        res['correct'] += correct
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        res['loss_arr'].append(loss.item())
        res['correct_arr'].append(correct)

        # if i != 0 and i % args.log_interval == 0:

    running_loss = sum(res['loss_arr'])/len(res['loss_arr'])
    running_acc = sum(res['correct_arr'])/(len(res['correct_arr'])*batch_size)
    avg_time = res['time']/res['counter'] * batch_size
    tmp_counter = res['counter']
    tmp_loss = res['loss'] / tmp_counter
    tmp_acc = res['correct'] / tmp_counter

    if N_EPOCHS:
        print(">> %s \t Epoch %d/%d \t Batch %d/%d \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, epoch + 1, N_EPOCHS, i, loader_length, running_loss, running_acc, tmp_acc, avg_time))
    else:
        print(">> %s \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, running_loss, running_acc, tmp_acc, avg_time))

    torch.cuda.empty_cache()
    # ---------- reduce -----------
    if partition == 'test':
        res['label'] = torch.cat(res['label']).unsqueeze(-1)
        res['score'] = torch.cat(res['score'])
        res['score'] = torch.cat((res['label'],res['score']),dim=-1)
    res['counter'] = res['counter']
    res['loss'] = res['loss'] / res['counter']
    res['acc'] = res['correct'] / res['counter']
    return res

def train(model, res, N_EPOCHS, model_path, log_path):
    ### training and validation
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    for epoch in range(N_EPOCHS):
        train_res = run(model, epoch, dataloaders['train'], partition='train', N_EPOCHS = N_EPOCHS)
        print("Time: train: %.2f \t Train loss %.4f \t Train acc: %.4f" % (train_res['time'],train_res['loss'],train_res['acc']))
        # if epoch % args.val_interval == 0:

        # if (args.local_rank == 0):
        torch.save(model.state_dict(), os.path.join(model_path, "checkpoint-epoch-{}.pt".format(epoch)) )
        with torch.no_grad():
            val_res = run(model, epoch, dataloaders['val'], partition='val')

        # if (args.local_rank == 0): # only master process save
        res['lr'].append(optimizer.param_groups[0]['lr'])
        res['train_time'].append(train_res['time'])
        res['val_time'].append(val_res['time'])
        res['train_loss'].append(train_res['loss'])
        res['train_acc'].append(train_res['acc'])
        res['val_loss'].append(val_res['loss'])
        res['val_acc'].append(val_res['acc'])
        res['epochs'].append(epoch)

        ## save best model
        if val_res['acc'] > res['best_val']:
            print("New best validation model, saving...")
            torch.save(model.state_dict(), os.path.join(model_path,"best-val-model.pt"))
            res['best_val'] = val_res['acc']
            res['best_epoch'] = epoch

        print("Epoch %d/%d finished." % (epoch, N_EPOCHS))
        print("Train time: %.2f \t Val time %.2f" % (train_res['time'], val_res['time']))
        print("Train loss %.4f \t Train acc: %.4f" % (train_res['loss'], train_res['acc']))
        print("Val loss: %.4f \t Val acc: %.4f" % (val_res['loss'], val_res['acc']))
        print("Best val acc: %.4f at epoch %d." % (res['best_val'],  res['best_epoch']))

        json_object = json.dumps(res, indent=4)
        with open(os.path.join(log_path, "train-result-epoch{}.json".format(epoch)), "w") as outfile:
            outfile.write(json_object)

        ## adjust learning rate
        if (epoch < 31):
            lr_scheduler.step(metrics=val_res['acc'])
        else:
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.5


def test(model, res, model_path, log_path):
    ### test on best model
    best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)
    model.load_state_dict(best_model)
    with torch.no_grad():
        test_res = run(model, 0, dataloaders['test'], partition='test')

    print("Final ", test_res['score'])
    pred = test_res['score'].cpu()

    np.save(os.path.join(log_path, "score.npy"), pred)
    fpr, tpr, thres, eB, eS  = buildROC(pred[...,0], pred[...,2])
    auc = roc_auc_score(pred[...,0], pred[...,2])

    metric = {'test_loss': test_res['loss'], 'test_acc': test_res['acc'],
              'test_auc': auc, 'test_1/eB_0.3':1./eB[0],'test_1/eB_0.5':1./eB[1]}
    res.update(metric)
    print("Test: Loss %.4f \t Acc %.4f \t AUC: %.4f \t 1/eB 0.3: %.4f \t 1/eB 0.5: %.4f"\
           % (test_res['loss'], test_res['acc'], auc, 1./eB[0], 1./eB[1]))
    json_object = json.dumps(res, indent=4)
    with open(os.path.join(log_path, "test-result.json"), "w") as outfile:
        outfile.write(json_object)

if __name__ == "__main__":

    N_EPOCHS = 45 # 60

    model_path = "models/LorentzNet/"
    log_path = "logs/LorentzNet/"
    # args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cuda
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cpu")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    ### create parallel model
    model = LorentzNet(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)

Input size of phi_e:  4
Model Size: 199
 train samples: 800
 val samples: 100
 test samples: 100




Training...


50it [00:00, 95.57it/s]


>> train 	 Epoch 1/45 	 Batch 49/50 	 Loss 0.6825 	 Running Acc 0.517 	 Total Acc 0.517 	 Avg Batch Time 0.0106
Time: train: 0.53 	 Train loss 0.6825 	 Train acc: 0.5175


7it [00:00, 208.22it/s]


>> val 	 Loss 0.6789 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 0/45 finished.
Train time: 0.53 	 Val time 0.04
Train loss 0.6825 	 Train acc: 0.5175
Val loss: 0.6762 	 Val acc: 0.5400
Best val acc: 0.5400 at epoch 0.


50it [00:00, 102.61it/s]


>> train 	 Epoch 2/45 	 Batch 49/50 	 Loss 0.6855 	 Running Acc 0.501 	 Total Acc 0.501 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.6855 	 Train acc: 0.5012


7it [00:00, 212.08it/s]


>> val 	 Loss 0.6760 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0014
Epoch 1/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.6855 	 Train acc: 0.5012
Val loss: 0.6729 	 Val acc: 0.5400
Best val acc: 0.5400 at epoch 0.


50it [00:00, 103.68it/s]


>> train 	 Epoch 3/45 	 Batch 49/50 	 Loss 0.6805 	 Running Acc 0.511 	 Total Acc 0.511 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.6805 	 Train acc: 0.5112


7it [00:00, 195.94it/s]


>> val 	 Loss 0.6718 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0015
New best validation model, saving...
Epoch 2/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.6805 	 Train acc: 0.5112
Val loss: 0.6680 	 Val acc: 0.5600
Best val acc: 0.5600 at epoch 2.


50it [00:00, 104.01it/s]


>> train 	 Epoch 4/45 	 Batch 49/50 	 Loss 0.6710 	 Running Acc 0.552 	 Total Acc 0.552 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.6710 	 Train acc: 0.5525


7it [00:00, 215.16it/s]


>> val 	 Loss 0.6666 	 Running Acc 2.250 	 Total Acc 0.630 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 3/45 finished.
Train time: 0.48 	 Val time 0.03
Train loss 0.6710 	 Train acc: 0.5525
Val loss: 0.6619 	 Val acc: 0.6300
Best val acc: 0.6300 at epoch 3.


50it [00:00, 104.06it/s]


>> train 	 Epoch 5/45 	 Batch 49/50 	 Loss 0.6622 	 Running Acc 0.588 	 Total Acc 0.588 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.6622 	 Train acc: 0.5875


7it [00:00, 204.45it/s]


>> val 	 Loss 0.6588 	 Running Acc 2.393 	 Total Acc 0.670 	 Avg Batch Time 0.0015
New best validation model, saving...
Epoch 4/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.6622 	 Train acc: 0.5875
Val loss: 0.6526 	 Val acc: 0.6700
Best val acc: 0.6700 at epoch 4.


50it [00:00, 103.65it/s]


>> train 	 Epoch 6/45 	 Batch 49/50 	 Loss 0.6528 	 Running Acc 0.661 	 Total Acc 0.661 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.6528 	 Train acc: 0.6613


7it [00:00, 214.82it/s]


>> val 	 Loss 0.6524 	 Running Acc 2.429 	 Total Acc 0.680 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 5/45 finished.
Train time: 0.48 	 Val time 0.03
Train loss 0.6528 	 Train acc: 0.6613
Val loss: 0.6450 	 Val acc: 0.6800
Best val acc: 0.6800 at epoch 5.


50it [00:00, 102.55it/s]


>> train 	 Epoch 7/45 	 Batch 49/50 	 Loss 0.6459 	 Running Acc 0.665 	 Total Acc 0.665 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.6459 	 Train acc: 0.6650


7it [00:00, 222.38it/s]


>> val 	 Loss 0.6485 	 Running Acc 2.357 	 Total Acc 0.660 	 Avg Batch Time 0.0014
Epoch 6/45 finished.
Train time: 0.49 	 Val time 0.03
Train loss 0.6459 	 Train acc: 0.6650
Val loss: 0.6403 	 Val acc: 0.6600
Best val acc: 0.6800 at epoch 5.


50it [00:00, 104.09it/s]


>> train 	 Epoch 8/45 	 Batch 49/50 	 Loss 0.6413 	 Running Acc 0.681 	 Total Acc 0.681 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.6413 	 Train acc: 0.6813


7it [00:00, 211.94it/s]


>> val 	 Loss 0.6477 	 Running Acc 2.393 	 Total Acc 0.670 	 Avg Batch Time 0.0014
Epoch 7/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.6413 	 Train acc: 0.6813
Val loss: 0.6392 	 Val acc: 0.6700
Best val acc: 0.6800 at epoch 5.


50it [00:00, 103.08it/s]


>> train 	 Epoch 9/45 	 Batch 49/50 	 Loss 0.6381 	 Running Acc 0.695 	 Total Acc 0.695 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.6381 	 Train acc: 0.6950


7it [00:00, 205.08it/s]


>> val 	 Loss 0.6352 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 0.0015
New best validation model, saving...
Epoch 8/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.6381 	 Train acc: 0.6950
Val loss: 0.6238 	 Val acc: 0.6900
Best val acc: 0.6900 at epoch 8.


50it [00:00, 101.47it/s]


>> train 	 Epoch 10/45 	 Batch 49/50 	 Loss 0.6154 	 Running Acc 0.725 	 Total Acc 0.725 	 Avg Batch Time 0.0099
Time: train: 0.50 	 Train loss 0.6154 	 Train acc: 0.7250


7it [00:00, 202.38it/s]


>> val 	 Loss 0.6204 	 Running Acc 2.429 	 Total Acc 0.680 	 Avg Batch Time 0.0015
Epoch 9/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.6154 	 Train acc: 0.7250
Val loss: 0.6043 	 Val acc: 0.6800
Best val acc: 0.6900 at epoch 8.


50it [00:00, 94.07it/s]


>> train 	 Epoch 11/45 	 Batch 49/50 	 Loss 0.5964 	 Running Acc 0.754 	 Total Acc 0.754 	 Avg Batch Time 0.0107
Time: train: 0.53 	 Train loss 0.5964 	 Train acc: 0.7538


7it [00:00, 212.38it/s]


>> val 	 Loss 0.6124 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 0.0014
Epoch 10/45 finished.
Train time: 0.53 	 Val time 0.03
Train loss 0.5964 	 Train acc: 0.7538
Val loss: 0.5953 	 Val acc: 0.6900
Best val acc: 0.6900 at epoch 8.


50it [00:00, 102.39it/s]


>> train 	 Epoch 12/45 	 Batch 49/50 	 Loss 0.5956 	 Running Acc 0.726 	 Total Acc 0.726 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.5956 	 Train acc: 0.7262


7it [00:00, 200.30it/s]


>> val 	 Loss 0.6077 	 Running Acc 2.429 	 Total Acc 0.680 	 Avg Batch Time 0.0015
Epoch 11/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5956 	 Train acc: 0.7262
Val loss: 0.5897 	 Val acc: 0.6800
Best val acc: 0.6900 at epoch 8.


50it [00:00, 103.90it/s]


>> train 	 Epoch 13/45 	 Batch 49/50 	 Loss 0.5785 	 Running Acc 0.751 	 Total Acc 0.751 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.5785 	 Train acc: 0.7512


7it [00:00, 209.46it/s]


>> val 	 Loss 0.6048 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 0.0014
Epoch 12/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5785 	 Train acc: 0.7512
Val loss: 0.5860 	 Val acc: 0.6900
Best val acc: 0.6900 at epoch 8.


50it [00:00, 101.57it/s]


>> train 	 Epoch 14/45 	 Batch 49/50 	 Loss 0.5808 	 Running Acc 0.740 	 Total Acc 0.740 	 Avg Batch Time 0.0099
Time: train: 0.49 	 Train loss 0.5808 	 Train acc: 0.7400


7it [00:00, 211.89it/s]


>> val 	 Loss 0.6029 	 Running Acc 2.607 	 Total Acc 0.730 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 13/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5808 	 Train acc: 0.7400
Val loss: 0.5840 	 Val acc: 0.7300
Best val acc: 0.7300 at epoch 13.


50it [00:00, 101.31it/s]


>> train 	 Epoch 15/45 	 Batch 49/50 	 Loss 0.5773 	 Running Acc 0.741 	 Total Acc 0.741 	 Avg Batch Time 0.0099
Time: train: 0.50 	 Train loss 0.5773 	 Train acc: 0.7412


7it [00:00, 221.14it/s]


>> val 	 Loss 0.6024 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 14/45 finished.
Train time: 0.50 	 Val time 0.03
Train loss 0.5773 	 Train acc: 0.7412
Val loss: 0.5834 	 Val acc: 0.7400
Best val acc: 0.7400 at epoch 14.


50it [00:00, 100.59it/s]


>> train 	 Epoch 16/45 	 Batch 49/50 	 Loss 0.5816 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5816 	 Train acc: 0.7500


7it [00:00, 186.87it/s]


>> val 	 Loss 0.6023 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 0.0016
Epoch 15/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5816 	 Train acc: 0.7500
Val loss: 0.5834 	 Val acc: 0.7400
Best val acc: 0.7400 at epoch 14.


50it [00:00, 104.91it/s]


>> train 	 Epoch 17/45 	 Batch 49/50 	 Loss 0.5808 	 Running Acc 0.754 	 Total Acc 0.754 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5808 	 Train acc: 0.7538


7it [00:00, 206.90it/s]


>> val 	 Loss 0.5948 	 Running Acc 2.500 	 Total Acc 0.700 	 Avg Batch Time 0.0015
Epoch 16/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5808 	 Train acc: 0.7538
Val loss: 0.5730 	 Val acc: 0.7000
Best val acc: 0.7400 at epoch 14.


50it [00:00, 105.55it/s]


>> train 	 Epoch 18/45 	 Batch 49/50 	 Loss 0.5707 	 Running Acc 0.751 	 Total Acc 0.751 	 Avg Batch Time 0.0095
Time: train: 0.48 	 Train loss 0.5707 	 Train acc: 0.7512


7it [00:00, 199.53it/s]


>> val 	 Loss 0.5902 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 0.0015
Epoch 17/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5707 	 Train acc: 0.7512
Val loss: 0.5675 	 Val acc: 0.7400
Best val acc: 0.7400 at epoch 14.


50it [00:00, 102.91it/s]


>> train 	 Epoch 19/45 	 Batch 49/50 	 Loss 0.5673 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.5673 	 Train acc: 0.7500


7it [00:00, 210.29it/s]


>> val 	 Loss 0.5868 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 18/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5673 	 Train acc: 0.7500
Val loss: 0.5631 	 Val acc: 0.7600
Best val acc: 0.7600 at epoch 18.


50it [00:00, 104.43it/s]


>> train 	 Epoch 20/45 	 Batch 49/50 	 Loss 0.5534 	 Running Acc 0.759 	 Total Acc 0.759 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5534 	 Train acc: 0.7588


7it [00:00, 184.16it/s]


>> val 	 Loss 0.5832 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 0.0016
Epoch 19/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5534 	 Train acc: 0.7588
Val loss: 0.5579 	 Val acc: 0.7600
Best val acc: 0.7600 at epoch 18.


50it [00:00, 102.25it/s]


>> train 	 Epoch 21/45 	 Batch 49/50 	 Loss 0.5535 	 Running Acc 0.745 	 Total Acc 0.745 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.5535 	 Train acc: 0.7450


7it [00:00, 198.52it/s]


>> val 	 Loss 0.5808 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 0.0015
Epoch 20/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5535 	 Train acc: 0.7450
Val loss: 0.5541 	 Val acc: 0.7600
Best val acc: 0.7600 at epoch 18.


50it [00:00, 100.68it/s]


>> train 	 Epoch 22/45 	 Batch 49/50 	 Loss 0.5495 	 Running Acc 0.743 	 Total Acc 0.743 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5495 	 Train acc: 0.7425


7it [00:00, 201.30it/s]


>> val 	 Loss 0.5768 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 0.0015
Epoch 21/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5495 	 Train acc: 0.7425
Val loss: 0.5492 	 Val acc: 0.7600
Best val acc: 0.7600 at epoch 18.


50it [00:00, 100.12it/s]


>> train 	 Epoch 23/45 	 Batch 49/50 	 Loss 0.5445 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5445 	 Train acc: 0.7500


7it [00:00, 197.96it/s]


>> val 	 Loss 0.5764 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 0.0015
New best validation model, saving...
Epoch 22/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5445 	 Train acc: 0.7500
Val loss: 0.5478 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 22.


50it [00:00, 100.70it/s]


>> train 	 Epoch 24/45 	 Batch 49/50 	 Loss 0.5416 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5416 	 Train acc: 0.7500


7it [00:00, 202.03it/s]


>> val 	 Loss 0.5757 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 0.0015
Epoch 23/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5416 	 Train acc: 0.7500
Val loss: 0.5463 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 22.


50it [00:00, 102.80it/s]


>> train 	 Epoch 25/45 	 Batch 49/50 	 Loss 0.5418 	 Running Acc 0.752 	 Total Acc 0.752 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.5418 	 Train acc: 0.7525


7it [00:00, 218.36it/s]


>> val 	 Loss 0.5749 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
New best validation model, saving...
Epoch 24/45 finished.
Train time: 0.49 	 Val time 0.03
Train loss 0.5418 	 Train acc: 0.7525
Val loss: 0.5456 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 104.93it/s]


>> train 	 Epoch 26/45 	 Batch 49/50 	 Loss 0.5372 	 Running Acc 0.746 	 Total Acc 0.746 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5372 	 Train acc: 0.7462


7it [00:00, 209.82it/s]


>> val 	 Loss 0.5730 	 Running Acc 2.786 	 Total Acc 0.780 	 Avg Batch Time 0.0015
Epoch 25/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5372 	 Train acc: 0.7462
Val loss: 0.5429 	 Val acc: 0.7800
Best val acc: 0.8000 at epoch 24.


50it [00:00, 107.01it/s]


>> train 	 Epoch 27/45 	 Batch 49/50 	 Loss 0.5411 	 Running Acc 0.746 	 Total Acc 0.746 	 Avg Batch Time 0.0094
Time: train: 0.47 	 Train loss 0.5411 	 Train acc: 0.7462


7it [00:00, 220.08it/s]


>> val 	 Loss 0.5727 	 Running Acc 2.786 	 Total Acc 0.780 	 Avg Batch Time 0.0014
Epoch 26/45 finished.
Train time: 0.47 	 Val time 0.03
Train loss 0.5411 	 Train acc: 0.7462
Val loss: 0.5424 	 Val acc: 0.7800
Best val acc: 0.8000 at epoch 24.


50it [00:00, 104.02it/s]


>> train 	 Epoch 28/45 	 Batch 49/50 	 Loss 0.5394 	 Running Acc 0.752 	 Total Acc 0.752 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.5394 	 Train acc: 0.7525


7it [00:00, 219.51it/s]


>> val 	 Loss 0.5727 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 27/45 finished.
Train time: 0.48 	 Val time 0.03
Train loss 0.5394 	 Train acc: 0.7525
Val loss: 0.5427 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 94.78it/s] 


>> train 	 Epoch 29/45 	 Batch 49/50 	 Loss 0.5389 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0106
Time: train: 0.53 	 Train loss 0.5389 	 Train acc: 0.7575


7it [00:00, 205.82it/s]


>> val 	 Loss 0.5720 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 28/45 finished.
Train time: 0.53 	 Val time 0.04
Train loss 0.5389 	 Train acc: 0.7575
Val loss: 0.5417 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 104.62it/s]


>> train 	 Epoch 30/45 	 Batch 49/50 	 Loss 0.5317 	 Running Acc 0.746 	 Total Acc 0.746 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5317 	 Train acc: 0.7462


7it [00:00, 207.60it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 29/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5317 	 Train acc: 0.7462
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 100.65it/s]


>> train 	 Epoch 31/45 	 Batch 49/50 	 Loss 0.5410 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5410 	 Train acc: 0.7500


7it [00:00, 202.54it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 30/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5410 	 Train acc: 0.7500
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 104.11it/s]


>> train 	 Epoch 32/45 	 Batch 49/50 	 Loss 0.5361 	 Running Acc 0.749 	 Total Acc 0.749 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5361 	 Train acc: 0.7488


7it [00:00, 215.22it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 31/45 finished.
Train time: 0.48 	 Val time 0.03
Train loss 0.5361 	 Train acc: 0.7488
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 105.01it/s]


>> train 	 Epoch 33/45 	 Batch 49/50 	 Loss 0.5319 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5319 	 Train acc: 0.7575


7it [00:00, 207.05it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 32/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5319 	 Train acc: 0.7575
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 103.69it/s]


>> train 	 Epoch 34/45 	 Batch 49/50 	 Loss 0.5300 	 Running Acc 0.761 	 Total Acc 0.761 	 Avg Batch Time 0.0097
Time: train: 0.49 	 Train loss 0.5300 	 Train acc: 0.7612


7it [00:00, 207.94it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 33/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5300 	 Train acc: 0.7612
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 101.28it/s]


>> train 	 Epoch 35/45 	 Batch 49/50 	 Loss 0.5316 	 Running Acc 0.749 	 Total Acc 0.749 	 Avg Batch Time 0.0099
Time: train: 0.50 	 Train loss 0.5316 	 Train acc: 0.7488


7it [00:00, 208.45it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 34/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5316 	 Train acc: 0.7488
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 103.79it/s]


>> train 	 Epoch 36/45 	 Batch 49/50 	 Loss 0.5314 	 Running Acc 0.762 	 Total Acc 0.762 	 Avg Batch Time 0.0097
Time: train: 0.48 	 Train loss 0.5314 	 Train acc: 0.7625


7it [00:00, 212.08it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 35/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5314 	 Train acc: 0.7625
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 103.47it/s]


>> train 	 Epoch 37/45 	 Batch 49/50 	 Loss 0.5332 	 Running Acc 0.749 	 Total Acc 0.749 	 Avg Batch Time 0.0097
Time: train: 0.49 	 Train loss 0.5332 	 Train acc: 0.7488


7it [00:00, 210.50it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 36/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5332 	 Train acc: 0.7488
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 101.29it/s]


>> train 	 Epoch 38/45 	 Batch 49/50 	 Loss 0.5346 	 Running Acc 0.764 	 Total Acc 0.764 	 Avg Batch Time 0.0099
Time: train: 0.50 	 Train loss 0.5346 	 Train acc: 0.7638


7it [00:00, 211.28it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 37/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5346 	 Train acc: 0.7638
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 101.28it/s]


>> train 	 Epoch 39/45 	 Batch 49/50 	 Loss 0.5302 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0099
Time: train: 0.50 	 Train loss 0.5302 	 Train acc: 0.7575


7it [00:00, 205.49it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 38/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5302 	 Train acc: 0.7575
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 104.68it/s]


>> train 	 Epoch 40/45 	 Batch 49/50 	 Loss 0.5334 	 Running Acc 0.755 	 Total Acc 0.755 	 Avg Batch Time 0.0096
Time: train: 0.48 	 Train loss 0.5334 	 Train acc: 0.7550


7it [00:00, 214.11it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 39/45 finished.
Train time: 0.48 	 Val time 0.04
Train loss 0.5334 	 Train acc: 0.7550
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 101.76it/s]


>> train 	 Epoch 41/45 	 Batch 49/50 	 Loss 0.5328 	 Running Acc 0.755 	 Total Acc 0.755 	 Avg Batch Time 0.0099
Time: train: 0.49 	 Train loss 0.5328 	 Train acc: 0.7550


7it [00:00, 203.93it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 40/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5328 	 Train acc: 0.7550
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 102.94it/s]


>> train 	 Epoch 42/45 	 Batch 49/50 	 Loss 0.5353 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0098
Time: train: 0.49 	 Train loss 0.5353 	 Train acc: 0.7575


7it [00:00, 195.48it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0015
Epoch 41/45 finished.
Train time: 0.49 	 Val time 0.04
Train loss 0.5353 	 Train acc: 0.7575
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 100.96it/s]


>> train 	 Epoch 43/45 	 Batch 49/50 	 Loss 0.5338 	 Running Acc 0.748 	 Total Acc 0.748 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.5338 	 Train acc: 0.7475


7it [00:00, 210.59it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 42/45 finished.
Train time: 0.50 	 Val time 0.04
Train loss 0.5338 	 Train acc: 0.7475
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 89.39it/s] 


>> train 	 Epoch 44/45 	 Batch 49/50 	 Loss 0.5306 	 Running Acc 0.762 	 Total Acc 0.762 	 Avg Batch Time 0.0112
Time: train: 0.56 	 Train loss 0.5306 	 Train acc: 0.7625


7it [00:00, 186.58it/s]


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0016
Epoch 43/45 finished.
Train time: 0.56 	 Val time 0.04
Train loss 0.5306 	 Train acc: 0.7625
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


50it [00:00, 97.47it/s]


>> train 	 Epoch 45/45 	 Batch 49/50 	 Loss 0.5352 	 Running Acc 0.754 	 Total Acc 0.754 	 Avg Batch Time 0.0103
Time: train: 0.52 	 Train loss 0.5352 	 Train acc: 0.7538


7it [00:00, 213.61it/s]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.5718 	 Running Acc 2.857 	 Total Acc 0.800 	 Avg Batch Time 0.0014
Epoch 44/45 finished.
Train time: 0.52 	 Val time 0.03
Train loss 0.5352 	 Train acc: 0.7538
Val loss: 0.5415 	 Val acc: 0.8000
Best val acc: 0.8000 at epoch 24.


7it [00:00, 198.18it/s]

>> test 	 Loss 0.5655 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 0.0015
Final  tensor([[1.0000, 0.3362, 0.6638],
        [1.0000, 0.2920, 0.7080],
        [0.0000, 0.7204, 0.2796],
        [1.0000, 0.5617, 0.4383],
        [1.0000, 0.3015, 0.6985],
        [0.0000, 0.6670, 0.3330],
        [0.0000, 0.7227, 0.2773],
        [1.0000, 0.4954, 0.5046],
        [1.0000, 0.3361, 0.6639],
        [0.0000, 0.5018, 0.4982],
        [0.0000, 0.6074, 0.3926],
        [0.0000, 0.6310, 0.3690],
        [0.0000, 0.3891, 0.6109],
        [0.0000, 0.5062, 0.4938],
        [0.0000, 0.5912, 0.4088],
        [0.0000, 0.5481, 0.4519],
        [0.0000, 0.5460, 0.4540],
        [1.0000, 0.4910, 0.5090],
        [0.0000, 0.7435, 0.2565],
        [1.0000, 0.2933, 0.7067],
        [1.0000, 0.4500, 0.5500],
        [0.0000, 0.4926, 0.5074],
        [0.0000, 0.2876, 0.7124],
        [0.0000, 0.5232, 0.4768],
        [1.0000, 0.7077, 0.2923],
        [0.0000, 0.8423, 0.1577],
        [0.0000, 0.8399, 




# 4. Equivariant Quantum Neural Networks
Now, let's move to quantum machine learning. Given some group $\mathcal{G}$, one common way to achieve equivariance [6] is to have a quantum neural network of the form $h_{\theta} = Tr[\rho \tilde{O}_{\theta}]$ such that:

$$\begin{align*}
\tilde{O}_{\theta} \in Comm(G) = \{A \in \mathbb{C}^{d\times d} / [A, R(g)] = 0 \text{ for all } g \in G\}
\end{align*}$$

To see why, we need to observe that the trace is cyclical, so:

$$\begin{align*}
    h_{\theta} (g\cdot \rho) = Tr[R(g)\rho R^{\dagger}(g)\tilde{O}_{\theta}] = Tr[\rho R^{\dagger}(g)\tilde{O}_{\theta}R(g)] &= Tr[\rho R^{\dagger}(g)R(g)\tilde{O}_{\theta}]\\
    &= Tr[\rho \tilde{O}_{\theta}]\\
    &= h_{\theta}(\rho).
\end{align*}$$

Essentially, we are using the Heisenberg picture, where we apply the time evolution to the measurement operator instead of the initial quantum state. When the observable is included in the commutant of $\mathcal{G}$, we can see how invariance is achieved.

The challenge with this approach is that it only works for finite-dimensional and compact groups, like $p4m$, $SO(3)$, etc. The Lorentz group is known to be continuous and non-compact, so it has no finite-dimensional unitary representation. Hence, the approach above is of no use for us. Hopefully, there is another way: instead of baking equivariance directly into the ansatze, we'll do it in the feature space and in the message passing function. When the input is invariant, the message passing becomes equivariant.

Similarly to LorentzNet, for standard jet tagging approach, our input is made of $4$-momentum vectors and any associated particle scalar one may wish to include, like color and charge. In fact, in this project, we start with the traditional LorentzNet architecture, but two modifications are made: first, the invariant metric can be extracted from the machine learned algebra; secondly, the $\phi_e, \phi_x, \phi_h$ and $\phi_m$ - classical parts modeled as classical multilayer perceptrons in Lorentznet, are now substituted by quantum parameterized circuits. Below we show how invariance-equivariance is preserved under this modification.

## Infrared safe observables
Another interesting bias to incorporate is the infrared and collinear (IRC) safety. An infrared and collinear safe observable is the same in the presence or absence of soft or collinear particles. In [6], an IRC-safe equivariant (classical) GNN was proposed for tagging simulated semi-visible jets from Hidden Valley models, showing superior performance on this data for Beyond the Standard Model (BSM) search.

We saw before that in LorentzNet, the message is calculated as:

\begin{equation}
m_{ij}^{l} = \phi_{e}(h_i, h_j, \psi(||x_{i}^{l} - x_{j}^{l}||^2), \psi(\langle x_{i}, x_{j}\rangle)),
\end{equation}

Now, intuitively, an IRC-safe model should give us a graph that stays invariant under any particle corresponding to an infinitesimal emission, or a collinear one. This means that such particles have no influence on other particles in our point cloud. But, how can we do this? Message passing!

\begin{align}
\text{IR safety}:& m^{l}(i,j) \rightarrow 0 \text{ as } z \rightarrow 0,\\
\text{C safety}:& m^{l}(i,j + r) = m^{l}(i,j) + m^{l}(i,r) \text{ as } \Delta_{jr} \rightarrow 0,
\end{align}

To ensure IR safety, we can not use $z_j$ directly, as it breaks equivariance. We propose, thus, the following substitution:

\begin{equation}
m_{ij}^{l} = \frac{\langle x_i , x_j\rangle}{\sum_{k \in \mathcal{N(j)} } \langle x_i , x_k\rangle } \cdot \phi_{e}(h_i, h_j, \psi(||x_{i}^{l} - x_{j}^{l}||^2), \psi(\langle x_{i}, x_{j}\rangle)),
\end{equation}


Where $\langle \cdot,\cdot\rangle$ is the Minkowski inner product, and $\mathcal{N(j)}$ represents all neighboring particles of $j$. If $j$ is a soft particle, then the Minkowski inner product should be small, thus , which makes the edge connection irrelevant, thus ensuring IR safety. Also, any Lorentz transformation preserves the inner product, so the message should remain symmetry-preserving.


## 4.1. Lorentz Equivariant Quantum Block (LEQB)

LEQB is the main piece of our model. We aim to fundamentally learn deeper quantum representations of $|\psi_{x}^{l+1}\rangle$ and $|\psi_h^{l+1} \rangle$ from $|\psi_{x}^{l} \rangle$ and $|\psi_{h}^{l}\rangle$, where:

$$\begin{align}
    |\psi_{x}^{l+1}\rangle &= \mathcal{U}_{x^{l+1}}({x}^{l})|0\rangle,\\
    |\psi_{h}^{l+1}\rangle &= \mathcal{U}_{h^{l+1}}({h}^{l})|0\rangle,
\end{align}$$

where $\mathcal{U_{x^{l}}}, \mathcal{U_{x^{l+1}}}, \mathcal{U_{h^{l}}}, \mathcal{U_{h^{l+1}}}$ are all parameterized standard gate unitaries, or variational circuits. Note that $x^{l}$ are the observables and $h^{l}$ are the particle scalars when $l=0$, but $x^{l} = \langle \psi_x | \mathcal{M} | \psi_x\rangle$ and $h^{l} = \langle \psi_h | \mathcal{M} | \psi_h\rangle$ for $l > 0$, where $\mathcal{M}$ is some measurement operator.

## 4.2. Theoretical analysis

Let's start with the following proposition:

> The coordinate embedding $x^{l} = \{x_1^{l} , x_2^{l} , \dots , x_n^{l}\}$ is Lorentz group equivariant and the node embedding $h^{l} = \{h_1^{l} , h_2^{l}, \dots , h_n^{l}\}$ - representing the particle scalars - is Lorentz group invariant.

To prove it, let $Q$ be some Lie group transformation. If the message $m_{ij}^{l}$ is invariant under the action of $Q$ for all $i,j,l,$ then $x_{i}^{l}$ is naturally Lie group equivariant since:

$$\begin{align*}
    Q\cdot x_i^{l+1} &= Q(x_i^{l} + \sum_{j\in \mathcal{N}(i)} x_j^{l}\cdot \phi_x (m_{ij}^{l}))\\
    &= Q\cdot x_i^{l} + \sum_{j\in \mathcal{N}(i)} Q\cdot x_j^{l}\cdot \phi_x (m_{ij}^{l}),
\end{align*}$$

where $Q$ acts under matrix multiplication. The equation above means that acting with $Q$ from the outside is the same as acting with $Q$ from the inside - directly into the node embeddings from the layer before. Then, for the invariance of $m_{ij}^{l}$, since the norm induced by the extracted metric is invariant under the action of $Q$, it holds that $\|\|x_{i}^{0} - x_{j}^{0}\|\|^2 = \|\|Q\cdot x_{i}^{0} - Q\cdot x_{j}^{0}\|\|^2$, and $\langle x_{i}^{0}, x_{j}^{0} \rangle = \langle Q\cdot x_{i}^{0}, Q\cdot x_{j}^{0} \rangle$. Since $m_{ij}^{l+1} = \phi_e(h_i^{l}, h_j^{l}, \|\|x_{i}^{l} - x_{j}^{l}\|\|^2, \langle x_{i}^{l}, x_{j}^{l} \rangle)$, and the norm and the inner product are already invariant, we just have to show that $h^{l}$ is also invariant, since:

$$\begin{equation*}
    h_i^{l+1} = h_i^{l} + \phi_h (h_i^{l}, \sum_{j\in \mathcal{N}(i)} w_{ij} m_{ij}^{l}).
\end{equation*}$$
    
For layer $l=0$, $h_{i}^{l}$ is already invariant (since it contains information only about the particle scalars). Then, $m_{ij}^{l+1}$ will be invariant, since all of its inputs are also invariant, and we follow the same logic for $x_{i}^{l+1}$. Given that these properties of $x,h,m$ hold for the first layer and the next, we reach the conclusion recursively.

Having a quick glance at the discussion we had about groups, equivariance, particles and quantum machine learning, we are getting a hint that the marriage between Physics and symmetries is actually deep. Indeed it is! To quote Philip Anderson, who won the 1977 Nobel prize “for their fundamental theoretical investigations of the electronic structure of magnetic and disordered systems”:

> It is only slightly overstating the case to say that physics is the study of symmetry.

In [33]:
import torch
import pennylane as qml
import torch.nn.functional as F
from torch import nn
from torch_geometric.utils import to_dense_adj

n_qubits = 4

dev = qml.device('default.qubit', wires=n_qubits)
# dev = qml.device("qiskit.aer", wires=n_qubits)


def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)

def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)

def RY_RX_layer(weights):
    """Applies a layer of parametrized RY and RX rotations."""
    for i, w in enumerate(weights):
        qml.RY(w, wires=i)
        qml.RX(w, wires=i)

def full_entangling_layer(n_qubits):
    """Applies CNOT gates between all pairs of qubits."""
    for i in range(n_qubits):
        for j in range(i+1, n_qubits):
            qml.CNOT(wires=[i, j])

def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOT.
    """
    # In other words it should apply something like :
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(nqubits - 1):
        qml.CRZ(np.pi / 2, wires=[i, i + 1])
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.SWAP(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.SWAP(wires=[i, i + 1])


@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth, n_qubits):
    """
    The variational quantum circuit.
    """

    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    # RY_layer(q_input_features)
    qml.AngleEmbedding(features=q_input_features, wires=range(n_qubits), rotation='Z')

    # Sequence of trainable variational layers
    # for k in range(q_depth):
    #     entangling_layer(n_qubits)
    #     RY_RX_layer(q_weights[k])
    #     # RY_layer(q_weights[k])
    for k in range(q_depth):
        if k % 2 == 0:
            entangling_layer(n_qubits)
            RY_layer(q_weights[k])
        else:
            full_entangling_layer(n_qubits)
            RY_RX_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self, n_qubits, q_depth = 1, q_delta=0.001):
        """
        Definition of the *dressed* layout.
        """
        print('n_qubits: ', n_qubits)
        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))

    def forward(self, input_features):
        """
        Optimized forward pass to reduce runtime.
        """

        # Quantum Embedding (U(X))
        q_in = torch.tanh(input_features) * np.pi / 2.0

        # Preallocate output tensor
        batch_size = q_in.shape[0]
        q_out = torch.zeros(batch_size, self.n_qubits, device=q_in.device)

        # Vectorized execution
        for i, elem in enumerate(q_in):
            q_out_elem = torch.hstack(quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)).float()
            q_out[i] = q_out_elem

        return q_out

In [34]:
# @title
import torch
from torch import nn
import numpy as np
import pennylane as qml

"""
    Lie-Equivariant Quantum Block (LEQB).

        - Given the Lie generators found (i.e.: through LieGAN, oracle-preserving latent flow, or some other approach
          that we develop further), once the metric tensor J is found via the equation:

                          L.J + J.(L^T) = 0,

          we just have to specify the metric to make the model symmetry-preserving to the corresponding Lie group.
          In the cells below, we can see how the model preserves symmetries (starting with the default Lorentz group),
          and when we change J to some other metric (Euclidean, for example), Lorentz boosts **break** equivariance, while other
          transformations preserve it (rotations, for the example shown in the cells below)
"""
class LEQB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LEQB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        self.dimension_reducer2 = nn.Linear(9, 4) # New linear layer for dimension reduction for phi_h
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2. n_input = n_hidden = 4
        print('Input size of phi_e: ', n_input)
        self.include_x = include_x

        """
            phi_e: input size: n_qubits -> output size: n_qubits
            n_hidden has to be equal to n_input,
            but this is just considering that this is a simple working example.
        """
#         self.phi_e = DressedQuantumNet(n_input)
        self.phi_e = nn.Sequential(
            nn.Linear(n_input, n_hidden, bias=False),  # n_input * 2 + n_edge_attr
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        n_hidden = n_input # n_input * 2 + n_edge_attr
#         self.phi_h = nn.Sequential(
#             nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
#             nn.BatchNorm1d(n_hidden),
#             nn.ReLU(),
#             nn.Linear(n_hidden, n_output))

        self.phi_h = DressedQuantumNet(n_hidden)

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)
        
#         self.phi_x = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             layer)

#         self.phi_m = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             nn.Linear(n_hidden, 1),
#             nn.Sigmoid())
        
        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        # self.phi_e = nn.Sequential(
        #     nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),
        #     nn.BatchNorm1d(n_hidden),
        #     nn.ReLU(),
        #     nn.Linear(n_hidden, n_hidden),
        #     nn.ReLU())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        out = self.dimension_reducer(out) # extra
        # print("Before embedding to |psi> : ", out)
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.dimension_reducer(out) # extra
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        agg = self.dimension_reducer2(agg) # extra for phi_h
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m):
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)
        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LieEQGNN(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LEQB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LieEQGNN, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LEQBs = nn.ModuleList([LEQB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LEQBs[i](h, x, edges, node_attr=scalars)

        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)
        return pred.squeeze(1)

In [35]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

if __name__ == "__main__":

    N_EPOCHS = 7 # 60

    model_path = "models/LieEQGNN/"
    log_path = "logs/LieEQGNN/"
    # utils_lorentz.args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cpu
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cuda")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    model = LieEQGNN(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)



Input size of phi_e:  4
n_qubits:  4
Model Size: 175
 train samples: 800
 val samples: 100
 test samples: 100
Training...


50it [30:29, 36.59s/it]


>> train 	 Epoch 1/7 	 Batch 49/50 	 Loss 0.6938 	 Running Acc 0.525 	 Total Acc 0.525 	 Avg Batch Time 36.5897
Time: train: 1829.49 	 Train loss 0.6938 	 Train acc: 0.5250


7it [02:24, 20.64s/it]


>> val 	 Loss 0.7071 	 Running Acc 1.643 	 Total Acc 0.460 	 Avg Batch Time 5.7779
New best validation model, saving...
Epoch 0/7 finished.
Train time: 1829.49 	 Val time 144.45
Train loss 0.6938 	 Train acc: 0.5250
Val loss: 0.7078 	 Val acc: 0.4600
Best val acc: 0.4600 at epoch 0.


50it [30:15, 36.31s/it]


>> train 	 Epoch 2/7 	 Batch 49/50 	 Loss 0.6900 	 Running Acc 0.525 	 Total Acc 0.525 	 Avg Batch Time 36.3126
Time: train: 1815.63 	 Train loss 0.6900 	 Train acc: 0.5250


7it [02:23, 20.53s/it]


>> val 	 Loss 0.7006 	 Running Acc 1.643 	 Total Acc 0.460 	 Avg Batch Time 5.7496
Epoch 1/7 finished.
Train time: 1815.63 	 Val time 143.74
Train loss 0.6900 	 Train acc: 0.5250
Val loss: 0.7006 	 Val acc: 0.4600
Best val acc: 0.4600 at epoch 0.


50it [29:57, 35.95s/it]


>> train 	 Epoch 3/7 	 Batch 49/50 	 Loss 0.6823 	 Running Acc 0.526 	 Total Acc 0.526 	 Avg Batch Time 35.9510
Time: train: 1797.55 	 Train loss 0.6823 	 Train acc: 0.5262


7it [02:23, 20.55s/it]


>> val 	 Loss 0.6914 	 Running Acc 1.643 	 Total Acc 0.460 	 Avg Batch Time 5.7553
Epoch 2/7 finished.
Train time: 1797.55 	 Val time 143.88
Train loss 0.6823 	 Train acc: 0.5262
Val loss: 0.6903 	 Val acc: 0.4600
Best val acc: 0.4600 at epoch 0.


50it [30:04, 36.10s/it]


>> train 	 Epoch 4/7 	 Batch 49/50 	 Loss 0.6730 	 Running Acc 0.610 	 Total Acc 0.610 	 Avg Batch Time 36.0982
Time: train: 1804.91 	 Train loss 0.6730 	 Train acc: 0.6100


7it [02:21, 20.27s/it]


>> val 	 Loss 0.6780 	 Running Acc 2.143 	 Total Acc 0.600 	 Avg Batch Time 5.6744
New best validation model, saving...
Epoch 3/7 finished.
Train time: 1804.91 	 Val time 141.86
Train loss 0.6730 	 Train acc: 0.6100
Val loss: 0.6749 	 Val acc: 0.6000
Best val acc: 0.6000 at epoch 3.


50it [30:01, 36.04s/it]


>> train 	 Epoch 5/7 	 Batch 49/50 	 Loss 0.6595 	 Running Acc 0.708 	 Total Acc 0.708 	 Avg Batch Time 36.0386
Time: train: 1801.93 	 Train loss 0.6595 	 Train acc: 0.7075


7it [02:22, 20.40s/it]


>> val 	 Loss 0.6634 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 5.7113
New best validation model, saving...
Epoch 4/7 finished.
Train time: 1801.93 	 Val time 142.78
Train loss 0.6595 	 Train acc: 0.7075
Val loss: 0.6576 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 4.


50it [30:01, 36.02s/it]


>> train 	 Epoch 6/7 	 Batch 49/50 	 Loss 0.6440 	 Running Acc 0.723 	 Total Acc 0.723 	 Avg Batch Time 36.0219
Time: train: 1801.10 	 Train loss 0.6440 	 Train acc: 0.7225


7it [02:21, 20.18s/it]


>> val 	 Loss 0.6513 	 Running Acc 2.500 	 Total Acc 0.700 	 Avg Batch Time 5.6500
Epoch 5/7 finished.
Train time: 1801.10 	 Val time 141.25
Train loss 0.6440 	 Train acc: 0.7225
Val loss: 0.6429 	 Val acc: 0.7000
Best val acc: 0.7700 at epoch 4.


50it [30:06, 36.12s/it]


>> train 	 Epoch 7/7 	 Batch 49/50 	 Loss 0.6374 	 Running Acc 0.719 	 Total Acc 0.719 	 Avg Batch Time 36.1234
Time: train: 1806.17 	 Train loss 0.6374 	 Train acc: 0.7188


7it [02:25, 20.81s/it]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.6462 	 Running Acc 2.571 	 Total Acc 0.720 	 Avg Batch Time 5.8282
Epoch 6/7 finished.
Train time: 1806.17 	 Val time 145.71
Train loss 0.6374 	 Train acc: 0.7188
Val loss: 0.6364 	 Val acc: 0.7200
Best val acc: 0.7700 at epoch 4.


7it [02:22, 20.32s/it]

>> test 	 Loss 0.6578 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 5.6903
Final  tensor([[1.0000, 0.4694, 0.5306],
        [1.0000, 0.4544, 0.5456],
        [0.0000, 0.5374, 0.4626],
        [1.0000, 0.5105, 0.4895],
        [1.0000, 0.4691, 0.5309],
        [0.0000, 0.5211, 0.4789],
        [0.0000, 0.5316, 0.4684],
        [1.0000, 0.5077, 0.4923],
        [1.0000, 0.4692, 0.5308],
        [0.0000, 0.5031, 0.4969],
        [0.0000, 0.5121, 0.4879],
        [0.0000, 0.5324, 0.4676],
        [0.0000, 0.4891, 0.5109],
        [0.0000, 0.4977, 0.5023],
        [0.0000, 0.5266, 0.4734],
        [0.0000, 0.5010, 0.4990],
        [0.0000, 0.5178, 0.4822],
        [1.0000, 0.4933, 0.5067],
        [0.0000, 0.5447, 0.4553],
        [1.0000, 0.4586, 0.5414],
        [1.0000, 0.4948, 0.5052],
        [0.0000, 0.5033, 0.4967],
        [0.0000, 0.4568, 0.5432],
        [0.0000, 0.4942, 0.5058],
        [1.0000, 0.5280, 0.4720],
        [0.0000, 0.5739, 0.4261],
        [0.0000, 0.5670, 


