# Lorentz-Equivariant Quantum Graph Neural Network (Lorentz-EQGNN)

In [1]:
# For Colab
!pip install torch_geometric
# !pip install torch_sparse
# !pip install torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
!pip install pennylane qiskit pennylane-qiskit pylatexenc

Collecting pennylane
  Downloading PennyLane-0.38.0-py3-none-any.whl.metadata (9.3 kB)
Collecting qiskit
  Downloading qiskit-1.2.4-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting pennylane-qiskit
  Downloading PennyLane_qiskit-0.38.1-py3-none-any.whl.metadata (6.4 kB)
Collecting pylatexenc
  Downloading pylatexenc-2.10.tar.gz (162 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.15.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting autograd (from pennylane)
  Downloading autograd-1.7.0-py3-none-any.whl.metadata (7.5 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.38 (from pennylane)
  Downloading PennyLane_Lightning-0.3

In [3]:
import pennylane as qml
import qiskit
print(qml.__version__)
print(qiskit.__version__)
import pennylane_qiskit
print(pennylane_qiskit.__version__)
import pennylane as qml
from pennylane import numpy as np
# from pennylane_qiskit import AerDevice

0.38.0
1.2.4
0.38.1


## Dataset

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import h5py
from matplotlib.colors import LogNorm
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split
import urllib

# Downloading Dataset
def download_file(url, filename):
    urllib.request.urlretrieve(url, filename)

photon_url = 'https://cernbox.cern.ch/remote.php/dav/public-files/AtBT8y4MiQYFcgc/SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5'
electron_url = 'https://cernbox.cern.ch/remote.php/dav/public-files/FbXw3V4XNyYB3oA/SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5'

download_file(photon_url, 'photon.hdf5')
download_file(electron_url, 'electron.hdf5')

def electron_photon(sample=100000):
    file_path_electron = "/kaggle/working/electron.hdf5"
    with h5py.File(file_path_electron, "r") as file:
        X_e = np.array(file["X"])
        y_e = np.array(file["y"])

    file_path_photon = "/kaggle/working/photon.hdf5"
    with h5py.File(file_path_photon, "r") as file:
        X_p = np.array(file["X"])
        y_p = np.array(file["y"])

    X = np.concatenate((X_e[:sample], X_p[:sample]), axis=0)
    y = np.concatenate((y_e[:sample], y_p[:sample]), axis=0)

    return X, y

In [4]:
X, y = electron_photon(sample=100000)
x_red = X
y_red = y
x_red.shape, y_red
jets_photon = x_red[y_red == 1][:10]
jets_photon.shape

(10, 32, 32, 2)

In [5]:
import torch
import numpy as np
import os
import h5py
from sklearn.model_selection import train_test_split
from scipy.sparse import coo_matrix

def save_electron_photon_tensors(num_data_per_class=500, max_nodes=139, save_dir="electron_photon/data"):
    """
    Generate and save tensor data files for Electron-Photon dataset in a graph-like format.

    Args:
        num_data_per_class (int): Number of samples to use per class
        max_nodes (int): Maximum number of nodes per graph
        save_dir (str): Directory to save the processed data
    """
    os.makedirs(save_dir, exist_ok=True)

    # Load electron and photon data
    with h5py.File("electron.hdf5", "r") as file:
        X_e = np.array(file["X"])[:num_data_per_class]
        y_e = np.ones(num_data_per_class)

    with h5py.File("photon.hdf5", "r") as file:
        X_p = np.array(file["X"])[:num_data_per_class]
        y_p = np.zeros(num_data_per_class)

    # Combine data
    X = np.concatenate((X_e, X_p), axis=0)
    labels = np.concatenate((y_e, y_p), axis=0)

    # Shuffle the data
    shuffle_idx = np.random.permutation(len(X))
    X = X[shuffle_idx]
    labels = labels[shuffle_idx]

    def image_to_nodes(image_data, max_nodes=139, threshold=1e-4):
        """Convert image data to graph nodes with position and feature information."""
        batch_size, height, width, channels = image_data.shape
        nodes = np.zeros((batch_size, max_nodes, 8))  # 8-dimensional node features
        p4s = np.zeros((batch_size, max_nodes, 4))    # Position and momentum information
        atom_masks = np.zeros((batch_size, max_nodes), dtype=bool)

        for b in range(batch_size):
            # Combine track and ECAL information
            combined = np.sum(image_data[b, :, :, :2], axis=-1)

            # Find significant points and their values
            significant_points = np.where(combined > threshold)
            values = combined[significant_points]

            # Sort points by their values and take exactly max_nodes
            sorted_indices = np.argsort(-values)
            n_points = min(len(sorted_indices), max_nodes)
            selected_indices = sorted_indices[:n_points]

            for idx_pos, idx in enumerate(selected_indices):
                h, w = significant_points[0][idx], significant_points[1][idx]

                # Enhanced node features (8-dimensional)
                nodes[b, idx_pos] = [
                    image_data[b, h, w, 0],  # track energy
                    image_data[b, h, w, 1],  # ECAL energy
                    h / height,              # normalized height position
                    w / width,               # normalized width position
                    np.sqrt(h**2 + w**2) / np.sqrt(height**2 + width**2),  # normalized radius
                    np.arctan2(h - height/2, w - width/2) / np.pi,         # angular position
                    values[idx] / np.max(values),                          # normalized energy
                    1.0 if image_data[b, h, w, 0] > image_data[b, h, w, 1] else 0.0  # track vs ECAL dominance
                ]

                # Create p4s (x, y, E, 0) - normalized coordinates
                x = (w - width/2) / (width/2)   # Center and normalize to [-1, 1]
                y = (h - height/2) / (height/2) # Center and normalize to [-1, 1]
                E = values[idx]                 # Energy value
                p4s[b, idx_pos] = [x, y, E, 0]
                atom_masks[b, idx_pos] = True

        return p4s, nodes, atom_masks

    # Convert image data to graph format
    p4s, nodes, atom_mask = image_to_nodes(X, max_nodes=max_nodes)

    # Convert to torch tensors
    p4s = torch.from_numpy(p4s).float()
    nodes = torch.from_numpy(nodes).float()
    labels = torch.from_numpy(labels).float()
    atom_mask = torch.from_numpy(atom_mask)

    # Create edge mask (fully connected graph between valid nodes)
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = edge_mask * diag_mask

    # Calculate edges
    batch_size = len(X)
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx * max_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    edges = np.stack([rows, cols])

    # Save tensors
    torch.save(p4s, os.path.join(save_dir, "p4s.pt"))
    torch.save(nodes, os.path.join(save_dir, "nodes.pt"))
    torch.save(labels, os.path.join(save_dir, "labels.pt"))
    torch.save(atom_mask, os.path.join(save_dir, "atom_mask.pt"))
    np.save(os.path.join(save_dir, "edge_mask.npy"), edge_mask.numpy())
    np.save(os.path.join(save_dir, "edges.npy"), edges)

    print(f"Saved tensor files to {save_dir}")
    print(f"Shapes:")
    print(f"p4s: {p4s.shape}")
    print(f"nodes: {nodes.shape}")
    print(f"labels: {labels.shape}")
    print(f"atom_mask: {atom_mask.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"edges: {edges.shape}")

    # Print label distribution
    unique_labels, counts = np.unique(labels, return_counts=True)
    print("\nLabel distribution:")
    for label, count in zip(unique_labels, counts):
        print(f"Class {label}: {count} samples")

if __name__ == '__main__':
    # Generate data for electron (1) vs photon (0)
    save_electron_photon_tensors(
        num_data_per_class=500,  # 500 samples per class = 1000 total
        max_nodes=200,           # Maximum number of nodes per graph
        save_dir="electron_photon/data"
    )

Saved tensor files to electron_photon/data
Shapes:
p4s: torch.Size([1000, 200, 4])
nodes: torch.Size([1000, 200, 8])
labels: torch.Size([1000])
atom_mask: torch.Size([1000, 200])
edge_mask: torch.Size([1000, 200, 200])
edges: (2, 420236)

Label distribution:
Class 0.0: 500 samples
Class 1.0: 500 samples


In [6]:
def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    edge_mask = data[-2]

    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edges]

In [8]:
from torch.utils.data import TensorDataset, DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import torch

def create_stratified_split(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_state=42):
    """
    Creates stratified train/val/test splits for binary classification.
    """
    # Extract labels and convert to integer labels for stratification
    labels = dataset.tensors[0].numpy()
    labels_int = labels.astype(int)  # Convert to int for stratification

    # Print initial distribution
    unique_classes, class_counts = np.unique(labels_int, return_counts=True)
    print("\nInitial class distribution:")
    for class_idx, count in zip(unique_classes, class_counts):
        print(f"Class {class_idx}: {count} samples ({count/len(labels)*100:.2f}%)")

    # First split: separate train from temporary val+test
    first_sss = StratifiedShuffleSplit(n_splits=1,
                                      train_size=train_ratio,
                                      random_state=random_state)

    train_idx, temp_idx = next(first_sss.split(np.zeros(len(labels)), labels_int))

    # Second split: separate val and test from temporary set
    temp_labels = labels_int[temp_idx]
    val_ratio_adjusted = val_ratio / (val_ratio + test_ratio)

    second_sss = StratifiedShuffleSplit(n_splits=1,
                                       train_size=val_ratio_adjusted,
                                       random_state=random_state)

    val_idx_temp, test_idx_temp = next(second_sss.split(np.zeros(len(temp_labels)), temp_labels))

    # Convert temporary indices to original dataset indices
    val_idx = temp_idx[val_idx_temp]
    test_idx = temp_idx[test_idx_temp]

    # Create subset datasets
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)
    test_dataset = Subset(dataset, test_idx)

    # Verify class distribution in each split
    def get_split_distribution(indices):
        split_labels = labels_int[indices]
        unique, counts = np.unique(split_labels, return_counts=True)
        return dict(zip(unique, counts / len(indices)))

    print("\nClass distribution in splits:")
    print("Train:", get_split_distribution(train_idx))
    print("Val:", get_split_distribution(val_idx))
    print("Test:", get_split_distribution(test_idx))

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }

# Let's reload the data and ensure we have binary labels
labels = torch.load('/kaggle/working/electron_photon/data/labels.pt')
p4s = torch.load('/kaggle/working/electron_photon/data/p4s.pt')
nodes = torch.load('/kaggle/working/electron_photon/data/nodes.pt')
atom_mask = torch.load('/kaggle/working/electron_photon/data/atom_mask.pt')
edge_mask = torch.from_numpy(np.load('/kaggle/working/electron_photon/data/edge_mask.npy'))

# Print label statistics before creating dataset
print("Label Statistics:")
print("Shape:", labels.shape)
print("Unique values:", torch.unique(labels))
print("Class distribution:", torch.bincount(labels.long()) / len(labels))

# Create the dataset
dataset_all = TensorDataset(labels, p4s, nodes, atom_mask, edge_mask)

# Create stratified splits
datasets = create_stratified_split(dataset_all)

# Create dataloaders
dataloaders = {
    split: DataLoader(
        dataset,
        batch_size=16,
        pin_memory=False,
        collate_fn=collate_fn,
        drop_last=True if (split == 'train') else False,
        num_workers=0,
        shuffle=(split == 'train')
    )
    for split, dataset in datasets.items()
}

Label Statistics:
Shape: torch.Size([1000])
Unique values: tensor([0., 1.])
Class distribution: tensor([0.5000, 0.5000])

Initial class distribution:
Class 0: 500 samples (50.00%)
Class 1: 500 samples (50.00%)

