In [27]:
import numpy as np
import random
from sympy import isprime
from fractions import Fraction
import sys
from tqdm import tqdm

from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit.quantum_info import Statevector
from qiskit_aer import AerSimulator
service = QiskitRuntimeService()

from utils import bin_to_int

In [32]:
class QDSimulator:
    """
        This script simulates secure communication between two parties, Alice and Bob, using
        the RSA encryption scheme. It also simulates an attack by a third party, Eve, who attempts 
        to break the RSA encryption using Shor's quantum factorization algorithm. Users can choose 
        to simulate quantum circuits locally using the AerSimulator or execute them on IBM's cloud-
        based quantum backends provided by Qiskit. The script supports circuits of up to approximately 
        22 qubits, enabling the factorization of RSA-encrypted messages that use prime numbers up to 
        the order of 10² with the implemented Shor’s algorithm.
    """

    def __init__(self, 
                 prime_lower_bound: int = int(1e1), 
                 prime_upper_bound: int = int(1e3), 
                 message: str = "hello world",
                 num_Qbits: int = 5,
                 verbose: bool = False,
                 ):
        # initialize lower and upper bound for prime numbers used for RSA
        self.prime_lower_bound = prime_lower_bound
        self.prime_upper_bound = prime_upper_bound
        # string message
        self.message = message
        # num Qbits
        self.n = num_Qbits
        # print stuff for debugging purposes
        self.verbose = verbose

        print("QDSimulator initialized! This class allows you to " 
              "\n   1. encrypt and decrypt messages using the RSA scheme, relying on ASCII for character to int encoding"
              "\n   2. decrypt the RSA encyption using quantum computing (AerSimulator or IBM's cloud-based quantum backends)"
        )
        

    def InitializeRSAKeys(self):
        """
            Generates encoding (encoding_key) and decoding (decoding_key) keys given two distinct 
            prime numbers p and q. Define G_M to be a group under modular multiplication (mod M). The 
            encoding (c) and decoding (d) keys are elements of G_{(p-1)(q-1)} and are eachothers inverses 
                cd ≡ 1 (mod (p-1)(q-1))
            such that for any integer a that is decrypted using
                b ≡ a^c (mod pq)
            it can be shown to be decrypted as following
                b^d ≡ a (mod N)
        """

        self._generate_primes()
        self._generate_encoding_key()
        self._generate_decoding_key()

    def _generate_primes(self):
        """
            Chooces two distinct random prime numbers p and q between the given lower 
            (prime_lower_bound) and upper (prime_upper_bound) bounds. Computes their product (N=p*q)
            and the group order of G_N (group_order = (p-1)*(q-1))
        """

        # construct a list of primes between prime_lower_bound and prime_upper_bound
        primes = [i for i in range(self.prime_lower_bound,self.prime_upper_bound) if isprime(i)]
        # pick two distinct random primes from list
        p, q = random.sample(primes, 2)
        self.p, self.q = p, q
        self.N, self.group_order = p*q, (p-1)*(q-1)

        print(f"\nGenerated random prime numbers p={p} and q={q}, and their product N={self.N}")
    
    def _generate_encoding_key(self):
        """
            Generates encoding_key ∈ G_{(p-1)(q-1)} (i.e key has no factor in common with (p-1)(q-1)
            NOTE: (p-1)*(q-1) is also the order of the G_N group.

            Picks random candidates that are smaller than group_order and checks their GCD via the 
            Euclidean Algorithm (see below). If GCD=1, candidate is chosen. 
            
            The encoding key should be shared with others.

            Euclidean Algorithm: 
            Computes the GCD of two integers. Assume we want to find the GCD of f and c, with f > c.
            Iteratively replace f and c with
                f' = f
                c' = f - int(f/c)*c
            where int(x) is the integer less than or equal to x. The GCD between f' and c' will be the
            same as between f and c. Iteratively apply algorithm until c'=0. The c' from the next to
            last iteration, call it c_0, is the GCD between f and c. 
        """

        found_encoding_key = False
        
        # loop over encoding_key candidates. Add loop counter to avoid infinite while loop
        loop_counter1 = 0
        while (not found_encoding_key and loop_counter1 < 100):
            # candidate for encoding key c
            encoding_key_candidate = random.randint(self.prime_lower_bound, self.group_order-1)
            # re-set f and c for next candidate
            c = encoding_key_candidate
            f = self.group_order

            # apply the Euclidean algorithm using loop. Add loop counter to avoid infinite while loop
            loop_counter2 = 0
            while (c > 0 and loop_counter2 < 100):
                # c from last iteration
                c_0 = c
                # Euclidean algorithm
                f, c = c, f - int(f/c)*c
                loop_counter2+=1
            
            # if GCD = 1, encoding key is c from last iteration
            if c_0 == 1:
                found_encoding_key = True
                self.encoding_key = encoding_key_candidate

            loop_counter1+=1

        if found_encoding_key:
            print(f"\nEncoding key generated c={self.encoding_key} (with (p-1)*(q-1)={self.group_order})")
        else:
            sys.exit(f"Could not generate an encoding key after {loop_counter1} attempts. Try increasing maximum allowed attempts")

    def _generate_decoding_key(self):
        """ 
            Generates the decoding key by finding the inverse modulo of the encoding key
                decoding_key * encoding_key = 1 (mod (p-1)(q-1))
            Do this by setting GCD=1 and applying the Eculidean algorithm backwards 

            The decoding key should not be shared with others.
        """

        f = self.group_order
        c = self.encoding_key

        found_decoding_key = False
        it = self.N 
        print("\nGenerating decoding key ...")
        for j in tqdm(range(it)):
            for k in range(-it,it):
                if (j*f+k*c == 1):
                    found_decoding_key = True
                    break
            if found_decoding_key:
                break
        if not found_decoding_key:
            sys.exit("Could not generate a decoding key. Try increasing searching parameter")

        d = k % f
        if ((c*d) % self.group_order) != 1:
            sys.exit(f"Decoding key check failed, incorrect key generated ...")

        self.decoding_key = d
        print(f"..Done. Check pased. Decoding key d={d}")
    

    def EncryptMessage(self):
        """ 
            Encrypt message of string type into a sequence of ints using the built-in ASCII encoding.
            Each integer (representing a char in the string) is furhter encrypted using the generated 
            encoding key above.
        """

        # ASCII encoding of string message -> ASCII message
        ascii_message = np.array([ord(char) for char in self.message])
        if ascii_message.any() > self.N:
            sys.exit(f"N to small. N is {self.N} and the largest ascii code is {max(ascii_message)}")

        # RSA encryption of ASCII message -> RSA message
        encoding_key = self.encoding_key
        N = self.N
        rsa_message = []
        for ascii_code in ascii_message:
            rsa_code = pow(int(ascii_code), encoding_key, N)
            rsa_message.append(rsa_code)
        self.encrypted_message = rsa_message

        print(f"\nMessage has been encrypted to {rsa_message}")
    
    def DecryptMessage(self):
        """ 
            Decrypt the encrypted message back to a sequuence of integers using the decoding key. Decode
            the ASCII code back to a collection of chars (string)
        """

        # Decryption of RSA message -> ASCII message
        decoding_key = self.decoding_key
        N = self.N
        ascii_message = []
        for rsa_code in self.encrypted_message:
            ascii_code = pow(rsa_code, decoding_key, N)
            ascii_message.append(ascii_code)
        
        # Decoding of ASCII message -> (string) message
        decrypted_message = ''.join(chr(ascii_code) for ascii_code in ascii_message)
        self.decrypted_message = decrypted_message

        print(f"\nMessage has been decrypted to '{decrypted_message}'")


    def QuantumDecryption(self):
        """
            Decrypt encrypted message using quantum computations. 

                1. Use Shor's factorization algorithm to identify the period of r Alice's original message
                    a. find the period r of Alice's encrypted message, i.e. find the period r of the modular
                        function
                            f(x) = b**x (mod N)
                        where b = encrypted_message, such that f(x) = f(x+r) for all natural numbers x. 
                    b. since encrypted_message = message**encoding_key, encrypted_message and (the original) 
                        message generate the same group, and therefore share the same period. Hence, by identifing 
                        the period of the encrypted message, we have also found the period of the encrypted message.
                
                2. Find the modulor r inverse of the publicly available encoding_key (c), called decoding_key_prime 
                    (d_prime), such that
                        c * d_prime ≡ 1 (mod r)
                    a. since encoding_key was chosen by Bob to have no factors in common with (p-1)(q-1) and r divides
                        the order of G_pq (p-1)(q-1) 
                        -> encoding_key has no factors in common with r
                        -> encoding_key (mod r) := encoding_key_prime ∈ G_r
                        -> ∃ decoding_key_prime s.t 
                            c * d_prime ≡
                                 ≡ c_prime * d_prime ≡ 
                                 ≡ 1 (mod r)
                    b. decoding_key_prime can be found using a classical computer, here we perform a forced scan
                3. Using the period r found by using Shor's algorithm and the decoding_key_prime, Alice's encrypted
                    message (b) can be decrypted to find her original message (a)
                        b^d_prime ≡
                            ≡ a^(c*d_prime) ≡
                            ≡ a^(1 + mr) ≡
                            ≡ a * (a^(r)^m) ≡
                            ≡ a * (1^r)) =
                            = a (mod r)
        """

        print("\nUsing the magical powers of quantum computation to break the RSA encryption ...")
        encrypted_message = self.encrypted_message

        decrypted_message = []
        # loop over one encrypted char at a time
        for b in tqdm(encrypted_message):
            # initialize the curcuit whose measurement is used to find the period
            self._initialize_quantum_curcuit(b)
            # perform measurements until the period is found
            self.found_period = False
            loop_counter = 0
            while not self.found_period and loop_counter <= 30:
                self._measure_quantum_curcuit_aerSimuulator()
                self._period_finding_algorithm(b)
                loop_counter += 1
            # if the period is found within the allowed number of iterations, use the period to decrypt 
            # the char, otherwise set the decrypted char to "?"
            if self.found_period:
                decrypted_char = self._retrieve_message(b)
            else:
                decrypted_char = "?"

            decrypted_message.append(decrypted_char)

            if self.verbose:
                print(f"   decrypted_char: {decrypted_char}")

        print(f"\n..Done. The message was decrypted to {''.join(self.decrypted_message)}")

    def _initialize_quantum_curcuit(self,b):
        """
            For Shor's algorithm a quantum curcuit is prepared in two steps.
                1. n-Qbit Hadamard transformation applied to input registry, and unitary gate U_f 
                represetning the modular function
                    f(x) = b^x (mod N)
                applied to output registry, producing the following state
                    1/2^{n/2} \Sum_{x=0}^{2^n-1} |x> |f(x)> * 
                The output regester is measured (to say f_0)) and the input register becomes superposition of 
                states that yield the particular output register measurement
                     1/\sqrt{m} \Sum_{x=0}^{m} |x_0 + k*r>
                where x_0 is the smallest value of x (0 <= x_0 < r) for which f(x_0) = f_0, and m is the smallest
                integer for which x_0 + mr >= 2^n.
                2. A quantum Fourier transform is applied on the resulting input register.

            NOTE: the modular function is not applied using a unitary gate. Instead the function f(x) is computed
            for each x, and a f(x) is randomly chosen. The curcuit is then initialized in the state *.
        """

        n = self.n
        N = self.N

        qreg = QuantumRegister(n, name="qreg")
        creg = ClassicalRegister(n, name="creg")
        qc = QuantumCircuit(qreg,creg)

        if self.verbose:
            print("\n   Running pre-calculations for quantum curcuit initialization ...")
        # compute \Sum_{x=0}^{2^n} f(x) where f(x) = b^x (mod N) and assign to output registry
        output_registry_state = np.zeros(2**n)
        previous_b = 1
        for input_registry_int in range(2**n):
            previous_b = (previous_b * b) % N
            output_registry_state[input_registry_int] = previous_b
        if self.verbose:
            print("   ..Done")

            print("   Initializing quantum curcuit ...")
        # perform measurement on output registry
        output_registry_state_measurement = int(np.random.choice(output_registry_state))

        # set input registry to superposition corresponding the the measurement performed on the output registry above
        input_registry_state = np.where(output_registry_state == output_registry_state_measurement, 1, 0)
        input_registry_state = input_registry_state / np.linalg.norm(input_registry_state)

        # initialize qiskit quantum curcuit 
        qc.initialize(input_registry_state, qreg[:])
        # self.psi = Statevector(qc)
        if self.verbose:
            print("   ..Done")

        # apply quantum Fourier transform to input registry. The 2-Qbit controlled-V gates have been replaced by 1-Qbit 
        # gates that act or not, depending on the outcome of a prior measurement of the control Qbit.
        for i in range(n-1,-1,-1):
            qc.h(qreg[i])
            qc.measure(qreg[i],creg[i])
            for j in range(i-1,-1,-1):
                qc.p(np.pi/(2**(abs(i-j))),j).c_if(creg[i],1)

        self.qc = qc
    
    def _measure_quantum_curcuit_aerSimuulator(self):
        """
            Measure quantum curcuit initialized above. Run locally using AerSimulator The measured value is used 
            to identify the period r of Alice's message (_period_finding_algorithm).
        """

        qc = self.qc

        # transpile quantum curcuit with AerSimulator
        simulator = AerSimulator()
        compiled_curcuit = transpile(qc, simulator)

        if self.verbose:
            print("   Measuring quantum curcuit ...")
        # run job and extract results
        job = simulator.run(compiled_curcuit, shots=1)
        sim_result = job.result()
        counts = sim_result.get_counts()

        self.result = sim_result
        self.counts = counts

        y_bn = list(counts.keys())[0]
        y = bin_to_int(y_bn)
        self.y = y
        if self.verbose:
            print(f"   ..Done. y was measured to {y}")


    def _measure_quantum_curcuit_ibm(self):
        """
        FIXME: clean-up, might not work
        Add description
        """
        print(1)
        qc = self.qc

        # fetch the backend
        print(2)
        backend_name = "ibm_brisbane"
        print(3)
        backend = service.backend(backend_name)

        # create a pass manager, which will populate the method with 
        # everything that needs to be know about the backend, e.g. connectivity,
        # native gates
        pass_manager = generate_preset_pass_manager(optimization_level=3, backend=backend)

        # transpile the code (source code to source code)
        qc_transpiled = pass_manager.run(qc)
        pass_manager = generate_preset_pass_manager(optimization_level=1, backend=backend)
        print(4)
        #qc_transpiled = transpile(qc, backend)
        qc_transpiled = pass_manager.run(qc)


        # options = EstimatorOptions()
        # # level 1 = use measurement readout mitigation, level 2 = zero noise extrapolation, ... 
        # options.resilience_level = 1
        # # transpolation already done on local device
        # options.optimization_level = 0
        # # qc has a lot idle time, use dynamic coupling
        # options.dynamical_decoupling.enable = True
        # options.dynamical_decoupling.sequence_type = "XY4"

        # generate estimator object with options
        #estimator = EstimatorV2(backend, options)

        # run job with transpiled qc and opterator list
        print(5)
        job = backend.run(qc_transpiled, shots=10)
        print(6)
        job_id = job.job_id()
        print(7)
        self.job_id = job_id
        print("\njob_id", job_id)

    def _check_job_ibm(self):
        """
        FIXME: clean-up, might not work
        Add description
        """

        job = service.job(self.job_id)
        result = job.result()
        counts = result.get_counts()
        self.counts = counts
        print("\ncounts", counts)


        try:
            for i in range(10):
                y_bn = list(counts.keys())[i]
                y = bin_to_int(y_bn)
                print(f"\n{i}: y was measured to {y}")
        except:
            pass

        y_bn = list(counts.keys())[0]
        y = bin_to_int(y_bn)
        self.y = y
        print(f"\n..Done. y was measured to {y}")


    def _periodic_function(self,b,x):
        """Helper function to compute value of periodic function"""
        N = self.N
        return (b**x) % N
    
    def _period_finding_algorithm(self, b):
        """
            The measured value y from the quantum curcuit above can be shown to be with at least 40% probability 
            within 1/2 of an integral multiple of 2^n/r, i.e. for some int j we have an estimate for the fraction 
            j/r in
                | y/2^n - j/r | <= 1/2^{n+1}.
            
            The value of j/r can be extracted from the known value of y/2^n by applying continued fractions. It 
            gives us integers j_0 and r_0 with no common factors that satisfy j_0/r_0 = j/r. The r_0 we learn is 
            thus a divisor of r. 

            NOTE: for the continued frations theorem to work we need to use an input an n-Qbit input register with
            2^n > N^2
        """

        n = self.n
        y = self.y
        N = self.N

        # our estimation of j/r
        j_over_r_approx_float = y/2**n
        if self.verbose:
            print(f"   denominator should be less than N {self.N} or {2**(n/2)}")
            print("   j_over_r_approx_float", j_over_r_approx_float)

        # compute continued fraction terms
        continued_fractions_list = []
        frac_part = j_over_r_approx_float
        counter = 0
        # FIXME
        counter_limit = 100
        while (frac_part > 1/2**n) and counter <= counter_limit:
            inv = 1/frac_part
            int_part = int(inv)
            frac_part = inv - int_part
            continued_fractions_list.append(int_part)
            counter+=1
        # retry if continued fractions stopped before finding largest possible denominator smaller than N
        if counter == counter_limit:
            if self.verbose:
                print(f"   continued fractions failed. Denominator {frac_part} needs to be less than {1/2**n}")
            return
        # retry if continued fractionas was unable to find a denominator which is less than N
        if len(continued_fractions_list) < 1:
            if self.verbose:
                print("   no period candidates found, retrying")
            return
        
        # using computed continued fractions terms above, compute the continued fraction which has denominator
        # less than 2**n
        j_over_r_approx_frac_temp = Fraction(-999)
        i=1
        while (i <= len(continued_fractions_list) and j_over_r_approx_frac_temp.denominator < self.N):
            j_over_r_approx_frac = j_over_r_approx_frac_temp
            continued_fractions_list_sliced = continued_fractions_list[:i]
            res = continued_fractions_list_sliced[-1]
            for int_part in reversed(continued_fractions_list_sliced[:-1]):
                res = int_part + Fraction(1,res)
            j_over_r_approx_frac_temp = Fraction(1,res)
            i+=1
            
        # check if j_0/r_0 fraction matches the j/r float, if not, retry
        if np.round(float(j_over_r_approx_frac),3) == np.round(j_over_r_approx_float,3):
            if self.verbose:
                print(f"   successfully found j/r to {j_over_r_approx_frac}")
        else:
            if self.verbose:
                print(f"   failed to converge j_over_r_approx_float {j_over_r_approx_float} to j_over_r_approx_frac {j_over_r_approx_frac}")
            return

        # pick r_0 from j_0/r_0
        r_0 = j_over_r_approx_frac.denominator
        # r will be some multiple of r_0 less than N
        #r_list = [i for i in range(r_0, 2**(int(n/2)), r_0)]
        r_list = [i for i in range(r_0, N, r_0)]

        if self.verbose:
            print(f"   found {len(r_list)} period candidates")

        # if too many potential r candidates, retry
        if len(r_list) > 30:
            if self.verbose:
                print("   too many period candidates, retrying")
            return
        
        # plug r candidates into periodic function and check
        for r_candidate in r_list:
            if b == self._periodic_function(b,1+r_candidate):
                self.r = r_candidate
                self.found_period = True
                break
        
        if not self.found_period and self.verbose:
            print("   no match for potential period, retrying")
    
    def _retrieve_message(self, b):
        """
            With the period r from Shor's algorithm and the decoding_key_prime, Alice's encrypted message
            can be decrypted.
        """
        r = self.r
        c = self.encoding_key
        N = self.N

        self.retrieved_message = False
        it = N+1
        found_d_prime = False
        for d_prime in range(0,it):
            if c * d_prime % r == 1:
                self.d_prime = d_prime
                found_d_prime = True
                break
        
        if not found_d_prime:
            if self.verbose:
                print("   did not find d prime")
            decrypted_char = "?"
        else:
            deciphered_ascii_code = b**d_prime % N
            decrypted_char = ''.join(chr(deciphered_ascii_code))

        return decrypted_char
        


