In [1]:
import asyncio
import random
from dotenv import load_dotenv
import importlib.util
import json
import math
import os
import tempfile
import qiskit
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit.circuit.classicalfunction import classical_function
from qiskit.circuit.classicalfunction.types import Int1
from qiskit.circuit.library import grover_operator, QFT
from qiskit_aer import AerSimulator
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

# Credentials are read from a local .env file (API_TOKEN, optional API_INSTANCE) instead of being hardcoded.
load_dotenv()
API_TOKEN = os.getenv("API_TOKEN")
API_INSTANCE = os.getenv("API_INSTANCE")  # None when the variable is not set
BACKEND_NAME = "ibm_rensselaer"
service = QiskitRuntimeService(
    channel="ibm_quantum",
    token=API_TOKEN,
    instance=API_INSTANCE,
)
backend = service.backend(name=BACKEND_NAME)

In [3]:
def get_variables(num_vars):
    """Return the input variable names ``x0, x1, ..., x{num_vars-1}`` used by the generated verifiers."""
    return [f"x{index}" for index in range(num_vars)]

def get_oracle(function_string):
    """
    Synthesize a quantum oracle circuit from the source code of a classical function.

    ``function_string`` is expected to be the source of a single ``@classical_function``-decorated
    function, such as the output of ``construct_clique_verifier(..., as_classical_function=True)``.
    Qiskit's classical-function synthesis needs real source code. The string is therefore
    written to a temporary module file and imported. It is synthesized while that file still
    exists.

    Args:
        function_string: Python source defining the function (decorator lines are allowed).

    Returns:
        The circuit returned by ``<function>.synth(registerless=False)``.

    Raises:
        ValueError: if no ``def <name>(`` line can be found in ``function_string``.
    """
    import re

    # Take the identifier after the first line-leading `def`. A plain str.split("def") would break
    # on function or decorator names that themselves contain "def" (e.g. "undefined_check").
    match = re.search(r"^\s*def\s+([A-Za-z_]\w*)\s*\(", function_string, re.MULTILINE)
    if match is None:
        raise ValueError("function_string does not contain a function definition")
    function_name = match.group(1)

    required_imports = """
from qiskit.circuit.classicalfunction import classical_function
from qiskit.circuit.classicalfunction.types import Int1
"""
    with tempfile.TemporaryDirectory() as temp_dir:
        module_name = "temp_boolean_func"
        file_path = os.path.join(temp_dir, f"{module_name}.py")

        with open(file_path, "w") as f:
            f.write(required_imports)
            f.write(function_string)

        spec = importlib.util.spec_from_file_location(module_name, file_path)
        temp_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(temp_module)

        # Named so it does not shadow the imported `classical_function` decorator.
        synthesizable = getattr(temp_module, function_name)
        # Synthesize inside the `with` so the temporary source file still exists.
        oracle = synthesizable.synth(registerless=False)

        return oracle

In [4]:
class SortPairNode:
    """A comparator in the sorting network: names of its ``high`` (OR / max) and ``low`` (AND / min) wires."""

    def __init__(self, high, low):
        self.high = high  # wire carrying the larger bit (for column-0 nodes: the raw input name)
        self.low = low    # wire carrying the smaller bit (None for column-0 nodes)

