In [14]:
try:
    import cirq
except ImportError:
    print("installing cirq...")
    !pip install --quiet cirq
    print("installed cirq.")

In [15]:
import fractions
import math
import random

import numpy as np
import sympy
from typing import Callable, List, Optional, Sequence, Union

import cirq

In [16]:
def multiplicative_group(n: int) -> List[int]:
    """Return the integers in [1, n) that are coprime to n, in increasing order.

    Args:
        n: The modulus; must be greater than 1 (enforced with an assertion).

    Returns:
        Every x with 1 <= x < n and gcd(x, n) == 1.
    """
    assert n > 1
    # gcd(1, n) is always 1, so starting the scan at 1 keeps 1 as the first element.
    return [
        candidate
        for candidate in range(1, n)
        if math.gcd(candidate, n) == 1
    ]

In [17]:
def classical_order_finder(x: int, n: int) -> Optional[int]:
    """Find the smallest positive r such that x**r mod n == 1, by brute force.

    Args:
        x: Element whose order is sought; must satisfy 2 <= x < n and gcd(x, n) == 1.
        n: The modulus.

    Returns:
        The order r of x modulo n.

    Raises:
        ValueError: If x is out of range or shares a factor with n.
    """
    # x must be a non-trivial element of the multiplicative group modulo n.
    if not (2 <= x < n and math.gcd(x, n) == 1):
        raise ValueError(f"Invalid x={x} for modulus n={n}.")

    # Keep multiplying by x until the running power comes back to 1.
    order = 1
    power = x
    while power != 1:
        power = (power * x) % n
        order += 1
    return order

In [18]:
class Adder(cirq.ArithmeticOperation):
    """Quantum addition built on cirq.ArithmeticOperation.

    Holds a target register and an input register; `apply` yields the sum of
    their values.
    """

    def __init__(self, target_register, input_register):
        """Remember the two registers (target first, input second)."""
        self.target_register = target_register
        self.input_register = input_register

    def registers(self):
        """Return the registers in the order (target, input)."""
        return (self.target_register, self.input_register)

    def with_registers(self, *new_registers):
        """Build a new Adder from replacement registers given in the same order."""
        return Adder(*new_registers)

    def apply(self, target_value, input_value):
        """Return the sum of the target and input register values."""
        total = target_value + input_value
        return total

In [20]:
class ModularExp(cirq.ArithmeticOperation):
    """Quantum modular exponentiation.

    Works on four "registers": a target qubit register, an exponent (either a
    qubit register or a classical integer), and the classical integers `base`
    and `modulus`. As implemented in `apply`, a target value t is mapped to

        t * base**exponent mod modulus    if t < modulus
        t                                 otherwise
    """

    def __init__(
        self,
        target: Sequence[cirq.Qid],
        exponent: Union[int, Sequence[cirq.Qid]],
        base: int,
        modulus: int
    ) -> None:
        """Store the registers after checking that the target is large enough.

        Args:
            target: Qubit register that receives the result.
            exponent: Qubit register or classical integer holding the exponent.
            base: Classical base of the exponentiation.
            modulus: Classical modulus.

        Raises:
            ValueError: If `target` has fewer qubits than `modulus.bit_length()`.
        """
        if len(target) < modulus.bit_length():
            raise ValueError(f'Register with {len(target)} qubits is too small '
                             f'for modulus {modulus}')
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
        """Return (target, exponent, base, modulus), the order `apply` expects."""
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
            self,
            *new_registers: Union[int, Sequence['cirq.Qid']],
    ) -> cirq.ArithmeticOperation:
        """Return a new ModularExp built from replacement registers.

        Args:
            *new_registers: Exactly four values in the order
                (target, exponent, base, modulus).

        Raises:
            ValueError: If the count is not four, if the target is not a
                sequence, or if base or modulus is not an int.
        """
        if len(new_registers) != 4:
            raise ValueError(f'Expected 4 registers (target, exponent, base, '
                             f'modulus), but got {len(new_registers)}')
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        """Compute the new target value from the four register values.

        Args:
            *register_values: (target, exponent, base, modulus) as integers.

        Returns:
            `target` unchanged when target >= modulus, otherwise
            (target * base**exponent) mod modulus.
        """
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        # Values outside [0, modulus) are passed through untouched.
        if target >= modulus:
            return target
        # Three-argument pow reduces modulo `modulus` while exponentiating, so
        # the huge intermediate base**exponent is never materialised. The
        # result equals (target * base**exponent) % modulus.
        return (target * pow(base, exponent, modulus)) % modulus

    def _circuit_diagram_info_(
            self,
            args: cirq.CircuitDiagramInfoArgs,
    ) -> cirq.CircuitDiagramInfo:
        """Label the wires: the first target qubit carries the full expression,
        later target qubits are 't<i>' and exponent qubits are 'e<i>'."""
        assert args.known_qubits is not None
        wire_symbols: List[str] = []
        # Running indices of the target / exponent qubits seen so far.
        t, e = 0, 0
        for qubit in args.known_qubits:
            if qubit in self.target:
                if t == 0:
                    # Show 'e' for a quantum exponent, the literal value otherwise.
                    if isinstance(self.exponent, Sequence):
                        e_str = 'e'
                    else:
                        e_str = str(self.exponent)
                    wire_symbols.append(
                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
                else:
                    wire_symbols.append('t' + str(t))
                t += 1
            if isinstance(self.exponent, Sequence) and qubit in self.exponent:
                wire_symbols.append('e' + str(e))
                e += 1
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))


