# Lie-Equivariant Quantum Graph Neural Network (Lie-EQGNN)

In [1]:
# For Colab
!pip install torch_geometric
# !pip install torch_sparse
# !pip install torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
!pip install pennylane qiskit pennylane-qiskit pylatexenc

Collecting pennylane
  Downloading PennyLane-0.39.0-py3-none-any.whl.metadata (9.2 kB)
Collecting qiskit
  Downloading qiskit-1.3.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting pennylane-qiskit
  Downloading PennyLane_qiskit-0.39.1-py3-none-any.whl.metadata (6.4 kB)
Collecting pylatexenc
  Downloading pylatexenc-2.10.tar.gz (162 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.15.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting autograd (from pennylane)
  Downloading autograd-1.7.0-py3-none-any.whl.metadata (7.5 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.39 (from pennylane)
  Downloading PennyLane_Lightning-0.3

In [3]:
import pennylane as qml
import qiskit
print(qml.__version__)
print(qiskit.__version__)
import pennylane_qiskit
print(pennylane_qiskit.__version__)
import pennylane as qml
from pennylane import numpy as np
# from pennylane_qiskit import AerDevice

0.39.0
1.2.4
0.39.1


In [4]:
!pip install energyflow

Collecting energyflow
  Downloading energyflow-1.4.0-py3-none-any.whl.metadata (5.6 kB)
Collecting h5py!=3.11.0,>=2.9.0 (from energyflow)
  Downloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting wasserstein>=1.0.1 (from energyflow)
  Downloading wasserstein-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Downloading energyflow-1.4.0-py3-none-any.whl (700 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m700.8/700.8 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m75.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hDownloading wasserstein-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (502 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m502.2/502.2 kB[0m [31m20

In [36]:
import torch
import numpy as np
import energyflow
from scipy.sparse import coo_matrix
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data.distributed import DistributedSampler


# we define a function to return an adjacencyy matrix
# for our graph data representing the jets.
def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    # print(edge_mask[0])
    # raise
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask *= diag_mask
    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edge_mask, edges]

def retrieve_dataloaders(batch_size, num_data = -1, use_one_hot = False, cache_dir = './data', num_workers=4):
    raw = energyflow.qg_jets.load(num_data=num_data, pad=True, ncol=4, generator='pythia',
                            with_bc=False, cache_dir=cache_dir)
    splits = ['train', 'val', 'test']
    data = {type:{'raw':None,'label':None} for type in splits}
    (data['train']['raw'],  data['val']['raw'],   data['test']['raw'],
    data['train']['label'], data['val']['label'], data['test']['label']) = \
        energyflow.utils.data_split(*raw, train=0.8, val=0.1, test=0.1, shuffle = False)

    enc = OneHotEncoder(handle_unknown='ignore').fit([[11],[13],[22],[130],[211],[321],[2112],[2212]])

    for split, value in data.items():
        pid = torch.from_numpy(np.abs(np.asarray(value['raw'][...,3], dtype=int))).unsqueeze(-1)
        p4s = torch.from_numpy(energyflow.p4s_from_ptyphipids(value['raw'],error_on_unknown=True))
        one_hot = enc.transform(pid.reshape(-1,1)).toarray().reshape(pid.shape[:2]+(-1,))
        # one_hot = enc.transform(pid.reshape(-1,1)).toarray().reshape(pid.shape[:2]+(-1,))
        one_hot = torch.from_numpy(one_hot)
        mass = torch.from_numpy(energyflow.ms_from_p4s(p4s)).unsqueeze(-1)
        charge = torch.from_numpy(energyflow.pids2chrgs(pid))

        if use_one_hot:
            nodes = one_hot

        # else:
        #     nodes = torch.cat((mass,charge),dim=-1)

        #     nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)


        else:
              # Concatenate mass and charge along the last dimension
              concatenated = torch.cat((mass, charge), dim=-1)  # Shape (batch_size, n_nodes, 2)

              # Reduce along the last dimension (e.g., by summing or averaging)
              nodes = concatenated.sum(dim=-1, keepdim=True)  # Shape (batch_size, n_nodes, 1)

              # Apply log-sign transformation if needed
              nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)

        atom_mask = (pid[...,0] != 0)

        value['p4s'] = p4s
        value['nodes'] = nodes
        value['label'] = torch.from_numpy(value['label'])
        value['atom_mask'] = atom_mask.to(torch.bool)

        if split == 'train':
            print(value['atom_mask'])

    datasets = {split: TensorDataset(value['label'], value['p4s'],
                                     value['nodes'], value['atom_mask'])
                for split, value in data.items()}

    # distributed training
    # train_sampler = DistributedSampler(datasets['train'], shuffle=True)
    # Construct PyTorch dataloaders from datasets
    dataloaders = {split: DataLoader(dataset,
                                     batch_size=batch_size,
                                     # sampler=train_sampler if (split == 'train') else DistributedSampler(dataset, shuffle=False),
                                     pin_memory=False,
                                     # persistent_workers=True,
                                     drop_last=True if (split == 'train') else False,
                                     num_workers=num_workers,
                                     collate_fn=collate_fn)
                        for split, dataset in datasets.items()}

    return dataloaders #train_sampler, dataloaders

if __name__ == '__main__':
    # train_sampler, dataloaders = retrieve_dataloaders(32, 100)
    dataloaders = retrieve_dataloaders(batch_size=16, num_data = 20, use_one_hot = True)
    for (label, p4s, nodes, atom_mask, edge_mask, edges) in dataloaders['train']:
        print(label.shape, p4s.shape, nodes.shape, atom_mask.shape,
              edge_mask.shape, edges[0].shape, edges[1].shape)
        break

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]])
torch.Size([16]) torch.Size([16, 139, 4]) torch.Size([16, 139, 8]) torch.Size([16, 139]) torch.Size([16, 139, 139]) torch.Size([28736]) torch.Size([28736])


In [37]:
# Test the first batch
for label, p4s, nodes, atom_mask, edge_mask, edges in dataloaders["train"]:
    print(f"Label shape: {label.shape}")
    print(f"4-momenta shape: {p4s.shape}")
    print(f"Node features shape: {nodes.shape}")
    print(f"Atom mask shape: {atom_mask.shape}")
    print(f"Edge mask shape: {edge_mask.shape}")
    print(f"Edge indices shapes: {edges[0].shape}, {edges[1].shape}")
    break

Label shape: torch.Size([16])
4-momenta shape: torch.Size([16, 139, 4])
Node features shape: torch.Size([16, 139, 8])
Atom mask shape: torch.Size([16, 139])
Edge mask shape: torch.Size([16, 139, 139])
Edge indices shapes: torch.Size([28736]), torch.Size([28736])


In [38]:
import torch
import numpy as np
import energyflow
import os
from sklearn.preprocessing import OneHotEncoder
from scipy.sparse import coo_matrix

def save_physics_tensors(num_data=-1, use_one_hot=False, save_dir="random/data"):
    """
    Generate and save tensor data files needed for physics analysis.

    Args:
        num_data: Number of data points to generate (-1 for all)
        save_dir: Directory to save the tensor files
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Load raw data
    raw = energyflow.qg_jets.load(
        num_data=num_data,
        pad=True,
        ncol=4,
        generator="pythia",
        with_bc=False,
    )

    # Get data and labels
    data, labels = raw

    # Initialize one-hot encoder for particle IDs
    enc = OneHotEncoder(handle_unknown="ignore").fit(
        [[11], [13], [22], [130], [211], [321], [2112], [2212]]
    )

    # Process data
    pid = torch.from_numpy(np.abs(np.asarray(data[..., 3], dtype=int))).unsqueeze(-1)
    p4s = torch.from_numpy(energyflow.p4s_from_ptyphipids(data, error_on_unknown=True))

    # Create one-hot encoded nodes
    one_hot = enc.transform(pid.reshape(-1, 1)).toarray().reshape(pid.shape[:2] + (-1,))
    nodes = torch.from_numpy(one_hot)
    mass = torch.from_numpy(energyflow.ms_from_p4s(p4s)).unsqueeze(-1)
    charge = torch.from_numpy(energyflow.pids2chrgs(pid))

    if use_one_hot:
        nodes = one_hot

    else:
          # Concatenate mass and charge along the last dimension
          concatenated = torch.cat((mass, charge), dim=-1)  # Shape (batch_size, n_nodes, 2)

          # Reduce along the last dimension (e.g., by summing or averaging)
          nodes = concatenated.sum(dim=-1, keepdim=True)  # Shape (batch_size, n_nodes, 1)

          # Apply log-sign transformation if needed
          nodes = torch.sign(nodes) * torch.log(torch.abs(nodes) + 1)

    # Create masks
    atom_mask = (pid[..., 0] != 0).to(torch.bool)

    # Create edge mask
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = edge_mask * diag_mask

    # Convert labels to tensor
    labels = torch.from_numpy(labels)

    # Calculate edges for the full dataset
    n_nodes = p4s.size(1)
    batch_size = p4s.size(0)

    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx * n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    edges = np.stack([rows, cols])

    # Save tensors
    torch.save(p4s, os.path.join(save_dir, "p4s.pt"))
    torch.save(nodes, os.path.join(save_dir, "nodes.pt"))
    torch.save(labels, os.path.join(save_dir, "labels.pt"))
    torch.save(atom_mask, os.path.join(save_dir, "atom_mask.pt"))
    np.save(os.path.join(save_dir, "edge_mask.npy"), edge_mask.numpy())
    np.save(os.path.join(save_dir, "edges.npy"), edges)

    print(f"Saved tensor files to {save_dir}")
    print(f"Shapes:")
    print(f"p4s: {p4s.shape}")
    print(f"nodes: {nodes.shape}")
    print(f"labels: {labels.shape}")
    print(f"atom_mask: {atom_mask.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"edges: {edges.shape}")

# Generate and save the tensor files
save_physics_tensors(num_data=1000, use_one_hot=False)  # Use same number of data points as before

Saved tensor files to random/data
Shapes:
p4s: torch.Size([1000, 139, 4])
nodes: torch.Size([1000, 139, 1])
labels: torch.Size([1000])
atom_mask: torch.Size([1000, 139])
edge_mask: torch.Size([1000, 139, 139])
edges: (2, 2145950)


In [39]:
from torch.utils.data import TensorDataset, random_split

def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    # edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    # diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    # edge_mask *= diag_mask

    edge_mask = data[-2]

    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edges]


p4s = torch.load('random/data/p4s.pt')
nodes = torch.load('random/data/nodes.pt')
labels = torch.load('random/data/labels.pt')
atom_mask = torch.load('random/data/atom_mask.pt')
edge_mask = torch.from_numpy(np.load('random/data/edge_mask.npy'))
edges = torch.from_numpy(np.load('random/data/edges.npy'))


# Create a TensorDataset
dataset_all = TensorDataset(labels, p4s, nodes, atom_mask, edge_mask)

# Define the split ratios
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Calculate the lengths for each split
total_size = len(dataset_all)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
test_size = total_size - train_size - val_size  # Ensure all data is used

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_all, [train_size, val_size, test_size])

# Create a dictionary to hold the datasets
datasets = {
    "train": train_dataset,
    "val": val_dataset,
    "test": test_dataset
}

dataloaders = {split: DataLoader(dataset,
                                 batch_size=16,
                                 # sampler=train_sampler if (split == 'train') else DistributedSampler(dataset, shuffle=False),
                                 pin_memory=False,
                                 # persistent_workers=True,
                                 collate_fn = collate_fn,
                                 drop_last=True if (split == 'train') else False,
                                 num_workers=0)
                    for split, dataset in datasets.items()}

  p4s = torch.load('random/data/p4s.pt')
  nodes = torch.load('random/data/nodes.pt')
  labels = torch.load('random/data/labels.pt')
  atom_mask = torch.load('random/data/atom_mask.pt')


In [40]:
# # we can peek at a batch to see what it looks like.
# next(iter(dataloaders['val']))

In [41]:
print(p4s.shape) # p4s
print(nodes.shape) # mass
print(atom_mask.shape) # torch.ones
print(edge_mask.shape) # adj_matrix

torch.Size([1000, 139, 4])
torch.Size([1000, 139, 1])
torch.Size([1000, 139])
torch.Size([1000, 139, 139])


In [42]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x794093db6200>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x794093c5ebf0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x794093c5f7c0>}

In [43]:
# Set desired dimensions
batch_size = 1
n_nodes = 3
device = 'cpu'
dtype = torch.float32

# Print initial shapes
print("Initial shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Select subset of data
p4s = p4s[:batch_size, :n_nodes, :]
atom_mask = atom_mask[:batch_size, :n_nodes]
edge_mask = edge_mask[:batch_size, :n_nodes, :n_nodes]
nodes = nodes[:batch_size, :n_nodes, :]

print("\nAfter selection shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Reshape tensors
atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
# Don't reshape edge_mask yet
nodes = nodes.view(batch_size * n_nodes, -1).to(device, dtype)

print("\nAfter reshape shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)  # original shape
print("nodes:", nodes.shape)

# Recalculate edges for the subset
from scipy.sparse import coo_matrix
rows, cols = [], []
for batch_idx in range(batch_size):
    nn = batch_idx * n_nodes
    # Convert edge_mask to numpy and remove any extra dimensions
    edge_mask_np = edge_mask[batch_idx].cpu().numpy().squeeze()
    x = coo_matrix(edge_mask_np)
    rows.append(nn + x.row)
    cols.append(nn + x.col)

edges = [torch.LongTensor(np.concatenate(rows)).to(device),
         torch.LongTensor(np.concatenate(cols)).to(device)]

# Now reshape edge_mask after edges are calculated
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])

Initial shapes:
p4s: torch.Size([1000, 139, 4])
atom_mask: torch.Size([1000, 139])
edge_mask: torch.Size([1000, 139, 139])
nodes: torch.Size([1000, 139, 1])

After selection shapes:
p4s: torch.Size([1, 3, 4])
atom_mask: torch.Size([1, 3])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([1, 3, 1])

After reshape shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([3, 1])

Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [44]:
batch_size = 1  #2500 #1
n_nodes = 3 #139
device = 'cpu'
dtype = torch.float32

atom_positions = p4s[:, :, :].view(batch_size * n_nodes, -1).to(device, dtype)

atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

edges = [a.to(device) for a in edges]
nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)

In [45]:
print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])


Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [46]:
atom_mask[0]#.shape

tensor([1.])

In [47]:
p4s[0][2]

tensor([ 1.1594, -0.2378, -1.1238, -0.0723], dtype=torch.float64)

In [48]:
atom_mask.shape

torch.Size([3, 1])

In [49]:
p4s.shape # batch_size (number of jets or graphs), n_nodes (particles), n_features

torch.Size([1, 3, 4])

In [50]:
# random: x(atom_pos), edge_indx_tensor (edges = adj_matrix), edge_tensor (edge_mask = adj_matrix)
print("Atom mask: {}".format(atom_mask[:2]))
print("Atom positions (x features, 4-momenta): {}".format(atom_positions[:2]))
print("Nodes (scalars: mass & charge): {}".format(nodes[:2]))
print("Edge mask: {}".format(edge_mask[:2]))
print("Edges: {}".format(edges[:2]))

Atom mask: tensor([[1.],
        [1.]])
Atom positions (x features, 4-momenta): tensor([[ 0.2861,  0.0078, -0.2687,  0.0980],
        [ 0.1653, -0.0258, -0.1580, -0.0414]])
Nodes (scalars: mass & charge): tensor([[-4.7488e-09],
        [-2.2813e-09]])
Edge mask: tensor([[False],
        [ True]])
Edges: [tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 2, 0, 1])]


In [51]:
edges[:2]#[0].shape

[tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 2, 0, 1])]

In [52]:
# model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
#                          edge_mask=edge_mask, n_nodes=n_nodes)

In [68]:
# @title
import torch
from torch import nn
import numpy as np



"""Some auxiliary functions"""

def unsorted_segment_sum(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    return result

def unsorted_segment_mean(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_mean`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    count = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    count.index_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

def normsq4(p):
    r''' Minkowski square norm
         `\|p\|^2 = p[0]^2-p[1]^2-p[2]^2-p[3]^2`
    '''
    psq = torch.pow(p, 2)
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def dotsq4(p,q):
    r''' Minkowski inner product
         `<p,q> = p[0]q[0]-p[1]q[1]-p[2]q[2]-p[3]q[3]`
    '''
    psq = p*q
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def normA_fn(A):
    return lambda p: torch.einsum('...i, ij, ...j->...', p, A, p)

def dotA_fn(A):
    return lambda p, q: torch.einsum('...i, ij, ...j->...', p, A, q)

def psi(p):
    ''' `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`
    '''
    return torch.sign(p) * torch.log(torch.abs(p) + 1)


"""Lorentz Group-Equivariant Block"""

class LGEB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LGEB, self).__init__()
        self.c_weight = c_weight
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        self.dimension_reducer = nn.Linear(n_input * 2 + n_edge_attr, n_input) # New linear layer for dimension reduction
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2.
        print('Input size of phi_e: ', n_input)

        self.include_x = include_x
        self.phi_e = nn.Sequential(
            nn.Linear(n_input, n_hidden, bias=False), # n_input * 2 + n_edge_attr
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4


    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        # Reduce the dimension of 'out' to 4 using a linear layer
        out = self.dimension_reducer(out)
        out = self.phi_e(out)
        print("m_model output: ", out.shape)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m): # norms
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # print("m: ", m.shape)
        # print("trans: ", trans.shape)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        # print("trans: ", trans.shape)
        # print("x.size: ", x.size(0))
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight # * norms[i, j], smth like that, or norms
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            # print("X: ", x)
            x = self.x_model(x, edges, x_diff, m)
            # print("phi_x(X) = ", x, '\n---\n')

        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LorentzNet(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LorentzNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LGEBs = nn.ModuleList([LGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LGEBs[i](h, x, edges, node_attr=scalars)
        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)

        # print("Final preds: \n", pred.cpu().detach().numpy())
        return pred.squeeze(1)

LGEB(self.n_hidden, self.n_hidden, self.n_hidden,\
                                    n_node_attr=n_scalar, dropout=dropout,\
                                    c_weight=c_weight, last_layer=\(i==n_layers-1), A=A, include_x=include_x)
                                    
We are using n_hidden = 4 and n_layers = 6

n_input=n_hidden, n_output=n_hidden, n_hidden=n_hidden, n_node_attr=n_scalar=8

### Now that we have the official code for the classical, just for sanity checking, let's test for equivariance

The cell below is just an auxiliary function to give us the boosts

In [69]:
# @title
from math import sqrt
import numpy as np

# Speed of light (m/s)
c = 299792458

"""Lorentz transformations describe the transition between two inertial reference
frames F and F', each of which is moving in some direction with respect to the
other. This code only calculates Lorentz transformations for movement in the x
direction with no spatial rotation (i.e., a Lorentz boost in the x direction).
The Lorentz transformations are calculated here as linear transformations of
four-vectors [ct, x, y, z] described by Minkowski space. Note that t (time) is
multiplied by c (the speed of light) in the first entry of each four-vector.

Thus, if X = [ct; x; y; z] and X' = [ct'; x'; y'; z'] are the four-vectors for
two inertial reference frames and X' moves in the x direction with velocity v
with respect to X, then the Lorentz transformation from X to X' is X' = BX,
where

    | γ  -γβ  0  0|
B = |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

is the matrix describing the Lorentz boost between X and X',
γ = 1 / √(1 - v²/c²) is the Lorentz factor, and β = v/c is the velocity as
a fraction of c.
"""


def beta(velocity: float) -> float:
    """
    Calculates β = v/c, the given velocity as a fraction of c
    >>> beta(c)
    1.0
    >>> beta(199792458)
    0.666435904801848
    """
    if velocity > c:
        raise ValueError("Speed must not exceed light speed 299,792,458 [m/s]!")
    elif velocity < 1:
        # Usually the speed should be much higher than 1 (c order of magnitude)
        raise ValueError("Speed must be greater than or equal to 1!")

    return velocity / c


def gamma(velocity: float) -> float:
    """
    Calculate the Lorentz factor γ = 1 / √(1 - v²/c²) for a given velocity
    >>> gamma(4)
    1.0000000000000002
    >>> gamma(1e5)
    1.0000000556325075
    >>> gamma(3e7)
    1.005044845777813
    >>> gamma(2.8e8)
    2.7985595722318277
    """
    return 1 / sqrt(1 - beta(velocity) ** 2)


def transformation_matrix(velocity: float) -> np.ndarray:
    """
    Calculate the Lorentz transformation matrix for movement in the x direction:

    | γ  -γβ  0  0|
    |-γβ  γ   0  0|
    | 0   0   1  0|
    | 0   0   0  1|

    where γ is the Lorentz factor and β is the velocity as a fraction of c
    >>> transformation_matrix(29979245)
    array([[ 1.00503781, -0.10050378,  0.        ,  0.        ],
           [-0.10050378,  1.00503781,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  1.        ,  0.        ],
           [ 0.        ,  0.        ,  0.        ,  1.        ]])
    """
    return np.array(
        [
            [gamma(velocity), -gamma(velocity) * beta(velocity), 0, 0],
            [-gamma(velocity) * beta(velocity), gamma(velocity), 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
    )


### Now, the model

In [70]:
# n_scalar = 8 in original !
model = LorentzNet(n_scalar = 1, n_hidden = 6, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

Input size of phi_e:  6


### Let's start with a default prediction

In [71]:
pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


In [72]:
pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


### ... taking any random nonsense transformation in the four-momentum vectors
i.e.: multiplying by 0.1. Does the hidden rep stay the same?

In [73]:
pred = model(scalars=nodes, x= 0.1 * atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


In [74]:
pred = model(scalars=nodes, x= 0.1 * atom_positions, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


### Even though the final logits in this case wasn't different, if we look the last output of h (which contains both scalar and 4-momenta information), it changed! Now, what about Lorentz transformations?

In [75]:
pred = model(scalars=nodes, x= (torch.tensor(transformation_matrix(220000000)) @ atom_positions.to(dtype=torch.float64).T).to(dtype=torch.float32).T, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


In [76]:
pred = model(scalars=nodes, x= (torch.tensor(transformation_matrix(220000000)) @ atom_positions.to(dtype=torch.float64).T).to(dtype=torch.float32).T, edges=edges, node_mask=atom_mask,
                     edge_mask=edge_mask, n_nodes=n_nodes)

m_model output:  torch.Size([6, 6])


## Equivariance works. Finally, let's train on some data

In [77]:
# @title
import torch
import os, json, random, string
import torch.distributed as dist

def makedir(path):
    try:
        os.makedirs(path)
    except OSError:
        pass

def args_init(args):
    r''' Initialize seed and exp_name.
    '''
    if args.seed is None: # use random seed if not specified
        args.seed = np.random.randint(100)
    if args.exp_name == '': # use random strings if not specified
        args.exp_name = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    if (args.local_rank == 0): # master
        print(args)
        makedir(f"{args.logdir}/{args.exp_name}")
        with open(f"{args.logdir}/{args.exp_name}/args.json", 'w') as f:
            json.dump(args.__dict__, f, indent=4)

def sum_reduce(num, device):
    r''' Sum the tensor across the devices.
    '''
    if not torch.is_tensor(num):
        rt = torch.tensor(num).to(device)
    else:
        rt = num.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        warmup_epoch: target learning rate is reached at warmup_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    Reference:
        https://github.com/ildoonet/pytorch-gradual-warmup-lr
    """

    def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.warmup_epoch = warmup_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    @property
    def _warmup_lr(self):
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch + 1) / self.warmup_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * (self.last_epoch + 1) / self.warmup_epoch + 1.) for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch - 1:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lr

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch==None else epoch
        if self.last_epoch >= self.warmup_epoch - 1:
            if not self.finished:
                warmup_lr = [base_lr * self.multiplier for base_lr in self.base_lrs]
                for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                    param_group['lr'] = lr
                self.finished = True
                return
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.warmup_epoch)
            return

        for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lr):
            param_group['lr'] = lr

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.warmup_epoch)
                self.last_epoch = self.after_scheduler.last_epoch + self.warmup_epoch + 1
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        result = {key: value for key, value in self.__dict__.items() if key != 'optimizer' or key != "after_scheduler"}
        if self.after_scheduler:
            result.update({"after_scheduler": self.after_scheduler.state_dict()})
        return result

    def load_state_dict(self, state_dict):
        after_scheduler_state = state_dict.pop("after_scheduler", None)
        self.__dict__.update(state_dict)
        if after_scheduler_state:
            self.after_scheduler.load_state_dict(after_scheduler_state)


