# Lie-Equivariant Quantum Graph Neural Network (Lie-EQGNN)

In [1]:
# For Colab
!pip install torch_geometric
# !pip install torch_sparse
# !pip install torch_scatter

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
!pip install pennylane qiskit pennylane-qiskit pylatexenc

Collecting pennylane
  Downloading PennyLane-0.38.0-py3-none-any.whl.metadata (9.3 kB)
Collecting qiskit
  Downloading qiskit-1.2.4-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting pennylane-qiskit
  Downloading PennyLane_qiskit-0.38.1-py3-none-any.whl.metadata (6.4 kB)
Collecting pylatexenc
  Downloading pylatexenc-2.10.tar.gz (162 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.15.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting autograd (from pennylane)
  Downloading autograd-1.7.0-py3-none-any.whl.metadata (7.5 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.38 (from pennylane)
  Downloading PennyLane_Lightning-0.3

In [3]:
import pennylane as qml
import qiskit
print(qml.__version__)
print(qiskit.__version__)
import pennylane_qiskit
print(pennylane_qiskit.__version__)
import pennylane as qml
from pennylane import numpy as np
# from pennylane_qiskit import AerDevice

0.38.0
1.2.4
0.38.1


## Dataset

In [16]:
import torch
import numpy as np
import os
from torchvision import datasets, transforms
from scipy.sparse import coo_matrix

def save_fashion_mnist_tensors(class_indices=[0, 1], num_data_per_class=500, save_dir="fashion_mnist/data"):
    """
    Generate and save tensor data files for Fashion MNIST dataset in a graph-like format.
    """
    assert len(class_indices) == 2, "Please specify exactly 2 class indices"
    os.makedirs(save_dir, exist_ok=True)

    # Load Fashion MNIST data
    fashion_mnist = datasets.FashionMNIST(root='./data', train=True, download=True,
                                        transform=transforms.ToTensor())

    # Select subset of data for specified classes
    indices = []
    labels = []
    for i, class_idx in enumerate(class_indices):
        class_indices_temp = (fashion_mnist.targets == class_idx).nonzero().squeeze()
        selected_indices = class_indices_temp[:num_data_per_class].tolist()
        indices.extend(selected_indices)
        # Create binary labels (0 for first class, 1 for second class)
        labels.extend([i] * num_data_per_class)

    # Convert indices and labels to tensors
    indices = torch.tensor(indices)
    labels = torch.tensor(labels, dtype=torch.float)

    # Shuffle the data
    shuffle_idx = torch.randperm(len(indices))
    indices = indices[shuffle_idx]
    labels = labels[shuffle_idx]

    X = fashion_mnist.data[indices].float() / 255.0  # Normalize to [0,1]

    batch_size = len(X)
    n_nodes = 10  # Maximum number of significant points to extract per image

    def image_to_nodes(image_data, max_nodes=10, threshold=0.1):
        """
        Convert Fashion MNIST images to graph nodes.
        Extracts significant points based on intensity values.
        """
        batch_size = len(image_data)
        nodes = np.zeros((batch_size, max_nodes, 1))  # Single channel feature
        p4s = np.zeros((batch_size, max_nodes, 4))    # Position and feature information
        atom_masks = np.zeros((batch_size, max_nodes), dtype=bool)

        for b in range(batch_size):
            # Find significant points in the image (edges and important features)
            significant_points = np.where(image_data[b] > threshold)
            values = image_data[b][significant_points]

            # Sort points by their intensity and take top max_nodes
            sorted_indices = np.argsort(-values)
            n_points = min(len(sorted_indices), max_nodes)
            selected_indices = sorted_indices[:n_points]

            for idx_pos, idx in enumerate(selected_indices):
                h, w = significant_points[0][idx], significant_points[1][idx]
                intensity = values[idx]

                # Node feature (intensity value)
                nodes[b, idx_pos, 0] = intensity

                # Create p4s (x, y, intensity, 0)
                # Normalize coordinates to [-1, 1] range
                x = (w - 13.5) / 13.5  # Center and normalize
                y = (h - 13.5) / 13.5  # Center and normalize
                p4s[b, idx_pos] = [x, y, intensity, 0]
                atom_masks[b, idx_pos] = True

        return p4s, nodes, atom_masks

    # Convert image data to graph format
    p4s, nodes, atom_mask = image_to_nodes(X.numpy())

    # Convert to torch tensors
    p4s = torch.from_numpy(p4s).float()
    nodes = torch.from_numpy(nodes).float()
    atom_mask = torch.from_numpy(atom_mask)

    # Create edge mask (fully connected graph between valid nodes)
    edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    edge_mask = edge_mask * diag_mask

    # Calculate edges
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx * n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)
    edges = np.stack([rows, cols])

    # Save tensors
    torch.save(p4s, os.path.join(save_dir, "p4s.pt"))
    torch.save(nodes, os.path.join(save_dir, "nodes.pt"))
    torch.save(labels, os.path.join(save_dir, "labels.pt"))
    torch.save(atom_mask, os.path.join(save_dir, "atom_mask.pt"))
    np.save(os.path.join(save_dir, "edge_mask.npy"), edge_mask.numpy())
    np.save(os.path.join(save_dir, "edges.npy"), edges)

    print(f"Saved tensor files to {save_dir}")
    print(f"Classes used: {class_indices}")
    print(f"Shapes:")
    print(f"p4s: {p4s.shape}")
    print(f"nodes: {nodes.shape}")
    print(f"labels: {labels.shape}")
    print(f"atom_mask: {atom_mask.shape}")
    print(f"edge_mask: {edge_mask.shape}")
    print(f"edges: {edges.shape}")

    # Print label distribution
    unique_labels, counts = torch.unique(labels, return_counts=True)
    print("\nLabel distribution:")
    for label, count in zip(unique_labels, counts):
        print(f"Class {label.item()}: {count.item()} samples")

# Generate data for T-shirts (0) vs Trousers (1)
save_fashion_mnist_tensors(
    class_indices=[0, 1],  # T-shirt/top vs Trouser
    num_data_per_class=5000,  # 500 samples per class = 1000 total
    save_dir="fashion_mnist/data"
)

Saved tensor files to fashion_mnist/data
Classes used: [0, 1]
Shapes:
p4s: torch.Size([10000, 10, 4])
nodes: torch.Size([10000, 10, 1])
labels: torch.Size([10000])
atom_mask: torch.Size([10000, 10])
edge_mask: torch.Size([10000, 10, 10])
edges: (2, 900000)

Label distribution:
Class 0.0: 5000 samples
Class 1.0: 5000 samples


In [17]:
# from torch.utils.data import TensorDataset, random_split
# import torch
# import numpy as np
# from torch.utils.data import TensorDataset, DataLoader
# from scipy.sparse import coo_matrix
# import h5py

def get_adj_matrix(n_nodes, batch_size, edge_mask):
    rows, cols = [], []
    for batch_idx in range(batch_size):
        nn = batch_idx*n_nodes
        x = coo_matrix(edge_mask[batch_idx])
        rows.append(nn + x.row)
        cols.append(nn + x.col)
    rows = np.concatenate(rows)
    cols = np.concatenate(cols)

    edges = [torch.LongTensor(rows), torch.LongTensor(cols)]
    return edges

def collate_fn(data):
    data = list(zip(*data)) # label p4s nodes atom_mask
    data = [torch.stack(item) for item in data]
    batch_size, n_nodes, _ = data[1].size()
    atom_mask = data[-1]
    # edge_mask = atom_mask.unsqueeze(1) * atom_mask.unsqueeze(2)
    # diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0)
    # edge_mask *= diag_mask

    edge_mask = data[-2]

    edges = get_adj_matrix(n_nodes, batch_size, edge_mask)
    return data + [edges]

In [18]:
from torch.utils.data import TensorDataset, DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import torch

def create_stratified_split(dataset, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_state=42):
    """
    Creates stratified train/val/test splits for binary classification.
    """
    # Extract labels and convert to integer labels for stratification
    labels = dataset.tensors[0].numpy()
    labels_int = labels.astype(int)  # Convert to int for stratification

    # Print initial distribution
    unique_classes, class_counts = np.unique(labels_int, return_counts=True)
    print("\nInitial class distribution:")
    for class_idx, count in zip(unique_classes, class_counts):
        print(f"Class {class_idx}: {count} samples ({count/len(labels)*100:.2f}%)")

    # First split: separate train from temporary val+test
    first_sss = StratifiedShuffleSplit(n_splits=1,
                                      train_size=train_ratio,
                                      random_state=random_state)

    train_idx, temp_idx = next(first_sss.split(np.zeros(len(labels)), labels_int))

    # Second split: separate val and test from temporary set
    temp_labels = labels_int[temp_idx]
    val_ratio_adjusted = val_ratio / (val_ratio + test_ratio)

    second_sss = StratifiedShuffleSplit(n_splits=1,
                                       train_size=val_ratio_adjusted,
                                       random_state=random_state)

    val_idx_temp, test_idx_temp = next(second_sss.split(np.zeros(len(temp_labels)), temp_labels))

    # Convert temporary indices to original dataset indices
    val_idx = temp_idx[val_idx_temp]
    test_idx = temp_idx[test_idx_temp]

    # Create subset datasets
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)
    test_dataset = Subset(dataset, test_idx)

    # Verify class distribution in each split
    def get_split_distribution(indices):
        split_labels = labels_int[indices]
        unique, counts = np.unique(split_labels, return_counts=True)
        return dict(zip(unique, counts / len(indices)))

    print("\nClass distribution in splits:")
    print("Train:", get_split_distribution(train_idx))
    print("Val:", get_split_distribution(val_idx))
    print("Test:", get_split_distribution(test_idx))

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }

# Let's reload the data and ensure we have binary labels
labels = torch.load('fashion_mnist/data/labels.pt')
p4s = torch.load('fashion_mnist/data/p4s.pt')
nodes = torch.load('fashion_mnist/data/nodes.pt')
atom_mask = torch.load('fashion_mnist/data/atom_mask.pt')
edge_mask = torch.from_numpy(np.load('fashion_mnist/data/edge_mask.npy'))

# Print label statistics before creating dataset
print("Label Statistics:")
print("Shape:", labels.shape)
print("Unique values:", torch.unique(labels))
print("Class distribution:", torch.bincount(labels.long()) / len(labels))

# Create the dataset
dataset_all = TensorDataset(labels, p4s, nodes, atom_mask, edge_mask)

# Create stratified splits
datasets = create_stratified_split(dataset_all)

# Create dataloaders
dataloaders = {
    split: DataLoader(
        dataset,
        batch_size=16,
        pin_memory=False,
        collate_fn=collate_fn,
        drop_last=True if (split == 'train') else False,
        num_workers=0,
        shuffle=(split == 'train')
    )
    for split, dataset in datasets.items()
}

Label Statistics:
Shape: torch.Size([10000])
Unique values: tensor([0., 1.])
Class distribution: tensor([0.5000, 0.5000])

Initial class distribution:
Class 0: 5000 samples (50.00%)
Class 1: 5000 samples (50.00%)

Class distribution in splits:
Train: {0: 0.5, 1: 0.5}
Val: {0: 0.5, 1: 0.5}
Test: {0: 0.5, 1: 0.5}


  labels = torch.load('fashion_mnist/data/labels.pt')
  p4s = torch.load('fashion_mnist/data/p4s.pt')
  nodes = torch.load('fashion_mnist/data/nodes.pt')
  atom_mask = torch.load('fashion_mnist/data/atom_mask.pt')


In [19]:
# Set desired dimensions
batch_size = 1
n_nodes = 3
device = 'cpu'
dtype = torch.float32

# Print initial shapes
print("Initial shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Select subset of data
p4s = p4s[:batch_size, :n_nodes, :]
atom_mask = atom_mask[:batch_size, :n_nodes]
edge_mask = edge_mask[:batch_size, :n_nodes, :n_nodes]
nodes = nodes[:batch_size, :n_nodes, :]

print("\nAfter selection shapes:")
print("p4s:", p4s.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)

# Reshape tensors
atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device, dtype)
# Don't reshape edge_mask yet
nodes = nodes.view(batch_size * n_nodes, -1).to(device, dtype)

