# Shor's algorithm, fully classical implementation

In [1]:
%matplotlib inline
import random
import math
import itertools
def period_finding_classical(a,N):
    # This is an inefficient classical algorithm to find the period of f(x)=a^x (mod N)
    # f(0) = a**0 (mod N) = 1, so we find the first x greater than 0 for which f(x) is also 1
    for r in itertools.count(start=1):
        if (a**r) % N == 1:
            return r

def shors_algorithm_classical(N):
    assert(N>0)
    assert(int(N)==N)
    while True:
        a=random.randint(0,N-1)
        g=math.gcd(a,N)
        if g!=1 or N==1:
            first_factor=g
            second_factor=int(N/g)
            return first_factor,second_factor
        else:
            r=period_finding_classical(a,N)  
            if r % 2 != 0:
                continue
            elif a**(int(r/2)) % N == -1 % N:
                continue
            else:
                first_factor=math.gcd(a**int(r/2)+1,N)
                second_factor=math.gcd(a**int(r/2)-1,N)
                if first_factor==N or second_factor==N:
                    continue
                return first_factor,second_factor


In [2]:
# Testing it out. Note because of the probabilistic nature of the algorithm, different factors and different ordering is possible
shors_algorithm_classical(15)
shors_algorithm_classical(91)

(13, 7)

# Shor's algorithm, working on a quantum implementation
## The following code will help give intuition for how to design a quantum circuit to do modular multiplication 

In [3]:
def U_a_modN(a,N,binary=False):
    """
    a and N are decimal
    This algorithm returns U_a where:
        U_a is a modular multiplication operator map from |x> to |ax mod N>
    If binary is set to True, the mapping is given in binary instead of in decimal notation.
    
    """
    res={}
    l=[]
    for i in range(1,N):
        l+=[a*i%N]
    res=set()

    for i in range(1,N):
        mp=[i]
        end=i
        nxt=i-1
        while l[nxt]!=end:
            mp+=[l[nxt]]
            nxt=l[nxt]-1
        res.add(tuple(mp))
    final_res=[]
    for item in res:
        dup=False
        for final_item in final_res:
            if set(item) == set(final_item):
                dup=True
        if not dup:
            final_res+=[item]
    if not binary:
        return final_res
    else:
        final_res_bin=[]
        for mapping in final_res:
            final_res_bin+=[tuple(['{0:06b}'.format(decimal) for decimal in mapping])]
        return final_res_bin
        
print(U_a_modN(8,35))
print(U_a_modN(8,35,binary=True))

[(7, 21, 28, 14), (34, 27, 6, 13), (2, 16, 23, 9), (26, 33, 19, 12), (18, 4, 32, 11), (24, 17, 31, 3), (15,), (30,), (5,), (8, 29, 22, 1), (20,), (10,), (25,)]
[('000111', '010101', '011100', '001110'), ('100010', '011011', '000110', '001101'), ('000010', '010000', '010111', '001001'), ('011010', '100001', '010011', '001100'), ('010010', '000100', '100000', '001011'), ('011000', '010001', '011111', '000011'), ('001111',), ('011110',), ('000101',), ('001000', '011101', '010110', '000001'), ('010100',), ('001010',), ('011001',)]


# This code implements modular multiplication by 2 mod 15

In [4]:
import qiskit
import matplotlib
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister, QISKitError
from  qiskit.tools.visualization  import circuit_drawer
from qiskit.extensions.standard import cx, cswap
from qiskit import Aer

from qiskit import IBMQ
# Authenticate an account and add for use during this session. Replace string
# argument with your private token.
IBMQ.enable_account("INSERT_YOUR_API_TOKEN_HERE")

def mult_2mod15_quantum(qr,qc):
    # Swap 0th qubit and 3rd qubit
    qc.cx(qr[0],qr[3])
    qc.cx(qr[3],qr[0])
    qc.cx(qr[0],qr[3])

    # Swap 0th qubit and 1st qubit
    qc.cx(qr[1],qr[0])
    qc.cx(qr[0],qr[1])
    qc.cx(qr[1],qr[0])

    # Swap 1st qubit and 2nd qubit
    qc.cx(qr[1],qr[2])
    qc.cx(qr[2],qr[1])
    qc.cx(qr[1],qr[2])


