In [2]:
import import_ipynb
import numpy as np
import math
import matplotlib.pyplot as plt
import gates as gate
import register as reg
from grover import grover
import random 
import string
plt.style.use('dark_background')

importing Jupyter notebook from gates.ipynb
importing Jupyter notebook from register.ipynb
importing Jupyter notebook from grover.ipynb


In [39]:
def shors(N,x,qubits):
    """
    Return one measurement of register 1 after the QFT, i.e. (ideally) a
    multiple of span/period for the function a -> x^a mod N.

        Parameters:
            N: int
                modulo value
            x: int
                base of exponents
            qubits: int
                qubit count for reg 1 (register 1 has 2**qubits basis states)

        Returns:
            reg1.measure(): int
                returns a measured basis state (multiple of period)

        This is Shor's quantum subroutine, not the entire classical/QC combination.
        It is executed in the following steps:
        1. Generate two registers, the first with 2**qubits basis states (the caller
           is expected to pick qubits so that this is at least N^2) and the second
           of size N.
        2. Force the first register into superposition.
        3. Apply the transformation x^a mod N to every basis state |a> of register 1
           and mark the resulting values in register 2.
        4. Measure the second register to get a value, k, which is entangled with
           the first.
        5. Ensure entanglement is represented: give every basis state a of register 1
           whose transformation yielded k the same non-zero amplitude, and set all
           other amplitudes to zero.
        6. Apply the Quantum Fourier Transform (gate.ft) to register 1 to amplify
           this periodic distribution.
        7. Measure register 1 to receive a multiple of the period.

    """
    # Steps 1 and 2
    # Initial setup of registers 1 and 2 (reg1, reg2). A global Hadamard (Hinitial)
    # is applied to reg1 to get the initial superposition. Note that step 5 below
    # overwrites every amplitude of reg1, so this superposition is illustrative.
    n = 1<<qubits
    reg1 = reg.given(1,n)
    Hinitial = gate.globalapp(qubits,gate.hgate())
    reg1 = reg.register(np.dot(reg1.var, Hinitial.matrix))
    reg2 = reg.given(0,N)
    reg2.var[0]=0+0j

    # x^a mod N for every basis state a, computed once with modular exponentiation.
    # Three-argument pow never materialises the (astronomically large) x**a, and the
    # int() casts keep the arithmetic exact even if numpy integers are passed in.
    base = int(x)
    modulus = int(N)
    residues = [pow(base, a, modulus) for a in range(n)]

    # Step 3
    # Mark each distinct output of the transformation exactly once in register 2.
    distinct = set(residues)
    for k in distinct:
        reg2.var[k]=1+0j
    # NOTE(review): this divides by the count rather than sqrt(count); it presumably
    # relies on the register class normalising before measurement -- confirm.
    reg2.var=reg2.var/len(distinct)

    # Step 4 and 5
    # Measure the second register, and collapse the first to remain consistent: keep
    # only the basis states a of reg1 whose transformation equals the measured k.
    k = reg2.measure()
    matches = 0
    for a, residue in enumerate(residues):
        if residue != k:
            reg1.var[a]=0+0j
        else:
            reg1.var[a]=1+0j
            matches+=1
    reg1.var=reg1.var/matches

    # Step 6 and 7
    # Apply the QFT to the entire register, which amplifies the spacing between the
    # repeated amplitudes. Return the measured output, which will be a multiple of
    # this period.
    ftall=gate.ft(qubits)
    reg1=reg.register(np.dot(reg1.var, ftall.matrix))
    return reg1.measure()
            