Class distribution in splits:
Train: {0: 0.5, 1: 0.5}
Val: {0: 0.5, 1: 0.5}
Test: {0: 0.5, 1: 0.5}


  labels = torch.load('/kaggle/working/electron_photon/data/labels.pt')
  p4s = torch.load('/kaggle/working/electron_photon/data/p4s.pt')
  nodes = torch.load('/kaggle/working/electron_photon/data/nodes.pt')
  atom_mask = torch.load('/kaggle/working/electron_photon/data/atom_mask.pt')


In [9]:
print(p4s.shape) # p4s
print(nodes.shape) # mass
print(atom_mask.shape) # torch.ones
print(edge_mask.shape) # adj_matrix

torch.Size([1000, 200, 4])
torch.Size([1000, 200, 8])
torch.Size([1000, 200])
torch.Size([1000, 200, 200])


In [10]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7a529ffa5b10>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7a51f1b42e60>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7a51f1b42d40>}

In [11]:
import torch
import numpy as np
import h5py
import os
from scipy.sparse import coo_matrix

def save_physics_tensors(num_data=1000, save_dir="random/data"):
    """
    Generate and save tensor data files for electron-photon analysis.

    Args:
        num_data: Number of data points per class
        save_dir: Directory to save the tensor files
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Load electron-photon data
    def load_data(sample=1000):
        with h5py.File("/kaggle/working/electron.hdf5", "r") as file:
            X_e = np.array(file["X"])[:sample]
            y_e = np.ones(sample)

        with h5py.File("/kaggle/working/photon.hdf5", "r") as file:
            X_p = np.array(file["X"])[:sample]
            y_p = np.zeros(sample)

        X = np.concatenate((X_e, X_p), axis=0)
        y = np.concatenate((y_e, y_p), axis=0)
        return X, y

    # Load raw data
    X, labels = load_data(sample=num_data)
    batch_size = len(X)
    n_nodes = 500  # Fixed number of nodes as in quark-gluon data

    # Process image data into graph format
    def image_to_nodes(image_data, max_nodes=100, threshold=1e-4):
        batch_size = len(image_data)
        nodes = np.zeros((batch_size, max_nodes, 1))  # Match quark-gluon node dimension
        p4s = np.zeros((batch_size, max_nodes, 4))
        atom_masks = np.zeros((batch_size, max_nodes), dtype=bool)

        for b in range(batch_size):
            # Combine track and ECAL information
            combined = np.sum(image_data[b, :, :, :2], axis=-1)
            significant_points = np.where(combined > threshold)
            values = combined[significant_points]

            # Sort points by their values and take top max_nodes
            sorted_indices = np.argsort(-values)
            n_points = min(len(sorted_indices), max_nodes)
            selected_indices = sorted_indices[:n_points]

            for idx_pos, idx in enumerate(selected_indices):
                h, w = significant_points[0][idx], significant_points[1][idx]

                # Create node features similar to quark-gluon format
                track_energy = image_data[b, h, w, 0]
                ecal_energy = image_data[b, h, w, 1]
                total_energy = track_energy + ecal_energy

                # Node feature (single dimension as in quark-gluon)
                nodes[b, idx_pos, 0] = np.log(total_energy + 1)

                # Create pseudo-p4s (pt, eta, phi, m) from position and energy
                pt = total_energy
                eta = (h - 16) / 16  # Convert pixel position to pseudorapidity-like value
                phi = (w - 16) / 16 * np.pi  # Convert pixel position to phi-like value
                mass = 0  # Assume massless particles

                p4s[b, idx_pos] = [pt, eta, phi, mass]
                atom_masks[b, idx_pos] = True

        return p4s, nodes, atom_masks

    # Convert image data to graph format
    p4s, nodes, atom_mask = image_to_nodes(X)

    # Convert to torch tensors
    p4s = torch.from_numpy(p4s).float()
    nodes = torch.from_numpy(nodes).float()
    labels = torch.from_numpy(labels).float()
    atom_mask = torch.from_numpy(atom_mask)

    # Create edge mask
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = edge_mask * diag_mask

    # Calculate edges
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx * n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    edges = np.stack([rows, cols])

    # Save tensors
    torch.save(p4s, os.path.join(save_dir, "p4s.pt"))
    torch.save(nodes, os.path.join(save_dir, "nodes.pt"))
    torch.save(labels, os.path.join(save_dir, "labels.pt"))
    torch.save(atom_mask, os.path.join(save_dir, "atom_mask.pt"))
    np.save(os.path.join(save_dir, "edge_mask.npy"), edge_mask.numpy())
    np.save(os.path.join(save_dir, "edges.npy"), edges)

    print(f"Saved tensor files to {save_dir}")
    print(f"Shapes:")
    print(f"p4s: {p4s.shape}")
    print(f"nodes: {nodes.shape}")
    print(f"labels: {labels.shape}")
    print(f"atom_mask: {atom_mask.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"edges: {edges.shape}")

# Generate and save the tensor files
save_physics_tensors(num_data=500)  # This will give 1000 total samples (500 each class)

Saved tensor files to random/data
Shapes:
p4s: torch.Size([1000, 100, 4])
nodes: torch.Size([1000, 100, 1])
labels: torch.Size([1000])
atom_mask: torch.Size([1000, 100])
edge_mask: torch.Size([1000, 100, 100])
edges: (2, 420236)


In [12]:
from torch.utils.data import TensorDataset, random_split

def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    # edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    # diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    # edge_mask *= diag_mask

    edge_mask = data[-2]

    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edges]


p4s = torch.load('random/data/p4s.pt')
nodes = torch.load('random/data/nodes.pt')
labels = torch.load('random/data/labels.pt')
atom_mask = torch.load('random/data/atom_mask.pt')
edge_mask = torch.from_numpy(np.load('random/data/edge_mask.npy'))
edges = torch.from_numpy(np.load('random/data/edges.npy'))


# Create a TensorDataset
dataset_all = TensorDataset(labels, p4s, nodes, atom_mask, edge_mask)

# Define the split ratios
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Calculate the lengths for each split
total_size = len(dataset_all)
train_size = int(total_size * train_ratio)
val_size = int(total_size * val_ratio)
test_size = total_size - train_size - val_size  # Ensure all data is used

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset_all, [train_size, val_size, test_size])

# Create a dictionary to hold the datasets
datasets = {
    "train": train_dataset,
    "val": val_dataset,
    "test": test_dataset
}

dataloaders = {split: DataLoader(dataset,
                                 batch_size=16,
                                 # sampler=train_sampler if (split == 'train') else DistributedSampler(dataset, shuffle=False),
                                 pin_memory=False,
                                 # persistent_workers=True,
                                 collate_fn = collate_fn,
                                 drop_last=True if (split == 'train') else False,
                                 num_workers=0)
                    for split, dataset in datasets.items()}

  p4s = torch.load('random/data/p4s.pt')
  nodes = torch.load('random/data/nodes.pt')
  labels = torch.load('random/data/labels.pt')
  atom_mask = torch.load('random/data/atom_mask.pt')


In [13]:
# Set desired dimensions
batch_size = 1
n_nodes = 3
device = 'cpu'
dtype = torch.float32

# Print initial shapes
print("Initial shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Select subset of data
p4s = p4s[:batch_size, :n_nodes, :]
atom_mask = atom_mask[:batch_size, :n_nodes]
edge_mask = edge_mask[:batch_size, :n_nodes, :n_nodes]
nodes = nodes[:batch_size, :n_nodes, :]

print("\nAfter selection shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Reshape tensors
atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
# Don't reshape edge_mask yet
nodes = nodes.view(batch_size * n_nodes, -1).to(device, dtype)

print("\nAfter reshape shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)  # original shape
print("nodes:", nodes.shape)

# Recalculate edges for the subset
from scipy.sparse import coo_matrix
rows, cols = [], []
for batch_idx in range(batch_size):
    nn = batch_idx * n_nodes
    # Convert edge_mask to numpy and remove any extra dimensions
    edge_mask_np = edge_mask[batch_idx].cpu().numpy().squeeze()
    x = coo_matrix(edge_mask_np)
    rows.append(nn + x.row)
    cols.append(nn + x.col)

edges = [torch.LongTensor(np.concatenate(rows)).to(device),
         torch.LongTensor(np.concatenate(cols)).to(device)]

# Now reshape edge_mask after edges are calculated
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])

Initial shapes:
p4s: torch.Size([1000, 100, 4])
atom_mask: torch.Size([1000, 100])
edge_mask: torch.Size([1000, 100, 100])
nodes: torch.Size([1000, 100, 1])

After selection shapes:
p4s: torch.Size([1, 3, 4])
atom_mask: torch.Size([1, 3])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([1, 3, 1])

After reshape shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([3, 1])

Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [14]:
print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])


Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


# 3. LorentzNet

In [15]:
# @title
import torch
from torch import nn
import numpy as np



"""Some auxiliary functions"""

def unsorted_segment_sum(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    return result

def unsorted_segment_mean(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_mean`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    count = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    count.index_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

def normsq4(p):
    r''' Minkowski square norm
         `\|p\|^2 = p[0]^2-p[1]^2-p[2]^2-p[3]^2`
    '''
    psq = torch.pow(p, 2)
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def dotsq4(p,q):
    r''' Minkowski inner product
         `<p,q> = p[0]q[0]-p[1]q[1]-p[2]q[2]-p[3]q[3]`
    '''
    psq = p*q
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def normA_fn(A):
    return lambda p: torch.einsum('...i, ij, ...j->...', p, A, p)

def dotA_fn(A):
    return lambda p, q: torch.einsum('...i, ij, ...j->...', p, A, q)

def psi(p):
    ''' `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`
    '''
    return torch.sign(p) * torch.log(torch.abs(p) + 1)


"""Lorentz Group-Equivariant Block"""

class LGEB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LGEB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2.
        print('Input size of phi_e: ', n_input)

        self.include_x = include_x
        self.phi_e = nn.Sequential(
            nn.Linear(n_input, n_hidden, bias=False), # n_input * 2 + n_edge_attr
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4


    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        # Reduce the dimension of 'out' to 4 using a linear layer
        out = self.dimension_reducer(out)
        out = self.phi_e(out)
        # print("m_model output: ", out.shape)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m): # norms
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # print("m: ", m.shape)
        # print("trans: ", trans.shape)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        # print("trans: ", trans.shape)
        # print("x.size: ", x.size(0))
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight # * norms[i, j], smth like that, or norms
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            # print("X: ", x)
            x = self.x_model(x, edges, x_diff, m)
            # print("phi_x(X) = ", x, '\n---\n')

        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LorentzNet(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LorentzNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LGEBs = nn.ModuleList([LGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LGEBs[i](h, x, edges, node_attr=scalars)
        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)

        # print("Final preds: \n", pred.cpu().detach().numpy())
        return pred.squeeze(1)

LGEB(self.n_hidden, self.n_hidden, self.n_hidden,\
                                    n_node_attr=n_scalar, dropout=dropout,\
                                    c_weight=c_weight, last_layer=\(i==n_layers-1), A=A, include_x=include_x)
                                    
We are using n_hidden = 4 and n_layers = 6

n_input=n_hidden, n_output=n_hidden, n_hidden=n_hidden, n_node_attr=n_scalar=8

In [16]:
# @title
import torch
import os, json, random, string
import torch.distributed as dist

def makedir(path):
    try:
        os.makedirs(path)
    except OSError:
        pass

def args_init(args):
    r''' Initialize seed and exp_name.
    '''
    if args.seed is None: # use random seed if not specified
        args.seed = np.random.randint(100)
    if args.exp_name == '': # use random strings if not specified
        args.exp_name = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    if (args.local_rank == 0): # master
        print(args)
        makedir(f"{args.logdir}/{args.exp_name}")
        with open(f"{args.logdir}/{args.exp_name}/args.json", 'w') as f:
            json.dump(args.__dict__, f, indent=4)

def sum_reduce(num, device):
    r''' Sum the tensor across the devices.
    '''
    if not torch.is_tensor(num):
        rt = torch.tensor(num).to(device)
    else:
        rt = num.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        warmup_epoch: target learning rate is reached at warmup_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    Reference:
        https://github.com/ildoonet/pytorch-gradual-warmup-lr
    """

    def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.warmup_epoch = warmup_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    @property
    def _warmup_lr(self):
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch + 1) / self.warmup_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * (self.last_epoch + 1) / self.warmup_epoch + 1.) for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch - 1:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lr

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch==None else epoch
        if self.last_epoch >= self.warmup_epoch - 1:
            if not self.finished:
                warmup_lr = [base_lr * self.multiplier for base_lr in self.base_lrs]
                for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                    param_group['lr'] = lr
                self.finished = True
                return
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.warmup_epoch)
            return

        for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lr):
            param_group['lr'] = lr

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.warmup_epoch)
                self.last_epoch = self.after_scheduler.last_epoch + self.warmup_epoch + 1
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        result = {key: value for key, value in self.__dict__.items() if key != 'optimizer' or key != "after_scheduler"}
        if self.after_scheduler:
            result.update({"after_scheduler": self.after_scheduler.state_dict()})
        return result

    def load_state_dict(self, state_dict):
        after_scheduler_state = state_dict.pop("after_scheduler", None)
        self.__dict__.update(state_dict)
        if after_scheduler_state:
            self.after_scheduler.load_state_dict(after_scheduler_state)