def mult_2mod15_quantum_test(x):
    qr = QuantumRegister(4)
    cr = ClassicalRegister(4)
    qc = QuantumCircuit(qr,cr)

    # input
    x_bin='{0:04b}'.format(x)
    for i,b in enumerate(x_bin):
        if int(b):
            qc.x(qr[i])
    # run circuit
    mult_2mod15_quantum(qr,qc)

    # measure results
    for i in range(4):
        qc.measure(qr[i],cr[i])
        
    import time
    from qiskit.tools.visualization import plot_histogram
    backend=Aer.get_backend('qasm_simulator')
    shots=50
    job_exp = qiskit.execute(qc, backend=backend)
    result = job_exp.result()
    final=result.get_counts(qc)
    result_in_order=list(final.keys())[0]
    dec=0
    for i,b in enumerate(result_in_order):
        if int(b):
            dec+=2**i
    return (x,dec)

def mult_2mod15_classical_test(x):
    return (x,2*x%15)

# testing!
for i in range(1,15):
    quantum=mult_2mod15_quantum_test(i)
    classical=mult_2mod15_classical_test(i)
    if quantum!=classical:
        print(quantum,classical)



## This code makes the previous an operation controlled by a control qubit

In [5]:
def controlled_mult_2mod15_quantum(qr,qc,control_qubit):
    """
    Controlled quantum circuit for multiplication by 2 mod 15.
        Note: control qubit should an index greater than 3, 
        and qubits 0,1,2,3 are reserved for circuit operations
    """
    # Swap 0th qubit and 3rd qubit
    qc.cswap(control_qubit,qr[0],qr[3])

    # Swap 0th qubit and 1st qubit
    qc.cswap(control_qubit,qr[1],qr[0])

    # Swap 1st qubit and 2nd qubit
    qc.cswap(control_qubit,qr[1],qr[2])


# This code performas the entire Shor's algorithm subroutine for multiplication by 2 mod 15

In [6]:
import math
def shors_subroutine_period_2mod15(qr,qc,cr):
    qc.x(qr[0])
    qc.h(qr[4])
    qc.h(qr[4])
    qc.measure(qr[4],cr[0])

    qc.h(qr[5])
    qc.cx(qr[5],qr[0])
    qc.cx(qr[5],qr[2])
    if cr[0] == 1:
        qc.u1(math.pi/2,qr[4]) #pi/2 is 90 degrees in radians
    qc.h(qr[5])
    qc.measure(qr[5],cr[1])

    qc.h(qr[6])
    controlled_mult_2mod15_quantum(qr,qc,qr[6])
    if cr[1] == 1:
        qc.u1(math.pi/2,qr[6]) # pi/2 is 90 degrees in radians
    if cr[0] == 1:
        qc.u1(math.pi/4,qr[6]) #pi/4 is 45 degrees in radians
    qc.h(qr[6])
    qc.measure(qr[6],cr[2]) 

# This code will help us read out the results from our quantum Shor's subroutine. First, implementing the code to compute the period from the output of the quantum computation:


In [7]:
# see https://arxiv.org/pdf/quant-ph/0010034.pdf for more details (convergence relations on page 11)
import math
def continued_fraction(xi,max_steps=100): # stop_after is cutoff for algorithm, for debugging
    """
    This function computes the continued fraction expansion of input xi
    per the recurrance relations on page 11 of https://arxiv.org/pdf/quant-ph/0010034.pdf
    
    """
    #a and xi initial
    all_as=[]
    all_xis=[]
    a_0=math.floor(xi)
    xi_0=xi-a_0
    all_as+=[a_0]
    all_xis+=[xi_0]
    # p and q initial
    all_ps=[]
    all_qs=[]
    p_0=all_as[0]
    q_0=1
    all_ps+=[p_0]
    all_qs+=[q_0]
    
    xi_n=xi_0
    while not numpy.isclose(xi_n,0,atol=1e-7):
        if len(all_as)>=max_steps:
            print("Warning: algorithm did not converge within max_steps %d steps, try increasing max_steps"%max_steps)
            break
        # computing a and xi
        a_nplus1=math.floor(1/xi_n)
        xi_nplus1=1/xi_n-a_nplus1
        all_as+=[a_nplus1]
        all_xis+=[xi_nplus1]
        xi_n=xi_nplus1
        # computing p and q
        n=len(all_as)-1
        if n==1:
            p_1=all_as[1]*all_as[0]+1
            q_1=all_as[1]
            all_ps+=[p_1]
            all_qs+=[q_1]
        else:
            p_n=all_as[n]*all_ps[n-1]+all_ps[n-2]
            q_n=all_as[n]*all_qs[n-1]+all_qs[n-2]
            all_ps+=[p_n]
            all_qs+=[q_n]
    return all_ps,all_qs,all_as,all_xis
    
