In [1]:
import fractions
import math
import random

import numpy as np
import sympy
from typing import Callable, List, Optional, Sequence, Union

import cirq

#### Order Finding

In [2]:
# Returns the multiplicative group of integers modulo n (the units of Z_n)

def multiplicative_group(n: int) -> List[int]:
    """Return the elements of Z_n^* in increasing order.

    Args:
        n: modulus; must be greater than 1.

    Returns:
        Every integer 1 <= k < n with gcd(k, n) == 1, sorted ascending
        (1 is always included).
    """
    assert n > 1
    return [1] + [k for k in range(2, n) if math.gcd(k, n) == 1]

In [3]:
# Demo: list the units modulo 15

n = 15
print(f"The multiplicative group modulo n = {n} is:", multiplicative_group(n), sep="\n")

The multiplicative group modulo n = 15 is:
[1, 2, 4, 7, 8, 11, 13, 14]


In [4]:
# Classical order finding
#    - repeatedly multiplies by x (modular exponentiation) until it reaches the
#      smallest r >= 1 with x^r mod n == 1

def classical_order_finder(x: int, n: int) -> Optional[int]:
    """Return the multiplicative order of x modulo n, found by brute force.

    Args:
        x: element whose order is wanted; must satisfy 2 <= x < n and
            gcd(x, n) == 1.
        n: modulus.

    Returns:
        The smallest r >= 1 such that x**r % n == 1.

    Raises:
        ValueError: if x is outside [2, n) or shares a factor with n.
    """
    in_group = 2 <= x < n and math.gcd(x, n) == 1
    if not in_group:
        raise ValueError(f"Invalid x={x} for modulus n={n}.")

    # power tracks x**order mod n
    order, power = 1, x
    while power != 1:
        power = power * x % n
        order += 1
    return order

In [5]:
# Ex: order of x = 7 in the multiplicative group modulo 15, computed classically

n = 15
x = 7
n_g = multiplicative_group(n)
r = classical_order_finder(x, n)

# Sanity check: x^r mod n must come out as 1.
print(f"Multiplicative group for n = {n}: {n_g}", end="\n\n")
print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")


Multiplicative group for n = 15: [1, 2, 4, 7, 8, 11, 13, 14]

x^r mod n = 7^4 mod 15 = 1


In [6]:
# Modular exponentiation in task 2 for period-finding subroutine
# Inheriting ArithmeticOperation performs the implicit modulus by N_t (dimension of target register)

class ModularExp(cirq.ArithmeticOperation):
    """
    Arithmetic operation implementing the unitary V that computes modular
    exponentiation x**e mod n:

        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n
        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y

    y: target register
    e: exponent register (or a classical int constant)
    x: base (classical int constant)
    n: modulus (classical int constant)

    Equivalently,

        V|y⟩|e⟩ = (U**e|y⟩)|e⟩

    where U is the unitary defined as

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y
    """
    def __init__(
        self, 
        target: Sequence[cirq.Qid],
        exponent: Union[int, Sequence[cirq.Qid]], 
        base: int,
        modulus: int
    ) -> None:
        """Store the registers and classical constants.

        Raises:
            ValueError: if the target register has fewer than
                ``modulus.bit_length()`` qubits.
        """
        if len(target) < modulus.bit_length():
            raise ValueError(f'Register with {len(target)} qubits is too small '
                             f'for modulus {modulus}')
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
        """Return (target, exponent, base, modulus), the order unpacked by `apply`."""
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
            self,
            *new_registers: Union[int, Sequence['cirq.Qid']],
    ) -> cirq.ArithmeticOperation:
        """Return a new ModularExp acting on `new_registers`.

        Raises:
            ValueError: if not exactly 4 registers are given, if the target is
                not a qubit sequence, or if base/modulus are not ints.
        """
        if len(new_registers) != 4:
            raise ValueError(f'Expected 4 registers (target, exponent, base, '
                             f'modulus), but got {len(new_registers)}')
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)
    
    def apply(self, *register_values: int) -> int:
        """Classical action on one assignment of register values.

        Computes (target * base**exponent) % modulus, where target and exponent
        are the integer values of the respective qubit registers and base and
        modulus are the constants (modulus = n, base = x in Z_n^*).
        Target values >= modulus are returned unchanged.
        """
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        if target >= modulus:
            return target
        # Three-argument pow keeps every intermediate below `modulus`. The
        # plain form base**exponent builds an integer whose size grows
        # linearly with the exponent value, which can be as large as
        # 2**len(exponent register) - 1.
        return (target * pow(base, exponent, modulus)) % modulus

    def _circuit_diagram_info_(
            self,
            args: cirq.CircuitDiagramInfoArgs,
    ) -> cirq.CircuitDiagramInfo:
        """Label the wires for circuit diagrams.

        The first target wire carries the full operation label, the remaining
        target wires are labeled t1, t2, ..., and exponent wires e0, e1, ....
        """
        assert args.known_qubits is not None
        wire_symbols: List[str] = []
        t, e = 0, 0
        for qubit in args.known_qubits:
            if qubit in self.target:
                if t == 0:
                    if isinstance(self.exponent, Sequence):
                        e_str = 'e'
                    else:
                        e_str = str(self.exponent)
                    wire_symbols.append(
                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
                else:
                    wire_symbols.append('t' + str(t))
                t += 1
            if isinstance(self.exponent, Sequence) and qubit in self.exponent:
                wire_symbols.append('e' + str(e))
                e += 1
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))