print("\nAfter reshape shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)  # original shape
print("nodes:", nodes.shape)

# Recalculate edges for the subset
from scipy.sparse import coo_matrix
rows, cols = [], []
for batch_idx in range(batch_size):
    nn = batch_idx * n_nodes
    # Convert edge_mask to numpy and remove any extra dimensions
    edge_mask_np = edge_mask[batch_idx].cpu().numpy().squeeze()
    x = coo_matrix(edge_mask_np)
    rows.append(nn + x.row)
    cols.append(nn + x.col)

edges = [torch.LongTensor(np.concatenate(rows)).to(device),
         torch.LongTensor(np.concatenate(cols)).to(device)]

# Now reshape edge_mask after edges are calculated
edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)

print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])

Initial shapes:
p4s: torch.Size([10000, 10, 4])
atom_mask: torch.Size([10000, 10])
edge_mask: torch.Size([10000, 10, 10])
nodes: torch.Size([10000, 10, 1])

After selection shapes:
p4s: torch.Size([1, 3, 4])
atom_mask: torch.Size([1, 3])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([1, 3, 1])

After reshape shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([1, 3, 3])
nodes: torch.Size([3, 1])

Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


In [20]:
print("\nFinal shapes:")
print("atom_positions:", atom_positions.shape)
print("atom_mask:", atom_mask.shape)
print("edge_mask:", edge_mask.shape)
print("nodes:", nodes.shape)
print("edges:", [e.shape for e in edges])


Final shapes:
atom_positions: torch.Size([3, 4])
atom_mask: torch.Size([3, 1])
edge_mask: torch.Size([9, 1])
nodes: torch.Size([3, 1])
edges: [torch.Size([6]), torch.Size([6])]


# 3. LorentzNet

In [21]:
# @title
import torch
from torch import nn
import numpy as np



"""Some auxiliary functions"""

