In [80]:
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric
import random 
import numpy as np
from tqdm import tqdm
import copy
import itertools



In [81]:
## Data preparation

# We generate (single n-cycle, two disjoint cycles) graph pairs as instructed and split them into 5 folds: 4 for training/cross-validation and 1 for testing.
def gen_graph_pairs(min_n, max_n):
  """Build (single cycle, two disjoint cycles) graph pairs.

  For every node count ``n`` in ``[min_n, max_n]`` and every split ``k``
  with ``3 <= k <= n - 3``, emit ``(C_n, C_k + C_{n-k})``: an ``n``-cycle
  next to the disjoint union of a ``k``-cycle and an ``(n-k)``-cycle, so
  both graphs have ``n`` nodes.

  Note: ``k`` and ``n - k`` are both enumerated, so whenever ``k != n - k``
  the same (isomorphic) pair appears twice in the result.

  Returns:
    A list of ``(single, disjoint)`` networkx graph tuples; empty when no
    ``n`` in the range is at least 6.
  """
  return [
      (nx.cycle_graph(n),
       nx.disjoint_union(nx.cycle_graph(k), nx.cycle_graph(n - k)))
      for n in range(min_n, max_n + 1)
      for k in range(3, n - 2)
  ]
def to_pyg(g, label, num_features=50):
  """Convert a networkx graph into a PyG data object with a graph label.

  Args:
    g: networkx graph; its structure is converted with
      ``torch_geometric.utils.from_networkx``.
    label: class label, stored as a 1-element tensor in ``data.y``.
    num_features: width of the all-zero node-feature matrix ``data.x``
      (defaults to 50, the value previously hard-coded).

  Returns:
    The converted data object with ``x`` of shape
    ``(g.number_of_nodes(), num_features)`` and ``y == tensor([label])``.
  """
  data = torch_geometric.utils.from_networkx(g)
  # All-zero features carry no node-specific information, so graphs can
  # only be told apart by their structure.
  data.x = torch.zeros((g.number_of_nodes(), num_features))
  data.y = torch.tensor([label])
  return data


def split(data_pairs, n_splits=5):
  """Shuffle example pairs and partition them into ``n_splits`` folds.

  The pairs are shuffled and divided into exactly ``n_splits`` contiguous
  folds whose sizes differ by at most one pair (the first
  ``len(data_pairs) % n_splits`` folds get one extra). Each fold is then
  flattened into a flat list of examples and shuffled so positives and
  negatives are interleaved. Because slicing happens on pairs, both members
  of a pair always land in the same fold.

  The caller's list is not modified.

  Args:
    data_pairs: sequence of tuples of examples, e.g. ``(positive, negative)``.
    n_splits: number of folds to return (default 5: 4 for training /
      cross-validation and 1 for testing).

  Returns:
    A list of ``n_splits`` lists of examples. Folds are empty when there are
    fewer pairs than folds.

  Raises:
    ValueError: if ``n_splits < 1``.
  """
  if n_splits < 1:
    raise ValueError("n_splits must be >= 1, got %r" % (n_splits,))
  # Shuffle a copy so the caller's list is left untouched.
  shuffled = list(data_pairs)
  random.shuffle(shuffled)
  # Distribute the remainder over the first folds instead of emitting an
  # extra partial fold.
  base, extra = divmod(len(shuffled), n_splits)
  splits = []
  start = 0
  for i in range(n_splits):
    stop = start + base + (1 if i < extra else 0)
    fold = list(itertools.chain.from_iterable(shuffled[start:stop]))  # flatten pairs
    random.shuffle(fold)  # shuffle positive/negative examples in the fold
    splits.append(fold)
    start = stop
  return splits