import numpy
def test_continued_fraction():
    """
    Testing the continued fraction  see https://arxiv.org/pdf/quant-ph/0010034.pdf, step 2.5 chart page 20
    NOTE: I believe there is a mistake in this chart at the last row, and that n should range as in my code below
    their chart is missing one line. Please contact me if you find differently! 
    """
    xi=13453/16384
    all_ps,all_qs,all_as,all_xis=continued_fraction(xi)
    ## step 2.5 chart in https://arxiv.org/pdf/quant-ph/0010034.pdf page 20
    #n_13453_16384=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
    #a_n_13453_16384=[0,1,4,1,1,2,3,1,1,3,1,1,1,1,3]
    #p_n_13453_16384=[0,1,4,5,9,23,78,101,179,638,817,1455,2272,3727,13453]
    #q_n_13453_16384=[1,1,5,6,11,28,95,123,218,777,995,1772,2767,4539,16384]
    ## what I find instead:
    n_13453_16384=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
    a_n_13453_16384=[0,1,4,1,1,2,3,1,1,3,1,1,1,1,2,1]
    p_n_13453_16384=[0,1,4,5,9,23,78,101,179,638,817,1455,2272,3727,9726,13453]
    q_n_13453_16384=[1,1,5,6,11,28,95,123,218,777,995,1772,2767,4539,11845,16384]
    for tup in [("ns",range(len(all_ps)),range(len(n_13453_16384))),
                ("as",all_as,a_n_13453_16384),
                ("ps",all_ps,p_n_13453_16384),
                ("qs",all_qs,q_n_13453_16384),
               ]:
        if not numpy.array_equal(tup[1],tup[2]):
            print(tup[0])
            print("act:",tup[1])
            print("exp:",tup[2])
            print()
            
from IPython.display import display, Math
def pretty_print_continued_fraction(results,raw_latex=False):
    all_ps,all_qs,all_as,all_xis=results
    for i,vals in enumerate(zip(all_ps,all_qs,all_as,all_xis)):
        p,q,a,xi=vals
        if raw_latex:
            print(r'\frac{p_%d}{q_%d}=\frac{%d}{%d}'%(i,i,p,q))
        else:
            display(Math(r'$\frac{p_%d}{q_%d}=\frac{%d}{%d}$'%(i,i,p,q)))
    
            
test_continued_fraction()
#pretty_print_continued_fraction(continued_fraction(5/8),raw_latex=True)
#pretty_print_continued_fraction(continued_fraction(0/8))
pretty_print_continued_fraction(continued_fraction(6/8))

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

# Next we will integrate the check for whether we have found the period into the continued fraction code, so that we can stop computing the continued fraction as soon as we've found the period

In [8]:
import math
def period_from_quantum_measurement(quantum_measurement,
                                    number_qubits,
                                    a_shor,
                                    N_shor,
                                    max_steps=100): # stop_after is cutoff for algorithm, for debugging
    """
    This function computes the continued fraction expansion of input xi
    per the recurrance relations on page 11 of https://arxiv.org/pdf/quant-ph/0010034.pdf
    a_shor is the random number chosen as part of Shor's algorithm
    N_shor is the number Shor's algorithm is trying to factor
    """
    xi=quantum_measurement/2**number_qubits
    
    #a and xi initial
    all_as=[]
    all_xis=[]
    a_0=math.floor(xi)
    xi_0=xi-a_0
    all_as+=[a_0]
    all_xis+=[xi_0]
    # p and q initial
    all_ps=[]
    all_qs=[]
    p_0=all_as[0]
    q_0=1
    all_ps+=[p_0]
    all_qs+=[q_0]
    
    xi_n=xi_0
    while not numpy.isclose(xi_n,0,atol=1e-7):
        if len(all_as)>=max_steps:
            print("Warning: algorithm did not converge within max_steps %d steps, try increasing max_steps"%max_steps)
            break
        # computing a and xi
        a_nplus1=math.floor(1/xi_n)
        xi_nplus1=1/xi_n-a_nplus1
        all_as+=[a_nplus1]
        all_xis+=[xi_nplus1]
        xi_n=xi_nplus1
        # computing p and q
        n=len(all_as)-1
        if n==1:
            p_1=all_as[1]*all_as[0]+1
            q_1=all_as[1]
            all_ps+=[p_1]
            all_qs+=[q_1]
        else:
            p_n=all_as[n]*all_ps[n-1]+all_ps[n-2]
            q_n=all_as[n]*all_qs[n-1]+all_qs[n-2]
            all_ps+=[p_n]
            all_qs+=[q_n]
        # check the q to see if it is our answer (note with this we skip the first q, as a trivial case)
        if a_shor**all_qs[-1]%N_shor == 1 % N_shor:
            return all_qs[-1]