def unsorted_segment_sum(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    return result

def unsorted_segment_mean(data, segment_ids, num_segments):
    r'''Custom PyTorch op to replicate TensorFlow's `unsorted_segment_mean`.
    Adapted from https://github.com/vgsatorras/egnn.
    '''
    result = data.new_zeros((num_segments, data.size(1)))
    count = data.new_zeros((num_segments, data.size(1)))
    result.index_add_(0, segment_ids, data)
    count.index_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

def normsq4(p):
    r''' Minkowski square norm
         `\|p\|^2 = p[0]^2-p[1]^2-p[2]^2-p[3]^2`
    '''
    psq = torch.pow(p, 2)
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def dotsq4(p,q):
    r''' Minkowski inner product
         `<p,q> = p[0]q[0]-p[1]q[1]-p[2]q[2]-p[3]q[3]`
    '''
    psq = p*q
    return 2 * psq[..., 0] - psq.sum(dim=-1)

def normA_fn(A):
    return lambda p: torch.einsum('...i, ij, ...j->...', p, A, p)

def dotA_fn(A):
    return lambda p, q: torch.einsum('...i, ij, ...j->...', p, A, q)

def psi(p):
    ''' `\psi(p) = Sgn(p) \cdot \log(|p| + 1)`
    '''
    return torch.sign(p) * torch.log(torch.abs(p) + 1)


"""Lorentz Group-Equivariant Block"""

class LGEB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LGEB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2.
        print('Input size of phi_e: ', n_input)

        self.include_x = include_x
        self.phi_e = nn.Sequential(
            nn.Linear(n_input, n_hidden, bias=False), # n_input * 2 + n_edge_attr
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU())

        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4


    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        # Reduce the dimension of 'out' to 4 using a linear layer
        out = self.dimension_reducer(out)
        out = self.phi_e(out)
        # print("m_model output: ", out.shape)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.phi_e(out)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m): # norms
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # print("m: ", m.shape)
        # print("trans: ", trans.shape)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        # print("trans: ", trans.shape)
        # print("x.size: ", x.size(0))
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight # * norms[i, j], smth like that, or norms
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            # print("X: ", x)
            x = self.x_model(x, edges, x_diff, m)
            # print("phi_x(X) = ", x, '\n---\n')

        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LorentzNet(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LGEB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LorentzNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LGEBs = nn.ModuleList([LGEB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LGEBs[i](h, x, edges, node_attr=scalars)
        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)

        # print("Final preds: \n", pred.cpu().detach().numpy())
        return pred.squeeze(1)

LGEB(self.n_hidden, self.n_hidden, self.n_hidden,\
                                    n_node_attr=n_scalar, dropout=dropout,\
                                    c_weight=c_weight, last_layer=\(i==n_layers-1), A=A, include_x=include_x)
                                    
We are using n_hidden = 4 and n_layers = 6

n_input=n_hidden, n_output=n_hidden, n_hidden=n_hidden, n_node_attr=n_scalar=8

In [22]:
# @title
import torch
import os, json, random, string
import torch.distributed as dist

def makedir(path):
    try:
        os.makedirs(path)
    except OSError:
        pass

def args_init(args):
    r''' Initialize seed and exp_name.
    '''
    if args.seed is None: # use random seed if not specified
        args.seed = np.random.randint(100)
    if args.exp_name == '': # use random strings if not specified
        args.exp_name = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    if (args.local_rank == 0): # master
        print(args)
        makedir(f"{args.logdir}/{args.exp_name}")
        with open(f"{args.logdir}/{args.exp_name}/args.json", 'w') as f:
            json.dump(args.__dict__, f, indent=4)

def sum_reduce(num, device):
    r''' Sum the tensor across the devices.
    '''
    if not torch.is_tensor(num):
        rt = torch.tensor(num).to(device)
    else:
        rt = num.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        warmup_epoch: target learning rate is reached at warmup_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    Reference:
        https://github.com/ildoonet/pytorch-gradual-warmup-lr
    """

    def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.warmup_epoch = warmup_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    @property
    def _warmup_lr(self):
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch + 1) / self.warmup_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * (self.last_epoch + 1) / self.warmup_epoch + 1.) for base_lr in self.base_lrs]

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch - 1:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return self._warmup_lr

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        self.last_epoch = self.last_epoch + 1 if epoch==None else epoch
        if self.last_epoch >= self.warmup_epoch - 1:
            if not self.finished:
                warmup_lr = [base_lr * self.multiplier for base_lr in self.base_lrs]
                for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                    param_group['lr'] = lr
                self.finished = True
                return
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.warmup_epoch)
            return

        for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lr):
            param_group['lr'] = lr

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.warmup_epoch)
                self.last_epoch = self.after_scheduler.last_epoch + self.warmup_epoch + 1
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        result = {key: value for key, value in self.__dict__.items() if key != 'optimizer' or key != "after_scheduler"}
        if self.after_scheduler:
            result.update({"after_scheduler": self.after_scheduler.state_dict()})
        return result

    def load_state_dict(self, state_dict):
        after_scheduler_state = state_dict.pop("after_scheduler", None)
        self.__dict__.update(state_dict)
        if after_scheduler_state:
            self.after_scheduler.load_state_dict(after_scheduler_state)


from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np

def buildROC(labels, score, targetEff=[0.3,0.5]):
    r''' ROC curve is a plot of the true positive rate (Sensitivity) in the function of the false positive rate
    (100-Specificity) for different cut-off points of a parameter. Each point on the ROC curve represents a
    sensitivity/specificity pair corresponding to a particular decision threshold. The Area Under the ROC
    curve (AUC) is a measure of how well a parameter can distinguish between two diagnostic groups.
    '''
    if not isinstance(targetEff, list):
        targetEff = [targetEff]
    fpr, tpr, threshold = roc_curve(labels, score)
    idx = [np.argmin(np.abs(tpr - Eff)) for Eff in targetEff]
    eB, eS = fpr[idx], tpr[idx]
    return fpr, tpr, threshold, eB, eS

In [23]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from tqdm import tqdm

def run(model, epoch, loader, partition, N_EPOCHS=None):
    if partition == 'train':
        model.train()
    else:
        model.eval()

    res = {'time':0, 'correct':0, 'loss': 0, 'counter': 0, 'acc': 0,
           'loss_arr':[], 'correct_arr':[],'label':[],'score':[]}

    tik = time.time()
    loader_length = len(loader)

    for i, (label, p4s, nodes, atom_mask, edge_mask, edges) in tqdm(enumerate(loader)):
        if partition == 'train':
            optimizer.zero_grad()

        batch_size, n_nodes, _ = p4s.size()
        atom_positions = p4s.view(batch_size * n_nodes, -1).to(device, dtype)
        atom_mask = atom_mask.view(batch_size * n_nodes, -1).to(device)
        edge_mask = edge_mask.reshape(batch_size * n_nodes * n_nodes, -1).to(device)
        nodes = nodes.view(batch_size * n_nodes, -1).to(device,dtype)
        edges = [a.to(device) for a in edges]
        label = label.to(device, dtype).long()

        pred = model(scalars=nodes, x=atom_positions, edges=edges, node_mask=atom_mask,
                         edge_mask=edge_mask, n_nodes=n_nodes)

        predict = pred.max(1).indices
        correct = torch.sum(predict == label).item()
        # print(pred.shape,label.shape)
        loss = loss_fn(pred, label)

        if partition == 'train':
            loss.backward()
            optimizer.step()
        elif partition == 'test':
            # save labels and probilities for ROC / AUC
            # print("Preds ", pred)
            score = torch.nn.functional.softmax(pred, dim = -1)
            # print("Score test ", score)
            # raise
            res['label'].append(label)
            res['score'].append(score)

        res['time'] = time.time() - tik
        res['correct'] += correct
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        res['loss_arr'].append(loss.item())
        res['correct_arr'].append(correct)

        # if i != 0 and i % args.log_interval == 0:

    running_loss = sum(res['loss_arr'])/len(res['loss_arr'])
    running_acc = sum(res['correct_arr'])/(len(res['correct_arr'])*batch_size)
    avg_time = res['time']/res['counter'] * batch_size
    tmp_counter = res['counter']
    tmp_loss = res['loss'] / tmp_counter
    tmp_acc = res['correct'] / tmp_counter

    if N_EPOCHS:
        print(">> %s \t Epoch %d/%d \t Batch %d/%d \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, epoch + 1, N_EPOCHS, i, loader_length, running_loss, running_acc, tmp_acc, avg_time))
    else:
        print(">> %s \t Loss %.4f \t Running Acc %.3f \t Total Acc %.3f \t Avg Batch Time %.4f" %
             (partition, running_loss, running_acc, tmp_acc, avg_time))

    torch.cuda.empty_cache()
    # ---------- reduce -----------
    if partition == 'test':
        res['label'] = torch.cat(res['label']).unsqueeze(-1)
        res['score'] = torch.cat(res['score'])
        res['score'] = torch.cat((res['label'],res['score']),dim=-1)
    res['counter'] = res['counter']
    res['loss'] = res['loss'] / res['counter']
    res['acc'] = res['correct'] / res['counter']
    return res

def train(model, res, N_EPOCHS, model_path, log_path):
    ### training and validation
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(log_path, exist_ok=True)

    for epoch in range(N_EPOCHS):
        train_res = run(model, epoch, dataloaders['train'], partition='train', N_EPOCHS = N_EPOCHS)
        print("Time: train: %.2f \t Train loss %.4f \t Train acc: %.4f" % (train_res['time'],train_res['loss'],train_res['acc']))
        # if epoch % args.val_interval == 0:

        # if (args.local_rank == 0):
        torch.save(model.state_dict(), os.path.join(model_path, "checkpoint-epoch-{}.pt".format(epoch)) )
        with torch.no_grad():
            val_res = run(model, epoch, dataloaders['val'], partition='val')

        # if (args.local_rank == 0): # only master process save
        res['lr'].append(optimizer.param_groups[0]['lr'])
        res['train_time'].append(train_res['time'])
        res['val_time'].append(val_res['time'])
        res['train_loss'].append(train_res['loss'])
        res['train_acc'].append(train_res['acc'])
        res['val_loss'].append(val_res['loss'])
        res['val_acc'].append(val_res['acc'])
        res['epochs'].append(epoch)

        ## save best model
        if val_res['acc'] > res['best_val']:
            print("New best validation model, saving...")
            torch.save(model.state_dict(), os.path.join(model_path,"best-val-model.pt"))
            res['best_val'] = val_res['acc']
            res['best_epoch'] = epoch

        print("Epoch %d/%d finished." % (epoch, N_EPOCHS))
        print("Train time: %.2f \t Val time %.2f" % (train_res['time'], val_res['time']))
        print("Train loss %.4f \t Train acc: %.4f" % (train_res['loss'], train_res['acc']))
        print("Val loss: %.4f \t Val acc: %.4f" % (val_res['loss'], val_res['acc']))
        print("Best val acc: %.4f at epoch %d." % (res['best_val'],  res['best_epoch']))

        json_object = json.dumps(res, indent=4)
        with open(os.path.join(log_path, "train-result-epoch{}.json".format(epoch)), "w") as outfile:
            outfile.write(json_object)

        ## adjust learning rate
        if (epoch < 31):
            lr_scheduler.step(metrics=val_res['acc'])
        else:
            for g in optimizer.param_groups:
                g['lr'] = g['lr']*0.5


def test(model, res, model_path, log_path):
    ### test on best model
    best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)
    model.load_state_dict(best_model)
    with torch.no_grad():
        test_res = run(model, 0, dataloaders['test'], partition='test')

    print("Final ", test_res['score'])
    pred = test_res['score'].cpu()

    np.save(os.path.join(log_path, "score.npy"), pred)
    fpr, tpr, thres, eB, eS  = buildROC(pred[...,0], pred[...,2])
    auc = roc_auc_score(pred[...,0], pred[...,2])

    metric = {'test_loss': test_res['loss'], 'test_acc': test_res['acc'],
              'test_auc': auc, 'test_1/eB_0.3':1./eB[0],'test_1/eB_0.5':1./eB[1]}
    res.update(metric)
    print("Test: Loss %.4f \t Acc %.4f \t AUC: %.4f \t 1/eB 0.3: %.4f \t 1/eB 0.5: %.4f"\
           % (test_res['loss'], test_res['acc'], auc, 1./eB[0], 1./eB[1]))
    json_object = json.dumps(res, indent=4)
    with open(os.path.join(log_path, "test-result.json"), "w") as outfile:
        outfile.write(json_object)

if __name__ == "__main__":

    N_EPOCHS = 55 # 60

    model_path = "models/LorentzNet/"
    log_path = "logs/LorentzNet/"
    # args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cuda
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cpu")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    ### create parallel model
    model = LorentzNet(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)



Input size of phi_e:  4
Model Size: 199
 train samples: 8000
 val samples: 1000
 test samples: 1000
Training...


500it [00:04, 124.76it/s]


>> train 	 Epoch 1/55 	 Batch 499/500 	 Loss 0.7082 	 Running Acc 0.503 	 Total Acc 0.503 	 Avg Batch Time 0.0080
Time: train: 4.01 	 Train loss 0.7082 	 Train acc: 0.5031


63it [00:00, 193.55it/s]


>> val 	 Loss 0.6903 	 Running Acc 0.992 	 Total Acc 0.500 	 Avg Batch Time 0.0026
New best validation model, saving...
Epoch 0/55 finished.
Train time: 4.01 	 Val time 0.33
Train loss 0.7082 	 Train acc: 0.5031
Val loss: 0.6905 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 0.


500it [00:03, 129.14it/s]


>> train 	 Epoch 2/55 	 Batch 499/500 	 Loss 0.6741 	 Running Acc 0.581 	 Total Acc 0.581 	 Avg Batch Time 0.0077
Time: train: 3.87 	 Train loss 0.6741 	 Train acc: 0.5806


63it [00:00, 244.89it/s]


>> val 	 Loss 0.6668 	 Running Acc 1.224 	 Total Acc 0.617 	 Avg Batch Time 0.0021
New best validation model, saving...
Epoch 1/55 finished.
Train time: 3.87 	 Val time 0.26
Train loss 0.6741 	 Train acc: 0.5806
Val loss: 0.6670 	 Val acc: 0.6170
Best val acc: 0.6170 at epoch 1.


500it [00:03, 131.67it/s]


>> train 	 Epoch 3/55 	 Batch 499/500 	 Loss 0.6033 	 Running Acc 0.727 	 Total Acc 0.727 	 Avg Batch Time 0.0076
Time: train: 3.80 	 Train loss 0.6033 	 Train acc: 0.7266


63it [00:00, 252.23it/s]


>> val 	 Loss 0.6500 	 Running Acc 1.210 	 Total Acc 0.610 	 Avg Batch Time 0.0020
Epoch 2/55 finished.
Train time: 3.80 	 Val time 0.25
Train loss 0.6033 	 Train acc: 0.7266
Val loss: 0.6506 	 Val acc: 0.6100
Best val acc: 0.6170 at epoch 1.


500it [00:03, 129.63it/s]


>> train 	 Epoch 4/55 	 Batch 499/500 	 Loss 0.5589 	 Running Acc 0.744 	 Total Acc 0.744 	 Avg Batch Time 0.0077
Time: train: 3.86 	 Train loss 0.5589 	 Train acc: 0.7444


63it [00:00, 238.59it/s]


>> val 	 Loss 0.5300 	 Running Acc 1.496 	 Total Acc 0.754 	 Avg Batch Time 0.0021
New best validation model, saving...
Epoch 3/55 finished.
Train time: 3.86 	 Val time 0.27
Train loss 0.5589 	 Train acc: 0.7444
Val loss: 0.5310 	 Val acc: 0.7540
Best val acc: 0.7540 at epoch 3.


500it [00:03, 133.99it/s]


>> train 	 Epoch 5/55 	 Batch 499/500 	 Loss 0.5477 	 Running Acc 0.745 	 Total Acc 0.745 	 Avg Batch Time 0.0075
Time: train: 3.73 	 Train loss 0.5477 	 Train acc: 0.7452


63it [00:00, 262.26it/s]


>> val 	 Loss 0.5755 	 Running Acc 1.417 	 Total Acc 0.714 	 Avg Batch Time 0.0019
Epoch 4/55 finished.
Train time: 3.73 	 Val time 0.24
Train loss 0.5477 	 Train acc: 0.7452
Val loss: 0.5776 	 Val acc: 0.7140
Best val acc: 0.7540 at epoch 3.


500it [00:03, 134.54it/s]


>> train 	 Epoch 6/55 	 Batch 499/500 	 Loss 0.5387 	 Running Acc 0.749 	 Total Acc 0.749 	 Avg Batch Time 0.0074
Time: train: 3.72 	 Train loss 0.5387 	 Train acc: 0.7492


63it [00:00, 241.81it/s]


>> val 	 Loss 0.5221 	 Running Acc 1.502 	 Total Acc 0.757 	 Avg Batch Time 0.0021
New best validation model, saving...
Epoch 5/55 finished.
Train time: 3.72 	 Val time 0.26
Train loss 0.5387 	 Train acc: 0.7492
Val loss: 0.5229 	 Val acc: 0.7570
Best val acc: 0.7570 at epoch 5.


500it [00:03, 137.55it/s]


>> train 	 Epoch 7/55 	 Batch 499/500 	 Loss 0.5273 	 Running Acc 0.759 	 Total Acc 0.759 	 Avg Batch Time 0.0073
Time: train: 3.64 	 Train loss 0.5273 	 Train acc: 0.7589


63it [00:00, 262.62it/s]


>> val 	 Loss 0.5103 	 Running Acc 1.538 	 Total Acc 0.775 	 Avg Batch Time 0.0019
New best validation model, saving...
Epoch 6/55 finished.
Train time: 3.64 	 Val time 0.24
Train loss 0.5273 	 Train acc: 0.7589
Val loss: 0.5113 	 Val acc: 0.7750
Best val acc: 0.7750 at epoch 6.


500it [00:03, 133.80it/s]


>> train 	 Epoch 8/55 	 Batch 499/500 	 Loss 0.5210 	 Running Acc 0.760 	 Total Acc 0.760 	 Avg Batch Time 0.0075
Time: train: 3.74 	 Train loss 0.5210 	 Train acc: 0.7600


63it [00:00, 252.76it/s]


>> val 	 Loss 0.4882 	 Running Acc 1.560 	 Total Acc 0.786 	 Avg Batch Time 0.0020
New best validation model, saving...
Epoch 7/55 finished.
Train time: 3.74 	 Val time 0.25
Train loss 0.5210 	 Train acc: 0.7600
Val loss: 0.4890 	 Val acc: 0.7860
Best val acc: 0.7860 at epoch 7.


500it [00:03, 129.26it/s]


>> train 	 Epoch 9/55 	 Batch 499/500 	 Loss 0.5271 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0077
Time: train: 3.87 	 Train loss 0.5271 	 Train acc: 0.7569


63it [00:00, 205.15it/s]


>> val 	 Loss 0.6399 	 Running Acc 1.238 	 Total Acc 0.624 	 Avg Batch Time 0.0025
Epoch 8/55 finished.
Train time: 3.87 	 Val time 0.31
Train loss 0.5271 	 Train acc: 0.7569
Val loss: 0.6417 	 Val acc: 0.6240
Best val acc: 0.7860 at epoch 7.


500it [00:03, 132.81it/s]


>> train 	 Epoch 10/55 	 Batch 499/500 	 Loss 0.5274 	 Running Acc 0.751 	 Total Acc 0.751 	 Avg Batch Time 0.0075
Time: train: 3.77 	 Train loss 0.5274 	 Train acc: 0.7510


63it [00:00, 250.40it/s]


>> val 	 Loss 0.5092 	 Running Acc 1.542 	 Total Acc 0.777 	 Avg Batch Time 0.0020
Epoch 9/55 finished.
Train time: 3.77 	 Val time 0.25
Train loss 0.5274 	 Train acc: 0.7510
Val loss: 0.5100 	 Val acc: 0.7770
Best val acc: 0.7860 at epoch 7.


500it [00:03, 128.09it/s]


>> train 	 Epoch 11/55 	 Batch 499/500 	 Loss 0.5206 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0078
Time: train: 3.91 	 Train loss 0.5206 	 Train acc: 0.7572


63it [00:00, 261.72it/s]


>> val 	 Loss 0.4823 	 Running Acc 1.556 	 Total Acc 0.784 	 Avg Batch Time 0.0020
Epoch 10/55 finished.
Train time: 3.91 	 Val time 0.24
Train loss 0.5206 	 Train acc: 0.7572
Val loss: 0.4826 	 Val acc: 0.7840
Best val acc: 0.7860 at epoch 7.


500it [00:03, 131.51it/s]


>> train 	 Epoch 12/55 	 Batch 499/500 	 Loss 0.5168 	 Running Acc 0.761 	 Total Acc 0.761 	 Avg Batch Time 0.0076
Time: train: 3.80 	 Train loss 0.5168 	 Train acc: 0.7608


63it [00:00, 259.20it/s]


>> val 	 Loss 0.5306 	 Running Acc 1.452 	 Total Acc 0.732 	 Avg Batch Time 0.0020
Epoch 11/55 finished.
Train time: 3.80 	 Val time 0.25
Train loss 0.5168 	 Train acc: 0.7608
Val loss: 0.5321 	 Val acc: 0.7320
Best val acc: 0.7860 at epoch 7.


500it [00:03, 129.28it/s]


>> train 	 Epoch 13/55 	 Batch 499/500 	 Loss 0.5206 	 Running Acc 0.757 	 Total Acc 0.757 	 Avg Batch Time 0.0077
Time: train: 3.87 	 Train loss 0.5206 	 Train acc: 0.7569


63it [00:00, 254.13it/s]


>> val 	 Loss 0.4804 	 Running Acc 1.548 	 Total Acc 0.780 	 Avg Batch Time 0.0020
Epoch 12/55 finished.
Train time: 3.87 	 Val time 0.25
Train loss 0.5206 	 Train acc: 0.7569
Val loss: 0.4808 	 Val acc: 0.7800
Best val acc: 0.7860 at epoch 7.


500it [00:03, 127.62it/s]


>> train 	 Epoch 14/55 	 Batch 499/500 	 Loss 0.5165 	 Running Acc 0.758 	 Total Acc 0.758 	 Avg Batch Time 0.0078
Time: train: 3.92 	 Train loss 0.5165 	 Train acc: 0.7580


63it [00:00, 247.71it/s]


>> val 	 Loss 0.4737 	 Running Acc 1.556 	 Total Acc 0.784 	 Avg Batch Time 0.0021
Epoch 13/55 finished.
Train time: 3.92 	 Val time 0.26
Train loss 0.5165 	 Train acc: 0.7580
Val loss: 0.4747 	 Val acc: 0.7840
Best val acc: 0.7860 at epoch 7.


500it [00:03, 126.59it/s]


>> train 	 Epoch 15/55 	 Batch 499/500 	 Loss 0.5171 	 Running Acc 0.760 	 Total Acc 0.760 	 Avg Batch Time 0.0079
Time: train: 3.95 	 Train loss 0.5171 	 Train acc: 0.7602


63it [00:00, 241.00it/s]


>> val 	 Loss 0.4781 	 Running Acc 1.550 	 Total Acc 0.781 	 Avg Batch Time 0.0021
Epoch 14/55 finished.
Train time: 3.95 	 Val time 0.26
Train loss 0.5171 	 Train acc: 0.7602
Val loss: 0.4789 	 Val acc: 0.7810
Best val acc: 0.7860 at epoch 7.


500it [00:03, 125.22it/s]


>> train 	 Epoch 16/55 	 Batch 499/500 	 Loss 0.5114 	 Running Acc 0.760 	 Total Acc 0.760 	 Avg Batch Time 0.0080
Time: train: 3.99 	 Train loss 0.5114 	 Train acc: 0.7599


63it [00:00, 234.04it/s]


>> val 	 Loss 0.4712 	 Running Acc 1.569 	 Total Acc 0.791 	 Avg Batch Time 0.0022
New best validation model, saving...
Epoch 15/55 finished.
Train time: 3.99 	 Val time 0.27
Train loss 0.5114 	 Train acc: 0.7599
Val loss: 0.4721 	 Val acc: 0.7910
Best val acc: 0.7910 at epoch 15.


500it [00:04, 123.47it/s]


>> train 	 Epoch 17/55 	 Batch 499/500 	 Loss 0.5138 	 Running Acc 0.765 	 Total Acc 0.765 	 Avg Batch Time 0.0081
Time: train: 4.05 	 Train loss 0.5138 	 Train acc: 0.7648


63it [00:00, 232.82it/s]


>> val 	 Loss 0.4972 	 Running Acc 1.530 	 Total Acc 0.771 	 Avg Batch Time 0.0022
Epoch 16/55 finished.
Train time: 4.05 	 Val time 0.27
Train loss 0.5138 	 Train acc: 0.7648
Val loss: 0.4979 	 Val acc: 0.7710
Best val acc: 0.7910 at epoch 15.


500it [00:03, 128.99it/s]


>> train 	 Epoch 18/55 	 Batch 499/500 	 Loss 0.5140 	 Running Acc 0.762 	 Total Acc 0.762 	 Avg Batch Time 0.0078
Time: train: 3.88 	 Train loss 0.5140 	 Train acc: 0.7621


63it [00:00, 231.80it/s]


>> val 	 Loss 0.4853 	 Running Acc 1.530 	 Total Acc 0.771 	 Avg Batch Time 0.0022
Epoch 17/55 finished.
Train time: 3.88 	 Val time 0.28
Train loss 0.5140 	 Train acc: 0.7621
Val loss: 0.4855 	 Val acc: 0.7710
Best val acc: 0.7910 at epoch 15.


500it [00:03, 132.26it/s]


>> train 	 Epoch 19/55 	 Batch 499/500 	 Loss 0.5183 	 Running Acc 0.759 	 Total Acc 0.759 	 Avg Batch Time 0.0076
Time: train: 3.78 	 Train loss 0.5183 	 Train acc: 0.7588


63it [00:00, 250.65it/s]


>> val 	 Loss 0.6117 	 Running Acc 1.278 	 Total Acc 0.644 	 Avg Batch Time 0.0020
Epoch 18/55 finished.
Train time: 3.78 	 Val time 0.25
Train loss 0.5183 	 Train acc: 0.7588
Val loss: 0.6134 	 Val acc: 0.6440
Best val acc: 0.7910 at epoch 15.


500it [00:03, 133.16it/s]


>> train 	 Epoch 20/55 	 Batch 499/500 	 Loss 0.5212 	 Running Acc 0.754 	 Total Acc 0.754 	 Avg Batch Time 0.0075
Time: train: 3.76 	 Train loss 0.5212 	 Train acc: 0.7545


63it [00:00, 251.99it/s]


>> val 	 Loss 0.4952 	 Running Acc 1.516 	 Total Acc 0.764 	 Avg Batch Time 0.0020
Epoch 19/55 finished.
Train time: 3.76 	 Val time 0.25
Train loss 0.5212 	 Train acc: 0.7545
Val loss: 0.4957 	 Val acc: 0.7640
Best val acc: 0.7910 at epoch 15.


500it [00:03, 129.60it/s]


>> train 	 Epoch 21/55 	 Batch 499/500 	 Loss 0.5045 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0077
Time: train: 3.86 	 Train loss 0.5045 	 Train acc: 0.7669


63it [00:00, 233.57it/s]


>> val 	 Loss 0.5790 	 Running Acc 1.357 	 Total Acc 0.684 	 Avg Batch Time 0.0022
Epoch 20/55 finished.
Train time: 3.86 	 Val time 0.27
Train loss 0.5045 	 Train acc: 0.7669
Val loss: 0.5808 	 Val acc: 0.6840
Best val acc: 0.7910 at epoch 15.


500it [00:03, 135.07it/s]


>> train 	 Epoch 22/55 	 Batch 499/500 	 Loss 0.5108 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0074
Time: train: 3.70 	 Train loss 0.5108 	 Train acc: 0.7671


63it [00:00, 245.87it/s]


>> val 	 Loss 0.4891 	 Running Acc 1.534 	 Total Acc 0.773 	 Avg Batch Time 0.0021
Epoch 21/55 finished.
Train time: 3.70 	 Val time 0.26
Train loss 0.5108 	 Train acc: 0.7671
Val loss: 0.4894 	 Val acc: 0.7730
Best val acc: 0.7910 at epoch 15.


500it [00:04, 123.14it/s]


>> train 	 Epoch 23/55 	 Batch 499/500 	 Loss 0.5102 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0081
Time: train: 4.06 	 Train loss 0.5102 	 Train acc: 0.7684


63it [00:00, 247.49it/s]


>> val 	 Loss 0.4710 	 Running Acc 1.558 	 Total Acc 0.785 	 Avg Batch Time 0.0021
Epoch 22/55 finished.
Train time: 4.06 	 Val time 0.26
Train loss 0.5102 	 Train acc: 0.7684
Val loss: 0.4722 	 Val acc: 0.7850
Best val acc: 0.7910 at epoch 15.


500it [00:03, 128.49it/s]


>> train 	 Epoch 24/55 	 Batch 499/500 	 Loss 0.5096 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0078
Time: train: 3.89 	 Train loss 0.5096 	 Train acc: 0.7676


63it [00:00, 224.30it/s]


>> val 	 Loss 0.4899 	 Running Acc 1.536 	 Total Acc 0.774 	 Avg Batch Time 0.0023
Epoch 23/55 finished.
Train time: 3.89 	 Val time 0.28
Train loss 0.5096 	 Train acc: 0.7676
Val loss: 0.4904 	 Val acc: 0.7740
Best val acc: 0.7910 at epoch 15.


500it [00:04, 117.59it/s]


>> train 	 Epoch 25/55 	 Batch 499/500 	 Loss 0.5103 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0085
Time: train: 4.25 	 Train loss 0.5103 	 Train acc: 0.7679


63it [00:00, 233.41it/s]


>> val 	 Loss 0.4885 	 Running Acc 1.554 	 Total Acc 0.783 	 Avg Batch Time 0.0022
Epoch 24/55 finished.
Train time: 4.25 	 Val time 0.27
Train loss 0.5103 	 Train acc: 0.7679
Val loss: 0.4891 	 Val acc: 0.7830
Best val acc: 0.7910 at epoch 15.


500it [00:03, 125.33it/s]


>> train 	 Epoch 26/55 	 Batch 499/500 	 Loss 0.5095 	 Running Acc 0.764 	 Total Acc 0.764 	 Avg Batch Time 0.0080
Time: train: 3.99 	 Train loss 0.5095 	 Train acc: 0.7642


63it [00:00, 219.56it/s]


>> val 	 Loss 0.4706 	 Running Acc 1.554 	 Total Acc 0.783 	 Avg Batch Time 0.0023
Epoch 25/55 finished.
Train time: 3.99 	 Val time 0.29
Train loss 0.5095 	 Train acc: 0.7642
Val loss: 0.4712 	 Val acc: 0.7830
Best val acc: 0.7910 at epoch 15.


500it [00:03, 126.44it/s]


>> train 	 Epoch 27/55 	 Batch 499/500 	 Loss 0.5000 	 Running Acc 0.771 	 Total Acc 0.771 	 Avg Batch Time 0.0079
Time: train: 3.96 	 Train loss 0.5000 	 Train acc: 0.7710


63it [00:00, 256.54it/s]


>> val 	 Loss 0.4826 	 Running Acc 1.514 	 Total Acc 0.763 	 Avg Batch Time 0.0020
Epoch 26/55 finished.
Train time: 3.96 	 Val time 0.25
Train loss 0.5000 	 Train acc: 0.7710
Val loss: 0.4837 	 Val acc: 0.7630
Best val acc: 0.7910 at epoch 15.


500it [00:03, 129.76it/s]


>> train 	 Epoch 28/55 	 Batch 499/500 	 Loss 0.5145 	 Running Acc 0.761 	 Total Acc 0.761 	 Avg Batch Time 0.0077
Time: train: 3.85 	 Train loss 0.5145 	 Train acc: 0.7606


63it [00:00, 250.67it/s]


>> val 	 Loss 0.4706 	 Running Acc 1.562 	 Total Acc 0.787 	 Avg Batch Time 0.0020
Epoch 27/55 finished.
Train time: 3.85 	 Val time 0.25
Train loss 0.5145 	 Train acc: 0.7606
Val loss: 0.4712 	 Val acc: 0.7870
Best val acc: 0.7910 at epoch 15.


500it [00:03, 129.81it/s]


>> train 	 Epoch 29/55 	 Batch 499/500 	 Loss 0.5103 	 Running Acc 0.765 	 Total Acc 0.765 	 Avg Batch Time 0.0077
Time: train: 3.85 	 Train loss 0.5103 	 Train acc: 0.7655


63it [00:00, 230.80it/s]


>> val 	 Loss 0.4641 	 Running Acc 1.571 	 Total Acc 0.792 	 Avg Batch Time 0.0022
New best validation model, saving...
Epoch 28/55 finished.
Train time: 3.85 	 Val time 0.28
Train loss 0.5103 	 Train acc: 0.7655
Val loss: 0.4647 	 Val acc: 0.7920
Best val acc: 0.7920 at epoch 28.


500it [00:04, 121.10it/s]


>> train 	 Epoch 30/55 	 Batch 499/500 	 Loss 0.5091 	 Running Acc 0.761 	 Total Acc 0.761 	 Avg Batch Time 0.0083
Time: train: 4.13 	 Train loss 0.5091 	 Train acc: 0.7610


63it [00:00, 246.69it/s]


>> val 	 Loss 0.4642 	 Running Acc 1.569 	 Total Acc 0.791 	 Avg Batch Time 0.0021
Epoch 29/55 finished.
Train time: 4.13 	 Val time 0.26
Train loss 0.5091 	 Train acc: 0.7610
Val loss: 0.4653 	 Val acc: 0.7910
Best val acc: 0.7920 at epoch 28.


500it [00:03, 132.21it/s]


>> train 	 Epoch 31/55 	 Batch 499/500 	 Loss 0.5033 	 Running Acc 0.766 	 Total Acc 0.766 	 Avg Batch Time 0.0076
Time: train: 3.78 	 Train loss 0.5033 	 Train acc: 0.7660


63it [00:00, 254.41it/s]


>> val 	 Loss 0.4688 	 Running Acc 1.569 	 Total Acc 0.791 	 Avg Batch Time 0.0020
Epoch 30/55 finished.
Train time: 3.78 	 Val time 0.25
Train loss 0.5033 	 Train acc: 0.7660
Val loss: 0.4697 	 Val acc: 0.7910
Best val acc: 0.7920 at epoch 28.


500it [00:03, 133.02it/s]


>> train 	 Epoch 32/55 	 Batch 499/500 	 Loss 0.5048 	 Running Acc 0.770 	 Total Acc 0.770 	 Avg Batch Time 0.0075
Time: train: 3.76 	 Train loss 0.5048 	 Train acc: 0.7696


63it [00:00, 180.38it/s]


>> val 	 Loss 0.4621 	 Running Acc 1.583 	 Total Acc 0.798 	 Avg Batch Time 0.0028
New best validation model, saving...
Epoch 31/55 finished.
Train time: 3.76 	 Val time 0.35
Train loss 0.5048 	 Train acc: 0.7696
Val loss: 0.4631 	 Val acc: 0.7980
Best val acc: 0.7980 at epoch 31.


500it [00:03, 130.24it/s]


>> train 	 Epoch 33/55 	 Batch 499/500 	 Loss 0.5070 	 Running Acc 0.765 	 Total Acc 0.765 	 Avg Batch Time 0.0077
Time: train: 3.84 	 Train loss 0.5070 	 Train acc: 0.7654


63it [00:00, 250.83it/s]


>> val 	 Loss 0.4631 	 Running Acc 1.579 	 Total Acc 0.796 	 Avg Batch Time 0.0020
Epoch 32/55 finished.
Train time: 3.84 	 Val time 0.25
Train loss 0.5070 	 Train acc: 0.7654
Val loss: 0.4640 	 Val acc: 0.7960
Best val acc: 0.7980 at epoch 31.


500it [00:03, 133.13it/s]


>> train 	 Epoch 34/55 	 Batch 499/500 	 Loss 0.5071 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0075
Time: train: 3.76 	 Train loss 0.5071 	 Train acc: 0.7681


63it [00:00, 237.24it/s]


>> val 	 Loss 0.4779 	 Running Acc 1.550 	 Total Acc 0.781 	 Avg Batch Time 0.0022
Epoch 33/55 finished.
Train time: 3.76 	 Val time 0.27
Train loss 0.5071 	 Train acc: 0.7681
Val loss: 0.4783 	 Val acc: 0.7810
Best val acc: 0.7980 at epoch 31.


500it [00:03, 129.65it/s]


>> train 	 Epoch 35/55 	 Batch 499/500 	 Loss 0.5063 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0077
Time: train: 3.86 	 Train loss 0.5063 	 Train acc: 0.7666


63it [00:00, 249.19it/s]


>> val 	 Loss 0.4676 	 Running Acc 1.585 	 Total Acc 0.799 	 Avg Batch Time 0.0020
New best validation model, saving...
Epoch 34/55 finished.
Train time: 3.86 	 Val time 0.26
Train loss 0.5063 	 Train acc: 0.7666
Val loss: 0.4685 	 Val acc: 0.7990
Best val acc: 0.7990 at epoch 34.


500it [00:03, 134.20it/s]


>> train 	 Epoch 36/55 	 Batch 499/500 	 Loss 0.5041 	 Running Acc 0.774 	 Total Acc 0.774 	 Avg Batch Time 0.0075
Time: train: 3.73 	 Train loss 0.5041 	 Train acc: 0.7740


63it [00:00, 265.50it/s]


>> val 	 Loss 0.4691 	 Running Acc 1.575 	 Total Acc 0.794 	 Avg Batch Time 0.0019
Epoch 35/55 finished.
Train time: 3.73 	 Val time 0.24
Train loss 0.5041 	 Train acc: 0.7740
Val loss: 0.4699 	 Val acc: 0.7940
Best val acc: 0.7990 at epoch 34.


500it [00:03, 132.32it/s]


>> train 	 Epoch 37/55 	 Batch 499/500 	 Loss 0.5084 	 Running Acc 0.771 	 Total Acc 0.771 	 Avg Batch Time 0.0076
Time: train: 3.78 	 Train loss 0.5084 	 Train acc: 0.7711


63it [00:00, 253.67it/s]


>> val 	 Loss 0.4658 	 Running Acc 1.583 	 Total Acc 0.798 	 Avg Batch Time 0.0020
Epoch 36/55 finished.
Train time: 3.78 	 Val time 0.25
Train loss 0.5084 	 Train acc: 0.7711
Val loss: 0.4667 	 Val acc: 0.7980
Best val acc: 0.7990 at epoch 34.


500it [00:03, 134.19it/s]


>> train 	 Epoch 38/55 	 Batch 499/500 	 Loss 0.5087 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0075
Time: train: 3.73 	 Train loss 0.5087 	 Train acc: 0.7666


63it [00:00, 257.24it/s]


>> val 	 Loss 0.4710 	 Running Acc 1.563 	 Total Acc 0.788 	 Avg Batch Time 0.0020
Epoch 37/55 finished.
Train time: 3.73 	 Val time 0.25
Train loss 0.5087 	 Train acc: 0.7666
Val loss: 0.4717 	 Val acc: 0.7880
Best val acc: 0.7990 at epoch 34.


500it [00:03, 135.46it/s]


>> train 	 Epoch 39/55 	 Batch 499/500 	 Loss 0.5128 	 Running Acc 0.764 	 Total Acc 0.764 	 Avg Batch Time 0.0074
Time: train: 3.69 	 Train loss 0.5128 	 Train acc: 0.7641


63it [00:00, 256.11it/s]


>> val 	 Loss 0.4674 	 Running Acc 1.565 	 Total Acc 0.789 	 Avg Batch Time 0.0020
Epoch 38/55 finished.
Train time: 3.69 	 Val time 0.25
Train loss 0.5128 	 Train acc: 0.7641
Val loss: 0.4681 	 Val acc: 0.7890
Best val acc: 0.7990 at epoch 34.


500it [00:03, 130.82it/s]


>> train 	 Epoch 40/55 	 Batch 499/500 	 Loss 0.5089 	 Running Acc 0.766 	 Total Acc 0.766 	 Avg Batch Time 0.0076
Time: train: 3.82 	 Train loss 0.5089 	 Train acc: 0.7665


63it [00:00, 180.38it/s]


>> val 	 Loss 0.4690 	 Running Acc 1.573 	 Total Acc 0.793 	 Avg Batch Time 0.0028
Epoch 39/55 finished.
Train time: 3.82 	 Val time 0.35
Train loss 0.5089 	 Train acc: 0.7665
Val loss: 0.4698 	 Val acc: 0.7930
Best val acc: 0.7990 at epoch 34.


500it [00:03, 128.94it/s]


>> train 	 Epoch 41/55 	 Batch 499/500 	 Loss 0.5060 	 Running Acc 0.770 	 Total Acc 0.770 	 Avg Batch Time 0.0078
Time: train: 3.88 	 Train loss 0.5060 	 Train acc: 0.7704


63it [00:00, 238.34it/s]


>> val 	 Loss 0.4619 	 Running Acc 1.581 	 Total Acc 0.797 	 Avg Batch Time 0.0021
Epoch 40/55 finished.
Train time: 3.88 	 Val time 0.27
Train loss 0.5060 	 Train acc: 0.7704
Val loss: 0.4629 	 Val acc: 0.7970
Best val acc: 0.7990 at epoch 34.


500it [00:03, 129.50it/s]


>> train 	 Epoch 42/55 	 Batch 499/500 	 Loss 0.5089 	 Running Acc 0.763 	 Total Acc 0.763 	 Avg Batch Time 0.0077
Time: train: 3.86 	 Train loss 0.5089 	 Train acc: 0.7635


63it [00:00, 237.85it/s]


>> val 	 Loss 0.4652 	 Running Acc 1.569 	 Total Acc 0.791 	 Avg Batch Time 0.0021
Epoch 41/55 finished.
Train time: 3.86 	 Val time 0.27
Train loss 0.5089 	 Train acc: 0.7635
Val loss: 0.4662 	 Val acc: 0.7910
Best val acc: 0.7990 at epoch 34.


500it [00:03, 133.57it/s]


>> train 	 Epoch 43/55 	 Batch 499/500 	 Loss 0.5097 	 Running Acc 0.763 	 Total Acc 0.763 	 Avg Batch Time 0.0075
Time: train: 3.74 	 Train loss 0.5097 	 Train acc: 0.7629


63it [00:00, 258.17it/s]


>> val 	 Loss 0.4637 	 Running Acc 1.573 	 Total Acc 0.793 	 Avg Batch Time 0.0020
Epoch 42/55 finished.
Train time: 3.74 	 Val time 0.25
Train loss 0.5097 	 Train acc: 0.7629
Val loss: 0.4646 	 Val acc: 0.7930
Best val acc: 0.7990 at epoch 34.


500it [00:03, 129.61it/s]


>> train 	 Epoch 44/55 	 Batch 499/500 	 Loss 0.5076 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0077
Time: train: 3.86 	 Train loss 0.5076 	 Train acc: 0.7666


63it [00:00, 254.59it/s]


>> val 	 Loss 0.4680 	 Running Acc 1.569 	 Total Acc 0.791 	 Avg Batch Time 0.0020
Epoch 43/55 finished.
Train time: 3.86 	 Val time 0.25
Train loss 0.5076 	 Train acc: 0.7666
Val loss: 0.4686 	 Val acc: 0.7910
Best val acc: 0.7990 at epoch 34.


500it [00:03, 131.46it/s]


>> train 	 Epoch 45/55 	 Batch 499/500 	 Loss 0.5074 	 Running Acc 0.766 	 Total Acc 0.766 	 Avg Batch Time 0.0076
Time: train: 3.81 	 Train loss 0.5074 	 Train acc: 0.7660


63it [00:00, 259.26it/s]


>> val 	 Loss 0.4630 	 Running Acc 1.579 	 Total Acc 0.796 	 Avg Batch Time 0.0020
Epoch 44/55 finished.
Train time: 3.81 	 Val time 0.25
Train loss 0.5074 	 Train acc: 0.7660
Val loss: 0.4638 	 Val acc: 0.7960
Best val acc: 0.7990 at epoch 34.


500it [00:03, 134.29it/s]


>> train 	 Epoch 46/55 	 Batch 499/500 	 Loss 0.5037 	 Running Acc 0.770 	 Total Acc 0.770 	 Avg Batch Time 0.0075
Time: train: 3.73 	 Train loss 0.5037 	 Train acc: 0.7698


63it [00:00, 253.78it/s]


>> val 	 Loss 0.4637 	 Running Acc 1.583 	 Total Acc 0.798 	 Avg Batch Time 0.0020
Epoch 45/55 finished.
Train time: 3.73 	 Val time 0.25
Train loss 0.5037 	 Train acc: 0.7698
Val loss: 0.4645 	 Val acc: 0.7980
Best val acc: 0.7990 at epoch 34.


500it [00:03, 133.80it/s]


>> train 	 Epoch 47/55 	 Batch 499/500 	 Loss 0.5035 	 Running Acc 0.774 	 Total Acc 0.774 	 Avg Batch Time 0.0075
Time: train: 3.74 	 Train loss 0.5035 	 Train acc: 0.7739


63it [00:00, 235.62it/s]


>> val 	 Loss 0.4710 	 Running Acc 1.562 	 Total Acc 0.787 	 Avg Batch Time 0.0022
Epoch 46/55 finished.
Train time: 3.74 	 Val time 0.27
Train loss 0.5035 	 Train acc: 0.7739
Val loss: 0.4717 	 Val acc: 0.7870
Best val acc: 0.7990 at epoch 34.


500it [00:03, 129.88it/s]


>> train 	 Epoch 48/55 	 Batch 499/500 	 Loss 0.5048 	 Running Acc 0.767 	 Total Acc 0.767 	 Avg Batch Time 0.0077
Time: train: 3.85 	 Train loss 0.5048 	 Train acc: 0.7674


63it [00:00, 237.30it/s]


>> val 	 Loss 0.4634 	 Running Acc 1.581 	 Total Acc 0.797 	 Avg Batch Time 0.0021
Epoch 47/55 finished.
Train time: 3.85 	 Val time 0.27
Train loss 0.5048 	 Train acc: 0.7674
Val loss: 0.4643 	 Val acc: 0.7970
Best val acc: 0.7990 at epoch 34.


500it [00:03, 131.77it/s]


>> train 	 Epoch 49/55 	 Batch 499/500 	 Loss 0.5048 	 Running Acc 0.770 	 Total Acc 0.770 	 Avg Batch Time 0.0076
Time: train: 3.80 	 Train loss 0.5048 	 Train acc: 0.7698


63it [00:00, 267.64it/s]


>> val 	 Loss 0.4654 	 Running Acc 1.577 	 Total Acc 0.795 	 Avg Batch Time 0.0019
Epoch 48/55 finished.
Train time: 3.80 	 Val time 0.24
Train loss 0.5048 	 Train acc: 0.7698
Val loss: 0.4663 	 Val acc: 0.7950
Best val acc: 0.7990 at epoch 34.


500it [00:03, 126.21it/s]


>> train 	 Epoch 50/55 	 Batch 499/500 	 Loss 0.5039 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0079
Time: train: 3.96 	 Train loss 0.5039 	 Train acc: 0.7676


63it [00:00, 233.69it/s]


>> val 	 Loss 0.4650 	 Running Acc 1.599 	 Total Acc 0.806 	 Avg Batch Time 0.0022
New best validation model, saving...
Epoch 49/55 finished.
Train time: 3.96 	 Val time 0.27
Train loss 0.5039 	 Train acc: 0.7676
Val loss: 0.4660 	 Val acc: 0.8060
Best val acc: 0.8060 at epoch 49.


500it [00:03, 132.07it/s]


>> train 	 Epoch 51/55 	 Batch 499/500 	 Loss 0.5048 	 Running Acc 0.770 	 Total Acc 0.770 	 Avg Batch Time 0.0076
Time: train: 3.79 	 Train loss 0.5048 	 Train acc: 0.7704


63it [00:00, 256.98it/s]


>> val 	 Loss 0.4628 	 Running Acc 1.577 	 Total Acc 0.795 	 Avg Batch Time 0.0020
Epoch 50/55 finished.
Train time: 3.79 	 Val time 0.25
Train loss 0.5048 	 Train acc: 0.7704
Val loss: 0.4636 	 Val acc: 0.7950
Best val acc: 0.8060 at epoch 49.


500it [00:03, 128.53it/s]


>> train 	 Epoch 52/55 	 Batch 499/500 	 Loss 0.5008 	 Running Acc 0.773 	 Total Acc 0.773 	 Avg Batch Time 0.0078
Time: train: 3.89 	 Train loss 0.5008 	 Train acc: 0.7729


63it [00:00, 248.85it/s]


>> val 	 Loss 0.4665 	 Running Acc 1.575 	 Total Acc 0.794 	 Avg Batch Time 0.0020
Epoch 51/55 finished.
Train time: 3.89 	 Val time 0.26
Train loss 0.5008 	 Train acc: 0.7729
Val loss: 0.4674 	 Val acc: 0.7940
Best val acc: 0.8060 at epoch 49.


500it [00:03, 132.30it/s]


>> train 	 Epoch 53/55 	 Batch 499/500 	 Loss 0.5038 	 Running Acc 0.769 	 Total Acc 0.769 	 Avg Batch Time 0.0076
Time: train: 3.78 	 Train loss 0.5038 	 Train acc: 0.7686


63it [00:00, 266.16it/s]


>> val 	 Loss 0.4655 	 Running Acc 1.575 	 Total Acc 0.794 	 Avg Batch Time 0.0019
Epoch 52/55 finished.
Train time: 3.78 	 Val time 0.24
Train loss 0.5038 	 Train acc: 0.7686
Val loss: 0.4661 	 Val acc: 0.7940
Best val acc: 0.8060 at epoch 49.


500it [00:03, 129.14it/s]


>> train 	 Epoch 54/55 	 Batch 499/500 	 Loss 0.5077 	 Running Acc 0.768 	 Total Acc 0.768 	 Avg Batch Time 0.0077
Time: train: 3.87 	 Train loss 0.5077 	 Train acc: 0.7684


63it [00:00, 235.28it/s]


>> val 	 Loss 0.4717 	 Running Acc 1.558 	 Total Acc 0.785 	 Avg Batch Time 0.0022
Epoch 53/55 finished.
Train time: 3.87 	 Val time 0.27
Train loss 0.5077 	 Train acc: 0.7684
Val loss: 0.4723 	 Val acc: 0.7850
Best val acc: 0.8060 at epoch 49.


500it [00:04, 120.84it/s]


>> train 	 Epoch 55/55 	 Batch 499/500 	 Loss 0.5030 	 Running Acc 0.769 	 Total Acc 0.769 	 Avg Batch Time 0.0083
Time: train: 4.14 	 Train loss 0.5030 	 Train acc: 0.7688


63it [00:00, 233.89it/s]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.4626 	 Running Acc 1.581 	 Total Acc 0.797 	 Avg Batch Time 0.0022
Epoch 54/55 finished.
Train time: 4.14 	 Val time 0.27
Train loss 0.5030 	 Train acc: 0.7688
Val loss: 0.4633 	 Val acc: 0.7970
Best val acc: 0.8060 at epoch 49.


63it [00:00, 244.95it/s]

>> test 	 Loss 0.4781 	 Running Acc 1.581 	 Total Acc 0.797 	 Avg Batch Time 0.0021
Final  tensor([[1.0000, 0.8694, 0.1306],
        [0.0000, 0.6214, 0.3786],
        [1.0000, 0.1845, 0.8155],
        ...,
        [1.0000, 0.2325, 0.7675],
        [0.0000, 0.5240, 0.4760],
        [1.0000, 0.1203, 0.8797]])
Test: Loss 0.4771 	 Acc 0.7970 	 AUC: 0.8621 	 1/eB 0.3: 55.5556 	 1/eB 0.5: 15.1515





# 4. Proposed


In [24]:
import torch
import pennylane as qml
import torch.nn.functional as F
from torch import nn
from torch_geometric.utils import to_dense_adj

n_qubits = 4

dev = qml.device('default.qubit', wires=n_qubits)
# dev = qml.device("qiskit.aer", wires=n_qubits)


def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)

def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)

def RY_RX_layer(weights):
    """Applies a layer of parametrized RY and RX rotations."""
    for i, w in enumerate(weights):
        qml.RY(w, wires=i)
        qml.RX(w, wires=i)

def full_entangling_layer(n_qubits):
    """Applies CNOT gates between all pairs of qubits."""
    for i in range(n_qubits):
        for j in range(i+1, n_qubits):
            qml.CNOT(wires=[i, j])

def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOT.
    """
    # In other words it should apply something like :
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(nqubits - 1):
        qml.CRZ(np.pi / 2, wires=[i, i + 1])
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.SWAP(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.SWAP(wires=[i, i + 1])


@qml.qnode(dev, interface="torch")
def quantum_net(q_input_features, q_weights_flat, q_depth, n_qubits):
    """
    The variational quantum circuit.
    """

    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    # RY_layer(q_input_features)
    qml.AngleEmbedding(features=q_input_features, wires=range(n_qubits), rotation='Z')

    # Sequence of trainable variational layers
    # for k in range(q_depth):
    #     entangling_layer(n_qubits)
    #     RY_RX_layer(q_weights[k])
    #     # RY_layer(q_weights[k])
    for k in range(q_depth):
        if k % 2 == 0:
            entangling_layer(n_qubits)
            RY_layer(q_weights[k])
        else:
            full_entangling_layer(n_qubits)
            RY_RX_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self, n_qubits, q_depth = 1, q_delta=0.001):
        """
        Definition of the *dressed* layout.
        """
        print('n_qubits: ', n_qubits)
        super().__init__()
        self.n_qubits = n_qubits
        self.q_depth = q_depth
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))

    def forward(self, input_features):
        """
        Optimized forward pass to reduce runtime.
        """

        # Quantum Embedding (U(X))
        q_in = torch.tanh(input_features) * np.pi / 2.0

        # Preallocate output tensor
        batch_size = q_in.shape[0]
        q_out = torch.zeros(batch_size, self.n_qubits, device=q_in.device)

        # Vectorized execution
        for i, elem in enumerate(q_in):
            q_out_elem = torch.hstack(quantum_net(elem, self.q_params, self.q_depth, self.n_qubits)).float()
            q_out[i] = q_out_elem

        return q_out

In [25]:
# @title
import torch
from torch import nn
import numpy as np
import pennylane as qml

"""
    Lie-Equivariant Quantum Block (LEQB).

        - Given the Lie generators found (i.e.: through LieGAN, oracle-preserving latent flow, or some other approach
          that we develop further), once the metric tensor J is found via the equation:

                          L.J + J.(L^T) = 0,

          we just have to specify the metric to make the model symmetry-preserving to the corresponding Lie group.
          In the cells below, we can see how the model preserves symmetries (starting with the default Lorentz group),
          and when we change J to some other metric (Euclidean, for example), Lorentz boosts **break** equivariance, while other
          transformations preserve it (rotations, for the example shown in the cells below)
"""
class LEQB(nn.Module):
    def __init__(self, n_input, n_output, n_hidden, n_node_attr=0,
                 dropout = 0., c_weight=1.0, last_layer=False, A=None, include_x=False):
        super(LEQB, self).__init__()
        self.c_weight = c_weight
        self.dimension_reducer = nn.Linear(10, 4) # New linear layer for dimension reduction
        self.dimension_reducer2 = nn.Linear(9, 4) # New linear layer for dimension reduction for phi_h
        n_edge_attr = 2 if not include_x else 10 # dims for Minkowski norm & inner product
        # With include_X = False, not include_x becomes True, so the value of n_edge_attr is 2. n_input = n_hidden = 4
        print('Input size of phi_e: ', n_input)
        self.include_x = include_x

        """
            phi_e: input size: n_qubits -> output size: n_qubits
            n_hidden has to be equal to n_input,
            but this is just considering that this is a simple working example.
        """
        self.phi_e = DressedQuantumNet(n_input)
#         self.phi_e = nn.Sequential(
#             nn.Linear(n_input, n_hidden, bias=False),  # n_input * 2 + n_edge_attr
#             nn.BatchNorm1d(n_hidden),
#             nn.ReLU(),
#             nn.Linear(n_hidden, n_hidden),
#             nn.ReLU())

        n_hidden = n_input # n_input * 2 + n_edge_attr
        self.phi_h = nn.Sequential(
            nn.Linear(n_hidden + n_input + n_node_attr, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output))

#         self.phi_h = DressedQuantumNet(n_hidden)

        layer = nn.Linear(n_hidden, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.phi_x = nn.Sequential(
            nn.Linear(n_hidden, n_hidden),
            nn.ReLU(),
            layer)

#         self.phi_x = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             layer)

#         self.phi_m = nn.Sequential(
#             DressedQuantumNet(n_hidden),
#             nn.Linear(n_hidden, 1),
#             nn.Sigmoid())

        self.phi_m = nn.Sequential(
            nn.Linear(n_hidden, 1),
            nn.Sigmoid())

        # self.phi_e = nn.Sequential(
        #     nn.Linear(n_input * 2 + n_edge_attr, n_hidden, bias=False),
        #     nn.BatchNorm1d(n_hidden),
        #     nn.ReLU(),
        #     nn.Linear(n_hidden, n_hidden),
        #     nn.ReLU())

        self.last_layer = last_layer
        if last_layer:
            del self.phi_x

        self.A = A
        self.norm_fn = normA_fn(A) if A is not None else normsq4
        self.dot_fn = dotA_fn(A) if A is not None else dotsq4

    def m_model(self, hi, hj, norms, dots):
        out = torch.cat([hi, hj, norms, dots], dim=1)
        out = self.dimension_reducer(out) # extra
        # print("Before embedding to |psi> : ", out)
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def m_model_extended(self, hi, hj, norms, dots, xi, xj):
        out = torch.cat([hi, hj, norms, dots, xi, xj], dim=1)
        out = self.dimension_reducer(out) # extra
        out = self.phi_e(out).squeeze(0)
        w = self.phi_m(out)
        out = out * w
        return out

    def h_model(self, h, edges, m, node_attr):
        i, j = edges
        agg = unsorted_segment_sum(m, i, num_segments=h.size(0))
        agg = torch.cat([h, agg, node_attr], dim=1)
        #agg = self.dimension_reducer2(agg) # extra for phi_h
        out = h + self.phi_h(agg)
        return out

    def x_model(self, x, edges, x_diff, m):
        i, j = edges
        trans = x_diff * self.phi_x(m)
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        # From https://github.com/vgsatorras/egnn
        # This is never activated but just in case it explosed it may save the train
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, i, num_segments=x.size(0))
        x = x + agg * self.c_weight
        return x

    def minkowski_feats(self, edges, x):
        i, j = edges
        x_diff = x[i] - x[j]
        norms = self.norm_fn(x_diff).unsqueeze(1)
        dots = self.dot_fn(x[i], x[j]).unsqueeze(1)
        norms, dots = psi(norms), psi(dots)
        return norms, dots, x_diff

    def forward(self, h, x, edges, node_attr=None):
        i, j = edges
        norms, dots, x_diff = self.minkowski_feats(edges, x)

        if self.include_x:
            m = self.m_model_extended(h[i], h[j], norms, dots, x[i], x[j])
        else:
            m = self.m_model(h[i], h[j], norms, dots) # [B*N, hidden]
        if not self.last_layer:
            x = self.x_model(x, edges, x_diff, m)
        h = self.h_model(h, edges, m, node_attr)
        return h, x, m

class LieEQGNN(nn.Module):
    r''' Implementation of LorentzNet.

    Args:
        - `n_scalar` (int): number of input scalars.
        - `n_hidden` (int): dimension of latent space.
        - `n_class`  (int): number of output classes.
        - `n_layers` (int): number of LEQB layers.
        - `c_weight` (float): weight c in the x_model.
        - `dropout`  (float): dropout rate.
    '''
    def __init__(self, n_scalar, n_hidden, n_class = 2, n_layers = 6, c_weight = 1e-3, dropout = 0., A=None, include_x=False):
        super(LieEQGNN, self).__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.embedding = nn.Linear(n_scalar, n_hidden)
        self.LEQBs = nn.ModuleList([LEQB(self.n_hidden, self.n_hidden, self.n_hidden,
                                    n_node_attr=n_scalar, dropout=dropout,
                                    c_weight=c_weight, last_layer=(i==n_layers-1), A=A, include_x=include_x)
                                    for i in range(n_layers)])
        self.graph_dec = nn.Sequential(nn.Linear(self.n_hidden, self.n_hidden),
                                       nn.ReLU(),
                                       nn.Dropout(dropout),
                                       nn.Linear(self.n_hidden, n_class)) # classification

    def forward(self, scalars, x, edges, node_mask, edge_mask, n_nodes):
        h = self.embedding(scalars)

        # print("h before (just the first particle): \n", h[0].cpu().detach().numpy())
        for i in range(self.n_layers):
            h, x, _ = self.LEQBs[i](h, x, edges, node_attr=scalars)

        # print("h after (just the first particle): \n", h[0].cpu().detach().numpy())

        h = h * node_mask
        h = h.view(-1, n_nodes, self.n_hidden)
        h = torch.mean(h, dim=1)
        pred = self.graph_dec(h)
        return pred.squeeze(1)

In [26]:
import os
import torch
from torch import nn, optim
import json, time
# import utils_lorentz
import numpy as np
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

if __name__ == "__main__":

    N_EPOCHS = 25 # 60

    model_path = "models/LieEQGNN/"
    log_path = "logs/LieEQGNN/"
    # utils_lorentz.args_init(args)

    ### set random seed
    torch.manual_seed(42)
    np.random.seed(42)

    ### initialize cpu
    # dist.init_process_group(backend='nccl')
    device = 'cpu' #torch.device("cuda")
    dtype = torch.float32

    ### load data
    # dataloaders = retrieve_dataloaders( batch_size,
    #                                     num_data=100000, # use all data
    #                                     cache_dir="datasets/QMLHEP/quark_gluons/",
    #                                     num_workers=0,
    #                                     use_one_hot=True)

    model = LieEQGNN(n_scalar = 1, n_hidden = 4, n_class = 2,\
                       dropout = 0.2, n_layers = 1,\
                       c_weight = 1e-3)

    model = model.to(device)

    ### print model and dataset information
    # if (args.local_rank == 0):
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Model Size:", pytorch_total_params)
    for (split, dataloader) in dataloaders.items():
        print(f" {split} samples: {len(dataloader.dataset)}")

    ### optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    ### lr scheduler
    base_scheduler = CosineAnnealingWarmRestarts(optimizer, 4, 2, verbose = False)
    lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1,\
                                                warmup_epoch=5,\
                                                after_scheduler=base_scheduler) ## warmup

    ### loss function
    loss_fn = nn.CrossEntropyLoss()

    ### initialize logs
    res = {'epochs': [], 'lr' : [],\
           'train_time': [], 'val_time': [],  'train_loss': [], 'val_loss': [],\
           'train_acc': [], 'val_acc': [], 'best_val': 0, 'best_epoch': 0}

    ### training and testing
    print("Training...")
    train(model, res, N_EPOCHS, model_path, log_path)
    test(model, res, model_path, log_path)