from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

def buildROC(labels, score, targetEff=[0.3,0.5]):
    r''' ROC curve is a plot of the true positive rate (Sensitivity) in the function of the false positive rate
    (100-Specificity) for different cut-off points of a parameter. Each point on the ROC curve represents a
    sensitivity/specificity pair corresponding to a particular decision threshold. The Area Under the ROC
    curve (AUC) is a measure of how well a parameter can distinguish between two diagnostic groups.
    '''
    if not isinstance(targetEff, list):
        targetEff = [targetEff]
    fpr, tpr, threshold = roc_curve(labels, score)
    idx = [np.argmin(np.abs(tpr - Eff)) for Eff in targetEff]
    eB, eS = fpr[idx], tpr[idx]
    return fpr, tpr, threshold, eB, eS

In [79]:
# import os
# import torch
# from torch import nn, optim
# import json, time
# # import utils_lorentz
# import numpy as np
# import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel
# from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# from tqdm import tqdm

# def run(model, epoch, loader, partition, N_EPOCHS=None):
#     if partition == 'train':
#         model.train()
#     else:
#         model.eval()

#     res = {'time':0, 'correct':0, 'loss': 0, 'counter': 0, 'acc': 0,
#            'loss_arr':[], 'correct_arr':[],'label':[],'score':[]}