In [82]:
## Quantum
### Initial experiment
# With n = 6 the only valid split is k = 3, so this yields a single pair:
# a 6-cycle (label 1) vs. two disjoint triangles (label 0).
graph_pairs = gen_graph_pairs(6, 6)
data_pairs = [(to_pyg(single, 1), to_pyg(union, 0)) for single, union in graph_pairs]
one_6cycle, two_triangles = data_pairs[0]
class QGNN(torch.nn.Module):
    """Hand-rolled statevector simulation of a quantum graph model (first draft).

    NOTE: the following cell redefines ``QGNN`` (PennyLane-based), so once that
    cell has run this class is shadowed and no longer used.

    Each node owns ``n_qubits`` qubits, for ``total_qubits = n_qubits * n_nodes``
    qubits and a complex state vector of length ``2**total_qubits``. Qubit ``q``
    corresponds to bit ``q`` of a basis-state index ``s`` (tested via
    ``s & (1 << q)``).
    """
    def __init__(self, n_qubits=1, n_nodes=6, verbose = False):
        """
        Args:
            n_qubits: qubits per node.
            n_nodes: number of nodes the state vector is sized for.
            verbose: print intermediate states and sanity checks in ``forward``.
        """
        super(QGNN, self).__init__()
        self.verbose = verbose
        self.n_qubits = n_qubits
        self.n_nodes = n_nodes
        # Random phase angles in [0, 2*pi). Stored as a plain tensor, not an
        # nn.Parameter, so it is not registered as a trainable parameter.
        self.rots = torch.rand(n_qubits, n_qubits)*2*np.pi
        self.total_qubits = n_qubits*n_nodes
        self.state_dim = 2**self.total_qubits
        # Defined but not used by forward(); its input width is 2**total_qubits.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.state_dim, self.state_dim//2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.state_dim//2, 1)
        )

    # Note: this implementation was chosen not for efficiency, but in the
    # hope that thinking in such low-level operations might help give
    # intuitions about the expressivity and inductive bias of the model.
    def forward(self, graph):
        """Simulate |+...+> -> per-edge phases -> Hadamards; return the 1-count distribution.

        Args:
            graph: object with an ``edge_index`` attribute; column ``e`` is read
                as source ``edges[0][e]`` and target ``edges[1][e]`` node indices.

        Returns:
            NumPy array ``counts`` of length ``total_qubits + 1`` where
            ``counts[k]`` is the total probability of basis states with exactly
            ``k`` ones. The even/odd and big/small tallies are only printed
            when ``verbose`` and do not affect the return value.
        """
        edges = graph.edge_index

        # start as |++...+>, uniform superposition over all bitstrings
        state = torch.ones(self.state_dim, dtype=torch.cfloat) / (2**(self.total_qubits/2))

        if self.verbose:
          print("Sanity check: init prob mass", torch.sum(torch.square(torch.abs(state))).numpy())


        # Every column of edge_index is processed on its own: if edge_index
        # lists both directions of an undirected edge, that edge receives its
        # phase twice.
        for edge in range(edges.shape[1]):
          # for each edge, apply matrix of CZ rotations
          for l in range(self.n_qubits):
            for r in range(self.n_qubits):
              # apply CZ(rots[l][r]) between qubit l of the source node and
              # qubit r of the target node
              ql = (edges[0][edge]*self.n_qubits + l).numpy()
              qr = (edges[1][edge]*self.n_qubits + r).numpy()
              for s in range(self.state_dim):
                # CZ(theta) = diag(1, 1, 1, e^{i theta}): only basis states
                # with both qubit ql and qubit qr set pick up the phase
                if (s & (1 << ql)) and (s & (1 << qr)):
                  #if torch.abs(state[s]) > 0.001:
                  #  print('Applying phase to state ' + str(s) + ' from edge between qubits ' + str(ql)  + ', ' + str(qr))
                  state[s] *= torch.exp(self.rots[l][r]*1j)
        if self.verbose:
          print(state.numpy())
        # Apply Hadamard on each qubit.
        # Each bitstring goes to each bitstring with +-1/2^(n/2) of its
        # probability amplitude, with the sign determined by the number of
        # 1s at the same position in source and target
        # (dense double loop over all source/target pairs: O(4**n * n) work)
        out = torch.zeros(self.state_dim, dtype=torch.cfloat)
        norm = 1/(2**(self.total_qubits/2))
        for s in range(self.state_dim):
          for t in range(self.state_dim):
            weight = state[s]*norm
            for i in range(self.total_qubits):
              if (s & (1 << i)) and (t & (1 << i)):
                weight *= -1
            out[t] += weight

        if self.verbose:
          print("Sanity check: final prob mass", torch.sum(torch.square(torch.abs(out))).numpy())

        probs = torch.square(torch.abs(out))
        if self.verbose:
          print(probs.numpy())
        # Diagnostic tallies (reported only when verbose).
        odds = 0
        evens = 0
        big = 0
        small = 0
        # counts[k] accumulates the probability of basis states with k ones.
        counts = np.zeros(self.total_qubits+1)
        for s in range(len(probs)):
          ones = 0
          for i in range(self.total_qubits):
            if (s & (1 << i)):
              ones += 1
          counts[ones] += probs[s]
          if ones >= self.total_qubits//2:
            big += probs[s]
          else:
            small += probs[s]
          if ones % 2 == 0:
            evens += probs[s]
          else:
            odds += probs[s]
        if self.verbose:
          print("evens vs odds: " + str(evens) + ", " + str(odds))
          print("big vs small: " + str(big) + ", " + str(small))
          print("1-count distr: ")
          print(counts)
        return counts

