In [1]:
!pip install pytorch-lightning
!pip install torch-geometric
!pip install energyflow
!pip install tensorcircuit
!pip install jax
!pip install jaxlib
!pip install torch-cluster

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.6.2-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.0-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch-lightning)
  Dow

In [2]:
!pip install pennylane

Collecting pennylane
  Downloading PennyLane-0.40.0-py3-none-any.whl.metadata (10 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting tomlkit (from pennylane)
  Downloading tomlkit-0.13.2-py3-none-any.whl.metadata (2.7 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.40 (from pennylane)
  Downloading PennyLane_Lightning-0.40.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (27 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.40->pennylane)
  Downloading scipy_openblas32-0.3.29.0.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5

In [3]:
!pip install torch-scatter

Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/108.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp311-cp311-linux_x86_64.whl size=3618829 sha256=ee93fb1524b3aaa549ce1b9e1f4bc15eb1e28cca8a3f2689c6692b2c0f5294c1
  Stored in directory: /root/.cache/pip/wheels/b8/d4/0e/a80af2465354ea7355a2c153b11af2da739cfcf08b6c0b28e2
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2


In [4]:
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import knn_graph, global_mean_pool
from sklearn.model_selection import train_test_split
import energyflow as ef
import pennylane as qml
import numpy as np

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
print("Loading dataset...")
qg_dataset = ef.qg_jets.load(num_data=5000, pad=True, ncol=4, generator='pythia')
x = qg_dataset[0]  # Jet features (pT, y, phi, energy)
y = torch.tensor(qg_dataset[1], dtype=torch.long)

Loading dataset...
Downloading QG_jets.npz from https://zenodo.org/record/3164691/files/QG_jets.npz?download=1 to /root/.energyflow/datasets


In [7]:
def construct_graph(jet_features, label, k=5):
    """Converts a jet (list of particles) into a graph using k-NN connectivity."""
    node_features = torch.tensor(jet_features, dtype=torch.float)
    edge_index = knn_graph(node_features, k=k, loop=False)

    return Data(x=node_features, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))

jet_graphs = [construct_graph(x[i], y[i], k=5) for i in range(len(x))]

In [None]:
train_data, val_data = train_test_split(jet_graphs, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(jet_graphs, test_size=0.1, random_state=42)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
val_loader = DataLoader(val_data, batch_size=128, shuffle=False)

In [None]:
def default_EDU(bond_params, rzz_param, wires):
    """EDU (Entanglement Driven Unit) for message passing."""
    w0, w1 = int(wires[0]), int(wires[1])

    qml.RZ(bond_params[0], wires=w0)
    qml.RX(bond_params[1], wires=w0)
    qml.RZ(bond_params[0], wires=w1)
    qml.RX(bond_params[1], wires=w1)
    qml.CNOT(wires=[w0, w1])
    qml.RZ(rzz_param, wires=w1)
    qml.CNOT(wires=[w0, w1])
    qml.RX(bond_params[0], wires=w0)
    qml.RZ(bond_params[1], wires=w0)
    qml.RX(bond_params[0], wires=w1)
    qml.RZ(bond_params[1], wires=w1)

In [None]:
class QuantumNet(nn.Module):
    """Quantum Neural Network Layer for Message Passing in Graphs."""
    def __init__(self, num_qubits, num_layers):
        super().__init__()
        self.num_qubits = num_qubits
        self.num_layers = num_layers
        self.params = nn.Parameter(torch.randn(num_layers * 4))

        self.dev = qml.device('default.qubit', wires=num_qubits)

    def forward(self, x, edge_index):
        batch_size = x.shape[0]
        num_nodes = min(batch_size, self.num_qubits)
        x = x[:num_nodes]
        edge_index = edge_index[:, (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)]

        @qml.qnode(self.dev, interface='torch', diff_method="backprop")
        def circuit():
            for i in range(num_nodes):
                qml.RY(x[i][0], wires=i)
                qml.RZ(x[i][1], wires=i)

            for layer in range(self.num_layers):
                rzz_param = self.params[layer]
                for start, end in edge_index.T:
                    default_EDU(self.params[layer:layer+2], rzz_param, [int(start), int(end)])

            return [qml.expval(qml.PauliZ(i)) for i in range(num_nodes)]

        output = torch.tensor(circuit(), dtype=torch.float32)
        return output.unsqueeze(0).expand(batch_size, -1)



class QGCNConv(torch_geometric.nn.GCNConv):
    """Quantum Graph Convolution Layer (GCN + QuantumNet)."""
    def __init__(self, in_channels, out_channels, num_layers=1, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)
        self.quantum_layer = QuantumNet(in_channels, num_layers)

    def forward(self, x, edge_index):
        return self.quantum_layer(x, edge_index)

In [None]:
class GraphGNNModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, **kwargs):
        super().__init__()
        self.embed = nn.Linear(c_in, c_hidden)
        self.GNN = QGCNConv(c_hidden, c_hidden, num_layers=2, **kwargs)
        self.head = nn.Linear(c_hidden, c_out)

    def forward(self, x, edge_index, batch_idx):
        batch_size = x.shape[0]
        batch_idx = torch.clamp(batch_idx[:batch_size], min=0, max=batch_size - 1)

        x = self.embed(x)
        x = self.GNN(x, edge_index)

        batch_idx = batch_idx[:x.shape[0]]
        x = global_mean_pool(x, batch_idx)

        return self.head(x)


class GraphLevelGNN(pl.LightningModule):
    def __init__(self, c_in, c_hidden, c_out, **model_kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = GraphGNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_out, **model_kwargs)
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, batch):
        x, edge_index, batch_idx = batch.x, batch.edge_index, batch.batch
        x = self.model(x, edge_index, batch_idx)
        return x

    def training_step(self, batch, batch_idx):
        y_pred = self.forward(batch)
        loss = self.loss_module(y_pred, batch.y.view(-1))
        acc = (y_pred.argmax(dim=1) == batch.y.view(-1)).float().mean()
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_pred = self.forward(batch)
        loss = self.loss_module(y_pred, batch.y.view(-1))
        acc = (y_pred.argmax(dim=1) == batch.y.view(-1)).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [None]:
def train_qgnn(num_epochs=50):
    trainer = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=1, max_epochs=num_epochs)
    model = GraphLevelGNN(c_in=4, c_hidden=6, c_out=2)
    trainer.fit(model, train_loader, val_loader)
    return model

best_model = train_qgnn()


Using device: cpu
Loading dataset...


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params | Mode 
---------------------------------------------------------
0 | model       | GraphGNNModel    | 94     | train
1 | loss_module | CrossEntropyLoss | 0      | train
---------------------------------------------------------
94        Trainable params
0         Non-trainable params
94        Total params
0.000     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4448. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1112. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.