In [7]:
# Resource estimate for the order finding circuit:
#   - it needs 3(L + 1) = 3L + 3 qubits, where L is the number of bits needed
#     to store the integer n being factored;
#   - the unitary implementing modular exponentiation on those qubits is a
#     2^(3L+3) x 2^(3L+3) matrix, i.e. 4^(3(L+1)) entries;
#   - e.g. n = 15 (L = 4) uses 15 qubits, so the unitary has 2^30 complex entries.

n = 15
L = n.bit_length()

# Target register: qubits 0 .. L-1 (L qubits).
# Exponent register: qubits L .. 3L+2 (2L + 3 qubits).
target, exponent = cirq.LineQubit.range(L), cirq.LineQubit.range(L, 3 * L + 3)

# Display the total number of qubits to factor n
print(f"To factor n = {n} which has L = {L} bits, we need 3L + 3 = {3 * L + 3} qubits.")

To factor n = 15 which has L = 4 bits, we need 3L + 3 = 15 qubits.


In [8]:
# x must be an element of the multiplicative group modulo n (for n = 15 that is
# [1, 2, 4, 7, 8, 11, 13, 14]; e.g. 3 is not, since gcd(3, 15) = 3).
x = 7
assert math.gcd(x, n) == 1, f"x={x} is not in the multiplicative group modulo n={n}"

# Displaying the unitary stays disabled: for n = 15 the operation acts on
# 3L + 3 = 15 qubits, i.e. a 2^15 x 2^15 matrix with 2^30 entries.
#cirq.unitary(ModularExp(target, exponent, x, n))


In [9]:
# Builds the quantum circuit whose measured exponent register encodes the order of x modulo n

def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    """
    Return a Quantum Phase Estimation circuit for the unitary

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y

    Qubit layout, with L = n.bit_length():
        LineQubit 0 .. L-1      target register
        LineQubit L .. 3L+2     exponent register (2L + 3 qubits), measured
                                under the key 'exponent'

    x: integer > 0 whose order modulo n is to be found
    n: modulus relative to which the order is taken
    """
    width = n.bit_length()
    qubits = cirq.LineQubit.range(3 * width + 3)
    target, exponent = qubits[:width], qubits[width:]

    operations = [
        # Flip the last target qubit; cirq arithmetic registers are big-endian,
        # so the target register then holds the value 1.
        cirq.X(target[-1]),
        # Uniform superposition over all exponent values.
        cirq.H.on_each(*exponent),
        # All controlled powers of U applied as a single arithmetic operation.
        ModularExp(target, exponent, x, n),
        # Inverse QFT turns the accumulated phase into a readable integer.
        cirq.qft(*exponent, inverse=True),
        cirq.measure(*exponent, key='exponent'),
    ]
    return cirq.Circuit(operations)