Input size of phi_e:  4
n_qubits:  4
Model Size: 199
 train samples: 8000
 val samples: 1000
 test samples: 1000
Training...


500it [21:43,  2.61s/it]


>> train 	 Epoch 1/25 	 Batch 499/500 	 Loss 0.7276 	 Running Acc 0.500 	 Total Acc 0.500 	 Avg Batch Time 2.6068
Time: train: 1303.39 	 Train loss 0.7276 	 Train acc: 0.5000


63it [01:41,  1.62s/it]


>> val 	 Loss 0.6951 	 Running Acc 0.992 	 Total Acc 0.500 	 Avg Batch Time 0.8158
New best validation model, saving...
Epoch 0/25 finished.
Train time: 1303.39 	 Val time 101.97
Train loss 0.7276 	 Train acc: 0.5000
Val loss: 0.6955 	 Val acc: 0.5000
Best val acc: 0.5000 at epoch 0.


500it [21:19,  2.56s/it]


>> train 	 Epoch 2/25 	 Batch 499/500 	 Loss 0.6927 	 Running Acc 0.533 	 Total Acc 0.533 	 Avg Batch Time 2.5589
Time: train: 1279.45 	 Train loss 0.6927 	 Train acc: 0.5331


63it [01:48,  1.72s/it]