from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

def buildROC(labels, score, targetEff=[0.3,0.5]):
    r''' ROC curve is a plot of the true positive rate (Sensitivity) in the function of the false positive rate
    (100-Specificity) for different cut-off points of a parameter. Each point on the ROC curve represents a
    sensitivity/specificity pair corresponding to a particular decision threshold. The Area Under the ROC
    curve (AUC) is a measure of how well a parameter can distinguish between two diagnostic groups.
    '''
    if not isinstance(targetEff, list):
        targetEff = [targetEff]
    fpr, tpr, threshold = roc_curve(labels, score)
    idx = [np.argmin(np.abs(tpr - Eff)) for Eff in targetEff]
    eB, eS = fpr[idx], tpr[idx]
    return fpr, tpr, threshold, eB, eS

In [18]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from tqdm import tqdm

def run(model, epoch, loader, partition, N_EPOCHS=None):
    if partition == 'train':
        model.train()
    else:
        model.eval()

    res = {'time':0, 'correct':0, 'loss': 0, 'counter': 0, 'acc': 0,
           'loss_arr':[], 'correct_arr':[],'label':[],'score':[]}

    tik = time.time()
    loader_length = len(loader)

    for i, (label, p4s, nodes, atom_mask, edge_mask, edges) in tqdm(enumerate(loader)):
        if partition == 'train':
            optimizer.zero_grad()

        batch_size, n_nodes, _ = p4s.size()
        atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
        atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device)
        edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)
        nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)
        edges = [a.to(device) for a in edges]
        label = label.to(device, dtype).long()

        pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                         edge_mask=edge_mask, n_nodes=n_nodes)

        predict = pred.max(1).indices
        correct = torch.sum(predict == label).item()
        # print(pred.shape,label.shape)
        loss = loss_fn(pred, label)

        if partition == 'train':
            loss.backward()
            optimizer.step()
        elif partition == 'test':
            # save labels and probilities for ROC / AUC
            # print("Preds ", pred)
            score = torch.nn.functional.softmax(pred, dim = -1)
            # print("Score test ", score)
            # raise
            res['label'].append(label)
            res['score'].append(score)

        res['time'] = time.time() - tik
        res['correct'] += correct
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        res['loss_arr'].append(loss.item())
        res['correct_arr'].append(correct)

        # if i != 0 and i % args.log_interval == 0:

    running_loss = sum(res['loss_arr'])/len(res['loss_arr'])
    running_acc = sum(res['correct_arr'])/(len(res['correct_arr'])*batch_size)
    avg_time = res['time']/res['counter'] * batch_size
    tmp_counter = res['counter']
    tmp_loss = res['loss'] / tmp_counter
    tmp_acc = res['correct'] / tmp_counter

    if N_EPOCHS:
        print(">> %s \t Epoch %d/%d \t Batch %d/%d \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, epoch + 1, N_EPOCHS, i, loader_length, running_loss, running_acc, tmp_acc, avg_time))
    else:
        print(">> %s \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, running_loss, running_acc, tmp_acc, avg_time))

    torch.cuda.empty_cache()
    # ---------- reduce -----------
    if partition == 'test':
        res['label'] = torch.cat(res['label']).unsqueeze(-1)
        res['score'] = torch.cat(res['score'])
        res['score'] = torch.cat((res['label'],res['score']),dim=-1)
    res['counter'] = res['counter']
    res['loss'] = res['loss'] / res['counter']
    res['acc'] = res['correct'] / res['counter']
    return res

def train(model, res, N_EPOCHS, model_path, log_path):
    ### training and validation
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    for epoch in range(N_EPOCHS):
        train_res = run(model, epoch, dataloaders['train'], partition='train', N_EPOCHS = N_EPOCHS)
        print("Time: train: %.2f \t Train loss %.4f \t Train acc: %.4f" % (train_res['time'],train_res['loss'],train_res['acc']))
        # if epoch % args.val_interval == 0:

        # if (args.local_rank == 0):
        torch.save(model.state_dict(), os.path.join(model_path, "checkpoint-epoch-{}.pt".format(epoch)) )
        with torch.no_grad():
            val_res = run(model, epoch, dataloaders['val'], partition='val')

        # if (args.local_rank == 0): # only master process save
        res['lr'].append(optimizer.param_groups[0]['lr'])
        res['train_time'].append(train_res['time'])
        res['val_time'].append(val_res['time'])
        res['train_loss'].append(train_res['loss'])
        res['train_acc'].append(train_res['acc'])
        res['val_loss'].append(val_res['loss'])
        res['val_acc'].append(val_res['acc'])
        res['epochs'].append(epoch)

        ## save best model
        if val_res['acc'] > res['best_val']:
            print("New best validation model, saving...")
            torch.save(model.state_dict(), os.path.join(model_path,"best-val-model.pt"))
            res['best_val'] = val_res['acc']
            res['best_epoch'] = epoch

        print("Epoch %d/%d finished." % (epoch, N_EPOCHS))
        print("Train time: %.2f \t Val time %.2f" % (train_res['time'], val_res['time']))
        print("Train loss %.4f \t Train acc: %.4f" % (train_res['loss'], train_res['acc']))
        print("Val loss: %.4f \t Val acc: %.4f" % (val_res['loss'], val_res['acc']))
        print("Best val acc: %.4f at epoch %d." % (res['best_val'],  res['best_epoch']))

        json_object = json.dumps(res, indent=4)
        with open(os.path.join(log_path, "train-result-epoch{}.json".format(epoch)), "w") as outfile:
            outfile.write(json_object)

        ## adjust learning rate
        if (epoch < 31):
            lr_scheduler.step(metrics=val_res['acc'])
        else:
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.5


def test(model, res, model_path, log_path):
    ### test on best model
    best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)
    model.load_state_dict(best_model)
    with torch.no_grad():
        test_res = run(model, 0, dataloaders['test'], partition='test')

    print("Final ", test_res['score'])
    pred = test_res['score'].cpu()

    np.save(os.path.join(log_path, "score.npy"), pred)
    fpr, tpr, thres, eB, eS  = buildROC(pred[...,0], pred[...,2])
    auc = roc_auc_score(pred[...,0], pred[...,2])

    metric = {'test_loss': test_res['loss'], 'test_acc': test_res['acc'],
              'test_auc': auc, 'test_1/eB_0.3':1./eB[0],'test_1/eB_0.5':1./eB[1]}
    res.update(metric)
    print("Test: Loss %.4f \t Acc %.4f \t AUC: %.4f \t 1/eB 0.3: %.4f \t 1/eB 0.5: %.4f"\
           % (test_res['loss'], test_res['acc'], auc, 1./eB[0], 1./eB[1]))
    json_object = json.dumps(res, indent=4)
    with open(os.path.join(log_path, "test-result.json"), "w") as outfile:
        outfile.write(json_object)

if __name__ == "__main__":

    N_EPOCHS = 55 # 60

    model_path = "models/LorentzNet/"
    log_path = "logs/LorentzNet/"
    # args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cuda
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cpu")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    ### create parallel model
    model = LorentzNet(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)



Input size of phi_e:  4
Model Size: 199
 train samples: 800
 val samples: 100
 test samples: 100
Training...


50it [00:00, 123.87it/s]


>> train 	 Epoch 1/55 	 Batch 49/50 	 Loss 0.6964 	 Running Acc 0.497 	 Total Acc 0.497 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6964 	 Train acc: 0.4975


7it [00:00, 264.69it/s]


>> val 	 Loss 0.6962 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.0011
New best validation model, saving...
Epoch 0/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6964 	 Train acc: 0.4975
Val loss: 0.6952 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [00:00, 123.35it/s]


>> train 	 Epoch 2/55 	 Batch 49/50 	 Loss 0.6955 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6955 	 Train acc: 0.5025


7it [00:00, 245.75it/s]


>> val 	 Loss 0.6962 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.0012
Epoch 1/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6955 	 Train acc: 0.5025
Val loss: 0.6949 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [00:00, 107.58it/s]


>> train 	 Epoch 3/55 	 Batch 49/50 	 Loss 0.6968 	 Running Acc 0.472 	 Total Acc 0.472 	 Avg Batch Time 0.0094
Time: train: 0.47 	 Train loss 0.6968 	 Train acc: 0.4725


7it [00:00, 222.99it/s]


>> val 	 Loss 0.6947 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.0014
Epoch 2/55 finished.
Train time: 0.47 	 Val time 0.03
Train loss 0.6968 	 Train acc: 0.4725
Val loss: 0.6939 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [00:00, 95.21it/s] 


>> train 	 Epoch 4/55 	 Batch 49/50 	 Loss 0.6931 	 Running Acc 0.516 	 Total Acc 0.516 	 Avg Batch Time 0.0106
Time: train: 0.53 	 Train loss 0.6931 	 Train acc: 0.5162


7it [00:00, 177.16it/s]


>> val 	 Loss 0.6942 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.0017
Epoch 3/55 finished.
Train time: 0.53 	 Val time 0.04
Train loss 0.6931 	 Train acc: 0.5162
Val loss: 0.6935 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [00:00, 100.97it/s]


>> train 	 Epoch 5/55 	 Batch 49/50 	 Loss 0.6943 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 0.0100
Time: train: 0.50 	 Train loss 0.6943 	 Train acc: 0.5000


7it [00:00, 250.35it/s]


>> val 	 Loss 0.6937 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0012
New best validation model, saving...
Epoch 4/55 finished.
Train time: 0.50 	 Val time 0.03
Train loss 0.6943 	 Train acc: 0.5000
Val loss: 0.6932 	 Val acc: 0.5400
Best val acc: 0.5400 at epoch 4.


50it [00:00, 121.38it/s]


>> train 	 Epoch 6/55 	 Batch 49/50 	 Loss 0.6950 	 Running Acc 0.482 	 Total Acc 0.482 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6950 	 Train acc: 0.4825


7it [00:00, 232.26it/s]


>> val 	 Loss 0.6933 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.0013
Epoch 5/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6950 	 Train acc: 0.4825
Val loss: 0.6929 	 Val acc: 0.5000
Best val acc: 0.5400 at epoch 4.


50it [00:00, 122.03it/s]


>> train 	 Epoch 7/55 	 Batch 49/50 	 Loss 0.6936 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6936 	 Train acc: 0.4988


7it [00:00, 254.09it/s]


>> val 	 Loss 0.6933 	 Running Acc 1.607 	 Total Acc 0.450 	 Avg Batch Time 0.0012
Epoch 6/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6936 	 Train acc: 0.4988
Val loss: 0.6929 	 Val acc: 0.4500
Best val acc: 0.5400 at epoch 4.


50it [00:00, 120.81it/s]


>> train 	 Epoch 8/55 	 Batch 49/50 	 Loss 0.6929 	 Running Acc 0.505 	 Total Acc 0.505 	 Avg Batch Time 0.0083
Time: train: 0.42 	 Train loss 0.6929 	 Train acc: 0.5050


7it [00:00, 239.48it/s]


>> val 	 Loss 0.6933 	 Running Acc 1.607 	 Total Acc 0.450 	 Avg Batch Time 0.0013
Epoch 7/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6929 	 Train acc: 0.5050
Val loss: 0.6929 	 Val acc: 0.4500
Best val acc: 0.5400 at epoch 4.


50it [00:00, 121.63it/s]


>> train 	 Epoch 9/55 	 Batch 49/50 	 Loss 0.6931 	 Running Acc 0.507 	 Total Acc 0.507 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6931 	 Train acc: 0.5075


7it [00:00, 246.27it/s]


>> val 	 Loss 0.6931 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.0012
Epoch 8/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6931 	 Train acc: 0.5075
Val loss: 0.6928 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 4.


50it [00:00, 122.50it/s]


>> train 	 Epoch 10/55 	 Batch 49/50 	 Loss 0.6938 	 Running Acc 0.496 	 Total Acc 0.496 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6938 	 Train acc: 0.4963


7it [00:00, 241.61it/s]


>> val 	 Loss 0.6932 	 Running Acc 1.679 	 Total Acc 0.470 	 Avg Batch Time 0.0012
Epoch 9/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6938 	 Train acc: 0.4963
Val loss: 0.6929 	 Val acc: 0.4700
Best val acc: 0.5400 at epoch 4.


50it [00:00, 117.84it/s]


>> train 	 Epoch 11/55 	 Batch 49/50 	 Loss 0.6939 	 Running Acc 0.492 	 Total Acc 0.492 	 Avg Batch Time 0.0085
Time: train: 0.43 	 Train loss 0.6939 	 Train acc: 0.4925


7it [00:00, 241.31it/s]


>> val 	 Loss 0.6929 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.0013
Epoch 10/55 finished.
Train time: 0.43 	 Val time 0.03
Train loss 0.6939 	 Train acc: 0.4925
Val loss: 0.6927 	 Val acc: 0.5000
Best val acc: 0.5400 at epoch 4.


50it [00:00, 120.21it/s]


>> train 	 Epoch 12/55 	 Batch 49/50 	 Loss 0.6924 	 Running Acc 0.531 	 Total Acc 0.531 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6924 	 Train acc: 0.5312


7it [00:00, 232.53it/s]


>> val 	 Loss 0.6928 	 Running Acc 1.679 	 Total Acc 0.470 	 Avg Batch Time 0.0013
Epoch 11/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6924 	 Train acc: 0.5312
Val loss: 0.6926 	 Val acc: 0.4700
Best val acc: 0.5400 at epoch 4.


50it [00:00, 121.78it/s]


>> train 	 Epoch 13/55 	 Batch 49/50 	 Loss 0.6928 	 Running Acc 0.527 	 Total Acc 0.527 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6928 	 Train acc: 0.5275


7it [00:00, 239.08it/s]


>> val 	 Loss 0.6928 	 Running Acc 1.679 	 Total Acc 0.470 	 Avg Batch Time 0.0013
Epoch 12/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6928 	 Train acc: 0.5275
Val loss: 0.6925 	 Val acc: 0.4700
Best val acc: 0.5400 at epoch 4.


50it [00:00, 118.35it/s]


>> train 	 Epoch 14/55 	 Batch 49/50 	 Loss 0.6933 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 0.0085
Time: train: 0.43 	 Train loss 0.6933 	 Train acc: 0.4988


7it [00:00, 227.20it/s]


>> val 	 Loss 0.6927 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.0013
Epoch 13/55 finished.
Train time: 0.43 	 Val time 0.03
Train loss 0.6933 	 Train acc: 0.4988
Val loss: 0.6924 	 Val acc: 0.5000
Best val acc: 0.5400 at epoch 4.


50it [00:00, 116.96it/s]


>> train 	 Epoch 15/55 	 Batch 49/50 	 Loss 0.6930 	 Running Acc 0.515 	 Total Acc 0.515 	 Avg Batch Time 0.0086
Time: train: 0.43 	 Train loss 0.6930 	 Train acc: 0.5150


7it [00:00, 248.49it/s]


>> val 	 Loss 0.6927 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.0012
Epoch 14/55 finished.
Train time: 0.43 	 Val time 0.03
Train loss 0.6930 	 Train acc: 0.5150
Val loss: 0.6924 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 4.


50it [00:00, 118.89it/s]


>> train 	 Epoch 16/55 	 Batch 49/50 	 Loss 0.6907 	 Running Acc 0.547 	 Total Acc 0.547 	 Avg Batch Time 0.0085
Time: train: 0.42 	 Train loss 0.6907 	 Train acc: 0.5475


7it [00:00, 241.12it/s]


>> val 	 Loss 0.6927 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.0013
Epoch 15/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6907 	 Train acc: 0.5475
Val loss: 0.6924 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 4.


50it [00:00, 119.22it/s]


>> train 	 Epoch 17/55 	 Batch 49/50 	 Loss 0.6910 	 Running Acc 0.544 	 Total Acc 0.544 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6910 	 Train acc: 0.5437


7it [00:00, 249.21it/s]


>> val 	 Loss 0.6928 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
New best validation model, saving...
Epoch 16/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6910 	 Train acc: 0.5437
Val loss: 0.6924 	 Val acc: 0.5600
Best val acc: 0.5600 at epoch 16.


50it [00:00, 122.25it/s]


>> train 	 Epoch 18/55 	 Batch 49/50 	 Loss 0.6923 	 Running Acc 0.519 	 Total Acc 0.519 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6923 	 Train acc: 0.5188


7it [00:00, 240.35it/s]


>> val 	 Loss 0.6928 	 Running Acc 2.036 	 Total Acc 0.570 	 Avg Batch Time 0.0013
New best validation model, saving...
Epoch 17/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6923 	 Train acc: 0.5188
Val loss: 0.6924 	 Val acc: 0.5700
Best val acc: 0.5700 at epoch 17.


50it [00:00, 124.06it/s]


>> train 	 Epoch 19/55 	 Batch 49/50 	 Loss 0.6918 	 Running Acc 0.531 	 Total Acc 0.531 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6918 	 Train acc: 0.5312


7it [00:00, 240.46it/s]


>> val 	 Loss 0.6926 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 18/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6918 	 Train acc: 0.5312
Val loss: 0.6924 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 120.26it/s]