def get_sort_statements(variables):
    """
    Build an insertion-style sorting network over boolean variables as Python assignment strings.

    Row i inserts ``variables[i]`` into the already-sorted outputs of rows 0..i-1 with a chain of
    comparators. Each comparator computes ``high = a or b`` (max) and ``low = a and b`` (min).

    Args:
        variables: list of input variable names.

    Returns:
        ``(statements, outputs)``:
        - ``statements``: the ``"name = a or b"`` / ``"name = a and b"`` strings in evaluation order.
        - ``outputs``: the sorted wire names in descending order, so ``outputs[k]`` is true iff at
          least ``k + 1`` inputs are true.
        An empty ``variables`` list gives ``([], [])``.
    """
    num_variables = len(variables)
    if num_variables == 0:
        # Nothing to sort; the output indexing below would otherwise raise IndexError.
        return [], []
    statements = []

    # nodes[i][j] is comparator j of row i; column 0 of each row holds the raw input being inserted.
    nodes = [[SortPairNode(None, None) for _ in range(num_variables)] for _ in range(num_variables)]
    for i in range(num_variables):
        nodes[i][0] = SortPairNode(variables[i], None)

    for i in range(1, num_variables):
        for j in range(1, i + 1):
            s_high = f"s_{i}_{j}_high"
            s_low = f"s_{i}_{j}_low"
            nodes[i][j] = SortPairNode(s_high, s_low)

            # Value being carried up the chain (starts as the new input variables[i]).
            carried = nodes[i][j - 1].high
            # Row i-1's sorted wires are its comparators' low outputs, plus the top comparator's high.
            existing = nodes[i - 1][j - 1].high if j == i else nodes[i - 1][j].low
            statements.append(f"{s_high} = {existing} or {carried}")
            statements.append(f"{s_low} = {existing} and {carried}")

    last_row = nodes[num_variables - 1]
    # Descending order: the top comparator's high, then each low output from the top comparator down.
    outputs = [last_row[num_variables - 1].high] + [last_row[j].low for j in range(num_variables - 1, 0, -1)]

    return statements, outputs

print(get_sort_statements(["x0", "x1", "x2", "x3"]))

(['s_1_1_high = x0 or x1', 's_1_1_low = x0 and x1', 's_2_1_high = s_1_1_low or x2', 's_2_1_low = s_1_1_low and x2', 's_2_2_high = s_1_1_high or s_2_1_high', 's_2_2_low = s_1_1_high and s_2_1_high', 's_3_1_high = s_2_1_low or x3', 's_3_1_low = s_2_1_low and x3', 's_3_2_high = s_2_2_low or s_3_1_high', 's_3_2_low = s_2_2_low and s_3_1_high', 's_3_3_high = s_2_2_high or s_3_2_high', 's_3_3_low = s_2_2_high and s_3_2_high'], ['s_3_3_high', 's_3_3_low', 's_3_2_low', 's_3_1_low'])


In [5]:
def test_sort_4(x0, x1, x2, x3):
    s_1_1_high = x0 or x1
    s_1_1_low = x0 and x1
    s_2_1_high = s_1_1_low or x2
    s_2_1_low = s_1_1_low and x2
    s_2_2_high = s_1_1_high or s_2_1_high
    s_2_2_low = s_1_1_high and s_2_1_high
    s_3_1_high = s_2_1_low or x3
    s_3_1_low = s_2_1_low and x3
    s_3_2_high = s_2_2_low or s_3_1_high
    s_3_2_low = s_2_2_low and s_3_1_high
    s_3_3_high = s_2_2_high or s_3_2_high
    s_3_3_low = s_2_2_high and s_3_2_high
    return (s_3_3_high, s_3_3_low, s_3_2_low, s_3_1_low)

assert test_sort_4(1, 0, 1, 0) == (1, 1, 0, 0)
assert test_sort_4(0, 1, 0, 1) == (1, 1, 0, 0)
assert test_sort_4(1, 0, 0, 1) == (1, 1, 0, 0)
assert test_sort_4(0, 1, 1, 0) == (1, 1, 0, 0)
assert test_sort_4(1, 1, 0, 0) == (1, 1, 0, 0)
assert test_sort_4(0, 0, 1, 1) == (1, 1, 0, 0)
assert test_sort_4(1, 0, 1, 1) == (1, 1, 1, 0)
assert test_sort_4(0, 1, 1, 1) == (1, 1, 1, 0)
assert test_sort_4(1, 1, 1, 0) == (1, 1, 1, 0)
assert test_sort_4(1, 1, 1, 1) == (1, 1, 1, 1)