>> val 	 Loss 0.6792 	 Running Acc 1.272 	 Total Acc 0.641 	 Avg Batch Time 0.8668
New best validation model, saving...
Epoch 1/25 finished.
Train time: 1279.45 	 Val time 108.35
Train loss 0.6927 	 Train acc: 0.5331
Val loss: 0.6794 	 Val acc: 0.6410
Best val acc: 0.6410 at epoch 1.


500it [21:39,  2.60s/it]


>> train 	 Epoch 3/25 	 Batch 499/500 	 Loss 0.6785 	 Running Acc 0.567 	 Total Acc 0.567 	 Avg Batch Time 2.5988
Time: train: 1299.41 	 Train loss 0.6785 	 Train acc: 0.5670


63it [01:42,  1.63s/it]


>> val 	 Loss 0.7483 	 Running Acc 0.992 	 Total Acc 0.500 	 Avg Batch Time 0.8235
Epoch 2/25 finished.
Train time: 1299.41 	 Val time 102.94
Train loss 0.6785 	 Train acc: 0.5670
Val loss: 0.7500 	 Val acc: 0.5000
Best val acc: 0.6410 at epoch 1.


500it [21:28,  2.58s/it]


>> train 	 Epoch 4/25 	 Batch 499/500 	 Loss 0.6439 	 Running Acc 0.618 	 Total Acc 0.618 	 Avg Batch Time 2.5765
Time: train: 1288.26 	 Train loss 0.6439 	 Train acc: 0.6179