>> train 	 Epoch 20/55 	 Batch 49/50 	 Loss 0.6912 	 Running Acc 0.522 	 Total Acc 0.522 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6912 	 Train acc: 0.5225


7it [00:00, 246.46it/s]


>> val 	 Loss 0.6923 	 Running Acc 1.893 	 Total Acc 0.530 	 Avg Batch Time 0.0012
Epoch 19/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6912 	 Train acc: 0.5225
Val loss: 0.6922 	 Val acc: 0.5300
Best val acc: 0.5700 at epoch 17.


50it [00:00, 112.05it/s]


>> train 	 Epoch 21/55 	 Batch 49/50 	 Loss 0.6919 	 Running Acc 0.521 	 Total Acc 0.521 	 Avg Batch Time 0.0090
Time: train: 0.45 	 Train loss 0.6919 	 Train acc: 0.5212


7it [00:00, 248.07it/s]


>> val 	 Loss 0.6926 	 Running Acc 1.964 	 Total Acc 0.550 	 Avg Batch Time 0.0012
Epoch 20/55 finished.
Train time: 0.45 	 Val time 0.03
Train loss 0.6919 	 Train acc: 0.5212
Val loss: 0.6922 	 Val acc: 0.5500
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.31it/s]


>> train 	 Epoch 22/55 	 Batch 49/50 	 Loss 0.6903 	 Running Acc 0.549 	 Total Acc 0.549 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6903 	 Train acc: 0.5487


7it [00:00, 260.09it/s]


>> val 	 Loss 0.6923 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 21/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6903 	 Train acc: 0.5487
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.63it/s]


>> train 	 Epoch 23/55 	 Batch 49/50 	 Loss 0.6897 	 Running Acc 0.549 	 Total Acc 0.549 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6897 	 Train acc: 0.5487


7it [00:00, 262.68it/s]


>> val 	 Loss 0.6924 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0012
Epoch 22/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6897 	 Train acc: 0.5487
Val loss: 0.6922 	 Val acc: 0.5400
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.09it/s]


>> train 	 Epoch 24/55 	 Batch 49/50 	 Loss 0.6887 	 Running Acc 0.566 	 Total Acc 0.566 	 Avg Batch Time 0.0083
Time: train: 0.42 	 Train loss 0.6887 	 Train acc: 0.5663


7it [00:00, 231.48it/s]


>> val 	 Loss 0.6925 	 Running Acc 1.857 	 Total Acc 0.520 	 Avg Batch Time 0.0013
Epoch 23/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6887 	 Train acc: 0.5663
Val loss: 0.6922 	 Val acc: 0.5200
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.43it/s]


>> train 	 Epoch 25/55 	 Batch 49/50 	 Loss 0.6896 	 Running Acc 0.542 	 Total Acc 0.542 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6896 	 Train acc: 0.5425


7it [00:00, 256.87it/s]


>> val 	 Loss 0.6924 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0012
Epoch 24/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6896 	 Train acc: 0.5425
Val loss: 0.6921 	 Val acc: 0.5400
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.96it/s]


>> train 	 Epoch 26/55 	 Batch 49/50 	 Loss 0.6894 	 Running Acc 0.549 	 Total Acc 0.549 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6894 	 Train acc: 0.5487


7it [00:00, 264.42it/s]


>> val 	 Loss 0.6926 	 Running Acc 1.893 	 Total Acc 0.530 	 Avg Batch Time 0.0011
Epoch 25/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6894 	 Train acc: 0.5487
Val loss: 0.6921 	 Val acc: 0.5300
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.54it/s]


>> train 	 Epoch 27/55 	 Batch 49/50 	 Loss 0.6878 	 Running Acc 0.568 	 Total Acc 0.568 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6878 	 Train acc: 0.5675


7it [00:00, 252.68it/s]


>> val 	 Loss 0.6926 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0012
Epoch 26/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6878 	 Train acc: 0.5675
Val loss: 0.6921 	 Val acc: 0.5400
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.06it/s]


>> train 	 Epoch 28/55 	 Batch 49/50 	 Loss 0.6884 	 Running Acc 0.560 	 Total Acc 0.560 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6884 	 Train acc: 0.5600


7it [00:00, 236.47it/s]


>> val 	 Loss 0.6929 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.0013
Epoch 27/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6884 	 Train acc: 0.5600
Val loss: 0.6921 	 Val acc: 0.5400
Best val acc: 0.5700 at epoch 17.


50it [00:00, 119.85it/s]


>> train 	 Epoch 29/55 	 Batch 49/50 	 Loss 0.6893 	 Running Acc 0.542 	 Total Acc 0.542 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6893 	 Train acc: 0.5425


7it [00:00, 248.54it/s]


>> val 	 Loss 0.6928 	 Running Acc 1.964 	 Total Acc 0.550 	 Avg Batch Time 0.0012
Epoch 28/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6893 	 Train acc: 0.5425
Val loss: 0.6921 	 Val acc: 0.5500
Best val acc: 0.5700 at epoch 17.