In [10]:
# Draw the order finding circuit for x = 7 modulo n = 15
n, x = 15, 7
circuit = make_order_finding_circuit(x, n)
print(circuit)

0: ────────ModularExp(t*7**e % 15)────────────────────────────
           │
1: ────────t1─────────────────────────────────────────────────
           │
2: ────────t2─────────────────────────────────────────────────
           │
3: ────X───t3─────────────────────────────────────────────────
           │
4: ────H───e0────────────────────────qft^-1───M('exponent')───
           │                         │        │
5: ────H───e1────────────────────────#2───────M───────────────
           │                         │        │
6: ────H───e2────────────────────────#3───────M───────────────
           │                         │        │
7: ────H───e3────────────────────────#4───────M───────────────
           │                         │        │
8: ────H───e4────────────────────────#5───────M───────────────
           │                         │        │
9: ────H───e5────────────────────────#6───────M───────────────
           │                         │        │
10: ───H───e6─────────────────

In [11]:
# Sample the order finding circuit for x = 5 modulo n = 6, eight shots

res = cirq.sample(make_order_finding_circuit(x=5, n=6), repetitions=8)
circuit = make_order_finding_circuit(x=5, n=6)

print("Raw measurements:")
print(res)
print("\nInteger in exponent register:")
print(res.data)


Raw measurements:
exponent=00101001, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000

Integer in exponent register:
   exponent
0         0
1         0
2       256
3         0
4       256
5         0
6         0
7       256


In [12]:
# Interprets the output of the order finding circuit

def process_measurement(result: cirq.Result, x: int, n: int) -> Optional[int]:
    """
    Determines s/r such that exp(2πis/r) is an eigenvalue
    of the unitary

        U|y⟩ = |xy mod n⟩  0 <= y < n
        U|y⟩ = |y⟩         n <= y

    then computes r by continued fraction expansion.

    result: result obtained by sampling output of circuit built by
        make_order_finding_circuit; only the first repetition of the
        'exponent' measurement is used
    x: base whose order is sought
    n: modulus

    Return: r with x**r % n == 1 (normally the order of x modulo n), or None if
        the measured phase is 0 or the recovered denominator fails that check
    """
    # Read the output integer of the exponent register (first repetition) and its width t
    exponent_as_integer = int(result.data["exponent"][0])
    exponent_num_bits = result.measurements["exponent"].shape[1]

    # Phase estimate m / 2**t as an exact rational; going through a float would
    # lose precision once the register is wider than 53 bits.
    eigenphase = fractions.Fraction(exponent_as_integer, 2**exponent_num_bits)

    # Run the continued fractions algorithm to determine f = s / r with r <= n
    f = eigenphase.limit_denominator(n)

    # If the numerator is zero, the order finder failed
    if f.numerator == 0:
        return None

    # Else, return the denominator if it is valid; three-argument pow avoids
    # materializing the full integer x**r
    r = f.denominator
    if pow(x, r, n) != 1:
        return None
    return r

In [13]:
# Estimates the smallest positive r with x**r mod n == 1 from one shot of the quantum circuit
def quantum_order_finder(x: int, n: int) -> Optional[int]:
    """
    Run the order finding circuit once and post-process the measurement.

    Args:
        x: integer whose order is to be computed, must be greater than one
           and belong to the multiplicative group of integers modulo n (which
           consists of positive integers relatively prime to n),
        n: modulus of the multiplicative group.

    Returns:
        Whatever process_measurement returns for the single sample: an r with
        x**r % n == 1 (normally the order), or None on failure.

    Raises:
        ValueError: if x is outside [2, n) or shares a factor with n.
    """
    in_group = 2 <= x < n and math.gcd(x, n) == 1
    if not in_group:
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    # cirq.sample is called without `repetitions`, i.e. a single shot.
    samples = cirq.sample(make_order_finding_circuit(x, n))
    return process_measurement(samples, x, n)