63it [01:41,  1.61s/it]


>> val 	 Loss 0.5918 	 Running Acc 1.395 	 Total Acc 0.703 	 Avg Batch Time 0.8107
New best validation model, saving...
Epoch 3/25 finished.
Train time: 1288.26 	 Val time 101.33
Train loss 0.6439 	 Train acc: 0.6179
Val loss: 0.5927 	 Val acc: 0.7030
Best val acc: 0.7030 at epoch 3.


500it [21:49,  2.62s/it]


>> train 	 Epoch 5/25 	 Batch 499/500 	 Loss 0.5849 	 Running Acc 0.681 	 Total Acc 0.681 	 Avg Batch Time 2.6185
Time: train: 1309.23 	 Train loss 0.5849 	 Train acc: 0.6809


63it [01:42,  1.63s/it]


>> val 	 Loss 0.5663 	 Running Acc 1.435 	 Total Acc 0.723 	 Avg Batch Time 0.8208
New best validation model, saving...
Epoch 4/25 finished.
Train time: 1309.23 	 Val time 102.60
Train loss 0.5849 	 Train acc: 0.6809
Val loss: 0.5668 	 Val acc: 0.7230
Best val acc: 0.7230 at epoch 4.


500it [21:23,  2.57s/it]


>> train 	 Epoch 6/25 	 Batch 499/500 	 Loss 0.5692 	 Running Acc 0.711 	 Total Acc 0.711 	 Avg Batch Time 2.5666
Time: train: 1283.31 	 Train loss 0.5692 	 Train acc: 0.7114