In [6]:
def construct_clique_verifier(graph, as_classical_function=False, clique_size=None):
    """
    Generate the source code of a Python function that verifies a candidate clique.

    ``graph`` is the upper triangle of the adjacency matrix, flattened row by row with one
    '0'/'1' character per vertex pair: e_01 e_02 ... e_0(n-1) e_12 ... e_(n-2)(n-1).
    So ``len(graph) == n * (n - 1) / 2``.

    The generated function ``is_clique`` takes n boolean arguments x0..x(n-1) (True = vertex is in
    the candidate set). It returns ``count and edge_sat``:
    - ``count``: at least ``clique_size`` vertices are selected;
    - ``edge_sat``: no selected pair is a missing edge.

    Args:
        graph: edge string in the format described above.
        as_classical_function: if True, emit a ``@classical_function`` with ``Int1`` annotations.
            It avoids the ``True`` literal, which Qiskit's classical-function parser cannot handle.
        clique_size: minimum number of selected vertices. Defaults to n // 2 (at least 1); 0 or None
            means "use the default".

    Returns:
        The function source as a string.

    Raises:
        ValueError: if ``len(graph)`` is not n*(n-1)/2 for an integer n, or ``clique_size`` is
            outside [1, n].
    """
    # Invert len(graph) == n(n-1)/2 with exact integer arithmetic (float sqrt can round wrongly).
    n = (1 + math.isqrt(1 + 8 * len(graph))) // 2
    if n * (n - 1) // 2 != len(graph):
        raise ValueError(f"graph length {len(graph)} is not n*(n-1)/2 for any integer n")
    variables = get_variables(n)
    statements, sort_outputs = get_sort_statements(variables)
    # max(1, ...) keeps n == 1 from indexing sort_outputs[-1]; for n >= 2 it equals n // 2.
    clique_size = clique_size or max(1, n // 2)
    if not 1 <= clique_size <= n:
        raise ValueError(f"clique_size must be between 1 and {n}, got {clique_size}")

    # sort_outputs[k] is true iff at least k + 1 vertices are selected
    statements.append("count = " + sort_outputs[clique_size-1])

    # whenever there is not an edge between two vertices, they cannot both be in the clique
    if as_classical_function:
        statements.append(f"edge_sat = {variables[0]} or not {variables[0]}") # this should be initialized to True, but qiskit classical function cannot yet parse True
    else:
        statements.append("edge_sat = True")
    edge_idx = 0
    for i in range(n):
        for j in range(i + 1, n):
            if graph[edge_idx] == '0':
                # TODO: we could reduce depth to log instead of linear by applying AND more efficiently
                # for now, we'll let tweedledum optimize this
                statements.append(f"edge_sat = edge_sat and not ({variables[i]} and {variables[j]})")
            edge_idx += 1

    statements.append("return count and edge_sat")
    if as_classical_function:
        output = "@classical_function\ndef is_clique(" + ", ".join([f"{v} : Int1" for v in variables]) + ") -> Int1:\n    "
    else:
        output = "def is_clique(" + ", ".join(variables) + "):\n    "
    output += "\n    ".join(statements)
    return output


In [None]:
# Show the generated plain-Python verifier for a 3-vertex graph whose only edge is (0, 1).
# Other example graphs (previously commented out, default clique_size):
#   "110001000000000", "111111111111111", "110001100100000"
example_verifier_source = construct_clique_verifier("100", clique_size=2)
print(example_verifier_source)


def is_clique(x0, x1, x2):
    s_1_1_high = x0 or x1
    s_1_1_low = x0 and x1
    s_2_1_high = s_1_1_low or x2
    s_2_1_low = s_1_1_low and x2
    s_2_2_high = s_1_1_high or s_2_1_high
    s_2_2_low = s_1_1_high and s_2_1_high
    count = s_2_2_low
    edge_sat = True
    edge_sat = edge_sat and not (x0 and x2)
    edge_sat = edge_sat and not (x1 and x2)
    return count and edge_sat
def is_clique(x0, x1, x2):
    s_1_1_high = x0 or x1
    s_1_1_low = x0 and x1
    s_2_1_high = s_1_1_low or x2
    s_2_1_low = s_1_1_low and x2
    s_2_2_high = s_1_1_high or s_2_1_high
    s_2_2_low = s_1_1_high and s_2_1_high
    count = s_2_2_low
    edge_sat = True
    return count and edge_sat


In [None]:
def direct_clique_oracle_circuit(graph, clique_size=None):
    """
    Build a reversible quantum oracle for the clique verifier directly from X / MCX gates.

    ``graph`` is the upper triangle of the adjacency matrix, flattened row by row with one
    '0'/'1' character per vertex pair: e_01 e_02 ... e_0(n-1) e_12 ... e_(n-2)(n-1).

    Qubit layout:
      * 0 .. n-1 : vertex qubits (1 = vertex is in the candidate clique)
      * n        : result qubit, flipped iff edge_sat AND count_sat
      * n + 1    : edge_sat, 1 iff no selected pair is a missing edge
      * n + 2    : count_sat, 1 iff at least ``clique_size`` vertices are selected
      * n + 3 .. : one ancilla per missing edge, then one per sorting-network wire
    Every gate used is self-inverse, so the ancillas are uncomputed by replaying the gate list in
    reverse. The circuit also carries n classical bits, which the oracle itself does not use.

    If clique_size is unspecified (None or 0), the default is n // 2 (at least 1).

    Raises:
        ValueError: if ``len(graph)`` is not n*(n-1)/2 for an integer n, or ``clique_size`` is
            outside [1, n].
    """
    # Invert len(graph) == n(n-1)/2 with exact integer arithmetic.
    n = (1 + math.isqrt(1 + 8 * len(graph))) // 2
    if n * (n - 1) // 2 != len(graph):
        raise ValueError(f"graph length {len(graph)} is not n*(n-1)/2 for any integer n")
    clique_size = clique_size or max(1, n // 2)
    if not 1 <= clique_size <= n:
        raise ValueError(f"clique_size must be between 1 and {n}, got {clique_size}")

    ret_qubit = n
    edge_sat_qubit = n + 1
    count_sat_qubit = n + 2
    variables = get_variables(n)
    statements, sort_outputs = get_sort_statements(variables)
    # The sorting-network wire that is true iff at least clique_size vertices are selected.
    count_var = sort_outputs[clique_size - 1]

    # map variable names to qubit indices
    var_map = {name: idx for idx, name in enumerate(variables)}

    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    missing_edges = [pair for pair, edge in zip(pairs, graph) if edge == '0']
    edge_ancillas = list(range(n + 3, n + 3 + len(missing_edges)))

    # Give every sorting-network wire its own ancilla, except count_var, which is written directly
    # into count_sat_qubit. Allocating exactly one qubit per remaining wire keeps indices in
    # range whichever statement count_var happens to be.
    qubit_idx = n + 3 + len(missing_edges)
    for statement in statements:
        lhs = statement.split('=')[0].strip()
        if lhs == count_var:
            var_map[lhs] = count_sat_qubit
        else:
            var_map[lhs] = qubit_idx
            qubit_idx += 1

    qc = qiskit.QuantumCircuit(qubit_idx, n)
    operations = []  # (gate_method, *args); all gates are self-inverse

    # whenever there is not an edge between two vertices, they cannot both be in the clique
    for (i, j), ancilla in zip(missing_edges, edge_ancillas):
        operations.append((qc.mcx, [i, j], ancilla))   # ancilla = x_i AND x_j
    for ancilla in edge_ancillas:
        operations.append((qc.x, ancilla))             # ancilla = NOT(x_i AND x_j)
    if edge_ancillas:
        operations.append((qc.mcx, edge_ancillas, edge_sat_qubit))
    else:
        operations.append((qc.x, edge_sat_qubit))      # no missing edges: always satisfied

    # count whether there are at least clique_size vertices in the clique
    for statement in statements:
        lhs, rhs = (part.strip() for part in statement.split('='))
        res = var_map[lhs]
        if ' or ' in rhs:
            a, b = (var_map[name.strip()] for name in rhs.split(' or '))
            # a OR b == NOT(NOT a AND NOT b): negate inputs, AND into res, negate res, restore inputs.
            operations += [(qc.x, a), (qc.x, b), (qc.mcx, [a, b], res),
                           (qc.x, res), (qc.x, a), (qc.x, b)]
        elif ' and ' in rhs:
            a, b = (var_map[name.strip()] for name in rhs.split(' and '))
            operations.append((qc.mcx, [a, b], res))
        else:
            raise ValueError(f"unsupported sort statement: {statement!r}")

    if count_var in variables:
        # n == 1: there is no sorting network; the count is the vertex bit itself, so copy it.
        operations.append((qc.cx, var_map[count_var], count_sat_qubit))

    # compute, flip the result qubit, then uncompute in reverse order
    for op, *args in operations:
        op(*args)
    qc.mcx([edge_sat_qubit, count_sat_qubit], ret_qubit)
    for op, *args in reversed(operations):
        op(*args)
    return qc

In [127]:
# Edge string for the graph under test (format: see construct_clique_verifier).
# Alternatives previously inlined as comments: "110001000000000", "110001100100000"
graph = "111000"
CLIQUE_SIZE = 2
# Invert len(graph) == n(n-1)/2 exactly with an integer square root.
n = (1 + math.isqrt(1 + 8 * len(graph))) // 2
print(n)
# Two constructions of the same verifier oracle. They get distinct names so the synthesized one
# is not silently discarded; the hand-built circuit is the one used downstream.
synthesized_clique_oracle = get_oracle(construct_clique_verifier(graph, as_classical_function=True, clique_size=CLIQUE_SIZE))
clique_oracle = direct_clique_oracle_circuit(graph, clique_size=CLIQUE_SIZE)
#clique_oracle.draw('text')

4


In [74]:

# Standalone Grover operator for the current oracle. The zero reflection acts only on the
# n vertex qubits, not on the result qubit or ancillas.
grover_op = grover_operator(oracle=clique_oracle, reflection_qubits=range(n))
#grover_op.draw('text')

In [21]:

def count_solutions(oracle, n, use_simulator=False, counting_qubits=None):
    """
    Estimate the number of solutions of an oracle by quantum counting.

    Quantum counting runs phase estimation on the Grover operator. The Grover operator has
    eigenphases ±theta with sin^2(theta/2) = m / 2**n. Phase estimation therefore measures
    phi = theta/(2*pi) or 1 - phi; both give the same m = N * sin^2(pi * phi). Each outcome is
    folded into [0, 1/2] before averaging, because averaging raw outcomes drifts toward 1/2
    regardless of m.

    Args:
        oracle: either a phase oracle on n qubits, or a U_f oracle whose qubit n is the result
            qubit (initialized here to H|1>). Any further qubits are ancillas assumed to start in |0>.
            NOTE(review): the Grover powers are controlled via ``.control()``, so oracles carrying
            classical bits probably cannot be used here; confirm.
        n: number of search qubits (the first n qubits of ``oracle``).
        use_simulator: run on AerSimulator instead of the module-level ``backend``.
        counting_qubits: number of phase-estimation qubits (phase resolution 1 / 2**counting_qubits);
            defaults to n.

    Returns:
        ``(m, phase)``: the estimated number of solutions and the mean folded phase in [0, 1/2].
    """
    assert oracle.num_qubits >= n
    # Anything beyond the n search qubits means U_f mode: qubit n is the result qubit.
    uf_mode = oracle.num_qubits > n
    if counting_qubits is None:
        counting_qubits = n
    oracle_qubits = list(range(counting_qubits, counting_qubits + oracle.num_qubits))
    counting_circuit = qiskit.QuantumCircuit(counting_qubits + oracle.num_qubits, counting_qubits)
    grover_op = grover_operator(oracle, reflection_qubits=range(n))

    counting_circuit.h(range(counting_qubits))
    counting_circuit.h(oracle_qubits[:n])
    # initialize the result qubit to H |1> if uf_mode
    if uf_mode:
        counting_circuit.x(oracle_qubits[n])
        counting_circuit.h(oracle_qubits[n])

    # Counting qubit i controls Q^(2^i), so it picks up phase 2^i * phi. The register then holds
    # QFT|x> in Qiskit's little-endian order, and the full inverse QFT (with swaps) returns |x>
    # with bit i on qubit i. The swap-free inverse would need the controlled powers reversed.
    for i in range(counting_qubits):
        controlled_grover = grover_op.power(2**i).control()
        counting_circuit.append(controlled_grover.to_instruction(), [i] + oracle_qubits)
    counting_circuit.append(QFT(counting_qubits).inverse(), range(counting_qubits))
    counting_circuit.measure(range(counting_qubits), range(counting_qubits))
    print("finished constructing circuit")

    if use_simulator:
        simulator = AerSimulator()
        pass_manager = generate_preset_pass_manager(optimization_level=1, backend=simulator)
        counting_circuit = pass_manager.run(counting_circuit)
        result = simulator.run(counting_circuit,shots=10**4).result()
        counts = result.get_counts()
    else:
        qc_transpiled = qiskit.transpile(counting_circuit, backend)
        sampler = Sampler(backend)
        print("running job")
        job = sampler.run([qc_transpiled], shots=10**4)
        result = job.result()[0]
        counts = result.data.c.get_counts()

    # Average the folded phase across shots. Outcomes phi and 1 - phi come from the two conjugate
    # eigenvectors and encode the same m. This replaces the earlier unexplained `phase /= 3`
    # correction, which compensated for the unfolded average.
    resolution = 2**counting_qubits
    total_phase = 0.0
    shots = 0
    for output, count in counts.items():
        phi = int(output, 2) / resolution
        total_phase += count * min(phi, 1 - phi)
        shots += count
    phase = total_phase / shots

    N = 2**n
    m = N * (1 - math.cos(2 * math.pi * phase)) / 2
    return m, phase

#count_solutions(clique_oracle, 6, use_simulator=False)

In [129]:
def find_solution(oracle, n, m, use_simulator=False):
    """
    Run Grover search with the (near-)optimal iteration count for an oracle with m solutions.

    The measured bitstrings are printed sorted by frequency, and the untranspiled search circuit
    is returned. Grover search is probabilistic, so the caller should check the most frequent
    outcomes with a classical verifier and rerun on failure. This function does not do that itself.

    Args:
        oracle: either a phase oracle on n qubits, or a U_f oracle whose qubit n is the result
            qubit (initialized here to H|1>). Any further qubits are treated as ancillas.
        n: number of search qubits (the first n qubits of ``oracle``).
        m: number of solutions (e.g. from ``count_solutions``); must satisfy 0 < m <= 2**n.
        use_simulator: run on AerSimulator instead of the module-level ``backend``.

    Returns:
        The search circuit before transpilation.

    Raises:
        ValueError: if m is outside (0, 2**n] or the oracle has fewer than n qubits.
    """
    N = 2**n
    # Validate before building anything. m == 0 would divide by zero below, and m > N has no
    # real rotation angle.
    if not 0 < m <= N:
        raise ValueError(f"m must satisfy 0 < m <= {N}, got {m}")
    if oracle.num_qubits < n:
        raise ValueError(f"oracle has {oracle.num_qubits} qubits but n={n} search qubits were requested")
    uf_mode = oracle.num_qubits >= n + 1
    grover_op = grover_operator(oracle, reflection_qubits=range(n))

    # Each Grover iteration rotates the state by 2*theta with sin(theta) = sqrt(m/N). About
    # pi/(4*theta) iterations bring it closest to the solution subspace.
    optimal_num_iterations = math.floor(
        math.pi / (4 * math.asin(math.sqrt(m / N)))
    )

    search_circuit = qiskit.QuantumCircuit(oracle.num_qubits, n)

    # initialize the result qubit to H |1> if uf_mode
    if uf_mode:
        search_circuit.x(n)
        search_circuit.h(n)

    search_circuit.h(range(n))
    search_circuit.compose(grover_op.power(optimal_num_iterations), inplace=True)
    search_circuit.measure(range(n), range(n))

    if use_simulator:
        simulator = AerSimulator()
        pass_manager = generate_preset_pass_manager(optimization_level=1, backend=simulator)
        qc = pass_manager.run(search_circuit)
        result = simulator.run(qc,shots=10**4).result()
        counts = result.get_counts()
    else:
        qc = qiskit.transpile(search_circuit, backend)
        sampler = Sampler(backend)
        print("running job")
        job = sampler.run([qc], shots=10**4)
        result = job.result()[0]
        counts = result.data.c.get_counts()

    print(sorted(counts.items(), key=lambda x: x[1], reverse=True))
    return search_circuit

# For graph "111000" at clique_size=2 exactly three vertex sets pass the verifier: {0,1}, {0,2}, {0,3}.
circuit = find_solution(clique_oracle, n, m=3, use_simulator=False)
#circuit.draw('text')

running job
[('01111', 707), ('01011', 679), ('01100', 669), ('01010', 647), ('00010', 644), ('00111', 633), ('01001', 625), ('01110', 617), ('00110', 612), ('01101', 610), ('01000', 610), ('00001', 607), ('00101', 601), ('00100', 599), ('00011', 572), ('00000', 568)]


In [None]:
# NOTE(review): literal (bitstring, count) pairs pasted into a code cell, presumably the sorted
# output of an earlier find_solution run. Evaluating this cell only displays the list.
[('1101', 721), ('0111', 714), ('1110', 708), ('1001', 666), ('1010', 640), ('1111', 629), ('0000', 626), ('0011', 618), ('0010', 609), ('1100', 608), ('0100', 589), ('0110', 587), ('1011', 583), ('1000', 575), ('0101', 565), ('0001', 562)]

In [None]:
# Sanity check of Qiskit classical-function synthesis on the 4-input sorting network:
# the function is true iff at least two of x0..x3 are 1.
# Only '#' comments are added here. The classical_function parser handles a restricted
# Python subset, and a docstring statement may not be accepted.
@classical_function
def test_count_2_of_4(x0: Int1, x1: Int1, x2: Int1, x3: Int1) -> Int1:
    # Same statements as printed for get_sort_statements(["x0", "x1", "x2", "x3"]);
    # s_3_3_high means "at least one input set" and s_3_3_low means "at least two".
    s_1_1_high = x0 or x1
    s_1_1_low = x0 and x1
    s_2_1_high = s_1_1_low or x2
    s_2_1_low = s_1_1_low and x2
    s_2_2_high = s_1_1_high or s_2_1_high
    s_2_2_low = s_1_1_high and s_2_1_high
    s_3_1_high = s_2_1_low or x3
    s_3_1_low = s_2_1_low and x3
    s_3_2_high = s_2_2_low or s_3_1_high
    s_3_2_low = s_2_2_low and s_3_1_high
    s_3_3_high = s_2_2_high or s_3_2_high
    s_3_3_low = s_2_2_high and s_3_2_high
    return s_3_3_high and s_3_3_low  # true if there are at least two 1s

# Synthesize a reversible circuit for the function and draw it.
# NOTE(review): registerless=False presumably keeps one named register per argument plus the result; confirm.
quantum_circuit = test_count_2_of_4.synth(registerless=False)
quantum_circuit.draw('text')

In [None]:
# Run the synthesized 2-of-4 counter on hardware for a fixed set of basis-state inputs.
# Character k of each input string prepares qubit k (x_k); the extra fifth qubit holds the result.
TEST_INPUTS = ['1111', '0000', '1010', '1011', '0011', '0101', '0110', '1100', '1001']
circuits = []
for inpt in TEST_INPUTS:
    qc = qiskit.QuantumCircuit(4 + 1) # +1 for result qubit
    for i, bit in enumerate(inpt):
        if bit == '1':
            qc.x(qc.qubits[i])

    qc.compose(quantum_circuit, inplace=True)
    qc.measure_all()
    # Pub index in the sampler job == position in `circuits` == position in TEST_INPUTS.
    circuits.append(qiskit.transpile(qc, backend=backend))


sampler = Sampler(backend)
job = sampler.run(circuits, shots=10**4)
job_id = job.job_id()

In [None]:
# Block until the hardware job finishes; the result is indexed by pub, in submission order.
results = job.result()

In [None]:
# Counts for the second submitted circuit (input '0000' in the hardware cell's input list);
# `meas` is the classical register created by measure_all().
results[1].data.meas.get_counts()

In [None]:
# Simulation comparison
from qiskit_aer import AerSimulator
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
simulator = AerSimulator()

circuits = []
for inpt in ['1111', '0000', '1010', '1011', '0011', '0101', '0110', '1100', '1001']:
    qc = qiskit.QuantumCircuit(4 + 1) # +1 for result qubit
    for i, bit in enumerate(inpt):
        if bit == '1':
            qc.x(qc.qubits[i])

    qc.compose(quantum_circuit, inplace=True)
    qc.measure_all()
    pm = generate_preset_pass_manager(optimization_level=1, backend=simulator)
    qc_transpiled = pm.run(qc)

    simulation = simulator.run(qc_transpiled, shots=10**4)
    simulation_counts = simulation.result()
    print(simulation_counts.get_counts())