#     tik = time.time()
#     loader_length = len(loader)

#     for i, (label, p4s, nodes, atom_mask, edge_mask, edges) in tqdm(enumerate(loader)):
#         if partition == 'train':
#             optimizer.zero_grad()

#         batch_size, n_nodes, _ = p4s.size()
#         atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
#         atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device)
#         edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)
#         nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)
#         edges = [a.to(device) for a in edges]
#         label = label.to(device, dtype).long()

#         pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
#                          edge_mask=edge_mask, n_nodes=n_nodes)

#         predict = pred.max(1).indices
#         correct = torch.sum(predict == label).item()
#         loss = loss_fn(pred, label)

#         if partition == 'train':
#             loss.backward()
#             optimizer.step()
#         elif partition == 'test':
#             # save labels and probilities for ROC / AUC
#             # print("Preds ", pred)
#             score = torch.nn.functional.softmax(pred, dim = -1)
#             # print("Score test ", score)
#             # raise
#             res['label'].append(label)
#             res['score'].append(score)

#         res['time'] = time.time() - tik
#         res['correct'] += correct
#         res['loss'] += loss.item() * batch_size
#         res['counter'] += batch_size
#         res['loss_arr'].append(loss.item())
#         res['correct_arr'].append(correct)

#         # if i != 0 and i % args.log_interval == 0:

#     running_loss = sum(res['loss_arr'])/len(res['loss_arr'])
#     running_acc = sum(res['correct_arr'])/(len(res['correct_arr'])*batch_size)
#     avg_time = res['time']/res['counter'] * batch_size
#     tmp_counter = res['counter']
#     tmp_loss = res['loss'] / tmp_counter
#     tmp_acc = res['correct'] / tmp_counter

#     if N_EPOCHS:
#         print(">> %s \t Epoch %d/%d \t Batch %d/%d \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
#              (partition, epoch + 1, N_EPOCHS, i, loader_length, running_loss, running_acc, tmp_acc, avg_time))
#     else:
#         print(">> %s \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
#              (partition, running_loss, running_acc, tmp_acc, avg_time))

#     torch.cuda.empty_cache()
#     # ---------- reduce -----------
#     if partition == 'test':
#         res['label'] = torch.cat(res['label']).unsqueeze(-1)
#         res['score'] = torch.cat(res['score'])
#         res['score'] = torch.cat((res['label'],res['score']),dim=-1)
#     res['counter'] = res['counter']
#     res['loss'] = res['loss'] / res['counter']
#     res['acc'] = res['correct'] / res['counter']
#     return res

# def train(model, res, N_EPOCHS, model_path, log_path):
#     ### training and validation
#     os.makedirs(model_path, exist_ok=True)
#     os.makedirs(log_path, exist_ok=True)

