# [Shor's Algorithm: code adapted from IBM's Qiskit Textbook tutorial](https://qiskit.org/textbook/ch-algorithms/shor.html)
## Code adapted for Cirq

In [8]:
import unittest

import cirq
from cirq.ops import H, X, I
import random
import matplotlib.pyplot as plt
from math import gcd
import numpy as np
from numpy.random import randint

#import hypothesis.strategies as st
#from hypothesis import given, settings

from fractions import Fraction
from math import gcd # greatest common divisor

In [9]:
# Demo parameters: counting-qubit count, base `a`, and the number N to factor.
n_count, a, N = 8, 7, 15

In [10]:
class aMod15Gate(cirq.Gate):
    """4-qubit gate applying the tutorial's "multiply by a mod 15" circuit `power` times.

    Only a in {2, 7, 8, 11, 13} is supported; the check is performed when the
    gate is decomposed.

    Args:
        a: base of the modular multiplication.
        power: number of times the base circuit is applied.
    """

    def __init__(self, a, power):
        # Actually invoke the parent initializer (the previous
        # `super(aMod15Gate, self)` only created a proxy object).
        super().__init__()
        self.a = a
        self.power = power

    def _num_qubits_(self):
        return 4

    def _decompose_(self, qubits):
        q0, q1, q2, q3 = qubits
        if self.a not in [2, 7, 8, 11, 13]:
            raise ValueError("'a' must be 2,7,8,11 or 13")
        # One application is a pure qubit permutation (a 4-cycle for
        # a in {2, 7, 8, 13}, two disjoint swaps for a == 11), optionally
        # followed by X on every qubit. X-on-all commutes with any qubit
        # permutation, so each single application has order dividing 4 and
        # only power % 4 repetitions change the unitary. This avoids emitting
        # hundreds of redundant gates for powers such as 2**7.
        # Non-positive powers apply nothing, as before.
        repetitions = max(self.power, 0) % 4
        for _ in range(repetitions):
            if self.a in [2, 13]:
                yield cirq.SWAP(q0, q1)
                yield cirq.SWAP(q1, q2)
                yield cirq.SWAP(q2, q3)
            if self.a in [7, 8]:
                yield cirq.SWAP(q2, q3)
                yield cirq.SWAP(q1, q2)
                yield cirq.SWAP(q0, q1)
            if self.a == 11:
                yield cirq.SWAP(q1, q3)
                yield cirq.SWAP(q0, q2)
            if self.a in [7, 11, 13]:
                yield cirq.X(q0)
                yield cirq.X(q1)
                yield cirq.X(q2)
                yield cirq.X(q3)

    def _circuit_diagram_info_(self, args):
        # Cirq expects one wire symbol per qubit; a bare string labels only a
        # single wire, which makes diagrams of this 4-qubit gate fail to render.
        return ("a mod 15", "#2", "#3", "#4")

