In [31]:
import pennylane as qml
from compare_sdk import *
from common_functions import *
import networkx as nx
# from functools import partial

In [32]:
# Load the KetGPT dataset; unpacking into a 1-tuple fails loudly unless exactly one dataset comes back.
(ds,) = qml.data.load("ketgpt")

In [33]:
# Example: replay one KetGPT circuit as a QNode, rewriting a few gates into the
# H / CNOT / RZ / RY vocabulary (each rewrite is exact up to a global phase).
seed = 0  # index into ds.circuits, not an RNG seed


@qml.qnode(qml.device('default.qubit'))
def circuit():
    """Re-queue the ops of ``ds.circuits[seed]`` and return the final state vector.

    Rewrites (listed in the order the gates are applied):
      * ``CZ``             -> H(target), CNOT, H(target)
      * ``U1(lam)``        -> RZ(lam)
      * ``U2(phi, delta)`` -> RZ(delta), RY(pi/2), RZ(phi)
    Each rewrite matches the original gate only up to a global phase, so the
    returned state does too. ``QubitUnitary`` ops are skipped entirely. When one
    is present, the returned state is NOT that of the dataset circuit. Every
    other op is re-queued unchanged with ``qml.apply``.
    """
    for op in ds.circuits[seed]:
        name, params, wires = op.name, op.parameters, op.wires
        if name == 'QubitUnitary':
            # NOTE(review): dropped, presumably to keep a gate-level circuit — confirm intended.
            continue
        if name == 'CZ':
            qml.Hadamard(wires[1])
            qml.CNOT(wires)
            qml.Hadamard(wires[1])
        elif name == 'U1':
            qml.RZ(params[0], wires=wires)
        elif name == 'U2':
            # U2(phi, delta) = RZ(phi) @ RY(pi/2) @ RZ(delta) up to phase (matrix
            # product), so in time order RZ(delta) must be applied FIRST.
            # Swapping the two rotations gives a different gate unless phi == delta.
            phi, delta = params[0], params[1]
            qml.RZ(delta, wires=wires)
            qml.RY(np.pi / 2, wires=wires)
            qml.RZ(phi, wires=wires)
        else:
            qml.apply(op)
    return qml.state()

In [34]:
# Split the graph returned by `qnode_info` into connected components and build one
# QNode per component from the matching entries of `circuit_info`.
G, circuit_info = qnode_info(qnode=circuit)
cc = [
    component
    for component in nx.connected_components(G.to_undirected())
    # NOTE(review): single-node components (isolated gates) are discarded here — confirm intended.
    if len(component) > 1
]
subcircuit_info = [
    [circuit_info[extract_index(node)] for node in component]
    for component in cc
]
sub_qnode = list(map(info_to_qnode, subcircuit_info))

In [43]:
def extract_info_from_qc(qc, ndigits=3):
    """Flatten a qiskit-style circuit into a list of per-instruction dicts.

    Each entry contains:
      * ``'label'``  – the op's custom label if set, else its name capitalized
                       (e.g. ``'cx'`` -> ``'Cx'``, ``'ry'`` -> ``'Ry'``);
      * ``'wires'``  – integer indices (via ``qc.find_bit``) of the qubits it acts on;
      * ``'params'`` – the op's parameters. Values exposing ``__float__`` that
                       actually convert are rounded to ``ndigits`` decimals.
                       Everything else (e.g. an unbound ``Parameter``, a
                       multi-element array) is passed through unchanged.

    Args:
        qc: circuit exposing ``.data`` (items with ``.operation`` and ``.qubits``)
            and ``.find_bit(qubit).index``.
        ndigits: decimals kept for numeric parameters (default 3).

    Returns:
        list[dict]: one dict per instruction, in circuit order.
    """
    def _as_rounded(p):
        if not hasattr(p, '__float__'):
            return p
        # Unbound ParameterExpressions and multi-element arrays define __float__
        # but raise on conversion, so a hasattr check alone is not a safe guard.
        try:
            return round(float(p), ndigits)
        except (TypeError, ValueError):
            return p

    circuit_info = []
    for instruction in qc.data:
        op = instruction.operation
        circuit_info.append({
            'label': op.label if op.label else op.name.capitalize(),
            'wires': [qc.find_bit(q).index for q in instruction.qubits],
            'params': [_as_rounded(p) for p in op.params],
        })
    return circuit_info

In [None]:
# Compile every subcircuit with `qiskit_optimizer` (presumably returns a qiskit
# QuantumCircuit) and flatten each result with `extract_info_from_qc`.
compiled_info = []
for qnode in sub_qnode:
    qc = qiskit_optimizer(qnode)  # compile the current subcircuit, not sub_qnode[0]
    compiled_info.append(extract_info_from_qc(qc))

## Next steps: convert the compiled circuits into a PyTorch Geometric graph dataset

In [None]:
# to dataset
import torch
import torch.nn as nn
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

# One-hot encoding of gate labels as emitted by `extract_info_from_qc`
# (qiskit op names after str.capitalize()). A label's position in the list
# fixes its feature column; range(4) must match the number of labels.
gate_map = {
    label: [1 if col == row else 0 for col in range(4)]
    for row, label in enumerate(["H", "Cx", "Rx", "Ry"])
}