In [83]:
import pennylane as qml

class QGNN(torch.nn.Module):
    """Quantum graph model simulated exactly on a PennyLane statevector device.

    Each node owns ``n_qubits`` wires (``total_qubits = n_qubits * n_nodes``).
    The circuit prepares |+...+>, applies ``ControlledPhaseShift(rots[l, r])``
    between wire ``l`` of the source node and wire ``r`` of the target node for
    every column of ``edge_index``, then applies a final Hadamard layer.
    ``forward`` returns the Hamming-weight (1-count) distribution of the
    resulting state.

    Args:
        n_qubits: wires per node.
        n_nodes: number of nodes the device is sized for; edge indices must lie
            in ``[0, n_nodes)``.
        verbose: print a probability-mass sanity check in ``forward``.
    """

    def __init__(self, n_qubits=1, n_nodes=6, verbose=False):
        super(QGNN, self).__init__()
        self.verbose = verbose
        self.n_qubits = n_qubits
        self.n_nodes = n_nodes
        self.total_qubits = n_qubits * n_nodes
        self.state_dim = 2 ** self.total_qubits

        # Random phase angles in [0, 2*pi); rots[l, r] couples wire l of an
        # edge's source node with wire r of its target node. Kept as a plain
        # tensor (as in the original parameterization), so it is not a
        # trainable parameter and is not stored in state_dict().
        self.rots = torch.rand(n_qubits, n_qubits) * 2 * np.pi

        # Analytic simulator (shots=None) so qml.state() gives exact amplitudes.
        self.dev = qml.device("default.qubit", wires=self.total_qubits, shots=None)

        @qml.qnode(self.dev, interface="torch", diff_method=None)
        def circuit(rots, edges_tensor):
            # Prepare |+...+> by applying H to every wire of |0...0>.
            for w in range(self.total_qubits):
                qml.Hadamard(wires=w)

            # ControlledPhaseShift(phi) = diag(1, 1, 1, e^{i phi}). All of these
            # gates are diagonal and commute, so the order of edge columns only
            # affects floating-point rounding.
            for e in range(edges_tensor.shape[1]):
                u = int(edges_tensor[0, e].item())
                v = int(edges_tensor[1, e].item())
                for l in range(self.n_qubits):
                    for r in range(self.n_qubits):
                        # Use the angle for this (l, r) pair; a single shared
                        # angle would ignore the rest of rots when n_qubits > 1.
                        qml.ControlledPhaseShift(
                            rots[l, r],
                            wires=[u * self.n_qubits + l, v * self.n_qubits + r],
                        )

            # Final layer: Hadamard on each wire.
            for w in range(self.total_qubits):
                qml.Hadamard(wires=w)

            return qml.state()

        self._circuit = circuit

        # Unused by forward(); kept so existing references to `mlp` still work.
        # Its input width is 2**total_qubits, so it grows exponentially with n_nodes.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.state_dim, self.state_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.state_dim // 2, 1)
        )

    def forward(self, graph):
        """Run the circuit for ``graph`` and return its 1-count distribution.

        Args:
            graph: object with an ``edge_index`` attribute, an integer tensor of
                shape ``[2, E]`` holding node indices in ``[0, n_nodes)``.

        Returns:
            NumPy float64 array ``counts`` of length ``total_qubits + 1`` where
            ``counts[k]`` is the total probability of measuring a bitstring with
            exactly ``k`` ones.

        Raises:
            ValueError: if ``edge_index`` references a node outside ``[0, n_nodes)``.
        """
        edges = graph.edge_index
        # Ensure CPU tensors for PennyLane.
        if edges.is_cuda:
            edges = edges.cpu()
        # Fail with a clear message instead of an opaque wire error when the
        # graph is larger than the device this model was built for.
        if edges.numel() > 0:
            lo, hi = int(edges.min()), int(edges.max())
            if lo < 0 or hi >= self.n_nodes:
                raise ValueError(
                    "edge_index references nodes in [{}, {}], but this QGNN was "
                    "built for n_nodes={}".format(lo, hi, self.n_nodes)
                )

        state = self._circuit(self.rots, edges)
        probs = torch.abs(state) ** 2  # shape [2**total_qubits]

        if self.verbose:
            print("Sanity check: final prob mass", torch.sum(probs).detach().cpu().numpy())

        # Hamming weight of every basis index, then sum probabilities per weight.
        # The bit ordering convention does not matter for a popcount.
        probs_np = probs.detach().cpu().numpy()
        basis = np.arange(probs_np.shape[0])
        hamming = np.zeros_like(basis)
        for i in range(self.total_qubits):
            hamming += (basis >> i) & 1
        counts = np.bincount(hamming, weights=probs_np, minlength=self.total_qubits + 1)

        return counts