50it [00:00, 115.47it/s]


>> train 	 Epoch 30/55 	 Batch 49/50 	 Loss 0.6888 	 Running Acc 0.556 	 Total Acc 0.556 	 Avg Batch Time 0.0087
Time: train: 0.44 	 Train loss 0.6888 	 Train acc: 0.5563


7it [00:00, 263.39it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 29/55 finished.
Train time: 0.44 	 Val time 0.03
Train loss 0.6888 	 Train acc: 0.5563
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.71it/s]


>> train 	 Epoch 31/55 	 Batch 49/50 	 Loss 0.6895 	 Running Acc 0.552 	 Total Acc 0.552 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6895 	 Train acc: 0.5525


7it [00:00, 261.26it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 30/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6895 	 Train acc: 0.5525
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.75it/s]


>> train 	 Epoch 32/55 	 Batch 49/50 	 Loss 0.6890 	 Running Acc 0.535 	 Total Acc 0.535 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6890 	 Train acc: 0.5350


7it [00:00, 271.35it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0011
Epoch 31/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6890 	 Train acc: 0.5350
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.88it/s]


>> train 	 Epoch 33/55 	 Batch 49/50 	 Loss 0.6892 	 Running Acc 0.550 	 Total Acc 0.550 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6892 	 Train acc: 0.5500


7it [00:00, 254.91it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 32/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6892 	 Train acc: 0.5500
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 124.01it/s]


>> train 	 Epoch 34/55 	 Batch 49/50 	 Loss 0.6895 	 Running Acc 0.541 	 Total Acc 0.541 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6895 	 Train acc: 0.5413


7it [00:00, 258.46it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 33/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6895 	 Train acc: 0.5413
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.78it/s]


>> train 	 Epoch 35/55 	 Batch 49/50 	 Loss 0.6886 	 Running Acc 0.551 	 Total Acc 0.551 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6886 	 Train acc: 0.5513


7it [00:00, 217.89it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0014
Epoch 34/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6886 	 Train acc: 0.5513
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.00it/s]


>> train 	 Epoch 36/55 	 Batch 49/50 	 Loss 0.6902 	 Running Acc 0.526 	 Total Acc 0.526 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6902 	 Train acc: 0.5262


7it [00:00, 252.76it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 35/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6902 	 Train acc: 0.5262
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.33it/s]


>> train 	 Epoch 37/55 	 Batch 49/50 	 Loss 0.6896 	 Running Acc 0.544 	 Total Acc 0.544 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6896 	 Train acc: 0.5437


7it [00:00, 241.22it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 36/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6896 	 Train acc: 0.5437
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 120.61it/s]


>> train 	 Epoch 38/55 	 Batch 49/50 	 Loss 0.6890 	 Running Acc 0.542 	 Total Acc 0.542 	 Avg Batch Time 0.0083
Time: train: 0.42 	 Train loss 0.6890 	 Train acc: 0.5425


7it [00:00, 236.91it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 37/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6890 	 Train acc: 0.5425
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.90it/s]


>> train 	 Epoch 39/55 	 Batch 49/50 	 Loss 0.6887 	 Running Acc 0.560 	 Total Acc 0.560 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6887 	 Train acc: 0.5600


7it [00:00, 219.89it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0014
Epoch 38/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6887 	 Train acc: 0.5600
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.72it/s]


>> train 	 Epoch 40/55 	 Batch 49/50 	 Loss 0.6866 	 Running Acc 0.583 	 Total Acc 0.583 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6866 	 Train acc: 0.5825


7it [00:00, 245.71it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 39/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6866 	 Train acc: 0.5825
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 125.07it/s]


>> train 	 Epoch 41/55 	 Batch 49/50 	 Loss 0.6878 	 Running Acc 0.565 	 Total Acc 0.565 	 Avg Batch Time 0.0081
Time: train: 0.40 	 Train loss 0.6878 	 Train acc: 0.5650


7it [00:00, 260.99it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 40/55 finished.
Train time: 0.40 	 Val time 0.03
Train loss 0.6878 	 Train acc: 0.5650
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.09it/s]


>> train 	 Epoch 42/55 	 Batch 49/50 	 Loss 0.6874 	 Running Acc 0.579 	 Total Acc 0.579 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6874 	 Train acc: 0.5787


7it [00:00, 217.95it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0014
Epoch 41/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6874 	 Train acc: 0.5787
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 112.30it/s]


>> train 	 Epoch 43/55 	 Batch 49/50 	 Loss 0.6893 	 Running Acc 0.547 	 Total Acc 0.547 	 Avg Batch Time 0.0090
Time: train: 0.45 	 Train loss 0.6893 	 Train acc: 0.5475


7it [00:00, 249.92it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 42/55 finished.
Train time: 0.45 	 Val time 0.03
Train loss 0.6893 	 Train acc: 0.5475
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 117.81it/s]


>> train 	 Epoch 44/55 	 Batch 49/50 	 Loss 0.6895 	 Running Acc 0.541 	 Total Acc 0.541 	 Avg Batch Time 0.0085
Time: train: 0.43 	 Train loss 0.6895 	 Train acc: 0.5413


7it [00:00, 254.67it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 43/55 finished.
Train time: 0.43 	 Val time 0.03
Train loss 0.6895 	 Train acc: 0.5413
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 118.32it/s]


>> train 	 Epoch 45/55 	 Batch 49/50 	 Loss 0.6894 	 Running Acc 0.547 	 Total Acc 0.547 	 Avg Batch Time 0.0085
Time: train: 0.42 	 Train loss 0.6894 	 Train acc: 0.5475


7it [00:00, 247.82it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 44/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6894 	 Train acc: 0.5475
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 121.30it/s]


>> train 	 Epoch 46/55 	 Batch 49/50 	 Loss 0.6888 	 Running Acc 0.557 	 Total Acc 0.557 	 Avg Batch Time 0.0083
Time: train: 0.41 	 Train loss 0.6888 	 Train acc: 0.5575


7it [00:00, 248.56it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 45/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6888 	 Train acc: 0.5575
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 124.30it/s]


>> train 	 Epoch 47/55 	 Batch 49/50 	 Loss 0.6885 	 Running Acc 0.562 	 Total Acc 0.562 	 Avg Batch Time 0.0081
Time: train: 0.40 	 Train loss 0.6885 	 Train acc: 0.5625


7it [00:00, 265.73it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0011
Epoch 46/55 finished.
Train time: 0.40 	 Val time 0.03
Train loss 0.6885 	 Train acc: 0.5625
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.54it/s]


>> train 	 Epoch 48/55 	 Batch 49/50 	 Loss 0.6891 	 Running Acc 0.545 	 Total Acc 0.545 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6891 	 Train acc: 0.5450


7it [00:00, 237.65it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 47/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6891 	 Train acc: 0.5450
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.47it/s]


>> train 	 Epoch 49/55 	 Batch 49/50 	 Loss 0.6907 	 Running Acc 0.534 	 Total Acc 0.534 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6907 	 Train acc: 0.5337


7it [00:00, 239.43it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 48/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6907 	 Train acc: 0.5337
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 119.19it/s]


>> train 	 Epoch 50/55 	 Batch 49/50 	 Loss 0.6889 	 Running Acc 0.545 	 Total Acc 0.545 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6889 	 Train acc: 0.5450


7it [00:00, 245.76it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 49/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6889 	 Train acc: 0.5450
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.62it/s]


>> train 	 Epoch 51/55 	 Batch 49/50 	 Loss 0.6881 	 Running Acc 0.562 	 Total Acc 0.562 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6881 	 Train acc: 0.5625


7it [00:00, 259.93it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 50/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6881 	 Train acc: 0.5625
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.66it/s]


>> train 	 Epoch 52/55 	 Batch 49/50 	 Loss 0.6898 	 Running Acc 0.531 	 Total Acc 0.531 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6898 	 Train acc: 0.5312


7it [00:00, 222.09it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 51/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6898 	 Train acc: 0.5312
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 122.10it/s]


>> train 	 Epoch 53/55 	 Batch 49/50 	 Loss 0.6894 	 Running Acc 0.541 	 Total Acc 0.541 	 Avg Batch Time 0.0082
Time: train: 0.41 	 Train loss 0.6894 	 Train acc: 0.5413


7it [00:00, 259.63it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 52/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6894 	 Train acc: 0.5413
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 120.38it/s]


>> train 	 Epoch 54/55 	 Batch 49/50 	 Loss 0.6891 	 Running Acc 0.554 	 Total Acc 0.554 	 Avg Batch Time 0.0084
Time: train: 0.42 	 Train loss 0.6891 	 Train acc: 0.5537


7it [00:00, 246.71it/s]


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0012
Epoch 53/55 finished.
Train time: 0.42 	 Val time 0.03
Train loss 0.6891 	 Train acc: 0.5537
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


50it [00:00, 123.61it/s]


>> train 	 Epoch 55/55 	 Batch 49/50 	 Loss 0.6891 	 Running Acc 0.556 	 Total Acc 0.556 	 Avg Batch Time 0.0081
Time: train: 0.41 	 Train loss 0.6891 	 Train acc: 0.5563


7it [00:00, 230.45it/s]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.6929 	 Running Acc 2.000 	 Total Acc 0.560 	 Avg Batch Time 0.0013
Epoch 54/55 finished.
Train time: 0.41 	 Val time 0.03
Train loss 0.6891 	 Train acc: 0.5563
Val loss: 0.6921 	 Val acc: 0.5600
Best val acc: 0.5700 at epoch 17.


7it [00:00, 217.48it/s]