#     for epoch in range(N_EPOCHS):
#         train_res = run(model, epoch, dataloaders['train'], partition='train', N_EPOCHS = N_EPOCHS)
#         print("Time: train: %.2f \t Train loss %.4f \t Train acc: %.4f" % (train_res['time'],train_res['loss'],train_res['acc']))
#         # if epoch % args.val_interval == 0:

#         # if (args.local_rank == 0):
#         torch.save(model.state_dict(), os.path.join(model_path, "checkpoint-epoch-{}.pt".format(epoch)) )
#         with torch.no_grad():
#             val_res = run(model, epoch, dataloaders['val'], partition='val')

#         # if (args.local_rank == 0): # only master process save
#         res['lr'].append(optimizer.param_groups[0]['lr'])
#         res['train_time'].append(train_res['time'])
#         res['val_time'].append(val_res['time'])
#         res['train_loss'].append(train_res['loss'])
#         res['train_acc'].append(train_res['acc'])
#         res['val_loss'].append(val_res['loss'])
#         res['val_acc'].append(val_res['acc'])
#         res['epochs'].append(epoch)

#         ## save best model
#         if val_res['acc'] > res['best_val']:
#             print("New best validation model, saving...")
#             torch.save(model.state_dict(), os.path.join(model_path,"best-val-model.pt"))
#             res['best_val'] = val_res['acc']
#             res['best_epoch'] = epoch

#         print("Epoch %d/%d finished." % (epoch, N_EPOCHS))
#         print("Train time: %.2f \t Val time %.2f" % (train_res['time'], val_res['time']))
#         print("Train loss %.4f \t Train acc: %.4f" % (train_res['loss'], train_res['acc']))
#         print("Val loss: %.4f \t Val acc: %.4f" % (val_res['loss'], val_res['acc']))
#         print("Best val acc: %.4f at epoch %d." % (res['best_val'],  res['best_epoch']))

#         json_object = json.dumps(res, indent=4)
#         with open(os.path.join(log_path, "train-result-epoch{}.json".format(epoch)), "w") as outfile:
#             outfile.write(json_object)

#         ## adjust learning rate
#         if (epoch < 31):
#             lr_scheduler.step(metrics=val_res['acc'])
#         else:
#             for g in optimizer.param_groups:
#                 g['lr'] = g['lr']*0.5


# def test(model, res, model_path, log_path):
#     ### test on best model
#     best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)
#     model.load_state_dict(best_model)
#     with torch.no_grad():
#         test_res = run(model, 0, dataloaders['test'], partition='test')

#     print("Final ", test_res['score'])
#     pred = test_res['score'].cpu()

#     np.save(os.path.join(log_path, "score.npy"), pred)
#     fpr, tpr, thres, eB, eS  = buildROC(pred[...,0], pred[...,2])
#     auc = roc_auc_score(pred[...,0], pred[...,2])

#     metric = {'test_loss': test_res['loss'], 'test_acc': test_res['acc'],
#               'test_auc': auc, 'test_1/eB_0.3':1./eB[0],'test_1/eB_0.5':1./eB[1]}
#     res.update(metric)
#     print("Test: Loss %.4f \t Acc %.4f \t AUC: %.4f \t 1/eB 0.3: %.4f \t 1/eB 0.5: %.4f"\
#            % (test_res['loss'], test_res['acc'], auc, 1./eB[0], 1./eB[1]))
#     json_object = json.dumps(res, indent=4)
#     with open(os.path.join(log_path, "test-result.json"), "w") as outfile:
#         outfile.write(json_object)

# if __name__ == "__main__":

#     N_EPOCHS = 45 # 60

#     model_path = "models/LorentzNet/"
#     log_path = "logs/LorentzNet/"
#     # args_init(args)

#     ### set random seed
#     torch.manual_seed(42)
#     np.random.seed(42)

#     ### initialize cuda
#     # dist.init_process_group(backend='nccl')
#     device = 'cpu' #torch.device("cpu")
#     dtype = torch.float32

#     ### load data
#     # dataloaders = retrieve_dataloaders( batch_size,
#     #                                     num_data=100000, # use all data
#     #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
#     #                                     num_workers=0,
#     #                                     use_one_hot=True)

#     ### create parallel model
#     model = LorentzNet(n_scalar = 1, n_hidden = 6, n_class = 2,\
#                        dropout = 0.2, n_layers = 1,\
#                        c_weight = 1e-3)

#     model = model.to(device)

#     ### print model and dataset information
#     # if (args.local_rank == 0):
#     pytorch_total_params = sum(p.numel() for p in model.parameters())
#     print("Model Size:", pytorch_total_params)
#     for (split, dataloader) in dataloaders.items():
#         print(f" {split} samples: {len(dataloader.dataset)}")

#     ### optimizer
#     optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

#     ### lr scheduler
#     base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
#     lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
#                                                 warmup_epoch=5,\
#                                                 after_scheduler=base_scheduler) ## warmup

#     ### loss function
#     loss_fn = nn.CrossEntropyLoss()

#     ### initialize logs
#     res = {'epochs': [], 'lr' : [],\
#            'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
#            'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

#     ### training and testing
#     print("Training...")
#     train(model, res, N_EPOCHS, model_path, log_path)
#     test(model, res, model_path, log_path)

In [89]:
import torch
import pennylane as qml
import torch.nn.functional as F
from torch import nn
from torch_geometric.utils import to_dense_adj

n_qubits = 8

dev = qml.device('default.qubit', wires=n_qubits)
# dev = qml.device("qiskit.aer", wires=n_qubits)


def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)

def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)

def RY_RX_layer(weights):
    """Applies a layer of parametrized RY and RX rotations."""
    for i, w in enumerate(weights):
        qml.RY(w, wires=i)
        qml.RX(w, wires=i)

def full_entangling_layer(n_qubits):
    """Applies CNOT gates between all pairs of qubits."""
    for i in range(n_qubits):
        for j in range(i+1, n_qubits):
            qml.CNOT(wires=[i, j])

def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOT.
    """
    # In other words it should apply something like :
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(nqubits - 1):
        qml.CRZ(np.pi / 2, wires=[i, i + 1])
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.SWAP(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.SWAP(wires=[i, i + 1])


@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth, n_qubits):
    """
    The variational quantum circuit.
    """

    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    # RY_layer(q_input_features)
    qml.AngleEmbedding(features=q_input_features, wires=range(n_qubits), rotation='Z')

    # Sequence of trainable variational layers
    # for k in range(q_depth):
    #     entangling_layer(n_qubits)
    #     RY_RX_layer(q_weights[k])
    #     # RY_layer(q_weights[k])
    for k in range(q_depth):
        if k % 2 == 0:
            entangling_layer(n_qubits)
            RY_layer(q_weights[k])
        else:
            full_entangling_layer(n_qubits)
            RY_RX_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self, n_qubits, q_depth = 1, q_delta=0.001):
        """
        Definition of the *dressed* layout.
        """
        print('n_qubits: ', n_qubits)
        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))

    def forward(self, input_features):
        """
        Optimized forward pass to reduce runtime.
        """

        # Quantum Embedding (U(X))
        q_in = torch.tanh(input_features) * np.pi / 2.0

        # Preallocate output tensor
        batch_size = q_in.shape[0]
        q_out = torch.zeros(batch_size, self.n_qubits, device=q_in.device)

        # Vectorized execution
        for i, elem in enumerate(q_in):
            q_out_elem = torch.hstack(quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)).float()
            q_out[i] = q_out_elem

        return q_out

In [90]:
# @title
import torch
from torch import nn
import numpy as np
import pennylane as qml

"""
    Lie-Equivariant Quantum Block (LEQB).

        - Given the Lie generators found (i.e.: through LieGAN, oracle-preserving latent flow, or some other approach
          that we develop further), once the metric tensor J is found via the equation:

                          L.J + J.(L^T) = 0,

          we just have to specify the metric to make the model symmetry-preserving to the corresponding Lie group.
          In the cells below, we can see how the model preserves symmetries (starting with the default Lorentz group),
          and when we change J to some other metric (Euclidean, for example), Lorentz boosts **break** equivariance, while other
          transformations preserve it (rotations, for the example shown in the cells below)