In [41]:
def algorithm(N):
    """
    Print prime factorisation

    Parameters:
        N: int
            product of two primes (must be greater than 2 so that a base
            can be drawn from [2, N))

    This runs through all the requirements of Shor's algorithm,
    which is mostly a classical problem. First, a random base x is
    drawn. If that base isn't coprime to N, the gcd already gives the
    answer, but that isn't QC. If it is coprime, the qubit count for
    register 1 (smallest power of two that is at least N^2) is found
    for the QC part. shors() is run and period() turns each measurement
    into a candidate period r. A measurement is retried if r is 0; a
    new base is drawn if r is odd or if x^(r/2) is congruent to +1 or
    -1 mod N (both give only the trivial factors 1 and N). Otherwise
    p and q are solved for and printed.
    """

    # Classical part
    done = False
    count=1
    while not done:
        print("Count:", count)
        count += 1
        # Create random basis where x: x^a mod N
        x = int(np.random.randint(2, N))
        print("x =", x)
        # Check if x happens to share a factor with N
        shared = math.gcd(N, x)
        if shared > 1:
            p = shared
            q = N // p  # integer division: exact for any size of N
            print("Guessed:", p, q)
            print(" ")
            done = True
        # If the factors aren't guessed by chance, go ahead and
        # try in the quantum setting
        else:
            # Define the setting for Shor's algorithm by finding the qubits
            # needed for register 1 (see shors function).
            qubits = 1
            while (1<<qubits)<N**2:
                qubits+=1
            span = 1<<qubits
            # Measurements already known to be useless for this x. The list
            # must live outside the retry loop, otherwise it is emptied on
            # every pass and never filters anything. 0 is always useless.
            mlist = [0]
            periodcheck = False
            while not periodcheck:
                # Run Shor's algorithm; for each multiple of the period
                # produced, ignore it if it is a known failure.
                rmultiple = shors(N, x, qubits)
                if rmultiple in mlist:
                    continue
                # failure conditions
                mlist.append(rmultiple)
                print("Multiple =", rmultiple)
                r = period(N, x, rmultiple / span)
                if r == 0:
                    print("RETRY: r = 0.")
                    print("VVV")
                # success condition
                else:
                    periodcheck = True
                    print("r =", r)
            # more failure conditions
            if r % 2 == 1:
                print("RETRY: Period odd")
                print("VVV")
            else:
                # x^(r/2) mod N via modular exponentiation (no huge integers,
                # no float division).
                half_power = pow(x, r // 2, N)
                # -1 mod N gives gcd(0, N) = N on the "+1" side; +1 mod N
                # (possible when r is a multiple of the true order) gives
                # gcd(0, N) = N on the "-1" side. Both yield only 1 and N.
                if half_power == N - 1 or half_power == 1:
                    print("RETRY: Trivial Factors")
                    print("VVV")
                # success condition, and classical solution.
                else:
                    p = math.gcd(half_power + 1, N)
                    q = math.gcd(half_power - 1, N)
                    print("Factors of",N,":", p, q)
                    print(" ")
                    done = True

In [33]:
def period(N, x, frac):
    """
    Return the period via the continued fraction method (all classical)

    Parameters:
        N: int
            modulo number
        x: int
            basis
        frac: double
            ratio of period multiple to basis size (measurement / span)

    Returns:
        int
            the first convergent denominator q (in order of increasing
            depth) with x^q mod N == 1, or 0 if no denominator fits or
            the expansion has no fractional terms.

    This function expands frac as a continued fraction, builds the
    denominators of its convergents and tests each one as a candidate
    period of a -> x^a mod N.
    """
    # First do the continued fraction expansion. The integer part is split
    # off up front (it is 0 for the usual 0 <= frac < 1 case).
    continued = [int(frac)]
    frac -= continued[-1]
    # Stop once the remainder is negligible; this also bounds the depth.
    while frac > 10. ** (-4):
        frac = 1 / frac
        continued.append(int(frac))
        frac -= continued[-1]

    # Without at least one fractional term there is no denominator to test
    # (e.g. frac == 0 or a remainder below the cutoff).
    if len(continued) < 2:
        return 0

    # Next build the convergent denominators. We know frac ~ p/q with p and q
    # coprime, and since we are dealing with a multiple of span/period, one of
    # the q should be the period. Recurrence: q_n = a_n * q_(n-1) + q_(n-2),
    # seeded with q_0 = 1 and q_1 = a_1.
    qt = [1, continued[1]]
    for i in range(2, len(continued)):
        qt.append(continued[i]*qt[i - 1]+qt[i - 2])

    # One of the qs (period) should satisfy x^q mod N == 1; return the first
    # that does. Modular pow avoids building x**q; the int() casts keep the
    # arithmetic exact if numpy integers are passed in.
    base = int(x)
    modulus = int(N)
    for candidate in qt[1:]:
        if pow(base, candidate, modulus) == 1:
            return candidate
    return 0

In [42]:
def silent(N):
    """
    Return the prime factorisation (same procedure as algorithm(), WITHOUT
    ANY PRINTED OUTPUT)

    Parameters:
        N: int
            product of two primes (must be greater than 2 so that a base
            can be drawn from [2, N))

    Returns:
        (p, q): tuple of int
            two non-trivial factors of N

    A random base x is drawn. If it isn't coprime to N, the gcd already
    gives the answer, but that isn't QC. If it is coprime, the qubit
    count for register 1 (smallest power of two that is at least N^2) is
    found, shors() is run and period() turns each measurement into a
    candidate period r. A measurement is retried if r is 0; a new base
    is drawn if r is odd or if x^(r/2) is congruent to +1 or -1 mod N
    (both give only the trivial factors 1 and N).
    """
    while True:
        x = int(np.random.randint(2, N))
        shared = math.gcd(N, x)
        if shared > 1:
            # Lucky guess: x shares a factor with N. Integer division keeps
            # the cofactor exact for any size of N.
            return shared, N // shared

        qubits = 1
        while (1<<qubits)<N**2:
            qubits+=1
        span = 1<<qubits

        # Measurements already known to be useless for this x. The list must
        # live outside the retry loop, otherwise it is emptied on every pass
        # and never filters anything. 0 is always useless.
        mlist = [0]
        r = 0
        while r == 0:
            rmultiple = shors(N, x, qubits)
            if rmultiple in mlist:
                continue
            mlist.append(rmultiple)
            r = period(N, x, rmultiple / span)

        # Odd period: x^(r/2) is not defined, draw a new base.
        if r % 2 == 1:
            continue
        # x^(r/2) mod N via modular exponentiation (no huge integers, no
        # float division).
        half_power = pow(x, r // 2, N)
        # -1 mod N and +1 mod N (the latter is possible when r is a multiple
        # of the true order) both yield only the factors 1 and N.
        if half_power == N - 1 or half_power == 1:
            continue
        p = math.gcd(half_power + 1, N)
        q = math.gcd(half_power - 1, N)
        return p,q

In [35]:
def Shor_Demonstration():
    """
    Print the steps involved in RSA generation, encoding, and decoding using Shor's

    This function is a demonstration of Shor's. It randomly generates a prime-product "N"
    from two distinct small primes, a public exponent "e", the totient "Lambda" =
    (p-1)(q-1), and encodes a random message "m".

    public = (N,e)
    private = (N,d)

    Given this public information, it then uses Shor's algorithm (via silent()) to factor
    N, thereby arriving at the totient, and using e finds the private key "d". From there,
    all one must do is take the cipher "c", and raise it to the "d"th power mod N to find
    the encoded message.

    """
    print(" RSA DECRYPTION")
    print("----ENCODING----")
    primes = [3, 5, 7]
    # RSA needs two DISTINCT primes: for p == q the totient is not (p-1)(q-1), the
    # message does not round-trip, and N = p^2 cannot be split through the period of
    # a coprime base. Sampling without replacement also rules out N = 49.
    p, q = random.sample(primes, 2)
    N = p*q
    Lambda = (p-1)*(q-1)
    print("N = " + str(N) + ", the product of primes " + str(p) + " and " + str(q))
    # Message in 1..N-1; m == N would be 0 mod N and could not be recovered as N.
    m = int(np.random.randint(1, N))
    print("Padded Message 'm' = " + str(m))
    # Smallest e >= 4 that is coprime to the totient.
    e=3
    for _ in range(2,Lambda):
        e+=1
        if math.gcd(e,Lambda)==1:
            break
    print("Public exponent 'e' = " + str(e))
    cipher = pow(m, e, N)
    print("m**e mod N = cipher 'c' = " + str(cipher))
    print("----DECODING-----")
    pshor,qshor = silent(N)
    print("Factors w/Shor = ",pshor,qshor)
    toitent = (pshor-1)*(qshor-1)
    # Modular inverse of e: the private exponent.
    d = pow(e, -1, toitent)
    print("Private Exponent 'd' = ",d)
    print("c**d mod N = m mod N")
    print("----TESTING----")
    mguess = pow(cipher, d, N)
    print("Decoded Message = ",mguess)

In [36]:
# Symbol tables for the letter <-> number padding used by encode() and full_process().
# alph: the 26 lowercase letters; a letter's code is its position in alph plus one.
alph=string.ascii_lowercase
# al: code -> symbol lookup. Index 0 is the space, indices 1..26 are 'a'..'z'.
# Derived from alph so the two tables cannot drift apart.
al = [" "] + list(alph)
def encode(message=None):
    """
    Convert a message to a list of numerical codes.

    Parameters:
        message: str, optional
            text to encode. When omitted, the user is prompted for it.

    Returns:
        list of int
            one code per character of the lower-cased message: 1..26 for
            'a'..'z', and 0 for a space or any other character that is
            not a lowercase ASCII letter.
    """
    if message is None:
        message = input("Write your message: ")
    # str.find returns -1 for characters outside alph, which maps them to 0.
    return [alph.find(character)+1 for character in message.lower()]

In [38]:
def full_process(N=21, e=17):
    """
    This is a directly applied RSA decryption scheme.

    Parameters:
        N: int, optional
            RSA modulus, a product of two distinct primes (default 21).
        e: int, optional
            public exponent; must be coprime to the totient of N, otherwise
            the modular inverse below raises ValueError (default 17).

    1. Prompt user for a message to encode.
    2. Pad this message: every character becomes a number
       (space/other = 0, 'a'..'z' = 1..26), see encode().
    3. Encode each padded number into a cipher value with (N, e).
    4. Use Shor's (via silent()) to factor N, rebuild the totient
       and from it the private key "d".
    5. Decipher every cipher value to reveal the padded message.
    6. Unpad the message and print the original text.

    Limitation: codes are reduced mod N, so with the default N = 21 only
    the space and the letters 'a'..'t' (codes 0..20) survive the round
    trip. Pass an N greater than 26 to cover the whole alphabet.
    """
    # Step 1 and 2
    code=encode()

    # Step 4
    # The factorisation depends only on N, not on the character being decoded,
    # so Shor's is run once rather than once per character.
    # For more information on how the factoring works see algorithm() or shors().
    pshor,qshor = silent(N)
    toitent = (pshor-1)*(qshor-1)
    d = pow(e, -1, toitent)

    # Step 3 and 5
    # Encrypt each padded number with the public key, then decrypt it with the
    # recovered private key.
    results = []
    for m in code:
        cipher = pow(m, e, N)
        results.append(pow(cipher, d, N))

    # Step 6
    # Code 0 is the space, codes 1..26 are 'a'..'z'.
    symbols = " " + string.ascii_lowercase
    decoded = "".join(symbols[value] for value in results)
    print("The decoded message is: ",decoded)