In [None]:
import math
import random
import sys
from fractions import Fraction

import numpy as np
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit.quantum_info import SparsePauliOp
from qiskit.quantum_info import Statevector
from qiskit.visualization import plot_histogram
from qiskit_aer import AerSimulator
from qiskit_aer.primitives import EstimatorV2, Estimator
from sympy import isprime
from tqdm import tqdm

from utils import bin_to_int







In [None]:
class QDSimulator:
    """
        Toy RSA-style encryption demo combined with a simulated quantum
        period-finding circuit.

        Typical usage: construct, call InitializeEncryptionKeys(), then
        EncryptMessage() and DecryptMessage(). The quantum helpers
        (_initialize_quantum_curcuit, _measure_quantum_curcuit,
        _period_finding_algorithm) operate on the first encrypted character.

        NOTE: keys are drawn with the `random` module and from a tiny range,
        so this is for demonstration only and not secure.
    """

    def __init__(self, 
                 lower_bound=int(1e1), 
                 upper_bound=int(1e3), 
                 message="hello world",
                 num_Qbits=5,
                 mod_N = None,
                 member = None,
                 exponent = None,
                 ):
        """
            Stores the configuration.

            lower_bound, upper_bound: half-open range [lower_bound, upper_bound)
                from which the primes p and q are drawn.
            message: the plain-text string to encrypt.
            num_Qbits: number of qubits (and classical bits) in the register.
            mod_N, member, exponent: accepted for testing purposes but
                currently not stored or used.
        """
        # Bounds for the prime search
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # Plain-text message
        self.message = message
        # Number of qubits in the register
        self.n = num_Qbits

        print("QDSimulator intitialized!")

    def InitializeEncryptionKeys(self):
        """
            Generates the primes (p and q) and from them the encoding
            (encoding_key) and decoding (decoding_key) keys.
        """
        self._generate_primes()
        self._generate_encoding_key()
        self._generate_decoding_key()

    def _generate_primes(self):
        """
            Chooses two distinct random primes (p and q) in
            [lower_bound, upper_bound) and stores them together with their
            product (N = p*q) and the group order ((p-1)*(q-1)).

            Exits via sys.exit if the range holds fewer than two primes.
        """
        primes = [i for i in range(self.lower_bound, self.upper_bound) if isprime(i)]
        if len(primes) < 2:
            sys.exit(f"Need at least two primes in [{self.lower_bound}, {self.upper_bound}) "
                     f"but found {len(primes)}. Try widening the bounds")

        # sample() guarantees p != q; with p == q the group order of N = p**2
        # is p*(p-1), not (p-1)**2, and decryption would fail.
        p, q = random.sample(primes, 2)
        self.p, self.q = p, q
        self.N, self.group_order = p*q, (p-1)*(q-1)

        print(f"\nGenerated random prime numbers p={p} and q={q}, and their product N={self.N}")
    
    def _generate_encoding_key(self):
        """
            Generates the encoding key by picking random candidates smaller
            than group_order = (p-1)*(q-1) until one is coprime with it
            (GCD = 1). Gives up (sys.exit) after a fixed number of attempts.

            The encoding key should be shared with others.
        """
        max_attempts = 100
        # Never go below 2: a key of 1 would leave the message unencrypted.
        smallest_candidate = max(2, self.lower_bound)
        largest_candidate = self.group_order - 1
        if smallest_candidate > largest_candidate:
            sys.exit(f"No encoding key candidates between {smallest_candidate} and {largest_candidate}")

        for _ in range(max_attempts):
            candidate = random.randint(smallest_candidate, largest_candidate)
            if math.gcd(candidate, self.group_order) == 1:
                self.encoding_key = candidate
                print(f"\nEncoding key generated c={self.encoding_key} (with (p-1)*(q-1)={self.group_order})")
                return

        sys.exit(f"Could not generate an encoding key after {max_attempts} attempts. Try increasing maximum allowed attempts")

    def _generate_decoding_key(self):
        """ 
            Generates the decoding key as the modular inverse of the encoding key:
                decoding_key * encoding_key = 1 (mod (p-1)(q-1))
            and verifies the result before storing it in self.decoding_key.

            The decoding key should not be shared with others.
        """
        f = self.group_order
        c = self.encoding_key

        print("\nGenerating decoding key ...")
        try:
            # Built-in modular inverse (Python 3.8+); raises ValueError when
            # c and f are not coprime.
            d = pow(c, -1, f)
        except ValueError:
            sys.exit("Could not generate a decoding key. Encoding key is not invertible modulo (p-1)*(q-1)")

        if (c*d) % f != 1:
            sys.exit(f"Decoding key check failed, incorrect key generated")

        self.decoding_key = d
        print(f"Done. Check pased. Decoding key d={d}")
    

    def EncryptMessage(self):
        """ 
            Encrypts self.message character by character: each character code
            m is mapped to m**encoding_key mod N. The resulting list of
            integers is stored in self.encrypted_message.

            Exits via sys.exit if any character code is not smaller than N,
            since such a code could not be recovered by decryption.
        """
        ascii_message = [ord(char) for char in self.message]
        N = self.N
        if ascii_message and max(ascii_message) >= N:
            sys.exit(f"N to small. N is {self.N} and the largest ascii code is {max(ascii_message)}")

        encoding_key = self.encoding_key
        encrypted_message = [pow(ascii_code, encoding_key, N) for ascii_code in ascii_message]
        self.encrypted_message = encrypted_message

        print(f"\nMessage has been encrypted to {encrypted_message}")
    

    def DecryptMessage(self):
        """ 
            Decrypts self.encrypted_message: each encrypted code e is mapped
            to e**decoding_key mod N and converted back to a character. The
            resulting string is stored in self.decrypted_message.
        """
        decoding_key = self.decoding_key
        N = self.N
        decrypted_ascii_message = [
            pow(encrypted_ascii_code, decoding_key, N)
            for encrypted_ascii_code in self.encrypted_message
        ]

        decrypted_message = ''.join(chr(ascii_code) for ascii_code in decrypted_ascii_message)
        self.decrypted_message = decrypted_message

        print(f"\nMessage has been decrypted to '{decrypted_message}'")


    def _initialize_quantum_curcuit(self):
        """
            Builds the circuit stored in self.qc.

            The function b**x mod N (b = first encrypted character) is
            evaluated classically for every x in [0, 2**n). One function value
            is drawn at random, and the register is initialized to the uniform
            superposition of all x giving that value. The state vector at this
            point is stored in self.psi. Afterwards every qubit, from the
            highest index down, gets a Hadamard gate and is measured; a
            measured 1 applies phase gates to the lower-indexed qubits.
        """
        n = self.n
        if not self.encrypted_message:
            sys.exit("Cannot initialize quantum curcuit without an encrypted message")
        b = self.encrypted_message[0]

        qreg = QuantumRegister(n, name="qreg")
        creg = ClassicalRegister(n, name="creg")
        qc = QuantumCircuit(qreg,creg)

        # f(x) = b**x mod N for every basis state x of the register
        output_registry_state = np.fromiter(
            (self._test_periodic_function(b, input_registry_int) for input_registry_int in range(2**n)),
            dtype=np.int64,
            count=2**n,
        )
        # Drawing uniformly over x picks each function value with probability
        # proportional to how often it occurs.
        output_registry_state_measurement = int(np.random.choice(output_registry_state))
        input_registry_state = np.where(output_registry_state == output_registry_state_measurement, 1, 0)
        input_registry_state = input_registry_state / np.linalg.norm(input_registry_state)

        qc.initialize(input_registry_state, qreg[:])
        self.psi = Statevector(qc)
        print("\nQuantum curcuit has been initialized in proper superposition of all possible states for input registry and corresponding states for output registry.")

        for i in range(n-1,-1,-1):
            qc.h(qreg[i])
            qc.measure(qreg[i],creg[i])
            if i > 0:
                # Classically controlled phases. `if_test` replaces the
                # `.c_if` method, which newer Qiskit releases no longer provide.
                with qc.if_test((creg[i], 1)):
                    for j in range(i-1,-1,-1):
                        qc.p(np.pi/(2**(i-j)), qreg[j])

        self.qc = qc
    
    def _measure_quantum_curcuit(self):
        """
            Runs self.qc once (shots=1) on the Aer simulator. Stores the raw
            result in self.result, the counts in self.counts and the measured
            bit string, converted with bin_to_int, in self.y.
        """
        qc = self.qc

        simulator = AerSimulator()
        compiled_curcuit = transpile(qc, simulator)
        job = simulator.run(compiled_curcuit, shots=1)
        sim_result = job.result()
        counts = sim_result.get_counts()
        self.result = sim_result
        self.counts = counts

        # A single shot yields exactly one key in the counts mapping.
        y_bn = next(iter(counts))
        self.y = bin_to_int(y_bn)
    
    def _test_periodic_function(self,b,x):
        """
            Returns b**x mod N, using modular exponentiation so that the
            intermediate values stay small even for large x.
        """
        return pow(int(b), int(x), int(self.N))

    
    def _period_finding_algorithm(self):
        """
            Estimates the period r from the measurement self.y.

            y / 2**(2*n) is expanded as a continued fraction and the last
            convergent j/r with denominator below 2**n is selected. Every
            multiple of that denominator below 2**n is a candidate period;
            the list is printed and stored in self.period_candidates.

            Exits via sys.exit if no suitable convergent exists or if it does
            not agree with y / 2**(2*n) to about three decimals.

            NOTE(review): the register built by _initialize_quantum_curcuit has
            n qubits, so y < 2**n; confirm that 2**(2*n) is the intended scale.
        """
        n = self.n
        y = int(self.y)
        scale = 2**(2*n)
        j_over_r_approx_float = y/scale

        # Continued fraction terms of y/scale, computed with exact integer
        # arithmetic (Euclidean algorithm) to avoid floating point round-off.
        continued_fractions_list = []
        numerator, denominator = y, scale
        while numerator:
            int_part, remainder = divmod(denominator, numerator)
            continued_fractions_list.append(int_part)
            denominator, numerator = numerator, remainder

        print(continued_fractions_list)

        # Walk through the convergents 1/(a1 + 1/(a2 + ...)) with the standard
        # recurrence and keep the last one whose denominator is below 2**n.
        j_over_r_approx_frac = None
        p_prev, q_prev = 1, 0
        p_curr, q_curr = 0, 1
        for int_part in continued_fractions_list:
            p_prev, p_curr = p_curr, int_part*p_curr + p_prev
            q_prev, q_curr = q_curr, int_part*q_curr + q_prev
            if q_curr >= 2**n:
                break
            j_over_r_approx_frac = Fraction(p_curr, q_curr)

        # check if j/r fraction matches the j/r float
        if j_over_r_approx_frac is not None and abs(float(j_over_r_approx_frac) - j_over_r_approx_float) < 1e-3:
            print(f"\nsuccessfully found j/r to {j_over_r_approx_frac}")
        else:
            sys.exit(f"\nfailed to converge j_over_r_approx_float {j_over_r_approx_float} to j_over_r_approx_frac {j_over_r_approx_frac}")

        r_0 = j_over_r_approx_frac.denominator
        r_list = list(range(r_0, 2**n, r_0))
        self.period_candidates = r_list

        print("r list", r_list)

        # TODO: test the candidates in r_list against the periodic function
        # to single out the true period.


    

        


# Demo run: draw the primes from [30, 100), then generate the keys and
# round-trip a short message through encryption and decryption.
my_class = QDSimulator(
    lower_bound=30,
    upper_bound=100,
    message="hello name is elias",
    num_Qbits=24,
)

my_class.InitializeEncryptionKeys()
my_class.EncryptMessage()
my_class.DecryptMessage()



In [None]:
# Build the circuit, keeping handles to the initial state vector and the circuit.
my_class._initialize_quantum_curcuit()
psi, qc = my_class.psi, my_class.qc

# Run a single shot and show the measured bit string.
my_class._measure_quantum_curcuit()
counts = my_class.counts
print(counts)

# Last expression of the cell: render the circuit diagram.
qc.draw(output="mpl")