In [11]:
def qft_dagger_cirq(qc, qubits, n):
    """Append an inverse QFT over the first `n` entries of `qubits` to `qc`.

    First reverses the register with SWAPs. Then, for each qubit in turn,
    applies a controlled phase from every lower-indexed qubit followed by a
    Hadamard. The phase uses `cirq.CZ ** t` with t = -1 / 2**distance.
    """
    # Reverse qubit order by swapping mirrored pairs.
    for lo in range(n // 2):
        hi = n - lo - 1
        qc.append(cirq.SWAP(qubits[lo], qubits[hi]))
    for target in range(n):
        for control in range(target):
            power = -1 / 2 ** (target - control)
            qc.append((cirq.CZ ** power)(qubits[control], qubits[target]))
        qc.append(cirq.H(qubits[target]))

In [12]:
def qpe_amod15(a):
    n_count = 8
    qubits = cirq.LineQubit.range(4+n_count)
    qc = cirq.Circuit()     
    for q in range(n_count):
        #print(q)
        qc.append(cirq.H(qubits[q]))     # Initialize counting qubits in state |+>
    qc.append(cirq.X(qubits[3+n_count])) # And auxiliary register in state |1>
    for q in range(n_count): # Do controlled-U operations
        qc.append(aMod15Gate(a, 2**q).on(qubits[8],qubits[9],qubits[10],qubits[11]).controlled_by(qubits[q]))
    qft_dagger_cirq(qc, qubits[:n_count], n_count) # Do inverse-QF
    qc.append(cirq.measure(*qubits[:8], key='m'))
    # Simulate Results
    simulator = cirq.Simulator()
    results = simulator.run(qc , repetitions =1)
    readings = np.array2string(results.measurements['m'][0], separator='')[1:-1][::-1]
    phase = int(readings,2)/(2**n_count)
    return phase

In [13]:
phase = qpe_amod15(a)  # single-shot estimate of s/r
Fraction(phase).limit_denominator(max_denominator=15)

Fraction(3, 4)

In [14]:
# Continued-fraction approximation of the phase; its denominator estimates r.
frac = Fraction(phase).limit_denominator(max_denominator=15)
s = frac.numerator
r = frac.denominator
print(r)

4


In [15]:
# Candidate factors gcd(a**(r/2) - 1, N) and gcd(a**(r/2) + 1, N).
guesses = [gcd(a ** (r // 2) + delta, N) for delta in (-1, 1)]
print(guesses)

[3, 5]


In [16]:
def find_factor(coprime, max_attempts=100, modulus=None, phase_estimator=None):
    """Factor the modulus with Shor's period-finding loop for a fixed base.

    Each attempt estimates a phase s/r, recovers r by continued fractions and
    tests gcd(a**(r//2) +/- 1, modulus). Attempts that yield phase 0 or no
    non-trivial factor are retried.

    Args:
        coprime: the base `a` whose order is estimated.
        max_attempts: maximum number of phase estimations.
        modulus: number being factored; defaults to the notebook-global `N`.
        phase_estimator: callable `a -> phase`; defaults to `qpe_amod15`.

    Returns:
        List of non-trivial factors found in the first successful attempt,
        or an empty list if none was found within `max_attempts`.
    """
    a = coprime
    n = N if modulus is None else modulus
    estimate_phase = qpe_amod15 if phase_estimator is None else phase_estimator
    for _ in range(max_attempts):
        phase = estimate_phase(a)  # Phase = s/r
        if phase == 0:
            continue  # s = 0 carries no information about r
        # Denominator should (hopefully!) tell us r
        r = Fraction(phase).limit_denominator(n).denominator
        half_power = a ** (r // 2)
        # Guesses for factors are gcd(a^{r/2} ± 1, n)
        guesses = [gcd(half_power - 1, n), gcd(half_power + 1, n)]
        factors = [g for g in guesses if g not in (1, n) and n % g == 0]
        # Previously the function returned after the first nonzero phase even
        # when no factor was found; keep trying instead.
        if factors:
            return factors
    return []

find_factor(coprime=7)

[3]

In [17]:
import fractions
import math
import random

import numpy as np
import sympy
from typing import Callable, List, Optional, Sequence, Union

import cirq

In [18]:
"""Function to compute the elements of Z_n."""
def multiplicative_group(n: int) -> List[int]:
    """Return the elements of the multiplicative group modulo n.

    The result is 1 followed by every x in [2, n) with gcd(x, n) == 1, in
    increasing order.

    Args:
        n: Modulus of the multiplicative group; asserted to be greater than 1.
    """
    print("multiplicative group")
    assert n > 1
    units = [x for x in range(2, n) if math.gcd(x, n) == 1]
    return [1] + units

In [19]:
"""Defines the modular exponential operation used in Shor's algorithm."""
class ModularExp(cirq.ArithmeticOperation):
    """Quantum modular exponentiation.

    This class represents the unitary which multiplies base raised to exponent
    into the target modulo the given modulus. More precisely, it represents the
    unitary V which computes modular exponentiation x**e mod n:

        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n
        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y

    where y is the target register, e is the exponent register, x is the base
    and n is the modulus. Consequently,

        V|y⟩|e⟩ = (U**e|y)|e⟩

    where U is the unitary defined as

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y

    NOTE(review): newer Cirq releases steer users towards `cirq.ArithmeticGate`;
    confirm `cirq.ArithmeticOperation` exists in the pinned Cirq version.
    """
    def __init__(
        self,
        target: Sequence[cirq.Qid],
        exponent: Union[int, Sequence[cirq.Qid]],
        base: int,
        modulus: int
    ) -> None:
        # The target register must be wide enough to hold every value < modulus.
        if len(target) < modulus.bit_length():
            raise ValueError(f'Register with {len(target)} qubits is too small '
                             f'for modulus {modulus}')
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
        """Return (target, exponent, base, modulus) in the order `apply` receives them."""
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
            self,
            *new_registers: Union[int, Sequence['cirq.Qid']],
    ) -> cirq.ArithmeticOperation:
        """Return a copy with replaced registers, validating their kinds.

        Raises:
            ValueError: wrong register count, a non-sequence target, or a
                non-int base/modulus.
        """
        if len(new_registers) != 4:
            raise ValueError(f'Expected 4 registers (target, exponent, base, '
                             f'modulus), but got {len(new_registers)}')
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        """Classical action on one basis state: y -> y * base**e mod modulus (if y < modulus)."""
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        if target >= modulus:
            return target
        # The exponent is a whole register value, so base**exponent can be an
        # enormous integer; three-argument pow reduces modulo `modulus` as it
        # goes and yields the same result far faster.
        return (target * pow(base, exponent, modulus)) % modulus

    def _circuit_diagram_info_(
            self,
            args: cirq.CircuitDiagramInfoArgs,
    ) -> cirq.CircuitDiagramInfo:
        """Label the first target wire with the full formula, the rest as t1.., e0.."""
        assert args.known_qubits is not None
        wire_symbols: List[str] = []
        t, e = 0, 0
        for qubit in args.known_qubits:
            if qubit in self.target:
                if t == 0:
                    if isinstance(self.exponent, Sequence):
                        e_str = 'e'
                    else:
                        e_str = str(self.exponent)
                    wire_symbols.append(
                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
                else:
                    wire_symbols.append('t' + str(t))
                t += 1
            if isinstance(self.exponent, Sequence) and qubit in self.exponent:
                wire_symbols.append('e' + str(e))
                e += 1
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))

In [20]:
def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    """Build the phase-estimation circuit whose measurement reveals the order of x mod n.

    The circuit estimates an eigenphase of the unitary

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y

    It uses a target register of n.bit_length() qubits followed by an
    exponent register of 2 * n.bit_length() + 3 qubits. The exponent register
    is measured under the key 'exponent'.

    Args:
        x: positive integer whose order modulo n is to be found
        n: modulus relative to which the order of x is to be found

    Returns:
        Quantum circuit for finding the order of x modulo n
    """
    width = n.bit_length()
    target = cirq.LineQubit.range(width)
    exponent = cirq.LineQubit.range(width, 3 * width + 3)
    operations = [
        cirq.X(target[width - 1]),                # flip the last target qubit
        cirq.H.on_each(*exponent),                # superpose all exponent values
        ModularExp(target, exponent, x, n),       # controlled modular exponentiation
        cirq.qft(*exponent, inverse=True),        # read the phase out of the exponent register
        cirq.measure(*exponent, key='exponent'),
    ]
    return cirq.Circuit(*operations)

In [21]:
def process_measurement(result: cirq.Result, x: int, n: int) -> Optional[int]:
    """Turn one sample of the order-finding circuit into a candidate order r.

    The measured 'exponent' integer divided by 2**(register width) gives an
    eigenphase s/r of

        U|y⟩ = |xy mod n⟩  0 <= y < n
        U|y⟩ = |y⟩         n <= y

    The best fraction with denominator <= n is taken. Its denominator is
    returned only if it really satisfies x**r % n == 1.

    Args:
        result: result obtained by sampling the output of the
            circuit built by make_order_finding_circuit
        x: base whose order is sought.
        n: modulus.

    Returns:
        r, the order of x modulo n, or None if the sample was uninformative.
    """
    print("process measurement")
    measured_value = result.data["exponent"][0]
    register_width = result.measurements["exponent"].shape[1]
    eigenphase = float(measured_value / 2**register_width)

    # Continued-fraction approximation s / r of the eigenphase.
    approximation = fractions.Fraction.from_float(eigenphase).limit_denominator(n)

    # A zero numerator means the sample carries no information about r.
    if approximation.numerator == 0:
        return None

    candidate = approximation.denominator
    return candidate if x**candidate % n == 1 else None

In [22]:
def quantum_order_finder(x: int, n: int) -> Optional[int]:
    """Estimate the smallest positive r with x**r mod n == 1 using one circuit sample.

    Args:
        x: integer whose order is to be computed, must be greater than one
           and belong to the multiplicative group of integers modulo n (which
           consists of positive integers relatively prime to n),
        n: modulus of the multiplicative group.

    Raises:
        ValueError: if x is not in [2, n) or shares a factor with n.
    """
    print("quantum order finder")
    in_group = 2 <= x < n and math.gcd(x, n) == 1
    if not in_group:
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    circuit = make_order_finding_circuit(x, n)
    print(circuit)

    # One sample, then classical post-processing into r (or None).
    return process_measurement(cirq.sample(circuit), x, n)

In [23]:
def find_factor_of_prime_power(n: int) -> Optional[int]:
    """Return c if n == c**k for some integer k >= 2, otherwise None.

    For each exponent k up to log2(n), the floating-point k-th root is
    rounded down and then up, and the first exact match is returned.
    """
    print("factor of prime power")
    largest_k = math.floor(math.log2(n))
    for k in range(2, largest_k + 1):
        root = math.pow(n, 1 / k)
        for candidate in (math.floor(root), math.ceil(root)):
            if candidate**k == n:
                return candidate
    return None


def find_factor(
    n: int,
    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,
    max_attempts: int = 30
) -> Optional[int]:
    """Returns a non-trivial factor of composite integer n.

    Each attempt picks a random x in [2, n - 1]. A lucky gcd(x, n) > 1 is
    returned directly. Otherwise `order_finder` is asked for the order r of
    x; an even r gives the candidate gcd(x**(r/2) - 1, n).

    Args:
        n: Integer to factor.
        order_finder: Function for finding the order of elements of the
            multiplicative group of integers modulo n.
        max_attempts: number of random x's to try, also an upper limit
            on the number of order_finder invocations.

    Returns:
        Non-trivial factor of n or None if no such factor was found.
        Factor k of n is trivial if it is 1 or n.
    """
    # If the number is prime, there are no non-trivial factors.
    if sympy.isprime(n):
        print("n is prime!")
        return None

    # NOTE(review): the classical shortcuts below are disabled, presumably so
    # the demo always exercises the order-finding path; re-enable them for
    # general use.
    # If the number is even, two is a non-trivial factor.
    #if n % 2 == 0:
    #    return 2

    # If n is a prime power, we can find a non-trivial factor efficiently.
    #c = find_factor_of_prime_power(n)
    #if c is not None:
    #    return c

    print("find factor")

    for _ in range(max_attempts):
        print("loop")
        # Choose a random number between 2 and n - 1.
        x = random.randint(2, n - 1)
        print("x " + str(x))

        # Most likely x and n will be relatively prime.
        c = math.gcd(x, n)
        print("c " + str(c))

        # If x and n are not relatively prime, we got lucky and found
        # a non-trivial factor.
        if 1 < c < n:
            return c

        # Compute the order r of x modulo n using the order finder.
        r = order_finder(x, n)
        print(r)

        # If the order finder failed, try again.
        if r is None:
            continue

        # Only an even order lets x**r - 1 split as
        # (x**(r/2) - 1)(x**(r/2) + 1); if r is odd, try again.
        if r % 2 != 0:
            continue

        # Compute the candidate factor. Modular pow avoids a huge x**(r//2).
        y = pow(x, r // 2, n)
        # No assertion on y: an order_finder may return a multiple of the true
        # order, which makes y == 1. Then gcd(y - 1, n) == n is rejected by the
        # range check below and we simply retry with a new x.
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None

In [69]:
find_factor(n=15)

find factor
loop
x 7
c 1
quantum order finder
0: ────────ModularExp(t*7**e % 15)────────────────────────────
           │
1: ────────t1─────────────────────────────────────────────────
           │
2: ────────t2─────────────────────────────────────────────────
           │
3: ────X───t3─────────────────────────────────────────────────
           │
4: ────H───e0────────────────────────qft^-1───M('exponent')───
           │                         │        │
5: ────H───e1────────────────────────#2───────M───────────────
           │                         │        │
6: ────H───e2────────────────────────#3───────M───────────────
           │                         │        │
7: ────H───e3────────────────────────#4───────M───────────────
           │                         │        │
8: ────H───e4────────────────────────#5───────M───────────────
           │                         │        │
9: ────H───e5────────────────────────#6───────M───────────────
           │                    

3