# Implementation of RSA algorithm
### Simone Cangini - RSA Project

## About RSA

RSA is an asymmetric (public-key) cryptosystem used to encrypt messages. Thanks to the public key, anyone can encrypt a message, while only the owner of the corresponding private key can decrypt it.<br>
Since it is computationally expensive, it is typically used only at the beginning of a conversation, to exchange a secret key for a faster symmetric-key cipher.<br>
The security of RSA relies on the practical difficulty of factoring the product of two large prime numbers.<br>
<br>
Alice:
 - generates 2 large prime numbers (**p**, **q**);
 - computes their product ($n=pq$);
 - computes the totient of the product, which is easy because she knows the factorization ($m=(p-1)(q-1)$);
 - chooses a number **e** coprime with **m**;
 - computes **d**, the inverse of e modulo m: $ed \equiv 1 \pmod{m}$.

The **public key** consists of **n** and **e**, while **d** must remain private.
<br><br>
**Encryption**: Bob encrypts the message x (an integer with $0 \le x < n$) by computing $y = x^{e} \bmod n$.<br>
**Decryption**: Alice recovers the original message x as $x = y^{d} \bmod n$.
<br>
<img src="http://www.mcseven.me/wp-content/uploads/2009/05/rsa_encryption.png" width="300px" margin="auto">
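A toy run with the classic textbook primes 61 and 53 makes the steps above concrete (tiny values, for illustration only). The last part also illustrates why security rests on factoring: anyone who can factor n can recompute d from the public key alone. This sketch uses the three-argument `pow` with exponent `-1` for the modular inverse, which requires Python 3.8 or later; `factor_small` is a hypothetical helper written only for this example.

```python
# Toy RSA with tiny primes: illustration only, NOT secure
p, q = 61, 53
n = p * q                  # 3233
m = (p - 1) * (q - 1)      # 3120, the totient of n
e = 17                     # gcd(17, 3120) == 1
d = pow(e, -1, m)          # inverse of e modulo m (Python 3.8+): 2753

x = 65                     # plaintext, 0 <= x < n
y = pow(x, e, n)           # Bob encrypts with the public key (n, e): 2790
print(y, pow(y, d, n))     # 2790 65

def factor_small(n):
    # Hypothetical helper: trial division, feasible only because n is tiny
    f = 2
    while n % f != 0:
        f += 1
    return f, n // f

# An attacker who factors n rebuilds d from the public key alone
p2, q2 = factor_small(n)
d_attacker = pow(e, -1, (p2 - 1) * (q2 - 1))
print(p2, q2, d_attacker)  # 53 61 2753
```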

In [1]:
import random

## Miller-Rabin algorithm
This section implements the Miller-Rabin primality test.<br>
The two required primes are created by generating large random numbers and checking whether they are prime.

### Decompose num

In [2]:
# Write num - 1 = 2^r * q, with q odd
    
num = 13
r, q = 0, num - 1
while q % 2 == 0:
    r += 1
    q //= 2
    
print(f'r: {r}; q: {q}')

r: 2; q: 3


`q` must be odd: the loop divides out every factor of 2, so that $num - 1 = 2^r q$.
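The decomposition $num - 1 = 2^r q$ is what makes a Miller-Rabin round stronger than a plain Fermat check ($a^{num-1} \equiv 1 \pmod{num}$). A minimal sketch with a fixed witness $a = 2$, using the Carmichael number $561 = 3 \cdot 11 \cdot 17$, which passes the Fermat check for every base coprime with it; both helper names are illustrative and not part of the notebook.

```python
def fermat_probable_prime(n, a):
    # Fermat check: a^(n-1) == 1 (mod n) holds for every prime n with gcd(a, n) == 1
    return pow(a, n - 1, n) == 1

def strong_probable_prime(n, a):
    # One Miller-Rabin round with a fixed witness a (n odd, n > 3)
    r, q = 0, n - 1
    while q % 2 == 0:          # n - 1 = 2^r * q with q odd
        r += 1
        q //= 2
    y = pow(a, q, n)
    if y == 1 or y == n - 1:
        return True
    for _ in range(r - 1):     # square up to r - 1 more times
        y = pow(y, 2, n)
        if y == n - 1:
            return True
    return False               # a is a witness: n is composite

# 561 = 3 * 11 * 17 fools the Fermat check but not the strong test
print(fermat_probable_prime(561, 2))   # True
print(strong_probable_prime(561, 2))   # False (sequence 263, 166, 67, 1)
print(strong_probable_prime(4813, 2))  # True (4813 is prime)
```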

### Implement algorithm as function

In [3]:
def miller_robin(num, niter=None):
    '''
    Miller-Rabin algorithm implementation
    Check if a number is prime
       
    With the default niter, the probability of a composite num
        being reported as prime is at most 4^-40 (about 8e-25)
       
    INPUT:
        num: int
            number to be checked
        niter: int (default 40)
            number of iterations to be performed
     
    OUTPUT:
        result: True|False
            True if number is prime
    '''
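    if niter is None:
        niter = 40
    
    # Handle small and even numbers explicitly
    # (num = 1 would otherwise loop forever below)
    if num < 2:
        return False
    if num in (2, 3):
        return True
    if num % 2 == 0:
        return False
    
    # Write num - 1 = 2^r * q, with q odd
    r, q = 0, num - 1
    while q % 2 == 0:
        r += 1
        q //= 2
        
    # Repeat niter times
    for _ in range(niter):
        # Random witness in [2, num-2] (1 and num-1 always pass)
        x = random.randint(2, num - 2)
        y = pow(x, q, num)  # same as (x**q) % num, but much faster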
        
        if y == 1 or y == num-1:
            # Number could be prime,
            #  skip the rest of the code
            continue
        
        for _ in range(r - 1):
            #y = (y**2) % num
            y = pow(y, 2, num)
            if y == num-1:
                # Number could be prime
                break
        else:
            # If none succeed, number is not prime
            return False
    return True

Check the implementation with numbers taken from prime tables (already known to be prime)

In [4]:
print(miller_robin(4813))

True


In [5]:
prime = [237091, 237137, 237143] # Known primes, can be checked the same way
# 145863 = 3 * 48621 is composite, so the test should return False
print(miller_robin(145863, 1))

False


## Creation of two prime numbers
Draw random numbers in $[2^L, 2^{L+1})$, i.e. $(L+1)$-bit numbers, until one passes the Miller-Rabin primality test

In [6]:
L = 512
p = 4
n_test = 0

# Cycle until a prime number is found
while not miller_robin(p, 10):
    p = random.randint(2**L, 2**(L+1)-1)
    n_test += 1

print(f'Numbers tested: {n_test}')
print(p)

Numbers tested: 38
17190683411255532182861363686309289566882635705069473076422375533982160072078504792628717805848198909058837222327648570979195823968415510036126928197810783
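By the prime number theorem, roughly one integer in $\ln N$ near $N$ is prime, so for 513-bit candidates about one in $513 \ln 2 \approx 356$ is prime, or about one in 178 if only odd numbers are drawn. The sketch below is an alternative to `create_prime_num`, not the notebook's method: it draws only odd candidates with the top bit set (so the result has exactly `bits` bits) and uses the `secrets` module, since the Python documentation states that `random` should not be used for security purposes. The trial division by a few small primes before Miller-Rabin is an extra shortcut assumed here.

```python
import secrets

def is_probable_prime(n, rounds=40):
    # Miller-Rabin with witnesses from a cryptographically secure generator
    if n < 2:
        return False
    for small in (2, 3, 5, 7, 11, 13):    # cheap trial division first
        if n % small == 0:
            return n == small
    r, q = 0, n - 1
    while q % 2 == 0:                      # n - 1 = 2^r * q, q odd
        r += 1
        q //= 2
    for _ in range(rounds):
        a = secrets.randbelow(n - 3) + 2   # witness in [2, n - 2]
        y = pow(a, q, n)
        if y == 1 or y == n - 1:
            continue
        for _ in range(r - 1):
            y = pow(y, 2, n)
            if y == n - 1:
                break
        else:
            return False                   # a proves n composite
    return True

def random_prime(bits):
    # Odd candidates with the top bit set: the result has exactly `bits` bits
    while True:
        candidate = secrets.randbits(bits) | (1 << (bits - 1)) | 1
        if is_probable_prime(candidate):
            return candidate

p = random_prime(513)    # same size as create_prime_num(512) in this notebook
print(p.bit_length())    # 513
```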


In [7]:
%timeit miller_robin(2341606047)

12.4 µs ± 86.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


### Make a compact function

In [8]:
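def create_prime_num(L, niter=None):
    '''Return a random probable prime in [2**L, 2**(L+1) - 1]'''
    p = 4  # even start value: forces at least one draw
    while not miller_robin(p, niter):
        p = random.randint(2**L, 2**(L+1)-1)
    return p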

## Compute gcd between two numbers
The function is needed to check if e is coprime with m

In [9]:
def gcd(n1, n2):
    '''
    Compute the greatest common divisor between the two numbers
    
    INPUT:
        n1, n2: int
    
    OUTPUT:
        gcd: int
    '''
    
    # Euclid's algorithm: gcd(a, b) = gcd(b, a mod b), gcd(a, 0) = a
    # (if n1 < n2, the first iteration simply swaps them)
    while n2 != 0:
        n1, n2 = n2, n1 % n2
    return abs(n1)

In [10]:
# 35139 = 51*689, 170595 = 51*3345 (689 and 3345 are coprime)
gcd(35139, 170595)

51
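The standard library's `math.gcd` offers a quick cross-check of the hand-written `gcd`. The sketch also shows why $e$ must be coprime with $m$: otherwise it has no inverse modulo $m$ (Python's `pow(e, -1, m)` raises `ValueError` in that case). `inverse_or_none` is a hypothetical helper, and $m = 3120$ is the toy totient for $p = 61$, $q = 53$.

```python
import math

# Cross-check the hand-written gcd() with the standard library
print(math.gcd(35139, 170595))   # 51

def inverse_or_none(e, m):
    # e is invertible modulo m exactly when gcd(e, m) == 1
    if math.gcd(e, m) != 1:
        return None
    return pow(e, -1, m)         # modular inverse, Python 3.8+

m = 3120                         # toy totient (61 - 1) * (53 - 1)
print(inverse_or_none(17, m))    # 2753
print(inverse_or_none(18, m))    # None, since gcd(18, 3120) = 6
```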

## Extended Euclidean algorithm
Needed to find **d**, the multiplicative inverse of e modulo m

In [11]:
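def gcdExtended(a, b):
    '''
    Extended Euclidean algorithm
    Return (x, y) such that a*x + b*y = gcd(a, b), for positive a, b
    '''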
    if abs(a) == 1:
        return a, 0
    elif abs(b) == 1:
        return 0, b

    r_old, r = a, b
    a_old, a = 1, 0
    b_old, b = 0, 1

    while r != 0:
        q = r_old // r
        r_old, r = r, r_old - q * r
        a_old, a = a, a_old - q * a
        b_old, b = b, b_old - q * b

    return a_old, b_old

In [12]:
a,b = gcdExtended(237091, 237137)
print(f'a: {a}; b: {b}')

a: -67017; b: 67004
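The Bézout coefficient returned for e can be negative (as with $-67017$ above), so the inverse should be reduced modulo m rather than taken in absolute value: if $ed \equiv 1 \pmod m$ with $d < 0$, then $e\lvert d\rvert \equiv -1 \pmod m$. A minimal sketch of a modular inverse built on the same iterative extended Euclid; `mod_inverse` is an illustrative name.

```python
def mod_inverse(e, m):
    # Iterative extended Euclid, tracking only the coefficient of e
    r_old, r = e, m
    s_old, s = 1, 0
    while r != 0:
        k = r_old // r
        r_old, r = r, r_old - k * r
        s_old, s = s, s_old - k * s
    if r_old != 1:
        raise ValueError("e is not invertible modulo m")
    return s_old % m   # reduce: the Bezout coefficient may be negative

a, b = 237091, 237137
print(a * (-67017) + b * 67004)   # 1: the Bezout identity from the output above
inv = mod_inverse(a, b)
print(inv)                        # 170120, i.e. -67017 + 237137
print(a * inv % b)                # 1
```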


## Compute all data for RSA
Use the previously implemented functions to compute all the parameters required by RSA

In [13]:
L = 512
p, q = create_prime_num(L), create_prime_num(L)

n = p*q
m = (p-1)*(q-1)

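# Pick a random e coprime with m, then compute d = e^-1 mod m
while True:
    e = random.randint(3, m - 1)
    if gcd(e, m) == 1:
        d, _ = gcdExtended(e, m)
        d %= m  # the Bezout coefficient may be negative
        break
ck = (e * d) % m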

print(f'e: {e}\n\nd: {d}')
print(f'\nck: {ck}')

e: 41332174336127729968396966022813715510795123694526512049019721122032338092950068502653231017230509328160657806825702657402153477690436607704554008977520598536771094610861572036693158986614019824864827308268810398870433779485491166052801327931854445809980647689124796526708817287392160398181810933960144334523

d: 22060925812435142817186411515187434644300194637150772036104454198415950602674498980380917777553273163172200880539485872068021825249617009835236970054745986677841904071182318911759168675132561879245788500623117998769389196164039243608265110484566924026767077257241245101187438224555322470840331856853650044267

ck: 1


`ck` verifies $d$: the product $ed \bmod m$ must be 1

### Test the encryption/decryption process

In [14]:
# n, e, d are provided by the code before
Kpub = [n, e] # Public key for Bob
Kpri = d      # Private key for Alice

# Compute ciphertext from the known plaintext 1234567890
encr = pow(1234567890, Kpub[1], Kpub[0])

# Compute plaintext from the ciphertext
decr = pow(encr, Kpri, n)

print(f'Original message: 1234567890\nEncrypted: {encr}\nDecrypted: {decr}')

Original message: 1234567890
Encrypted: 126400850010232192146520584330097987222641815939137317421444919078226243755934891955174728074864494601566291789047293443113792313699430115131109944516095523484293713621970723203634482696139735043612468984230468778410755265325336068478834998049956867869402506752133047953516973411754644018955031304236444962488
Decrypted: 1234567890


# Packing all code in a class
All the previous functions are collected in a single class, which generates the key pair when it is instantiated.

In [15]:
class myRSA(object):
    '''
    RSA Encryption/Decryption class
    
    Methods:
        - __init__(KSize, debug)
        - static encrypt(plain, kPub)
        - decrypt(encr)
    
    Attributes:
        - kPub: list of int
            get public key as list [n, e]
        - KSize: int
            return the key size
    ''' 
    
    def __init__(self, KSize=512, debug=False):
        '''Init RSA object, create public and private key'''
               
        if KSize < 256:
            KSize = 256
        self._L = KSize
        self._niter = 40
        
        self._p, self._q = self._create_prime_num(), self._create_prime_num()
        
        self.n = self.p*self.q
        self.m = (self.p-1)*(self.q-1)
        
        # Pick a random e coprime with m, then compute d = e^-1 mod m
        while True:
            self._e = random.randint(3, self.m - 1)
            if self._gcd(self.e, self.m) == 1:
                d, _ = self._gcdExtended(self.e, self.m)
                self._d = d % self.m  # Bezout coefficient may be negative
                break
        
        if debug:
            print("Key generated")
    
    @staticmethod
    def encrypt(plain, kPub):
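        # If plain is provided as string, convert it to a number
        if isinstance(plain, str):
            plain = int.from_bytes(plain.encode('utf-8'), 'big')
        # RSA only works on integers in [0, n)
        if not 0 <= plain < kPub[0]:
            raise ValueError('message too large for this key')
        
        return pow(plain, kPub[1], kPub[0])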
    
    def decrypt(self, encr):
        return pow(encr, self.d, self.n)
    
    # Return public key [n, e]
    @property
    def kPub(self):
        return [self.n, self.e]
    # Public key is generated internally
    @kPub.setter
    def kPub(self, val):
        raise AttributeError('Denied')
    
    # Return p number
    @property
    def p(self):
        return self._p
    # p is generated internally
    @p.setter
    def p(self, val):
        raise AttributeError('Denied')
        
    # Return q number
    @property
    def q(self):
        return self._q
    # q is generated internally
    @q.setter
    def q(self, val):
        raise AttributeError('Denied')
    
    # Return e number
    @property
    def e(self):
        return self._e
    # e is generated internally
    @e.setter
    def e(self, val):
        raise AttributeError('Denied')
        
    # Return d number
    @property
    def d(self):
        return self._d
    # d is generated internally
    @d.setter
    def d(self, val):
        raise AttributeError('Denied')
    
    # Return Key Size
    @property
    def KSize(self):
        return self._L
    # Key Size must be set once on __init__ call
    @KSize.setter
    def KSize(self, val):
        raise AttributeError('Denied')
    
    @staticmethod
    def _gcd(n1, n2):
        '''
        Compute the greatest common divisor between the two numbers

        INPUT:
            n1,n2: int
        
        OUTPUT:
            gcd: int
        '''

        # Euclid's algorithm: gcd(a, b) = gcd(b, a mod b), gcd(a, 0) = a
        # (if n1 < n2, the first iteration simply swaps them)
        while n2 != 0:
            n1, n2 = n2, n1 % n2
        return abs(n1)
    
    @staticmethod
    def _gcdExtended(a, b): 
        if abs(a) == 1:
            return a, 0
        elif abs(b) == 1:
            return 0, b

        r_old, r = a, b
        a_old, a = 1, 0
        b_old, b = 0, 1

        while r != 0:
            q = r_old // r
            r_old, r = r, r_old - q * r
            a_old, a = a, a_old - q * a
            b_old, b = b, b_old - q * b

        return a_old, b_old
    
    def _create_prime_num(self):
        p = 4
        while not self._miller_robin(p, self._niter):
            p = random.randint(2**(self._L), 2**(self._L+1)-1)
        return p
    
    @staticmethod
    def _miller_robin(num, niter=None):
        '''
        Miller-Rabin algorithm implementation
        Check if a number is prime
        
        With the default niter, the probability of a composite num
            being reported as prime is at most 4^-40 (about 8e-25)
        
        INPUT:
            num: int
                number to be checked
            niter: int (default 40)
                number of iterations to be performed
        
        OUTPUT:
            result: True|False
                True if number is prime
        '''
        
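        if niter is None:
            niter = 40

        # Handle small and even numbers explicitly
        # (num = 1 would otherwise loop forever below)
        if num < 2:
            return False
        if num in (2, 3):
            return True
        if num % 2 == 0:
            return False

        # Write num - 1 = 2^r * q, with q odd
        r, q = 0, num - 1
        while q % 2 == 0:
            r += 1
            q //= 2

        # Repeat niter times
        for _ in range(niter):
            # Random witness in [2, num-2] (1 and num-1 always pass)
            x = random.randint(2, num - 2)
            y = pow(x, q, num)  # same as (x**q) % num, but much faster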

            if y == 1 or y == num-1:
                # Number could be prime,
                #  skip the rest of the code
                continue

            for _ in range(r - 1):
                #y = (y**2) % num
                y = pow(y, 2, num)
                if y == num-1:
                    # Number could be prime
                    break
            else:
                # If none succeed, number is not prime
                return False
        return True
    
    def __str__(self):
        _str = f'Public Key:\nn: {self.n}\n\ne: {self.e}'
        return _str
    
    def __repr__(self):
        return f'--- RSA implementation ---\n{self.__str__()}'

In [22]:
rsa = myRSA()
print(rsa)

Public Key:
n: 296734531723323797017682564905439498697331234885792742678802100344832068645767602576994094791108773438107177078703302349242794680816035874926311703414143439948704720382332280814410876718322771531643286077206734836944005571168155180480500183713523921179367798204117071704121330579884985666268983587103814677189

e: 74662175262538067940949679689305067976389012485496164525119114582418814282697472104780356803619638831027851437441038255764264771031335203504608153914811353163334902016340610279703113218964085872712218950969795309579090390002716318754262932431334964568828696259136456014152581612912637579167426067765167345683


## Test class implementation

In [16]:
rsa = myRSA(debug=True)

# Compute ciphertext from the known plaintext "1234567890"
encr = rsa.encrypt(1234567890, rsa.kPub)

# Compute plaintext from the ciphertext
decr = rsa.decrypt(encr)

print(f'Original message: 1234567890\nEncrypted: {encr}\nDecrypted: {decr}')

Key generated
Original message: 1234567890
Encrypted: 393999596407271049607514490623741818597497719011080548770253852927327217159269413106144013220268181615635645544575937322031392753430574704670110282007276905322170521964341381707160771307338135834274626047339772760452893572676164757066659464131031228765378147948526103141240230910520660325200549188920295903222
Decrypted: 1234567890


In [17]:
rsa = myRSA(debug=True)
mess = "Hello world!"

# Compute ciphertext from the known plaintext "Hello world!"
encr = rsa.encrypt(mess, rsa.kPub)

# Compute plaintext from the ciphertext
decr = rsa.decrypt(encr)
# Convert back to bytes, using the minimal byte length of the integer
decr = str(decr.to_bytes((decr.bit_length() + 7) // 8, byteorder='big'))

print(f'Original message: {mess}\nEncrypted: {encr}\nDecrypted: {decr}')

Key generated
Original message: Hello world!
Encrypted: 90792472862784219600237927895167038538309081218865072194012391355703315813969051781559683358820244176402524013217480966588845476645815653774173704654549495977837592016218585049952871514319194449563522146259902303823097877004020875048876794172912853245550337638251072336132996631704142825966261529518218804521
Decrypted: b'Hello world!'
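`myRSA.encrypt` turns a string into one big-endian integer, so decryption has to reverse that step with the right byte length. A sketch of both directions that computes the length from `bit_length()` instead of hard-coding it; the helper names are hypothetical. Two caveats of this simple scheme: leading NUL characters are lost, and the integer must stay below n, which `max_message_bytes` guarantees conservatively.

```python
def text_to_int(text):
    # UTF-8 bytes read as one big-endian integer (same conversion as myRSA.encrypt)
    return int.from_bytes(text.encode("utf-8"), "big")

def int_to_text(value):
    # Use the minimal number of bytes, then decode back to str
    length = (value.bit_length() + 7) // 8
    return value.to_bytes(length, "big").decode("utf-8")

def max_message_bytes(n):
    # Any message of at most this many bytes encodes to an integer < n
    return (n.bit_length() - 1) // 8

x = text_to_int("Hello world!")
print(x.bit_length())           # 95: the first byte, "H" = 0x48, has a leading zero bit
print(int_to_text(x))           # Hello world!
print(max_message_bytes(3233))  # 1 for the toy modulus n = 3233
```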


# Test implementation with public online RSA enc/dec tools
For this purpose I used some low-level functions of PyCryptodome. The reason is the way keys are shared.<br>
A base64-encoded key, which is how keys are usually copied and pasted, does not contain only the numbers that this implementation computes: public and private keys are packed in a standard structure together with other information, such as an identifier of the key algorithm.<br>
To verify the implementation with public online tools I need to add this information as well. This packing is done by the library function, while all the values are computed by my implementation.<br>
Reference here: [construct function](https://pycryptodome.readthedocs.io/en/latest/src/public_key/rsa.html#Crypto.PublicKey.RSA.construct)

In [52]:
import base64
from Crypto.PublicKey import RSA

rsa1 = myRSA()
object_key = RSA.construct((rsa1.n, rsa1.e)) # .construct accepts a tuple with only (n, e) to build a public key

string_key = object_key.exportKey(format='DER') # Export as binary (DER)
b64_string_key = base64.b64encode(string_key) # Encode in base64 to have a printable object that can be copied/pasted
print(b64_string_key)

b'MIIBIDANBgkqhkiG9w0BAQEFAAOCAQ0AMIIBCAKBgQILhjcMcegdHvfc5XSeM9GXP+tXgr1iKIv+HYos56XBRlrYxFJC6UXQUOjH9jSlLDg1MGzrz8ytd2/vmXH7/TMJ5dNBTmxt1gl+6/4ggThL+qmzd/2sO+C4rzb0mgBdYk2OigiJhqlmMSJTfujdM7PEhPKh18dL1QKtY2fEsrQ9BwKBgQDXt1Sa3YqSnpfFeqamdd2qsfLhE8q3WESRYlzDbIak0Dn2++vB3awaY0Yo2Py1FlutM+kUXDSHm23gAda2+KdndvB1nrUdZ7PC3l6fSK7Z8KdvyfKqlaNWfdddho/djwyRlZVY1doNX8YCGZRhq9gsa0nOCDsM3g1fke3ZqQ2aNw=='



The public key produced by the library function was converted to base64 and pasted into an [online RSA encryption/decryption tool](https://www.devglan.com/online-tools/rsa-encryption-decryption). The plaintext entered was "sample text to try RSA encoding".<br>
The ciphertext returned by the tool was copied into the variable `b64_encr`.

In [68]:
b64_encr = "Ae0ODdiaA4oGWm4ezss/6UBueNBDU/yZ7VC2/CUQLTBQkzKti2G4q+Z1Zr9N3L7qgN96C0f1q9oSiai8JUmaf5fpcuSqWs7BMJaaYTHPnyHYWUviN7gBAzPXiyPjWQcNZjcxQSvZrhUYOh+7PlIFtGqJGPELgnmZmXZNn24QaEcs"

byte_encr = base64.b64decode(b64_encr)
int_encr = int.from_bytes(byte_encr, 'big')
print(int_encr)

int_plain = rsa1.decrypt(int_encr)
str_plain = int_plain.to_bytes(150, byteorder='big') # 150 bytes is more than the modulus length, hence extra leading zero bytes
print(f'Plain text: {str_plain}')

346234924467792436865103504756694979092436575621472142977695190282534596739260169898993362238330337411956032487348261981523899175705341526819315748415146737032685502367619652351003132775773183255551210167816115177630366174472618999668254065591303540431014998403923128478658067141546207732388540018043050805036
Plain text: b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x99\xc9\x9a\xb5|\xad\x8fxdX\x8an\xf2H\xc1\xaeH\xa2`E\xd3\xf4\xcd0d\x8b\xfc\x14g\xdfq\xcb\x841\x90l\x06\x08"/\xc5\x1f\x19V\xb2\x85?\xe2m\x9a%\xe4A\x91\x1bA\xc2\x8c9y\xa9\xc8-YQ\x1c\xb9\xd60"\xd9m\xfeb\xadt\xd4\xe1\xc7\xb5)\xb3\x16\x8eS\x16\xcb\xe4{<\xf8\xcb\xe4\x9b\xf7\x00sample text to try RSA encoding'


The base64 ciphertext is converted into an integer, decrypted by the custom `.decrypt()` method, and the resulting integer is converted back into a byte array and printed.
<br><br>
The printed bytes show non-printable characters and, at the end, the plaintext entered on the website.<br>
The prefix matches the PKCS#1 v1.5 encryption padding (RFC 8017): a `0x00 0x02` header, a run of random non-zero bytes and a `0x00` separator before the message (the additional leading zeros come from `to_bytes(150)`). The padding is encrypted together with the message and changes at every encryption, so it carries no useful information: the message is everything after the first `0x00` separator.
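A minimal sketch that strips PKCS#1 v1.5 encryption padding from a decrypted integer. `pkcs1_v15_unpad` is a hypothetical helper, and the demo uses a made-up modulus and fixed padding bytes instead of random ones. It is meant for inspecting results only: real implementations must not reveal padding errors (Bleichenbacher's attack exploits exactly that), so production code should rely on a library.

```python
def pkcs1_v15_unpad(int_plain, n):
    # Expected layout (RFC 8017, RSAES-PKCS1-v1_5):
    #   EM = 0x00 || 0x02 || PS || 0x00 || M, with PS = at least 8 non-zero bytes
    k = (n.bit_length() + 7) // 8              # modulus length in bytes
    em = int_plain.to_bytes(k, "big")
    if em[0] != 0x00 or em[1] != 0x02:
        raise ValueError("not a PKCS#1 v1.5 encryption block")
    sep = em.find(b"\x00", 2)                  # first zero byte ends the padding
    if sep < 10:                               # missing separator or PS shorter than 8 bytes
        raise ValueError("invalid padding")
    return em[sep + 1:]

# Demo with a hypothetical 1026-bit modulus (k = 129 bytes) and fixed padding bytes
n_demo = (1 << 1025) + 12345
msg = b"sample text to try RSA encoding"
k = (n_demo.bit_length() + 7) // 8
em = b"\x00\x02" + b"\x99" * (k - 3 - len(msg)) + b"\x00" + msg
print(pkcs1_v15_unpad(int.from_bytes(em, "big"), n_demo))  # b'sample text to try RSA encoding'
```

With the notebook's variables, `pkcs1_v15_unpad(int_plain, rsa1.n)` should return the plaintext entered on the website.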

In [34]:
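# Note: this encodes the decimal digit string of the numbers, not their binary value
encodedBytes = base64.b64encode(f'{encr}'.encode("utf-8"))
encodedStr = str(encodedBytes, "utf-8")

# d is the private exponent from an earlier cell: in practice it must never be shared
priKBytes = base64.b64encode(f'{d}'.encode("utf-8"))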
priKStr = str(priKBytes, "utf-8")

print(f'Encoded message: {encodedStr}\n\nKey: {priKStr}')

Encoded message: NjgxMTMxNjc4NzczNzU0NTA5NzExNzczNjg5MDMwNzI2ODA3MTE3OTgzMTI1MzQzMTk2MDM0MDY1MDAwMTMzODcyNjcwMTAzNDYzMDA2ODk5MTU2NTg3OTQxNzc2MjUwNjEzNzM5OTI2NDg2Njg4NzQ2OTQwMjg2Mjk0MzgxMzM0OTAyNzU1Mzk0NjYzNzU1NzkwOTgxNDQxNTk5MjQ0NTY0NzM3MDA2Mzc4NTcwMzk2OTY0NjI4OTI4MjA4NTk2NjE1NzA5OTM0MTIzMTQ0NTU1ODg0MjgwMjI5MTE4MTcxNjQ4MjQyMjI5MjgwODU2NjgxODE4OTU5MDk0OTA1NzA3NTAwNTg3NzE2MDA2MDE5NzU0NTQ1MjU1NDE5NjUwNzc4MTM0NzkzMTg3MTU4MjM5NDI3

Key: MTYzODIyOTYzMTE0Njk5MzY2NjE3NDU3ODY5MTI4NzY5MDAzNTYwNDYxNzQ5Mzk2MzEwNjY1NDI3NDQ1MTExODcwOTYzMTYwMzYyNzY4NzMwMzEwMzQwODgxNjI0MzI5NzU2ODY4Mzc0OTg1MTUyMzcyOTMwOTM0MTAwODcyMzcwMjE0MjIyNjU3NTc2MjM4Mzc1MTAyNjgzNjE4MzU4ODgwNTI1MTM1NjU1NjQ0NjA0Njg4NjQzNjg3NDUwNDcxNjQ4NzM3NDI1MDAxMTkyMzI4NzA1NjQwOTEwMTU3MjUxMzY4NjgyOTY2MjQyNjEzMTgwMDU5MTI0ODE2MjA3MjUyODc1ODk2NjU1NjM4MDkxMDQwODU2NzU4ODAxNDI2MDcwNzc5MzczMzMwMTg1ODYxNTc5
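The cell above base64-encodes the decimal digit string of each number, so the result decodes to ASCII digits rather than to the binary value that tools such as the one used for `b64_encr` work with. A sketch of the binary round trip, mirroring how `b64_encr` was decoded; the helper names are hypothetical and the toy modulus 3233 comes from $p = 61$, $q = 53$.

```python
import base64

def int_to_b64(value, length):
    # Fixed-length big-endian bytes (typically the modulus length), then base64 text
    return base64.b64encode(value.to_bytes(length, "big")).decode("ascii")

def b64_to_int(text):
    # Inverse step, as done for b64_encr: base64 -> bytes -> integer
    return int.from_bytes(base64.b64decode(text), "big")

n = 3233                          # toy modulus from p = 61, q = 53
k = (n.bit_length() + 7) // 8     # modulus length in bytes: 2
c = pow(65, 17, n)                # toy ciphertext: 2790
token = int_to_b64(c, k)
print(token)                      # CuY=
print(b64_to_int(token) == c)     # True
```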


# Conclusion
The implementation works straightforwardly with an integer as plaintext. All the algorithms have been implemented as seen during the lectures, with no particular optimizations.<br>
Some improvements can be made in the choice of `e`. In practice `e` is usually chosen with few 1s in its binary representation (a common choice is $65537 = 2^{16}+1$), which speeds up the exponentiation with the square-and-multiply algorithm discussed in the lecture.
<br><br>
I faced many difficulties in the implementation with a string as plaintext, in particular in translating text into an integer value that the algorithm can process.<br>
Tests with public RSA encryption/decryption services are difficult for the same reason, aggravated by the fact that keys and ciphertexts are encoded with base64 and other standards. This required many attempts at converting strings to bytes and then to integers in a way that the online tools can work with.
<br><br>
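The initial problems with finding prime numbers were due to the fact that computing $x^{e}$ first and reducing it modulo $n$ afterwards (`(x**e) % n`) performs far worse than the built-in `pow(x, e, n)`, which reduces modulo $n$ at every step and keeps the intermediate numbers small. The naive computation was very slow even with relatively small numbers of 6-7 digits. The built-in could also have been replaced by the iterative method presented in the lecture. The timings below show that most of the cost comes from computing the huge power itself (about 172 ms) rather than from the final reduction (under 1 ms).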
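The iterative square-and-multiply method mentioned above can be sketched as follows (an illustration, not the notebook's code; Python's built-in `pow` already implements fast exponentiation of this kind internally). Counting the modular multiplications also makes the earlier remark about `e` concrete: in this version an exponent with bit length $\ell$ and $h$ ones costs $(\ell - 1) + (h - 1)$ multiplications, so $65537 = 2^{16}+1$ needs only 17.

```python
def square_and_multiply(x, e, n):
    # Left-to-right binary exponentiation for e >= 1, reducing mod n at every step.
    # Returns (x^e mod n, number of modular multiplications performed).
    bits = bin(e)[2:]                  # binary digits of e, most significant first
    result, mults = x % n, 0           # the leading 1 bit just loads x
    for bit in bits[1:]:
        result = result * result % n   # square for every further bit
        mults += 1
        if bit == "1":
            result = result * x % n    # multiply for every further 1 bit
            mults += 1
    return result, mults

x, n = 123, 237137   # same values as the timing cells below
for e in (65537, 254081):
    value, mults = square_and_multiply(x, e, n)
    assert value == pow(x, e, n)
    print(e, "ones:", bin(e).count("1"), "multiplications:", mults)
# 65537 ones: 2 multiplications: 17
# 254081 ones: 7 multiplications: 23
```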

In [60]:
%timeit -n1 tt = (123**254081) % 237137

172 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [59]:
%timeit -n1 tt = pow(123, 254081, 237137)

10.2 µs ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [61]:
p = 123**254081
%timeit p % 237137

828 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