63it [01:41,  1.61s/it]


>> val 	 Loss 0.5519 	 Running Acc 1.421 	 Total Acc 0.716 	 Avg Batch Time 0.8119
Epoch 5/25 finished.
Train time: 1283.31 	 Val time 101.49
Train loss 0.5692 	 Train acc: 0.7114
Val loss: 0.5525 	 Val acc: 0.7160
Best val acc: 0.7230 at epoch 4.


500it [21:50,  2.62s/it]


>> train 	 Epoch 7/25 	 Batch 499/500 	 Loss 0.5601 	 Running Acc 0.727 	 Total Acc 0.727 	 Avg Batch Time 2.6205
Time: train: 1310.26 	 Train loss 0.5601 	 Train acc: 0.7266


63it [01:46,  1.70s/it]


>> val 	 Loss 0.5522 	 Running Acc 1.444 	 Total Acc 0.728 	 Avg Batch Time 0.8543
New best validation model, saving...
Epoch 6/25 finished.
Train time: 1310.26 	 Val time 106.79
Train loss 0.5601 	 Train acc: 0.7266
Val loss: 0.5525 	 Val acc: 0.7280
Best val acc: 0.7280 at epoch 6.


500it [22:01,  2.64s/it]


>> train 	 Epoch 8/25 	 Batch 499/500 	 Loss 0.5554 	 Running Acc 0.732 	 Total Acc 0.732 	 Avg Batch Time 2.6434
Time: train: 1321.70 	 Train loss 0.5554 	 Train acc: 0.7319


63it [01:45,  1.67s/it]


>> val 	 Loss 0.5399 	 Running Acc 1.452 	 Total Acc 0.732 	 Avg Batch Time 0.8435
New best validation model, saving...
Epoch 7/25 finished.
Train time: 1321.70 	 Val time 105.44
Train loss 0.5554 	 Train acc: 0.7319
Val loss: 0.5404 	 Val acc: 0.7320
Best val acc: 0.7320 at epoch 7.


500it [21:40,  2.60s/it]


>> train 	 Epoch 9/25 	 Batch 499/500 	 Loss 0.5607 	 Running Acc 0.732 	 Total Acc 0.732 	 Avg Batch Time 2.6006
Time: train: 1300.28 	 Train loss 0.5607 	 Train acc: 0.7315


63it [01:41,  1.61s/it]


>> val 	 Loss 0.9873 	 Running Acc 1.111 	 Total Acc 0.560 	 Avg Batch Time 0.8116
Epoch 8/25 finished.
Train time: 1300.28 	 Val time 101.45
Train loss 0.5607 	 Train acc: 0.7315
Val loss: 0.9849 	 Val acc: 0.5600
Best val acc: 0.7320 at epoch 7.


500it [21:27,  2.57s/it]


>> train 	 Epoch 10/25 	 Batch 499/500 	 Loss 0.5551 	 Running Acc 0.734 	 Total Acc 0.734 	 Avg Batch Time 2.5745
Time: train: 1287.25 	 Train loss 0.5551 	 Train acc: 0.7342


63it [01:42,  1.62s/it]