my_class = QDSimulator(prime_lower_bound=int(3e1), 
                       prime_upper_bound=int(7e1),
                       message="Hello World!",
                       num_Qbits = 20,
                       verbose = False
                       )

QDSimulator initialized! This class allows you to 
   1. encrypt and decrypt messages using the RSA scheme, relying on ASCII for character to int encoding
   2. decrypt the RSA encyption using quantum computing (AerSimulator or IBM's cloud-based quantum backends)


In [33]:
my_class.InitializeRSAKeys()
my_class.EncryptMessage()
my_class.DecryptMessage()


Generated random prime numbers p=43 and q=37, and their product N=1591

Encoding key generated c=169 (with (p-1)*(q-1)=1512)

Generating decoding key ...


  5%|▍         | 75/1591 [00:00<00:00, 4947.20it/s]

..Done. Check pased. Decoding key d=841

Message has been encrypted to [1534, 101, 366, 366, 111, 462, 560, 111, 1447, 366, 100, 377]

Message has been decrypted to 'Hello World!'





In [34]:
my_class.QuantumDecryption()


Using the magical powers of quantum computation to break the RSA encryption ...


100%|██████████| 12/12 [04:13<00:00, 21.11s/it]


..Done. The message was decrypted to Hello World!





In [38]:
qc = my_class.qc
# draw quantum curcuit
# qc.draw(output="mpl")