def circuit_to_data(circuit, wire_embed=None, max_params=2):
    """Convert a flattened circuit (list of gate dicts) into a PyG ``Data`` graph.

    Nodes are gates, in circuit order. A node's feature vector is the
    ``gate_map`` one-hot for its label, followed by its first ``max_params``
    parameters, zero-padded. Extra parameters are dropped. For each wire a gate
    acts on, a directed edge links the previous gate on that wire to this gate.
    Two gates sharing several wires therefore get one edge per shared wire.

    Args:
        circuit: list of dicts with keys ``'label'``, ``'wires'``, ``'params'``
            (the format produced by ``extract_info_from_qc``).
        wire_embed: currently unused; kept for interface compatibility.
        max_params: number of parameter slots per node (default 2).

    Returns:
        Data with ``x`` of shape ``(num_gates, len(one_hot) + max_params)`` and
        ``edge_index`` of shape ``(2, num_edges)``.

    Raises:
        KeyError: if a gate label has no entry in the module-level ``gate_map``.
    """
    node_features = []
    edge_list = []
    last_gate_for_wire = {}  # wire -> node id of the most recent gate on that wire

    for node_id, instr in enumerate(circuit):
        gate = instr["label"]
        if gate not in gate_map:
            raise KeyError(f"gate label {gate!r} missing from gate_map (known: {list(gate_map)})")
        gate_feat = torch.tensor(gate_map[gate], dtype=torch.float)

        padded = [0.0] * max_params
        for i, p in enumerate(instr["params"][:max_params]):
            padded[i] = p
        param_vec = torch.tensor(padded, dtype=torch.float)

        node_features.append(torch.cat([gate_feat, param_vec]))

        # Edges must be built per gate, inside this loop, so every gate links
        # back to its predecessor on each wire it touches.
        for w in instr["wires"]:
            if w in last_gate_for_wire:
                edge_list.append([last_gate_for_wire[w], node_id])
            last_gate_for_wire[w] = node_id

    if node_features:
        x = torch.stack(node_features)
    else:
        # torch.stack rejects an empty list; emit a 0-row matrix of the right width.
        feat_dim = len(next(iter(gate_map.values()), [])) + max_params
        x = torch.empty((0, feat_dim), dtype=torch.float)

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

# -----------------------------
# Dataset class
# -----------------------------
class QuantumCircuitDataset(InMemoryDataset):
    def __init__(self, circuits, transform=None, pre_transform=None):
        super().__init__('.', transform, pre_transform)
        data_list = [circuit_to_data(c) for c in circuits]
        self.data, self.slices = self.collate(data_list)

# -----------------------------
# Example usage
# -----------------------------

# Two small toy circuits in the `extract_info_from_qc` format, defined here so
# the cell runs on a fresh kernel. Every label must be a key of `gate_map`.
circuit1 = [
    {"label": "H", "wires": [0], "params": []},
    {"label": "Cx", "wires": [0, 1], "params": []},
    {"label": "Ry", "wires": [1], "params": [0.5]},
]
circuit2 = [
    {"label": "Rx", "wires": [0], "params": [1.2]},
    {"label": "Cx", "wires": [1, 0], "params": []},
    {"label": "H", "wires": [1], "params": []},
]
circuits = [circuit1, circuit2]

dataset = QuantumCircuitDataset(circuits)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch in loader:
    print("Batch node features shape:", batch.x.shape)
    print("Batch edge index:", batch.edge_index)

In [None]:
def circuit_to_data(circuit, gate_map, max_params=2):
    """Convert a flattened circuit into a PyG ``Data`` graph using an explicit ``gate_map``.

    NOTE(review): this redefines ``circuit_to_data`` from an earlier cell, and
    here ``gate_map`` is a required second argument. Callers that pass only the
    circuit (e.g. ``QuantumCircuitDataset``) raise TypeError once this cell has run.
    Consider keeping a single definition.

    Nodes are gates, in circuit order. A node's feature vector is
    ``gate_map[label]`` followed by its first ``max_params`` parameters,
    zero-padded. Extra parameters are dropped. For each wire a gate acts on, a
    directed edge links the previous gate on that wire to this gate.

    Args:
        circuit: list of dicts with keys ``'label'``, ``'wires'``, ``'params'``.
        gate_map: mapping from gate label to its (one-hot) feature list.
        max_params: number of parameter slots per node (default 2).

    Returns:
        Data with ``x`` of shape ``(num_gates, len(one_hot) + max_params)`` and
        ``edge_index`` of shape ``(2, num_edges)``. When there are no edges,
        ``edge_index`` is an empty ``(2, 0)`` tensor.

    Raises:
        KeyError: if a gate label has no entry in ``gate_map``.
    """
    node_features = []
    edge_list = []

    # Track the last gate node seen for each wire
    last_gate_for_wire = {}

    for node_id, instr in enumerate(circuit):
        gate = instr["label"]
        if gate not in gate_map:
            raise KeyError(f"gate label {gate!r} missing from gate_map (known: {list(gate_map)})")

        # Gate one-hot
        gate_feat = torch.tensor(gate_map[gate], dtype=torch.float)

        # Params padded to length max_params
        padded = [0.0] * max_params
        for i, p in enumerate(instr["params"][:max_params]):
            padded[i] = p
        param_vec = torch.tensor(padded, dtype=torch.float)

        # Node feature = concat(gate, params)
        node_features.append(torch.cat([gate_feat, param_vec]))

        # Add edges: connect to previous gate(s) on the same wire(s)
        for w in instr["wires"]:
            if w in last_gate_for_wire:
                edge_list.append([last_gate_for_wire[w], node_id])
            last_gate_for_wire[w] = node_id

    if node_features:
        x = torch.stack(node_features)
    else:
        # torch.stack rejects an empty list; emit a 0-row matrix of the right width.
        feat_dim = len(next(iter(gate_map.values()), [])) + max_params
        x = torch.empty((0, feat_dim), dtype=torch.float)

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]).t() would be shape (0,), but PyG requires (2, num_edges).
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)