>> val 	 Loss 0.5870 	 Running Acc 1.421 	 Total Acc 0.716 	 Avg Batch Time 0.8175
Epoch 9/25 finished.
Train time: 1287.25 	 Val time 102.19
Train loss 0.5551 	 Train acc: 0.7342
Val loss: 0.5880 	 Val acc: 0.7160
Best val acc: 0.7320 at epoch 7.


500it [21:08,  2.54s/it]


>> train 	 Epoch 11/25 	 Batch 499/500 	 Loss 0.5573 	 Running Acc 0.738 	 Total Acc 0.738 	 Avg Batch Time 2.5362
Time: train: 1268.08 	 Train loss 0.5573 	 Train acc: 0.7376


63it [01:43,  1.64s/it]


>> val 	 Loss 0.5384 	 Running Acc 1.472 	 Total Acc 0.742 	 Avg Batch Time 0.8250
New best validation model, saving...
Epoch 10/25 finished.
Train time: 1268.08 	 Val time 103.13
Train loss 0.5573 	 Train acc: 0.7376
Val loss: 0.5390 	 Val acc: 0.7420
Best val acc: 0.7420 at epoch 10.


500it [21:39,  2.60s/it]


>> train 	 Epoch 12/25 	 Batch 499/500 	 Loss 0.5438 	 Running Acc 0.742 	 Total Acc 0.742 	 Avg Batch Time 2.5989
Time: train: 1299.47 	 Train loss 0.5438 	 Train acc: 0.7419


63it [01:49,  1.74s/it]


>> val 	 Loss 0.5453 	 Running Acc 1.437 	 Total Acc 0.724 	 Avg Batch Time 0.8794
Epoch 11/25 finished.
Train time: 1299.47 	 Val time 109.92
Train loss 0.5438 	 Train acc: 0.7419
Val loss: 0.5453 	 Val acc: 0.7240
Best val acc: 0.7420 at epoch 10.


500it [21:22,  2.57s/it]


>> train 	 Epoch 13/25 	 Batch 499/500 	 Loss 0.5466 	 Running Acc 0.745 	 Total Acc 0.745 	 Avg Batch Time 2.5656
Time: train: 1282.80 	 Train loss 0.5466 	 Train acc: 0.7450


63it [01:41,  1.61s/it]


>> val 	 Loss 0.5245 	 Running Acc 1.482 	 Total Acc 0.747 	 Avg Batch Time 0.8112
New best validation model, saving...
Epoch 12/25 finished.
Train time: 1282.80 	 Val time 101.40
Train loss 0.5466 	 Train acc: 0.7450
Val loss: 0.5246 	 Val acc: 0.7470
Best val acc: 0.7470 at epoch 12.


500it [21:52,  2.62s/it]


>> train 	 Epoch 14/25 	 Batch 499/500 	 Loss 0.5431 	 Running Acc 0.747 	 Total Acc 0.747 	 Avg Batch Time 2.6247
Time: train: 1312.33 	 Train loss 0.5431 	 Train acc: 0.7472


63it [01:44,  1.66s/it]


>> val 	 Loss 0.5312 	 Running Acc 1.460 	 Total Acc 0.736 	 Avg Batch Time 0.8343
Epoch 13/25 finished.
Train time: 1312.33 	 Val time 104.29
Train loss 0.5431 	 Train acc: 0.7472
Val loss: 0.5314 	 Val acc: 0.7360
Best val acc: 0.7470 at epoch 12.


500it [21:50,  2.62s/it]


>> train 	 Epoch 15/25 	 Batch 499/500 	 Loss 0.5455 	 Running Acc 0.744 	 Total Acc 0.744 	 Avg Batch Time 2.6207
Time: train: 1310.37 	 Train loss 0.5455 	 Train acc: 0.7435


63it [01:47,  1.71s/it]


>> val 	 Loss 0.5261 	 Running Acc 1.484 	 Total Acc 0.748 	 Avg Batch Time 0.8603
New best validation model, saving...
Epoch 14/25 finished.
Train time: 1310.37 	 Val time 107.53
Train loss 0.5455 	 Train acc: 0.7435
Val loss: 0.5263 	 Val acc: 0.7480
Best val acc: 0.7480 at epoch 14.


500it [21:30,  2.58s/it]


>> train 	 Epoch 16/25 	 Batch 499/500 	 Loss 0.5440 	 Running Acc 0.745 	 Total Acc 0.745 	 Avg Batch Time 2.5800
Time: train: 1290.02 	 Train loss 0.5440 	 Train acc: 0.7448


63it [01:41,  1.60s/it]


>> val 	 Loss 0.5253 	 Running Acc 1.464 	 Total Acc 0.738 	 Avg Batch Time 0.8082
Epoch 15/25 finished.
Train time: 1290.02 	 Val time 101.03
Train loss 0.5440 	 Train acc: 0.7448
Val loss: 0.5255 	 Val acc: 0.7380
Best val acc: 0.7480 at epoch 14.


500it [22:11,  2.66s/it]


>> train 	 Epoch 17/25 	 Batch 499/500 	 Loss 0.5543 	 Running Acc 0.737 	 Total Acc 0.737 	 Avg Batch Time 2.6622
Time: train: 1331.08 	 Train loss 0.5543 	 Train acc: 0.7365


63it [01:52,  1.79s/it]


>> val 	 Loss 0.5311 	 Running Acc 1.474 	 Total Acc 0.743 	 Avg Batch Time 0.9013
Epoch 16/25 finished.
Train time: 1331.08 	 Val time 112.66
Train loss 0.5543 	 Train acc: 0.7365
Val loss: 0.5313 	 Val acc: 0.7430
Best val acc: 0.7480 at epoch 14.


500it [21:38,  2.60s/it]


>> train 	 Epoch 18/25 	 Batch 499/500 	 Loss 0.5497 	 Running Acc 0.744 	 Total Acc 0.744 	 Avg Batch Time 2.5979
Time: train: 1298.93 	 Train loss 0.5497 	 Train acc: 0.7442


63it [01:47,  1.71s/it]


>> val 	 Loss 0.6295 	 Running Acc 1.319 	 Total Acc 0.665 	 Avg Batch Time 0.8610
Epoch 17/25 finished.
Train time: 1298.93 	 Val time 107.63
Train loss 0.5497 	 Train acc: 0.7442
Val loss: 0.6293 	 Val acc: 0.6650
Best val acc: 0.7480 at epoch 14.


500it [21:36,  2.59s/it]


>> train 	 Epoch 19/25 	 Batch 499/500 	 Loss 0.5465 	 Running Acc 0.741 	 Total Acc 0.741 	 Avg Batch Time 2.5921
Time: train: 1296.05 	 Train loss 0.5465 	 Train acc: 0.7406


63it [01:44,  1.66s/it]


>> val 	 Loss 0.5395 	 Running Acc 1.452 	 Total Acc 0.732 	 Avg Batch Time 0.8367
Epoch 18/25 finished.
Train time: 1296.05 	 Val time 104.59
Train loss 0.5465 	 Train acc: 0.7406
Val loss: 0.5402 	 Val acc: 0.7320
Best val acc: 0.7480 at epoch 14.


500it [21:29,  2.58s/it]


>> train 	 Epoch 20/25 	 Batch 499/500 	 Loss 0.5463 	 Running Acc 0.739 	 Total Acc 0.739 	 Avg Batch Time 2.5782
Time: train: 1289.08 	 Train loss 0.5463 	 Train acc: 0.7390


63it [01:40,  1.59s/it]


>> val 	 Loss 0.5492 	 Running Acc 1.450 	 Total Acc 0.731 	 Avg Batch Time 0.8018
Epoch 19/25 finished.
Train time: 1289.08 	 Val time 100.22
Train loss 0.5463 	 Train acc: 0.7390
Val loss: 0.5499 	 Val acc: 0.7310
Best val acc: 0.7480 at epoch 14.


500it [21:20,  2.56s/it]


>> train 	 Epoch 21/25 	 Batch 499/500 	 Loss 0.5468 	 Running Acc 0.741 	 Total Acc 0.741 	 Avg Batch Time 2.5614
Time: train: 1280.71 	 Train loss 0.5468 	 Train acc: 0.7409


63it [01:40,  1.59s/it]


>> val 	 Loss 0.5423 	 Running Acc 1.458 	 Total Acc 0.735 	 Avg Batch Time 0.8035
Epoch 20/25 finished.
Train time: 1280.71 	 Val time 100.44
Train loss 0.5468 	 Train acc: 0.7409
Val loss: 0.5430 	 Val acc: 0.7350
Best val acc: 0.7480 at epoch 14.


500it [21:23,  2.57s/it]


>> train 	 Epoch 22/25 	 Batch 499/500 	 Loss 0.5394 	 Running Acc 0.744 	 Total Acc 0.744 	 Avg Batch Time 2.5663
Time: train: 1283.13 	 Train loss 0.5394 	 Train acc: 0.7440


63it [01:41,  1.62s/it]


>> val 	 Loss 0.5236 	 Running Acc 1.476 	 Total Acc 0.744 	 Avg Batch Time 0.8152
Epoch 21/25 finished.
Train time: 1283.13 	 Val time 101.90
Train loss 0.5394 	 Train acc: 0.7440
Val loss: 0.5240 	 Val acc: 0.7440
Best val acc: 0.7480 at epoch 14.


500it [21:35,  2.59s/it]


>> train 	 Epoch 23/25 	 Batch 499/500 	 Loss 0.5442 	 Running Acc 0.746 	 Total Acc 0.746 	 Avg Batch Time 2.5918
Time: train: 1295.89 	 Train loss 0.5442 	 Train acc: 0.7462


63it [01:40,  1.60s/it]


>> val 	 Loss 0.5212 	 Running Acc 1.470 	 Total Acc 0.741 	 Avg Batch Time 0.8045
Epoch 22/25 finished.
Train time: 1295.89 	 Val time 100.57
Train loss 0.5442 	 Train acc: 0.7462
Val loss: 0.5214 	 Val acc: 0.7410
Best val acc: 0.7480 at epoch 14.


500it [21:04,  2.53s/it]


>> train 	 Epoch 24/25 	 Batch 499/500 	 Loss 0.5432 	 Running Acc 0.745 	 Total Acc 0.745 	 Avg Batch Time 2.5290
Time: train: 1264.52 	 Train loss 0.5432 	 Train acc: 0.7451


63it [01:41,  1.62s/it]


>> val 	 Loss 0.5313 	 Running Acc 1.460 	 Total Acc 0.736 	 Avg Batch Time 0.8140
Epoch 23/25 finished.
Train time: 1264.52 	 Val time 101.75
Train loss 0.5432 	 Train acc: 0.7451
Val loss: 0.5312 	 Val acc: 0.7360
Best val acc: 0.7480 at epoch 14.


500it [21:18,  2.56s/it]


>> train 	 Epoch 25/25 	 Batch 499/500 	 Loss 0.5415 	 Running Acc 0.750 	 Total Acc 0.750 	 Avg Batch Time 2.5566
Time: train: 1278.31 	 Train loss 0.5415 	 Train acc: 0.7496


63it [01:40,  1.59s/it]
  best_model = torch.load(os.path.join(model_path, "best-val-model.pt"), map_location=device)


>> val 	 Loss 0.5145 	 Running Acc 1.500 	 Total Acc 0.756 	 Avg Batch Time 0.8023
New best validation model, saving...
Epoch 24/25 finished.
Train time: 1278.31 	 Val time 100.29
Train loss 0.5415 	 Train acc: 0.7496
Val loss: 0.5147 	 Val acc: 0.7560
Best val acc: 0.7560 at epoch 24.


63it [01:41,  1.60s/it]

>> test 	 Loss 0.5116 	 Running Acc 1.484 	 Total Acc 0.748 	 Avg Batch Time 0.8087
Final  tensor([[1.0000, 0.8497, 0.1503],
        [0.0000, 0.5746, 0.4254],
        [1.0000, 0.2029, 0.7971],
        ...,
        [1.0000, 0.3306, 0.6694],
        [0.0000, 0.2611, 0.7389],
        [1.0000, 0.1844, 0.8156]])
Test: Loss 0.5108 	 Acc 0.7480 	 AUC: 0.8320 	 1/eB 0.3: 33.3333 	 1/eB 0.5: 12.5000