In [84]:
# from qiskit import QuantumCircuit
# from qiskit.quantum_info import Statevector
# from qiskit.circuit.library.standard_gates import HGate, CPhaseGate

# class QGNN(torch.nn.Module):
#     def __init__(self, n_qubits=1, n_nodes=6, verbose=False):
#         super(QGNN, self).__init__()
#         self.verbose = verbose
#         self.n_qubits = n_qubits
#         self.n_nodes = n_nodes
#         self.total_qubits = n_qubits * n_nodes
#         self.state_dim = 2 ** self.total_qubits

#         # Match original parameterization
#         self.rots = torch.rand(n_qubits, n_qubits) * 2 * np.pi

#         # Keep the MLP to preserve attribute presence
#         self.mlp = torch.nn.Sequential(
#             torch.nn.Linear(self.state_dim, self.state_dim // 2),
#             torch.nn.ReLU(),
#             torch.nn.Linear(self.state_dim // 2, 1)
#         )

#     def _build_circuit(self, theta, edges):
#         qc = QuantumCircuit(self.total_qubits)

#         # Prepare |+...+>
#         for w in range(self.total_qubits):
#             qc.append(HGate(), [w])

#         # Apply per-edge Controlled-Phase(theta): CP(theta) = diag(1,1,1,e^{i theta})
#         # Qiskit CP(θ) matches ControlledPhaseShift(θ)
#         for e in range(edges.shape[1]):
#             u = int(edges[0, e].item())
#             v = int(edges[1, e].item())
#             for l in range(self.n_qubits):
#                 for r in range(self.n_qubits):
#                     ql = u * self.n_qubits + l
#                     qr = v * self.n_qubits + r
#                     qc.append(CPhaseGate(theta=float(theta)), [ql, qr])

#         # Final Hadamards
#         for w in range(self.total_qubits):
#             qc.append(HGate(), [w])