In [21]:
def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    """Build the circuit used to estimate the order of x modulo n.

    With L = n.bit_length(), the circuit uses L target qubits followed by
    2 * L + 3 exponent qubits on a line.

    Args:
        x: Base passed to the modular exponentiation.
        n: Modulus passed to the modular exponentiation.

    Returns:
        A circuit that flips the last target qubit, puts every exponent qubit
        through a Hadamard, applies ModularExp, runs the inverse QFT on the
        exponent register and measures it under the key 'exponent'.
    """
    num_target_qubits = n.bit_length()
    num_exponent_qubits = 2 * num_target_qubits + 3

    target = cirq.LineQubit.range(num_target_qubits)
    exponent = cirq.LineQubit.range(
        num_target_qubits, num_target_qubits + num_exponent_qubits)

    operations = [
        # Presumably initialises the target register to the integer 1
        # (last qubit = least significant bit) -- TODO confirm bit order.
        cirq.X(target[-1]),
        cirq.H.on_each(*exponent),
        ModularExp(target, exponent, x, n),
        cirq.qft(*exponent, inverse=True),
        cirq.measure(*exponent, key='exponent'),
    ]
    return cirq.Circuit(*operations)

In [22]:
def process_measurement(result: 'cirq.Result', x: int, n: int) -> Optional[int]:
    """Turn one sample of the order-finding circuit into a candidate order.

    Args:
        result: Sample whose "exponent" measurement holds the readout of the
            exponent register.
        x: Element whose order modulo n is being sought.
        n: The modulus.

    Returns:
        An r with x**r mod n == 1, or None when the sample is unusable (the
        estimated fraction is zero or its denominator is not such an r).
    """
    # Read the output integer of the exponent register and scale it to [0, 1).
    exponent_as_integer = result.data["exponent"][0]
    exponent_num_bits = result.measurements["exponent"].shape[1]
    eigenphase = float(exponent_as_integer / 2**exponent_num_bits)

    # Continued fractions: approximate the phase by s / r with r <= n.
    f = fractions.Fraction.from_float(eigenphase).limit_denominator(n)

    # A zero numerator carries no information about the order.
    if f.numerator == 0:
        return None

    # Accept the denominator only if it really satisfies x**r == 1 (mod n).
    # Three-argument pow avoids building the full power x**r first.
    r = f.denominator
    if pow(x, r, n) != 1:
        return None
    return r