"""
class LEQB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LEQB, self).__init__()
        self.c_weight = c_weight
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        self.dimension_reducer = nn.Linear(n_input * 2 + n_edge_attr, n_input) # New linear layer for dimension reduction
        self.dimension_reducer2 = nn.Linear(n_input * 2 + n_edge_attr - 1, n_input) # New linear layer for dimension reduction for phi_h
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2. n_input = n_hidden = 4
        print('Input size of phi_e: ', n_input)
        self.include_x = include_x

        """
            phi_e: input size: n_qubits -> output size: n_qubits
            n_hidden has to be equal to n_input,
            but this is just considering that this is a simple working example.
        """
        self.phi_e = DressedQuantumNet(n_input)
        # self.phi_e = nn.Sequential(
        #     nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),  # n_input * 2 + n_edge_attr
        #     nn.BatchNorm1d(n_hidden),
        #     nn.ReLU(),
        #     nn.Linear(n_hidden, n_hidden),
        #     nn.ReLU())

        n_hidden = n_input # n_input * 2 + n_edge_attr
        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        # self.phi_h = DressedQuantumNet(n_hidden)

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)
        
#         self.phi_x = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             layer)

#         self.phi_m = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             nn.Linear(n_hidden, 1),
#             nn.Sigmoid())
        
        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        # self.phi_e = nn.Sequential(
        #     nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),
        #     nn.BatchNorm1d(n_hidden),
        #     nn.ReLU(),
        #     nn.Linear(n_hidden, n_hidden),
        #     nn.ReLU())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        out = self.dimension_reducer(out) # extra
        # print("Before embedding to |psi> : ", out)
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.dimension_reducer(out) # extra
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        # agg = self.dimension_reducer2(agg) # extra for phi_h
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m):
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)
        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LieEQGNN(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LEQB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LieEQGNN, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LEQBs = nn.ModuleList([LEQB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LEQBs[i](h, x, edges, node_attr=scalars)

        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)
        return pred.squeeze(1)

In [91]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

if __name__ == "__main__":

    N_EPOCHS =  25 # 60

    model_path = "models/LieEQGNN/"
    log_path = "logs/LieEQGNN/"
    # utils_lorentz.args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cpu
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cuda")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    model = LieEQGNN(n_scalar = 1, n_hidden = 8, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)

Input size of phi_e:  8
n_qubits:  8
Model Size: 651
 train samples: 800
 val samples: 100
 test samples: 100
Training...


50it [20:11, 24.22s/it]


>> train 	 Epoch 1/25 	 Batch 49/50 	 Loss 0.7081 	 Running Acc 0.506 	 Total Acc 0.506 	 Avg Batch Time 24.2203
Time: train: 1211.01 	 Train loss 0.7081 	 Train acc: 0.5062


7it [01:28, 12.61s/it]


>> val 	 Loss 0.7011 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 3.5311
New best validation model, saving...
Epoch 0/25 finished.
Train time: 1211.01 	 Val time 88.28
Train loss 0.7081 	 Train acc: 0.5062
Val loss: 0.7012 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 0.


50it [19:45, 23.71s/it]


>> train 	 Epoch 2/25 	 Batch 49/50 	 Loss 0.6931 	 Running Acc 0.509 	 Total Acc 0.509 	 Avg Batch Time 23.7085
Time: train: 1185.43 	 Train loss 0.6931 	 Train acc: 0.5088


7it [01:26, 12.32s/it]


>> val 	 Loss 0.6863 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 3.4487
Epoch 1/25 finished.
Train time: 1185.43 	 Val time 86.22
Train loss 0.6931 	 Train acc: 0.5088
Val loss: 0.6864 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 0.


50it [19:54, 23.88s/it]


>> train 	 Epoch 3/25 	 Batch 49/50 	 Loss 0.6691 	 Running Acc 0.641 	 Total Acc 0.641 	 Avg Batch Time 23.8839
Time: train: 1194.20 	 Train loss 0.6691 	 Train acc: 0.6412


7it [01:29, 12.72s/it]


>> val 	 Loss 0.6605 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 3.5628
New best validation model, saving...
Epoch 2/25 finished.
Train time: 1194.20 	 Val time 89.07
Train loss 0.6691 	 Train acc: 0.6412
Val loss: 0.6606 	 Val acc: 0.7600
Best val acc: 0.7600 at epoch 2.


50it [19:47, 23.76s/it]


>> train 	 Epoch 4/25 	 Batch 49/50 	 Loss 0.6460 	 Running Acc 0.691 	 Total Acc 0.691 	 Avg Batch Time 23.7569
Time: train: 1187.84 	 Train loss 0.6460 	 Train acc: 0.6913


7it [01:27, 12.46s/it]


>> val 	 Loss 0.6449 	 Running Acc 2.429 	 Total Acc 0.680 	 Avg Batch Time 3.4891
Epoch 3/25 finished.
Train time: 1187.84 	 Val time 87.23
Train loss 0.6460 	 Train acc: 0.6913
Val loss: 0.6454 	 Val acc: 0.6800
Best val acc: 0.7600 at epoch 2.


50it [19:55, 23.91s/it]


>> train 	 Epoch 5/25 	 Batch 49/50 	 Loss 0.6251 	 Running Acc 0.669 	 Total Acc 0.669 	 Avg Batch Time 23.9081
Time: train: 1195.41 	 Train loss 0.6251 	 Train acc: 0.6687


7it [01:26, 12.39s/it]


>> val 	 Loss 0.6357 	 Running Acc 2.357 	 Total Acc 0.660 	 Avg Batch Time 3.4696
Epoch 4/25 finished.
Train time: 1195.41 	 Val time 86.74
Train loss 0.6251 	 Train acc: 0.6687
Val loss: 0.6363 	 Val acc: 0.6600
Best val acc: 0.7600 at epoch 2.


50it [19:42, 23.65s/it]


>> train 	 Epoch 6/25 	 Batch 49/50 	 Loss 0.6265 	 Running Acc 0.693 	 Total Acc 0.693 	 Avg Batch Time 23.6463
Time: train: 1182.32 	 Train loss 0.6265 	 Train acc: 0.6925


7it [01:26, 12.39s/it]


>> val 	 Loss 0.6287 	 Running Acc 2.429 	 Total Acc 0.680 	 Avg Batch Time 3.4689
Epoch 5/25 finished.
Train time: 1182.32 	 Val time 86.72
Train loss 0.6265 	 Train acc: 0.6925
Val loss: 0.6292 	 Val acc: 0.6800
Best val acc: 0.7600 at epoch 2.


50it [19:59, 24.00s/it]


>> train 	 Epoch 7/25 	 Batch 49/50 	 Loss 0.6108 	 Running Acc 0.731 	 Total Acc 0.731 	 Avg Batch Time 23.9950
Time: train: 1199.75 	 Train loss 0.6108 	 Train acc: 0.7312


7it [01:27, 12.45s/it]


>> val 	 Loss 0.6249 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 3.4851
Epoch 6/25 finished.
Train time: 1199.75 	 Val time 87.13
Train loss 0.6108 	 Train acc: 0.7312
Val loss: 0.6257 	 Val acc: 0.6900
Best val acc: 0.7600 at epoch 2.


50it [19:59, 24.00s/it]


>> train 	 Epoch 8/25 	 Batch 49/50 	 Loss 0.6111 	 Running Acc 0.725 	 Total Acc 0.725 	 Avg Batch Time 23.9984
Time: train: 1199.92 	 Train loss 0.6111 	 Train acc: 0.7250


7it [01:28, 12.68s/it]


>> val 	 Loss 0.6236 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 3.5518
Epoch 7/25 finished.
Train time: 1199.92 	 Val time 88.80
Train loss 0.6111 	 Train acc: 0.7250
Val loss: 0.6244 	 Val acc: 0.6900
Best val acc: 0.7600 at epoch 2.


50it [20:08, 24.16s/it]


>> train 	 Epoch 9/25 	 Batch 49/50 	 Loss 0.6185 	 Running Acc 0.710 	 Total Acc 0.710 	 Avg Batch Time 24.1649
Time: train: 1208.24 	 Train loss 0.6185 	 Train acc: 0.7100


7it [01:26, 12.34s/it]


>> val 	 Loss 0.6171 	 Running Acc 2.464 	 Total Acc 0.690 	 Avg Batch Time 3.4548
Epoch 8/25 finished.
Train time: 1208.24 	 Val time 86.37
Train loss 0.6185 	 Train acc: 0.7100
Val loss: 0.6180 	 Val acc: 0.6900
Best val acc: 0.7600 at epoch 2.


50it [19:47, 23.75s/it]


>> train 	 Epoch 10/25 	 Batch 49/50 	 Loss 0.6074 	 Running Acc 0.711 	 Total Acc 0.711 	 Avg Batch Time 23.7533
Time: train: 1187.67 	 Train loss 0.6074 	 Train acc: 0.7113


7it [01:25, 12.26s/it]


>> val 	 Loss 0.6108 	 Running Acc 2.536 	 Total Acc 0.710 	 Avg Batch Time 3.4324
Epoch 9/25 finished.
Train time: 1187.67 	 Val time 85.81
Train loss 0.6074 	 Train acc: 0.7113
Val loss: 0.6120 	 Val acc: 0.7100
Best val acc: 0.7600 at epoch 2.


50it [20:03, 24.06s/it]


>> train 	 Epoch 11/25 	 Batch 49/50 	 Loss 0.6015 	 Running Acc 0.715 	 Total Acc 0.715 	 Avg Batch Time 24.0627
Time: train: 1203.14 	 Train loss 0.6015 	 Train acc: 0.7150


7it [01:26, 12.37s/it]


>> val 	 Loss 0.6059 	 Running Acc 2.679 	 Total Acc 0.750 	 Avg Batch Time 3.4640
Epoch 10/25 finished.
Train time: 1203.14 	 Val time 86.60
Train loss 0.6015 	 Train acc: 0.7150
Val loss: 0.6071 	 Val acc: 0.7500
Best val acc: 0.7600 at epoch 2.


50it [19:36, 23.53s/it]


>> train 	 Epoch 12/25 	 Batch 49/50 	 Loss 0.5906 	 Running Acc 0.734 	 Total Acc 0.734 	 Avg Batch Time 23.5316
Time: train: 1176.58 	 Train loss 0.5906 	 Train acc: 0.7338


7it [01:26, 12.33s/it]


>> val 	 Loss 0.6004 	 Running Acc 2.607 	 Total Acc 0.730 	 Avg Batch Time 3.4530
Epoch 11/25 finished.
Train time: 1176.58 	 Val time 86.33
Train loss 0.5906 	 Train acc: 0.7338
Val loss: 0.6020 	 Val acc: 0.7300
Best val acc: 0.7600 at epoch 2.


50it [19:48, 23.78s/it]


>> train 	 Epoch 13/25 	 Batch 49/50 	 Loss 0.5861 	 Running Acc 0.739 	 Total Acc 0.739 	 Avg Batch Time 23.7764
Time: train: 1188.82 	 Train loss 0.5861 	 Train acc: 0.7388


7it [01:26, 12.41s/it]


>> val 	 Loss 0.5969 	 Running Acc 2.679 	 Total Acc 0.750 	 Avg Batch Time 3.4753
Epoch 12/25 finished.
Train time: 1188.82 	 Val time 86.88
Train loss 0.5861 	 Train acc: 0.7388
Val loss: 0.5987 	 Val acc: 0.7500
Best val acc: 0.7600 at epoch 2.


50it [19:26, 23.32s/it]


>> train 	 Epoch 14/25 	 Batch 49/50 	 Loss 0.5803 	 Running Acc 0.739 	 Total Acc 0.739 	 Avg Batch Time 23.3230
Time: train: 1166.15 	 Train loss 0.5803 	 Train acc: 0.7388


7it [01:26, 12.31s/it]


>> val 	 Loss 0.5947 	 Running Acc 2.607 	 Total Acc 0.730 	 Avg Batch Time 3.4457
Epoch 13/25 finished.
Train time: 1166.15 	 Val time 86.14
Train loss 0.5803 	 Train acc: 0.7388
Val loss: 0.5967 	 Val acc: 0.7300
Best val acc: 0.7600 at epoch 2.


50it [19:46, 23.73s/it]


>> train 	 Epoch 15/25 	 Batch 49/50 	 Loss 0.5801 	 Running Acc 0.733 	 Total Acc 0.733 	 Avg Batch Time 23.7294
Time: train: 1186.47 	 Train loss 0.5801 	 Train acc: 0.7325


7it [01:25, 12.20s/it]


>> val 	 Loss 0.5937 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 3.4168
Epoch 14/25 finished.
Train time: 1186.47 	 Val time 85.42
Train loss 0.5801 	 Train acc: 0.7325
Val loss: 0.5958 	 Val acc: 0.7400
Best val acc: 0.7600 at epoch 2.


50it [19:45, 23.71s/it]


>> train 	 Epoch 16/25 	 Batch 49/50 	 Loss 0.5892 	 Running Acc 0.719 	 Total Acc 0.719 	 Avg Batch Time 23.7078
Time: train: 1185.39 	 Train loss 0.5892 	 Train acc: 0.7188


7it [01:27, 12.47s/it]


>> val 	 Loss 0.5934 	 Running Acc 2.643 	 Total Acc 0.740 	 Avg Batch Time 3.4925
Epoch 15/25 finished.
Train time: 1185.39 	 Val time 87.31
Train loss 0.5892 	 Train acc: 0.7188
Val loss: 0.5955 	 Val acc: 0.7400
Best val acc: 0.7600 at epoch 2.


50it [19:48, 23.77s/it]


>> train 	 Epoch 17/25 	 Batch 49/50 	 Loss 0.5879 	 Running Acc 0.713 	 Total Acc 0.713 	 Avg Batch Time 23.7681
Time: train: 1188.41 	 Train loss 0.5879 	 Train acc: 0.7125


7it [01:26, 12.41s/it]


>> val 	 Loss 0.5874 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.4740
New best validation model, saving...
Epoch 16/25 finished.
Train time: 1188.41 	 Val time 86.85
Train loss 0.5879 	 Train acc: 0.7125
Val loss: 0.5899 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [19:49, 23.78s/it]


>> train 	 Epoch 18/25 	 Batch 49/50 	 Loss 0.5768 	 Running Acc 0.734 	 Total Acc 0.734 	 Avg Batch Time 23.7846
Time: train: 1189.23 	 Train loss 0.5768 	 Train acc: 0.7338


7it [01:28, 12.64s/it]


>> val 	 Loss 0.5855 	 Running Acc 2.607 	 Total Acc 0.730 	 Avg Batch Time 3.5384
Epoch 17/25 finished.
Train time: 1189.23 	 Val time 88.46
Train loss 0.5768 	 Train acc: 0.7338
Val loss: 0.5881 	 Val acc: 0.7300
Best val acc: 0.7700 at epoch 16.


50it [19:50, 23.81s/it]


>> train 	 Epoch 19/25 	 Batch 49/50 	 Loss 0.5746 	 Running Acc 0.721 	 Total Acc 0.721 	 Avg Batch Time 23.8106
Time: train: 1190.53 	 Train loss 0.5746 	 Train acc: 0.7212


7it [01:27, 12.45s/it]


>> val 	 Loss 0.5786 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.4857
Epoch 18/25 finished.
Train time: 1190.53 	 Val time 87.14
Train loss 0.5746 	 Train acc: 0.7212
Val loss: 0.5810 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [20:00, 24.00s/it]


>> train 	 Epoch 20/25 	 Batch 49/50 	 Loss 0.5664 	 Running Acc 0.726 	 Total Acc 0.726 	 Avg Batch Time 24.0050
Time: train: 1200.25 	 Train loss 0.5664 	 Train acc: 0.7262


7it [01:31, 13.08s/it]


>> val 	 Loss 0.5740 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.6612
Epoch 19/25 finished.
Train time: 1200.25 	 Val time 91.53
Train loss 0.5664 	 Train acc: 0.7262
Val loss: 0.5767 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [19:57, 23.95s/it]


>> train 	 Epoch 21/25 	 Batch 49/50 	 Loss 0.5468 	 Running Acc 0.736 	 Total Acc 0.736 	 Avg Batch Time 23.9534
Time: train: 1197.67 	 Train loss 0.5468 	 Train acc: 0.7362


7it [01:27, 12.46s/it]


>> val 	 Loss 0.5708 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.4876
Epoch 20/25 finished.
Train time: 1197.67 	 Val time 87.19
Train loss 0.5468 	 Train acc: 0.7362
Val loss: 0.5735 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [19:29, 23.39s/it]


>> train 	 Epoch 22/25 	 Batch 49/50 	 Loss 0.5542 	 Running Acc 0.721 	 Total Acc 0.721 	 Avg Batch Time 23.3929
Time: train: 1169.64 	 Train loss 0.5542 	 Train acc: 0.7212


7it [01:25, 12.25s/it]


>> val 	 Loss 0.5674 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.4293
Epoch 21/25 finished.
Train time: 1169.64 	 Val time 85.73
Train loss 0.5542 	 Train acc: 0.7212
Val loss: 0.5703 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [19:50, 23.82s/it]


>> train 	 Epoch 23/25 	 Batch 49/50 	 Loss 0.5508 	 Running Acc 0.740 	 Total Acc 0.740 	 Avg Batch Time 23.8188
Time: train: 1190.94 	 Train loss 0.5508 	 Train acc: 0.7400


7it [01:27, 12.53s/it]


>> val 	 Loss 0.5635 	 Running Acc 2.714 	 Total Acc 0.760 	 Avg Batch Time 3.5081
Epoch 22/25 finished.
Train time: 1190.94 	 Val time 87.70
Train loss 0.5508 	 Train acc: 0.7400
Val loss: 0.5667 	 Val acc: 0.7600
Best val acc: 0.7700 at epoch 16.


50it [19:57, 23.96s/it]


>> train 	 Epoch 24/25 	 Batch 49/50 	 Loss 0.5656 	 Running Acc 0.709 	 Total Acc 0.709 	 Avg Batch Time 23.9576
Time: train: 1197.88 	 Train loss 0.5656 	 Train acc: 0.7087


7it [01:28, 12.66s/it]


>> val 	 Loss 0.5619 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.5452
Epoch 23/25 finished.
Train time: 1197.88 	 Val time 88.63
Train loss 0.5656 	 Train acc: 0.7087
Val loss: 0.5649 	 Val acc: 0.7700
Best val acc: 0.7700 at epoch 16.


50it [20:02, 24.05s/it]


>> train 	 Epoch 25/25 	 Batch 49/50 	 Loss 0.5524 	 Running Acc 0.734 	 Total Acc 0.734 	 Avg Batch Time 24.0507
Time: train: 1202.53 	 Train loss 0.5524 	 Train acc: 0.7338


7it [01:26, 12.35s/it]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.5607 	 Running Acc 2.893 	 Total Acc 0.810 	 Avg Batch Time 3.4587
New best validation model, saving...
Epoch 24/25 finished.
Train time: 1202.53 	 Val time 86.47
Train loss 0.5524 	 Train acc: 0.7338
Val loss: 0.5639 	 Val acc: 0.8100
Best val acc: 0.8100 at epoch 24.


7it [01:27, 12.45s/it]

>> test 	 Loss 0.5258 	 Running Acc 2.750 	 Total Acc 0.770 	 Avg Batch Time 3.4874
Final  tensor([[1.0000, 0.4188, 0.5812],
        [0.0000, 0.3460, 0.6540],
        [1.0000, 0.6909, 0.3091],
        [1.0000, 0.3148, 0.6852],
        [1.0000, 0.2739, 0.7261],
        [1.0000, 0.4070, 0.5930],
        [1.0000, 0.3294, 0.6706],
        [1.0000, 0.4165, 0.5835],
        [1.0000, 0.7745, 0.2255],
        [1.0000, 0.3993, 0.6007],
        [0.0000, 0.6105, 0.3895],
        [1.0000, 0.5358, 0.4642],
        [0.0000, 0.4644, 0.5356],
        [1.0000, 0.2847, 0.7153],
        [0.0000, 0.7641, 0.2359],
        [1.0000, 0.4351, 0.5649],
        [0.0000, 0.4926, 0.5074],
        [0.0000, 0.6473, 0.3527],
        [1.0000, 0.3029, 0.6971],
        [0.0000, 0.7143, 0.2857],
        [1.0000, 0.5034, 0.4966],
        [0.0000, 0.7692, 0.2308],
        [1.0000, 0.3618, 0.6382],
        [0.0000, 0.7710, 0.2290],
        [1.0000, 0.5715, 0.4285],
        [0.0000, 0.8830, 0.1170],
        [1.0000, 0.4907, 