#         return qc

#     def forward(self, graph):
#         edges = graph.edge_index
#         if edges.is_cuda:
#             edges = edges.cpu()

#         theta = self.rots[0, 0].item()  # use single parameter

#         qc = self._build_circuit(theta, edges)
#         # Exact statevector simulation
#         sv = Statevector.from_label('0' * self.total_qubits)
#         sv = sv.evolve(qc)
#         probs = np.abs(sv.data) ** 2  # NumPy array length 2**total_qubits

#         if self.verbose:
#             print("Sanity check: final prob mass", probs.sum())

#         counts = np.zeros(self.total_qubits + 1, dtype=np.float64)
#         for s in range(probs.shape[0]):
#             ones = int(bin(s).count("1"))
#             counts[ones] += probs[s]

#         return counts


In [85]:

### Permute edge lists to confirm invariance
# Only the column order of edge_index changes (node labels are untouched), so
# this checks invariance to edge ordering, not to node relabeling.
perm = np.random.permutation(two_triangles.edge_index.shape[1])
two_triangles_alt = copy.deepcopy(two_triangles)
two_triangles_alt.edge_index = two_triangles_alt.edge_index[:, perm]
# Give the 6-cycle its own permutation sized to its own edge count; reusing
# `perm` would silently drop edges if the graphs ever differed in edge count.
perm_6cycle = np.random.permutation(one_6cycle.edge_index.shape[1])
one_6cycle_alt = copy.deepcopy(one_6cycle)
one_6cycle_alt.edge_index = one_6cycle_alt.edge_index[:, perm_6cycle]
print(two_triangles.edge_index)
print(two_triangles_alt.edge_index)


tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
        [1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4]])
tensor([[4, 2, 3, 0, 0, 5, 2, 3, 5, 1, 4, 1],
        [3, 0, 5, 2, 1, 3, 1, 4, 4, 0, 5, 2]])


In [86]:

qgnn = QGNN()
# Evaluate every graph with the same model instance (same random angles) so the
# original and edge-permuted outputs are directly comparable.
out_triangles = qgnn(two_triangles)
out_triangles_alt = qgnn(two_triangles_alt)
out_6cycle = qgnn(one_6cycle)
out_6cycle_alt = qgnn(one_6cycle_alt)
print(out_triangles)
print(out_triangles_alt)
print(out_6cycle)
print(out_6cycle_alt)
# The edge gates are diagonal controlled phases, which commute, so permuting
# edge columns may only change floating-point rounding.
np.testing.assert_allclose(out_triangles, out_triangles_alt, atol=1e-6)
np.testing.assert_allclose(out_6cycle, out_6cycle_alt, atol=1e-6)

[4.79294549e-01 3.06249534e-01 1.41127622e-01 5.70343779e-02
 1.32447172e-02 2.65255665e-03 3.96643408e-04]
[4.79294549e-01 3.06249534e-01 1.41127622e-01 5.70343779e-02
 1.32447172e-02 2.65255665e-03 3.96643408e-04]
[4.78149946e-01 3.06249534e-01 1.43269399e-01 5.70343779e-02
 1.23949709e-02 2.65255665e-03 2.49215487e-04]
[4.78149946e-01 3.06249534e-01 1.43269399e-01 5.70343779e-02
 1.23949709e-02 2.65255665e-03 2.49215487e-04]


In [None]:

### Animate 1-count distribution as a function of the single parameter
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import rc
rc('animation', html='jshtml')

def create_anim(inps, frames, titles=None):
  fig, axs = plt.subplots(len(inps))
  for i, ax in enumerate(axs):
    ax.set_ylim(bottom=0, top=1)
    if titles:
      ax.set_title(titles[i])
  bins = inps[0].x.shape[0] + 1
  bar_containers = [ax.bar(np.linspace(0, 6, 7), np.zeros(bins)) for ax in axs]
  rotations = np.linspace(-np.pi/2, +np.pi/2, frames)