In [23]:
def quantum_order_finder(x: int, n: int) -> Optional[int]:
    """Estimate the order of x modulo n from one sample of the quantum circuit.

    Args:
        x: Element whose order is sought; must satisfy 2 <= x < n and gcd(x, n) == 1.
        n: The modulus.

    Returns:
        The value produced by process_measurement: an r with x**r mod n == 1,
        or None if this sample did not yield one.

    Raises:
        ValueError: If x is not a valid element of the multiplicative group
            modulo n.
    """
    # Reject anything that is not a non-trivial unit modulo n.
    if not (2 <= x < n and math.gcd(x, n) == 1):
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    # Build the circuit, draw a single sample, and post-process it.
    sample = cirq.sample(make_order_finding_circuit(x, n))
    return process_measurement(sample, x, n)

In [24]:
def _integer_root(n: int, k: int) -> int:
    """Return floor(n ** (1 / k)) for n >= 1 and k >= 1 using integer arithmetic only."""
    # 2 ** (bit_length // k + 1) raised to the k-th power always exceeds n.
    lo, hi = 0, 1 << (n.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid**k <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


def find_factor_of_prime_power(n: int) -> Optional[int]:
    """Returns non-trivial factor of n if n is a prime power, else None.

    More precisely, for the smallest k >= 2 such that n == c**k for an integer
    c, returns c; returns None when no such k exists.

    The k-th roots are computed exactly with integers, so the answer stays
    correct for n far beyond floating-point precision.

    Args:
        n: Positive integer to test.

    Raises:
        ValueError: If n < 1 (the original raised the same type via math.log2).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # n.bit_length() - 1 == floor(log2(n)) is the largest exponent worth trying.
    for k in range(2, n.bit_length()):
        root = _integer_root(n, k)
        if root**k == n:
            return root
    return None


def find_factor(
    n: int,
    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,
    max_attempts: int = 30
) -> Optional[int]:
    """Try to find a non-trivial factor of n via order finding.

    Args:
        n: The integer to factor.
        order_finder: Callable taking (x, n) and returning an r with
            x**r mod n == 1, or None on failure.
        max_attempts: Maximum number of random bases x to try.

    Returns:
        A non-trivial factor of n, or None if n is prime or no factor was
        found within `max_attempts` attempts.
    """
    # If the number is prime, there are no non-trivial factors.
    if sympy.isprime(n):
        print("n is prime!")
        return None

    # If the number is even, two is a non-trivial factor.
    if n % 2 == 0:
        return 2

    # If n is a prime power, we can find a non-trivial factor efficiently.
    c = find_factor_of_prime_power(n)
    if c is not None:
        return c

    for _ in range(max_attempts):
        # Choose a random number between 2 and n - 1.
        x = random.randint(2, n - 1)

        # Most likely x and n will be relatively prime. If they are not, we
        # got lucky and the gcd is already a non-trivial factor.
        c = math.gcd(x, n)
        if 1 < c < n:
            return c

        # Compute the order r of x modulo n using the order finder.
        r = order_finder(x, n)

        # If the order finder failed or the order is odd, try again.
        if r is None or r % 2 != 0:
            continue

        # y is a square root of 1 modulo n. Modular pow avoids building the
        # full power x**(r // 2) before reducing.
        y = pow(x, r // 2, n)

        # y == 1 means r was not the minimal order (the order finder only
        # guarantees x**r == 1); retry instead of aborting on an assertion.
        if not 1 < y < n:
            continue

        # gcd(y - 1, n) is a non-trivial factor unless y == -1 (mod n).
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None

In [38]:
# Number to factor
n = 15

# Attempt to find a factor. find_factor returns None when n is prime or when
# every attempt failed, so guard the division below against that case.
p = find_factor(n, order_finder=quantum_order_finder)

print("Factoring n = pq =", n)
if p is None:
    q = None
    print("No non-trivial factor was found.")
else:
    q = n // p
    print("p =", p)
    print("q =", q)

Factoring n = pq = 15
p = 3
q = 5


In [None]:
# Sanity check: the factors multiply back to n (False if no factor was found).
p is not None and p * q == n