In [14]:
# Returns a non-trivial factor of n if n is a perfect power (in particular a
# prime power p**k with k >= 2), else None

def _integer_kth_root(n: int, k: int) -> int:
    """Return floor(n ** (1 / k)) using exact integer arithmetic (n >= 0, k >= 1)."""
    if n < 2:
        return n
    # Start from a power of two that is >= the true root; integer Newton
    # iteration then decreases monotonically down to the floor of the root.
    root = 1 << -(-n.bit_length() // k)
    while True:
        candidate = ((k - 1) * root + n // root ** (k - 1)) // k
        if candidate >= root:
            return root
        root = candidate


def find_factor_of_prime_power(n: int) -> Optional[int]:
    """Return c when n == c**k for integers c >= 2, k >= 2, else None.

    Exponents k are tried in increasing order and the base of the first match
    is returned (e.g. 16 -> 4, 27 -> 3). Roots are computed exactly with
    integers, so the check stays correct for n far beyond float precision.

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"n must be positive, got {n}")
    # n.bit_length() - 1 == floor(log2(n)); any larger k would need a base < 2.
    for k in range(2, n.bit_length()):
        c = _integer_kth_root(n, k)
        if c ** k == n:
            return c
    return None

In [15]:
# Returns a non-trivial factor of composite integer n
def find_factor(
    n: int,
    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,
    max_attempts: int = 30
) -> Optional[int]:
    """
    Find a non-trivial factor of n via Shor's reduction from factoring to order finding.

    Args:
        n: Integer to factor.
        order_finder: Function for finding the order of elements of the
            multiplicative group of integers modulo n.
        max_attempts: number of random x's to try, also an upper limit
            on the number of order_finder invocations.

    Returns:
        Non-trivial factor of n or None if no such factor was found.
        Factor k of n is trivial if it is 1 or n.
    """
    # If the number is prime, there are no non-trivial factors
    if sympy.isprime(n):
        print("n is prime!")
        return None

    # If the number is even, two is a non-trivial factor
    if n % 2 == 0:
        return 2

    # If n is a perfect (e.g. prime) power, we can find a non-trivial factor efficiently
    c = find_factor_of_prime_power(n)
    if c is not None:
        return c

    for _ in range(max_attempts):
        # Choose a random number between 2 and n - 1
        x = random.randint(2, n - 1)

        # Most likely x and n will be relatively prime
        c = math.gcd(x, n)

        # If x and n are not relatively prime, we got lucky and found
        # a non-trivial factor
        if 1 < c < n:
            return c

        # Compute the order r of x modulo n using the order finder
        r = order_finder(x, n)

        # If the order finder failed, try again
        if r is None:
            continue

        # Only an even r gives x**r - 1 = (x**(r/2) - 1)(x**(r/2) + 1),
        # so if the order r is odd, try again
        if r % 2 != 0:
            continue

        y = pow(x, r // 2, n)
        # y == 1 means r is not the order of x (it may be a multiple of it);
        # y == n - 1 means x**(r/2) ≡ -1 (mod n), which only yields trivial
        # gcds. Either way, try another x instead of failing.
        if y in (1, n - 1):
            continue

        # Compute the non-trivial factor
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None


In [23]:
# Example of factoring via Shor's algorithm (order finding)
# NOTE: n = 55 has L = 6 bits, so every order finding circuit acts on
# 3L + 3 = 21 qubits; simulating it can take a long time.

# Number to factor
n = 55

# Attempt to find a factor (find_factor returns None when every attempt fails)
p = find_factor(n, order_finder=quantum_order_finder)

if p is None:
    print(f"Could not factor n = {n}.")
else:
    q = n // p
    print("Factoring n = pq =", n)
    print("   p =", p)
    print("   q =", q)


KeyboardInterrupt: 