>> test 	 Loss 0.6906 	 Running Acc 2.179 	 Total Acc 0.610 	 Avg Batch Time 0.0014
Final  tensor([[1.0000, 0.4909, 0.5091],
        [0.0000, 0.5016, 0.4984],
        [1.0000, 0.4915, 0.5085],
        [0.0000, 0.5023, 0.4977],
        [1.0000, 0.5035, 0.4965],
        [0.0000, 0.4963, 0.5037],
        [0.0000, 0.5004, 0.4996],
        [0.0000, 0.5018, 0.4982],
        [1.0000, 0.5002, 0.4998],
        [0.0000, 0.5012, 0.4988],
        [1.0000, 0.5012, 0.4988],
        [1.0000, 0.4853, 0.5147],
        [0.0000, 0.4988, 0.5012],
        [0.0000, 0.4949, 0.5051],
        [1.0000, 0.4980, 0.5020],
        [1.0000, 0.4976, 0.5024],
        [1.0000, 0.4938, 0.5062],
        [1.0000, 0.4930, 0.5070],
        [1.0000, 0.4989, 0.5011],
        [1.0000, 0.5043, 0.4957],
        [1.0000, 0.4870, 0.5130],
        [0.0000, 0.5018, 0.4982],
        [1.0000, 0.4949, 0.5051],
        [0.0000, 0.5002, 0.4998],
        [1.0000, 0.5017, 0.4983],
        [1.0000, 0.4933, 0.5067],
        [1.0000, 0.5019, 




# 4. Proposed


In [19]:
import torch
import pennylane as qml
import torch.nn.functional as F
from torch import nn
from torch_geometric.utils import to_dense_adj

n_qubits = 4

dev = qml.device('default.qubit', wires=n_qubits)
# dev = qml.device("qiskit.aer", wires=n_qubits)


def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)

def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)

def RY_RX_layer(weights):
    """Applies a layer of parametrized RY and RX rotations."""
    for i, w in enumerate(weights):
        qml.RY(w, wires=i)
        qml.RX(w, wires=i)

def full_entangling_layer(n_qubits):
    """Applies CNOT gates between all pairs of qubits."""
    for i in range(n_qubits):
        for j in range(i+1, n_qubits):
            qml.CNOT(wires=[i, j])

def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOT.
    """
    # In other words it should apply something like :
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(nqubits - 1):
        qml.CRZ(np.pi / 2, wires=[i, i + 1])
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.SWAP(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.SWAP(wires=[i, i + 1])


@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth, n_qubits):
    """
    The variational quantum circuit.
    """

    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    # RY_layer(q_input_features)
    qml.AngleEmbedding(features=q_input_features, wires=range(n_qubits), rotation='Z')

    # Sequence of trainable variational layers
    # for k in range(q_depth):
    #     entangling_layer(n_qubits)
    #     RY_RX_layer(q_weights[k])
    #     # RY_layer(q_weights[k])
    for k in range(q_depth):
        if k % 2 == 0:
            entangling_layer(n_qubits)
            RY_layer(q_weights[k])
        else:
            full_entangling_layer(n_qubits)
            RY_RX_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self, n_qubits, q_depth = 1, q_delta=0.001):
        """
        Definition of the *dressed* layout.
        """
        print('n_qubits: ', n_qubits)
        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))

    def forward(self, input_features):
        """
        Optimized forward pass to reduce runtime.
        """

        # Quantum Embedding (U(X))
        q_in = torch.tanh(input_features) * np.pi / 2.0

        # Preallocate output tensor
        batch_size = q_in.shape[0]
        q_out = torch.zeros(batch_size, self.n_qubits, device=q_in.device)

        # Vectorized execution
        for i, elem in enumerate(q_in):
            q_out_elem = torch.hstack(quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)).float()
            q_out[i] = q_out_elem

        return q_out

In [20]:
# @title
import torch
from torch import nn
import numpy as np
import pennylane as qml

"""
    Lorentz-Equivariant Quantum Block (LEQB).

        - Given the Lie generators found (i.e.: through LieGAN, oracle-preserving latent flow, or some other approach
          that we develop further), once the metric tensor J is found via the equation:

                          L.J + J.(L^T) = 0,

          we just have to specify the metric to make the model symmetry-preserving to the corresponding Lie group.
          In the cells below, we can see how the model preserves symmetries (starting with the default Lorentz group),
          and when we change J to some other metric (Euclidean, for example), Lorentz boosts **break** equivariance, while other
          transformations preserve it (rotations, for the example shown in the cells below)
"""
class LEQB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LEQB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        self.dimension_reducer2 = nn.Linear(9, 4) # New linear layer for dimension reduction for phi_h
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2. n_input = n_hidden = 4
        print('Input size of phi_e: ', n_input)
        self.include_x = include_x

        """
            phi_e: input size: n_qubits -> output size: n_qubits
            n_hidden has to be equal to n_input,
            but this is just considering that this is a simple working example.
        """
        self.phi_e = DressedQuantumNet(n_input)
#         self.phi_e = nn.Sequential(
#             nn.Linear(n_input, n_hidden, bias=False),  # n_input * 2 + n_edge_attr
#             nn.BatchNorm1d(n_hidden),
#             nn.ReLU(),
#             nn.Linear(n_hidden, n_hidden),
#             nn.ReLU())

        n_hidden = n_input # n_input * 2 + n_edge_attr
        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

#         self.phi_h = DressedQuantumNet(n_hidden)

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

#         self.phi_x = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             layer)

#         self.phi_m = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             nn.Linear(n_hidden, 1),
#             nn.Sigmoid())

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        # self.phi_e = nn.Sequential(
        #     nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),
        #     nn.BatchNorm1d(n_hidden),
        #     nn.ReLU(),
        #     nn.Linear(n_hidden, n_hidden),
        #     nn.ReLU())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        out = self.dimension_reducer(out) # extra
        # print("Before embedding to |psi> : ", out)
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.dimension_reducer(out) # extra
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        #agg = self.dimension_reducer2(agg) # extra for phi_h
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m):
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)
        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LieEQGNN(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LEQB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LieEQGNN, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LEQBs = nn.ModuleList([LEQB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LEQBs[i](h, x, edges, node_attr=scalars)

        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)
        return pred.squeeze(1)

In [22]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

if __name__ == "__main__":

    N_EPOCHS = 55 # 60

    model_path = "models/LieEQGNN/"
    log_path = "logs/LieEQGNN/"
    # utils_lorentz.args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cpu
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cuda")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    model = LieEQGNN(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)



Input size of phi_e:  4
n_qubits:  4
Model Size: 199
 train samples: 800
 val samples: 100
 test samples: 100
Training...


50it [04:29,  5.40s/it]


>> train 	 Epoch 1/55 	 Batch 49/50 	 Loss 0.7467 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.3992
Time: train: 269.96 	 Train loss 0.7467 	 Train acc: 0.5025


7it [00:20,  2.97s/it]


>> val 	 Loss 0.7514 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8314
New best validation model, saving...
Epoch 0/55 finished.
Train time: 269.96 	 Val time 20.79
Train loss 0.7467 	 Train acc: 0.5025
Val loss: 0.7524 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:22,  5.24s/it]


>> train 	 Epoch 2/55 	 Batch 49/50 	 Loss 0.7334 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.2447
Time: train: 262.23 	 Train loss 0.7334 	 Train acc: 0.5025


7it [00:20,  2.94s/it]


>> val 	 Loss 0.7398 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8246
Epoch 1/55 finished.
Train time: 262.23 	 Val time 20.61
Train loss 0.7334 	 Train acc: 0.5025
Val loss: 0.7408 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:22,  5.25s/it]


>> train 	 Epoch 3/55 	 Batch 49/50 	 Loss 0.7219 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.2490
Time: train: 262.45 	 Train loss 0.7219 	 Train acc: 0.5025


7it [00:20,  2.95s/it]


>> val 	 Loss 0.7249 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8251
Epoch 2/55 finished.
Train time: 262.45 	 Val time 20.63
Train loss 0.7219 	 Train acc: 0.5025
Val loss: 0.7258 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:26,  5.33s/it]


>> train 	 Epoch 4/55 	 Batch 49/50 	 Loss 0.7088 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.3303
Time: train: 266.52 	 Train loss 0.7088 	 Train acc: 0.5025


7it [00:20,  2.98s/it]


>> val 	 Loss 0.7112 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8331
Epoch 3/55 finished.
Train time: 266.52 	 Val time 20.83
Train loss 0.7088 	 Train acc: 0.5025
Val loss: 0.7121 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:25,  5.31s/it]


>> train 	 Epoch 5/55 	 Batch 49/50 	 Loss 0.7025 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.3118
Time: train: 265.59 	 Train loss 0.7025 	 Train acc: 0.5025


7it [00:21,  3.02s/it]


>> val 	 Loss 0.7015 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8456
Epoch 4/55 finished.
Train time: 265.59 	 Val time 21.14
Train loss 0.7025 	 Train acc: 0.5025
Val loss: 0.7024 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:28,  5.37s/it]


>> train 	 Epoch 6/55 	 Batch 49/50 	 Loss 0.6934 	 Running Acc 0.509 	 Total Acc 0.509 	 Avg Batch Time 5.3707
Time: train: 268.53 	 Train loss 0.6934 	 Train acc: 0.5088


7it [00:22,  3.19s/it]


>> val 	 Loss 0.6967 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8933
Epoch 5/55 finished.
Train time: 268.53 	 Val time 22.33
Train loss 0.6934 	 Train acc: 0.5088
Val loss: 0.6975 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:25,  5.31s/it]


>> train 	 Epoch 7/55 	 Batch 49/50 	 Loss 0.6941 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 5.3062
Time: train: 265.31 	 Train loss 0.6941 	 Train acc: 0.5000


7it [00:20,  3.00s/it]


>> val 	 Loss 0.6953 	 Running Acc 1.679 	 Total Acc 0.470 	 Avg Batch Time 0.8388
Epoch 6/55 finished.
Train time: 265.31 	 Val time 20.97
Train loss 0.6941 	 Train acc: 0.5000
Val loss: 0.6961 	 Val acc: 0.4700
Best val acc: 0.4800 at epoch 0.


50it [04:25,  5.30s/it]


>> train 	 Epoch 8/55 	 Batch 49/50 	 Loss 0.6900 	 Running Acc 0.545 	 Total Acc 0.545 	 Avg Batch Time 5.3010
Time: train: 265.05 	 Train loss 0.6900 	 Train acc: 0.5450


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6949 	 Running Acc 1.679 	 Total Acc 0.470 	 Avg Batch Time 0.8294
Epoch 7/55 finished.
Train time: 265.05 	 Val time 20.74
Train loss 0.6900 	 Train acc: 0.5450
Val loss: 0.6958 	 Val acc: 0.4700
Best val acc: 0.4800 at epoch 0.


50it [04:23,  5.28s/it]


>> train 	 Epoch 9/55 	 Batch 49/50 	 Loss 0.6922 	 Running Acc 0.550 	 Total Acc 0.550 	 Avg Batch Time 5.2797
Time: train: 263.99 	 Train loss 0.6922 	 Train acc: 0.5500


7it [00:20,  2.99s/it]


>> val 	 Loss 0.6934 	 Running Acc 1.714 	 Total Acc 0.480 	 Avg Batch Time 0.8385
Epoch 8/55 finished.
Train time: 263.99 	 Val time 20.96
Train loss 0.6922 	 Train acc: 0.5500
Val loss: 0.6943 	 Val acc: 0.4800
Best val acc: 0.4800 at epoch 0.


50it [04:23,  5.27s/it]


>> train 	 Epoch 10/55 	 Batch 49/50 	 Loss 0.6928 	 Running Acc 0.506 	 Total Acc 0.506 	 Avg Batch Time 5.2705
Time: train: 263.53 	 Train loss 0.6928 	 Train acc: 0.5062


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6929 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.8321
New best validation model, saving...
Epoch 9/55 finished.
Train time: 263.53 	 Val time 20.80
Train loss 0.6928 	 Train acc: 0.5062
Val loss: 0.6938 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 9.


50it [04:23,  5.28s/it]


>> train 	 Epoch 11/55 	 Batch 49/50 	 Loss 0.6946 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.2779
Time: train: 263.90 	 Train loss 0.6946 	 Train acc: 0.4988


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6928 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.8317
Epoch 10/55 finished.
Train time: 263.90 	 Val time 20.79
Train loss 0.6946 	 Train acc: 0.4988
Val loss: 0.6938 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 9.


50it [04:25,  5.31s/it]


>> train 	 Epoch 12/55 	 Batch 49/50 	 Loss 0.6915 	 Running Acc 0.512 	 Total Acc 0.512 	 Avg Batch Time 5.3092
Time: train: 265.46 	 Train loss 0.6915 	 Train acc: 0.5125


7it [00:21,  3.01s/it]


>> val 	 Loss 0.6924 	 Running Acc 1.786 	 Total Acc 0.500 	 Avg Batch Time 0.8431
Epoch 11/55 finished.
Train time: 265.46 	 Val time 21.08
Train loss 0.6915 	 Train acc: 0.5125
Val loss: 0.6933 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 9.


50it [04:24,  5.29s/it]


>> train 	 Epoch 13/55 	 Batch 49/50 	 Loss 0.6919 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.2872
Time: train: 264.36 	 Train loss 0.6919 	 Train acc: 0.4988


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6923 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8291
New best validation model, saving...
Epoch 12/55 finished.
Train time: 264.36 	 Val time 20.73
Train loss 0.6919 	 Train acc: 0.4988
Val loss: 0.6932 	 Val acc: 0.5100
Best val acc: 0.5100 at epoch 12.


50it [04:31,  5.43s/it]


>> train 	 Epoch 14/55 	 Batch 49/50 	 Loss 0.6950 	 Running Acc 0.484 	 Total Acc 0.484 	 Avg Batch Time 5.4253
Time: train: 271.27 	 Train loss 0.6950 	 Train acc: 0.4838


7it [00:21,  3.00s/it]


>> val 	 Loss 0.6922 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8409
Epoch 13/55 finished.
Train time: 271.27 	 Val time 21.02
Train loss 0.6950 	 Train acc: 0.4838
Val loss: 0.6932 	 Val acc: 0.5100
Best val acc: 0.5100 at epoch 12.


50it [04:24,  5.29s/it]


>> train 	 Epoch 15/55 	 Batch 49/50 	 Loss 0.6927 	 Running Acc 0.519 	 Total Acc 0.519 	 Avg Batch Time 5.2944
Time: train: 264.72 	 Train loss 0.6927 	 Train acc: 0.5188


7it [00:21,  3.05s/it]


>> val 	 Loss 0.6922 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8553
Epoch 14/55 finished.
Train time: 264.72 	 Val time 21.38
Train loss 0.6927 	 Train acc: 0.5188
Val loss: 0.6932 	 Val acc: 0.5100
Best val acc: 0.5100 at epoch 12.


50it [04:30,  5.41s/it]


>> train 	 Epoch 16/55 	 Batch 49/50 	 Loss 0.6964 	 Running Acc 0.486 	 Total Acc 0.486 	 Avg Batch Time 5.4051
Time: train: 270.26 	 Train loss 0.6964 	 Train acc: 0.4863


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6922 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8301
Epoch 15/55 finished.
Train time: 270.26 	 Val time 20.75
Train loss 0.6964 	 Train acc: 0.4863
Val loss: 0.6932 	 Val acc: 0.5100
Best val acc: 0.5100 at epoch 12.


50it [04:24,  5.29s/it]


>> train 	 Epoch 17/55 	 Batch 49/50 	 Loss 0.6951 	 Running Acc 0.487 	 Total Acc 0.487 	 Avg Batch Time 5.2942
Time: train: 264.71 	 Train loss 0.6951 	 Train acc: 0.4875


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6923 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8301
Epoch 16/55 finished.
Train time: 264.71 	 Val time 20.75
Train loss 0.6951 	 Train acc: 0.4875
Val loss: 0.6933 	 Val acc: 0.5100
Best val acc: 0.5100 at epoch 12.


50it [04:24,  5.29s/it]


>> train 	 Epoch 18/55 	 Batch 49/50 	 Loss 0.6962 	 Running Acc 0.471 	 Total Acc 0.471 	 Avg Batch Time 5.2896
Time: train: 264.48 	 Train loss 0.6962 	 Train acc: 0.4713


7it [00:21,  3.01s/it]


>> val 	 Loss 0.6924 	 Running Acc 1.893 	 Total Acc 0.530 	 Avg Batch Time 0.8437
New best validation model, saving...
Epoch 17/55 finished.
Train time: 264.48 	 Val time 21.09
Train loss 0.6962 	 Train acc: 0.4713
Val loss: 0.6934 	 Val acc: 0.5300
Best val acc: 0.5300 at epoch 17.


50it [04:25,  5.31s/it]


>> train 	 Epoch 19/55 	 Batch 49/50 	 Loss 0.6913 	 Running Acc 0.502 	 Total Acc 0.502 	 Avg Batch Time 5.3071
Time: train: 265.36 	 Train loss 0.6913 	 Train acc: 0.5025


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6921 	 Running Acc 1.857 	 Total Acc 0.520 	 Avg Batch Time 0.8246
Epoch 18/55 finished.
Train time: 265.36 	 Val time 20.61
Train loss 0.6913 	 Train acc: 0.5025
Val loss: 0.6930 	 Val acc: 0.5200
Best val acc: 0.5300 at epoch 17.


50it [04:26,  5.33s/it]


>> train 	 Epoch 20/55 	 Batch 49/50 	 Loss 0.6900 	 Running Acc 0.517 	 Total Acc 0.517 	 Avg Batch Time 5.3279
Time: train: 266.40 	 Train loss 0.6900 	 Train acc: 0.5175


7it [00:20,  3.00s/it]


>> val 	 Loss 0.6917 	 Running Acc 1.893 	 Total Acc 0.530 	 Avg Batch Time 0.8399
Epoch 19/55 finished.
Train time: 266.40 	 Val time 21.00
Train loss 0.6900 	 Train acc: 0.5175
Val loss: 0.6926 	 Val acc: 0.5300
Best val acc: 0.5300 at epoch 17.


50it [04:22,  5.25s/it]


>> train 	 Epoch 21/55 	 Batch 49/50 	 Loss 0.6900 	 Running Acc 0.526 	 Total Acc 0.526 	 Avg Batch Time 5.2468
Time: train: 262.34 	 Train loss 0.6900 	 Train acc: 0.5262


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6916 	 Running Acc 1.929 	 Total Acc 0.540 	 Avg Batch Time 0.8310
New best validation model, saving...
Epoch 20/55 finished.
Train time: 262.34 	 Val time 20.77
Train loss 0.6900 	 Train acc: 0.5262
Val loss: 0.6926 	 Val acc: 0.5400
Best val acc: 0.5400 at epoch 20.


50it [04:26,  5.32s/it]


>> train 	 Epoch 22/55 	 Batch 49/50 	 Loss 0.6916 	 Running Acc 0.505 	 Total Acc 0.505 	 Avg Batch Time 5.3206
Time: train: 266.03 	 Train loss 0.6916 	 Train acc: 0.5050


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6916 	 Running Acc 1.893 	 Total Acc 0.530 	 Avg Batch Time 0.8330
Epoch 21/55 finished.
Train time: 266.03 	 Val time 20.83
Train loss 0.6916 	 Train acc: 0.5050
Val loss: 0.6927 	 Val acc: 0.5300
Best val acc: 0.5400 at epoch 20.


50it [04:24,  5.28s/it]


>> train 	 Epoch 23/55 	 Batch 49/50 	 Loss 0.6937 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.2827
Time: train: 264.13 	 Train loss 0.6937 	 Train acc: 0.4988


7it [00:21,  3.04s/it]


>> val 	 Loss 0.6918 	 Running Acc 1.857 	 Total Acc 0.520 	 Avg Batch Time 0.8499
Epoch 22/55 finished.
Train time: 264.13 	 Val time 21.25
Train loss 0.6937 	 Train acc: 0.4988
Val loss: 0.6928 	 Val acc: 0.5200
Best val acc: 0.5400 at epoch 20.


50it [04:30,  5.42s/it]


>> train 	 Epoch 24/55 	 Batch 49/50 	 Loss 0.6891 	 Running Acc 0.512 	 Total Acc 0.512 	 Avg Batch Time 5.4194
Time: train: 270.97 	 Train loss 0.6891 	 Train acc: 0.5125


7it [00:21,  3.01s/it]


>> val 	 Loss 0.6917 	 Running Acc 1.857 	 Total Acc 0.520 	 Avg Batch Time 0.8422
Epoch 23/55 finished.
Train time: 270.97 	 Val time 21.06
Train loss 0.6891 	 Train acc: 0.5125
Val loss: 0.6927 	 Val acc: 0.5200
Best val acc: 0.5400 at epoch 20.


50it [04:27,  5.36s/it]


>> train 	 Epoch 25/55 	 Batch 49/50 	 Loss 0.6940 	 Running Acc 0.494 	 Total Acc 0.494 	 Avg Batch Time 5.3588
Time: train: 267.94 	 Train loss 0.6940 	 Train acc: 0.4938


7it [00:21,  3.09s/it]


>> val 	 Loss 0.6917 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8658
Epoch 24/55 finished.
Train time: 267.94 	 Val time 21.65
Train loss 0.6940 	 Train acc: 0.4938
Val loss: 0.6928 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:31,  5.44s/it]


>> train 	 Epoch 26/55 	 Batch 49/50 	 Loss 0.6899 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.4358
Time: train: 271.79 	 Train loss 0.6899 	 Train acc: 0.4988


7it [00:20,  2.99s/it]


>> val 	 Loss 0.6917 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8360
Epoch 25/55 finished.
Train time: 271.79 	 Val time 20.90
Train loss 0.6899 	 Train acc: 0.4988
Val loss: 0.6928 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:25,  5.31s/it]


>> train 	 Epoch 27/55 	 Batch 49/50 	 Loss 0.6892 	 Running Acc 0.522 	 Total Acc 0.522 	 Avg Batch Time 5.3132
Time: train: 265.66 	 Train loss 0.6892 	 Train acc: 0.5225


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6916 	 Running Acc 1.857 	 Total Acc 0.520 	 Avg Batch Time 0.8325
Epoch 26/55 finished.
Train time: 265.66 	 Val time 20.81
Train loss 0.6892 	 Train acc: 0.5225
Val loss: 0.6927 	 Val acc: 0.5200
Best val acc: 0.5400 at epoch 20.


50it [04:26,  5.34s/it]


>> train 	 Epoch 28/55 	 Batch 49/50 	 Loss 0.6896 	 Running Acc 0.517 	 Total Acc 0.517 	 Avg Batch Time 5.3383
Time: train: 266.91 	 Train loss 0.6896 	 Train acc: 0.5175


7it [00:20,  2.95s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8273
Epoch 27/55 finished.
Train time: 266.91 	 Val time 20.68
Train loss 0.6896 	 Train acc: 0.5175
Val loss: 0.6926 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:23,  5.27s/it]


>> train 	 Epoch 29/55 	 Batch 49/50 	 Loss 0.6918 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 5.2707
Time: train: 263.53 	 Train loss 0.6918 	 Train acc: 0.5000


7it [00:20,  2.95s/it]


>> val 	 Loss 0.6916 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8254
Epoch 28/55 finished.
Train time: 263.53 	 Val time 20.64
Train loss 0.6918 	 Train acc: 0.5000
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:26,  5.33s/it]


>> train 	 Epoch 30/55 	 Batch 49/50 	 Loss 0.6919 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 5.3295
Time: train: 266.47 	 Train loss 0.6919 	 Train acc: 0.5000


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8296
Epoch 29/55 finished.
Train time: 266.47 	 Val time 20.74
Train loss 0.6919 	 Train acc: 0.5000
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:22,  5.26s/it]


>> train 	 Epoch 31/55 	 Batch 49/50 	 Loss 0.6876 	 Running Acc 0.525 	 Total Acc 0.525 	 Avg Batch Time 5.2579
Time: train: 262.89 	 Train loss 0.6876 	 Train acc: 0.5250


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8313
Epoch 30/55 finished.
Train time: 262.89 	 Val time 20.78
Train loss 0.6876 	 Train acc: 0.5250
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:26,  5.33s/it]


>> train 	 Epoch 32/55 	 Batch 49/50 	 Loss 0.6950 	 Running Acc 0.480 	 Total Acc 0.480 	 Avg Batch Time 5.3266
Time: train: 266.33 	 Train loss 0.6950 	 Train acc: 0.4800


7it [00:21,  3.09s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8659
Epoch 31/55 finished.
Train time: 266.33 	 Val time 21.65
Train loss 0.6950 	 Train acc: 0.4800
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:23,  5.28s/it]


>> train 	 Epoch 33/55 	 Batch 49/50 	 Loss 0.6934 	 Running Acc 0.484 	 Total Acc 0.484 	 Avg Batch Time 5.2757
Time: train: 263.79 	 Train loss 0.6934 	 Train acc: 0.4838


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8304
Epoch 32/55 finished.
Train time: 263.79 	 Val time 20.76
Train loss 0.6934 	 Train acc: 0.4838
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:26,  5.32s/it]


>> train 	 Epoch 34/55 	 Batch 49/50 	 Loss 0.6865 	 Running Acc 0.554 	 Total Acc 0.554 	 Avg Batch Time 5.3224
Time: train: 266.12 	 Train loss 0.6865 	 Train acc: 0.5537


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8225
Epoch 33/55 finished.
Train time: 266.12 	 Val time 20.56
Train loss 0.6865 	 Train acc: 0.5537
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:22,  5.25s/it]


>> train 	 Epoch 35/55 	 Batch 49/50 	 Loss 0.6957 	 Running Acc 0.468 	 Total Acc 0.468 	 Avg Batch Time 5.2464
Time: train: 262.32 	 Train loss 0.6957 	 Train acc: 0.4675


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8238
Epoch 34/55 finished.
Train time: 262.32 	 Val time 20.60
Train loss 0.6957 	 Train acc: 0.4675
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:25,  5.32s/it]


>> train 	 Epoch 36/55 	 Batch 49/50 	 Loss 0.6918 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.3186
Time: train: 265.93 	 Train loss 0.6918 	 Train acc: 0.4988


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8318
Epoch 35/55 finished.
Train time: 265.93 	 Val time 20.79
Train loss 0.6918 	 Train acc: 0.4988
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:24,  5.29s/it]


>> train 	 Epoch 37/55 	 Batch 49/50 	 Loss 0.6950 	 Running Acc 0.489 	 Total Acc 0.489 	 Avg Batch Time 5.2938
Time: train: 264.69 	 Train loss 0.6950 	 Train acc: 0.4888


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8306
Epoch 36/55 finished.
Train time: 264.69 	 Val time 20.77
Train loss 0.6950 	 Train acc: 0.4888
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:27,  5.35s/it]


>> train 	 Epoch 38/55 	 Batch 49/50 	 Loss 0.6939 	 Running Acc 0.489 	 Total Acc 0.489 	 Avg Batch Time 5.3502
Time: train: 267.51 	 Train loss 0.6939 	 Train acc: 0.4888


7it [00:20,  2.98s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8335
Epoch 37/55 finished.
Train time: 267.51 	 Val time 20.84
Train loss 0.6939 	 Train acc: 0.4888
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:25,  5.30s/it]


>> train 	 Epoch 39/55 	 Batch 49/50 	 Loss 0.6899 	 Running Acc 0.510 	 Total Acc 0.510 	 Avg Batch Time 5.3016
Time: train: 265.08 	 Train loss 0.6899 	 Train acc: 0.5100


7it [00:22,  3.22s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.9006
Epoch 38/55 finished.
Train time: 265.08 	 Val time 22.52
Train loss 0.6899 	 Train acc: 0.5100
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:32,  5.44s/it]


>> train 	 Epoch 40/55 	 Batch 49/50 	 Loss 0.6916 	 Running Acc 0.495 	 Total Acc 0.495 	 Avg Batch Time 5.4441
Time: train: 272.20 	 Train loss 0.6916 	 Train acc: 0.4950


7it [00:21,  3.05s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8549
Epoch 39/55 finished.
Train time: 272.20 	 Val time 21.37
Train loss 0.6916 	 Train acc: 0.4950
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:25,  5.31s/it]


>> train 	 Epoch 41/55 	 Batch 49/50 	 Loss 0.6851 	 Running Acc 0.531 	 Total Acc 0.531 	 Avg Batch Time 5.3136
Time: train: 265.68 	 Train loss 0.6851 	 Train acc: 0.5312


7it [00:20,  2.95s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8273
Epoch 40/55 finished.
Train time: 265.68 	 Val time 20.68
Train loss 0.6851 	 Train acc: 0.5312
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:28,  5.38s/it]


>> train 	 Epoch 42/55 	 Batch 49/50 	 Loss 0.6913 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 5.3777
Time: train: 268.88 	 Train loss 0.6913 	 Train acc: 0.5000


7it [00:20,  2.95s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8263
Epoch 41/55 finished.
Train time: 268.88 	 Val time 20.66
Train loss 0.6913 	 Train acc: 0.5000
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:23,  5.26s/it]


>> train 	 Epoch 43/55 	 Batch 49/50 	 Loss 0.6950 	 Running Acc 0.477 	 Total Acc 0.477 	 Avg Batch Time 5.2646
Time: train: 263.23 	 Train loss 0.6950 	 Train acc: 0.4775


7it [00:20,  2.97s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8320
Epoch 42/55 finished.
Train time: 263.23 	 Val time 20.80
Train loss 0.6950 	 Train acc: 0.4775
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:22,  5.25s/it]


>> train 	 Epoch 44/55 	 Batch 49/50 	 Loss 0.6926 	 Running Acc 0.486 	 Total Acc 0.486 	 Avg Batch Time 5.2550
Time: train: 262.75 	 Train loss 0.6926 	 Train acc: 0.4863


7it [00:20,  2.93s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8197
Epoch 43/55 finished.
Train time: 262.75 	 Val time 20.49
Train loss 0.6926 	 Train acc: 0.4863
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:19,  5.20s/it]


>> train 	 Epoch 45/55 	 Batch 49/50 	 Loss 0.6915 	 Running Acc 0.499 	 Total Acc 0.499 	 Avg Batch Time 5.1992
Time: train: 259.96 	 Train loss 0.6915 	 Train acc: 0.4988


7it [00:20,  2.93s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8207
Epoch 44/55 finished.
Train time: 259.96 	 Val time 20.52
Train loss 0.6915 	 Train acc: 0.4988
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:25,  5.31s/it]


>> train 	 Epoch 46/55 	 Batch 49/50 	 Loss 0.6958 	 Running Acc 0.482 	 Total Acc 0.482 	 Avg Batch Time 5.3064
Time: train: 265.32 	 Train loss 0.6958 	 Train acc: 0.4825


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8287
Epoch 45/55 finished.
Train time: 265.32 	 Val time 20.72
Train loss 0.6958 	 Train acc: 0.4825
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:21,  5.23s/it]


>> train 	 Epoch 47/55 	 Batch 49/50 	 Loss 0.6924 	 Running Acc 0.497 	 Total Acc 0.497 	 Avg Batch Time 5.2261
Time: train: 261.30 	 Train loss 0.6924 	 Train acc: 0.4975


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8240
Epoch 46/55 finished.
Train time: 261.30 	 Val time 20.60
Train loss 0.6924 	 Train acc: 0.4975
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:22,  5.26s/it]


>> train 	 Epoch 48/55 	 Batch 49/50 	 Loss 0.6955 	 Running Acc 0.474 	 Total Acc 0.474 	 Avg Batch Time 5.2597
Time: train: 262.98 	 Train loss 0.6955 	 Train acc: 0.4738


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8235
Epoch 47/55 finished.
Train time: 262.98 	 Val time 20.59
Train loss 0.6955 	 Train acc: 0.4738
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:21,  5.22s/it]


>> train 	 Epoch 49/55 	 Batch 49/50 	 Loss 0.6903 	 Running Acc 0.515 	 Total Acc 0.515 	 Avg Batch Time 5.2244
Time: train: 261.22 	 Train loss 0.6903 	 Train acc: 0.5150


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8285
Epoch 48/55 finished.
Train time: 261.22 	 Val time 20.71
Train loss 0.6903 	 Train acc: 0.5150
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:21,  5.23s/it]


>> train 	 Epoch 50/55 	 Batch 49/50 	 Loss 0.6938 	 Running Acc 0.491 	 Total Acc 0.491 	 Avg Batch Time 5.2290
Time: train: 261.45 	 Train loss 0.6938 	 Train acc: 0.4913


7it [00:20,  2.98s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8357
Epoch 49/55 finished.
Train time: 261.45 	 Val time 20.89
Train loss 0.6938 	 Train acc: 0.4913
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:21,  5.24s/it]


>> train 	 Epoch 51/55 	 Batch 49/50 	 Loss 0.6904 	 Running Acc 0.504 	 Total Acc 0.504 	 Avg Batch Time 5.2369
Time: train: 261.84 	 Train loss 0.6904 	 Train acc: 0.5038


7it [00:20,  2.93s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8210
Epoch 50/55 finished.
Train time: 261.84 	 Val time 20.53
Train loss 0.6904 	 Train acc: 0.5038
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:23,  5.28s/it]


>> train 	 Epoch 52/55 	 Batch 49/50 	 Loss 0.6912 	 Running Acc 0.505 	 Total Acc 0.505 	 Avg Batch Time 5.2757
Time: train: 263.78 	 Train loss 0.6912 	 Train acc: 0.5050


7it [00:20,  2.96s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8295
Epoch 51/55 finished.
Train time: 263.78 	 Val time 20.74
Train loss 0.6912 	 Train acc: 0.5050
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:32,  5.46s/it]


>> train 	 Epoch 53/55 	 Batch 49/50 	 Loss 0.6873 	 Running Acc 0.535 	 Total Acc 0.535 	 Avg Batch Time 5.4580
Time: train: 272.90 	 Train loss 0.6873 	 Train acc: 0.5350


7it [00:20,  2.98s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8341
Epoch 52/55 finished.
Train time: 272.90 	 Val time 20.85
Train loss 0.6873 	 Train acc: 0.5350
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:22,  5.26s/it]


>> train 	 Epoch 54/55 	 Batch 49/50 	 Loss 0.6937 	 Running Acc 0.482 	 Total Acc 0.482 	 Avg Batch Time 5.2576
Time: train: 262.88 	 Train loss 0.6937 	 Train acc: 0.4825


7it [00:20,  2.94s/it]


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8222
Epoch 53/55 finished.
Train time: 262.88 	 Val time 20.55
Train loss 0.6937 	 Train acc: 0.4825
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


50it [04:23,  5.27s/it]


>> train 	 Epoch 55/55 	 Batch 49/50 	 Loss 0.6928 	 Running Acc 0.497 	 Total Acc 0.497 	 Avg Batch Time 5.2675
Time: train: 263.38 	 Train loss 0.6928 	 Train acc: 0.4975


7it [00:20,  2.97s/it]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.6915 	 Running Acc 1.821 	 Total Acc 0.510 	 Avg Batch Time 0.8318
Epoch 54/55 finished.
Train time: 263.38 	 Val time 20.80
Train loss 0.6928 	 Train acc: 0.4975
Val loss: 0.6927 	 Val acc: 0.5100
Best val acc: 0.5400 at epoch 20.


7it [00:21,  3.09s/it]

>> test 	 Loss 0.6858 	 Running Acc 2.393 	 Total Acc 0.670 	 Avg Batch Time 0.8643
Final  tensor([[1.0000, 0.4797, 0.5203],
        [0.0000, 0.5055, 0.4945],
        [1.0000, 0.4939, 0.5061],
        [0.0000, 0.5174, 0.4826],
        [1.0000, 0.5021, 0.4979],
        [0.0000, 0.5050, 0.4950],
        [0.0000, 0.5051, 0.4949],
        [0.0000, 0.5174, 0.4826],
        [1.0000, 0.4786, 0.5214],
        [0.0000, 0.5150, 0.4850],
        [1.0000, 0.5077, 0.4923],
        [1.0000, 0.4847, 0.5153],
        [0.0000, 0.5096, 0.4904],
        [0.0000, 0.4797, 0.5203],
        [1.0000, 0.5027, 0.4973],
        [1.0000, 0.5009, 0.4991],
        [1.0000, 0.4937, 0.5063],
        [1.0000, 0.4845, 0.5155],
        [1.0000, 0.4988, 0.5012],
        [1.0000, 0.4956, 0.5044],
        [1.0000, 0.4872, 0.5128],
        [0.0000, 0.5051, 0.4949],
        [1.0000, 0.4867, 0.5133],
        [0.0000, 0.5151, 0.4849],
        [1.0000, 0.5123, 0.4877],
        [1.0000, 0.4831, 0.5169],
        [1.0000, 0.5170, 