period_from_quantum_measurement(13453,14,3,91) #should return, for example 6 per page 20 of https://arxiv.org/pdf/quant-ph/0010034.pdf

6

In [9]:
# Testing this:
import qiskit
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister

def binary_string_to_decimal(s):
    dec=0
    for i in s[::-1]:
        if int(i):
            dec+=2**int(i)
    return dec

def run_shors_subroutine_period2_mod15():
    qr = QuantumRegister(7)
    cr = ClassicalRegister(3)
    qc = QuantumCircuit(qr,cr)
    # initialize x to be a superposition of all possible r quibit values
    #for i in range(4):
    #    qc.h(qr[i])
    # run circuit (which includes measurement steps)
    shors_subroutine_period_2mod15(qr,qc,cr)
        
    import time
    from qiskit.tools.visualization import plot_histogram
    backend=Aer.get_backend('qasm_simulator')
    job_exp = qiskit.execute(qc, backend=backend,shots=1)
    result = job_exp.result()
    final=result.get_counts(qc)
    # convert final result to decimal
    measurement=binary_string_to_decimal(list(final.keys())[0])
    period_r=period_from_quantum_measurement(measurement,3,2,15)
    return period_r
print(run_shors_subroutine_period2_mod15())



4


# The last thing to do will be to implement the full Shor's algorithm and check if the r is correct by plugging it in, getting factors and checking results. If not, rerun the algorithm. 

In [10]:
def period_finding_quantum(a,N):
    # for the sake of example we will not implement this algorithm in full generality
    # rather, we will create an example with one specific a and one specific N
    # extension work could be done to impl
    if a==2 and N==15:
        return run_shors_subroutine_period2_mod15()
    else:
        raise Exception("Not implemented for N=%d, a=%d" % (N,a))
        
def shors_algorithm_quantum(N,fixed_a=None):
    assert(N>0)
    assert(int(N)==N)
    while True:
        if not fixed_a:
            a=random.randint(0,N-1) 
        else:
            a=fixed_a
        g=math.gcd(a,N)
        if g!=1 or N==1:
            first_factor=g
            second_factor=int(N/g)
            return first_factor,second_factor
        else:
            r=period_finding_quantum(a,N)  
            if not r:
                continue
            if r % 2 != 0:
                continue
            elif a**(int(r/2)) % N == -1 % N:
                continue
            else:
                first_factor=math.gcd(a**int(r/2)+1,N)
                second_factor=math.gcd(a**int(r/2)-1,N)
                if first_factor==N or second_factor==N:
                    continue
                if first_factor*second_factor!=N:
                    # checking our work
                    continue
                return first_factor,second_factor



In [11]:
# Here's our final result
shors_algorithm_quantum(15,fixed_a=2)


(5, 3)

In [13]:
# Now trying it out to see how the algorithm would function if we let it choose a given random a:
for a in range(15):
    # Here's the result for a given a:
    try:
        print("randomly chosen a=%d would result in %s"%(a,shors_algorithm_quantum(15,fixed_a=a)))
    except:
        print("FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=%d at this stage"%a)
            

randomly chosen a=0 would result in (5, 3)
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=1 at this stage
randomly chosen a=2 would result in (5, 3)
randomly chosen a=3 would result in (3, 5)
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=4 at this stage
randomly chosen a=5 would result in (5, 3)
randomly chosen a=6 would result in (3, 5)
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=7 at this stage
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=8 at this stage
randomly chosen a=9 would result in (3, 5)
randomly chosen a=10 would result in (5, 3)
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=11 at this stage
randomly chosen a=12 would result in (3, 5)
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=13 at this stage
FINISH IMPLEMENTING algorithm doesn't work with a randomly chosen a=14 at this stage
