# RSA signatures

Sigurjón Ágústsson

First, we import libraries and load the first 1000 primes in an array. We will use them when generating probable prime numbers later.

In [2]:
import random
import time
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt

In [3]:
def load_small_primes():
    pfile =  open("1000primes.txt", "r")
    small_primes = [int(num) for num in pfile.read().split(',')]
    return small_primes

small_odd_primes = load_small_primes()
print(f"{len(small_odd_primes)} small odd primes loaded : {small_odd_primes[:5]}...{small_odd_primes[-3:]}")

999 small odd primes loaded : [3, 5, 7, 11, 13]...[7901, 7907, 7919]


First we define a couple of helper functions that are useful when printing large numbers. 

In [4]:
# Miscellaneous functions
def std_form(n):
    """Returns a number n in standard form (string)"""
    string_n = str(n)
    num = string_n[0] + "." + string_n[1:3] + f"*10^{len(string_n)-1}"
    return num

def seconds_to_prefix(sec):
    """Just a function that inputs number of seconds (float) and 
        outputs a string representation with a unit (s, ms or min) """
    if sec < 1:
        ms = sec * 1000
        return "{} ms".format(round(ms,4))
    elif sec < 60:
        return "{} s".format(round(sec,4))
    else:
        min = int(sec) // 60
        rem = seconds_to_prefix(sec-60*min)
        return "{} min and {}".format(min,rem)
    
def divisible_by_small_prime(n, resolution=1000):
    """A function that checks if a number n has one of the 
        1000 first primes as a factor (excluding 2 because 
        it is checked for already)"""
    global small_odd_primes # The array of primes
    for small_prime in small_odd_primes[:resolution]:
        if n == small_prime:
            # In this case n is a small prime number, and we return
            # False to indicate it is not divisible by any of
            # the smaller primes
            return False
        if (n % small_prime) == 0:
            return True
    return False

def gcd(a,b):
    """Naive gcd function"""
    r = a%b
    if r==0:
        return b
    return gcd(b,r)

Now we define the important functions, 

In [91]:
def extended_euclidiean(a,b):
    """The extended euclidian algorithm"""
    r = [a,b]
    s = [1,0]
    t = [0,1]
    i = 2
    while True:
        q_im1 = r[-2] // r[-1]
        r_i = r[-2] - q_im1 * r[-1]
        s_i = s[-2] - q_im1 * s[-1]
        t_i = t[-2] - q_im1 * t[-1]
        
        if r_i == 0:
            u,v = s[-1],t[-1]
            break
        r.append(r_i)
        s.append(s_i)
        t.append(t_i)
    return u, v

def montgomery(a, R, N):
    """Returns the montgomery form of the integer a with respect to N"""
    return (a*R) % N

def exp_powering_with_reductions(a,b,N, debug=False):
    """Returns a**b mod n"""
    # First we choose R, it should be a power of two ! works only for odd N!
    if not N&1:
        print("warning N is even: {}".format(std_form(N)))
    pwr_of_R = math.ceil(math.log(N)/math.log(2))
    R = 1<<pwr_of_R
    # Then we precompute R',N'
    u,v = extended_euclidiean(R,N)
    Rp,Np = u+N, R-v
    res_bar = montgomery(1,R,N)
    a_bar = montgomery(a,R,N)
    # we assume b = b_0b_1...b_{k-1}. we want to iterate over all bits
    for i in range(b.bit_length()):
        b_i = (b&(1<<(b.bit_length()-1-i)))>>(b.bit_length()-1-i)
        res_bar = res_bar * res_bar
        res_bar = REDC(res_bar, R,Rp, N,Np, pwr_of_R, debug)
        if debug: assert res_bar <= N, f"{res_bar} is not in interval [0, N-1] N={N}"
        if b_i:
            res_bar = res_bar * a_bar
            res_bar = REDC(res_bar, R,Rp, N,Np, pwr_of_R, debug)
            if debug: assert res_bar <= N, f"{res_bar} is not in interval [0, N-1] for N={N}"
    res = montgomery(res_bar, Rp, N)
    return res    

def REDC(T, R,Rp,N,Np,pwr_of_R, debug=False):
    # T is res_bar
    m = ((T % R)*Np) % R
    assert R == 2**pwr_of_R, f"the power is incorrect, {R}!=2**{pwr_of_R}={2**pwr_of_R}"
#     t = (T + m*N)>>pwr_of_R
    t = (T+m*N)//R
    if debug: assert t<2*N , f"N={N}, t={t} is too large"
    if t>=N:
        return t-N
    else: 
        return t
    
def oracle(M, N):
    """decides whether the first reduction of (a**b) is computed with a subtraction or not"""
    # recall we are computing M**b mod N where b is the secret number we want to find
    square = M**2
    if square >= N:
        cube = square*M
        if cube>=N:
            return True
        else:
            return False
    else:
        return False
        
    
   

In [84]:
N = 13*19
R = 1<<(math.ceil(math.log(N)/math.log(2)))
u,v = extended_euclidiean(R,N)
Rp,Np = u+N, R-v
print(N,Np, ",", R,Rp)
print("RR' - NN':")
print(-N*Np + R*Rp)

247 313 , 256 302
RR' - NN':
1


In [51]:
a,b,N = 5,7,13
print("size b: ", b.bit_length())
exp_powering_with_reductions(a,b,N)

for i in range(100):
    continue
    print(f"{N}:")
    exp_powering_with_reductions(a,b,N)
    print("---")


size b:  3
extra REDC
extra REDC
extra REDC


In [18]:
############## not in use ###################
def modular_exponentiation_old_outdated(a,n):
    """Python's fast way of doing modular exponentiation"""
    """Returns a**(n-1) mod n"""
    return pow(a,(n-1),n)

def modular_exponentiation_outdated(a,N):
    """Python's fast way of doing modular exponentiation"""
    """Returns a**(N-1) mod N"""
    return exp_powering_with_reductions(a,(n-1), N)

In [19]:
a,b,N = 5,7, 13
print("a,b,N=", a,b,N)
print(f"normal: {pow(a,b,N)}\nwith reductions: {exp_powering_with_reductions(a,b,N)}")
a,b,N = 5,91, 131
print("a,b,N=", a,b,N)
print(f"normal: {pow(a,b,N)}\nwith reductions: {exp_powering_with_reductions(a,b,N)}")
      

a,b,N= 5 7 13
normal: 8
with reductions: 8
a,b,N= 5 91 131
normal: 53
with reductions: 53


In [57]:
f, t = 0,0
#N = small_odd_primes[-1]
print(a,b,N)
for _ in range(100):
    m = random.randint(0,N)
    val = exp_powering_with_reductions(a,b,N)
    if oracle(m, N):
        t += 1
    else:
        f += 1
    
print("successful")
print(f"subtraction performed {f} times \nNot perfomred {t} times")

5 7 13
successful
subtraction performed 25 times 
Not perfomred 75 times


Let's demonstrate with a timing example.

In [58]:
bits = 20
# Let n be a random 20 bit integer
n_odd = True
while n_odd:
    n = random.randint((1<<(bits-1)),((1<<bits)-1))
    if n&1:
        break
a = random.randint(0xFFFF,n-1) # a very large number but still less than n


ex1_start = time.perf_counter()
print(a**(n-1) % n )
ex1_end = time.perf_counter()

ex3_start = time.perf_counter()
print(pow(a,n-1,n) )
ex3_end = time.perf_counter()
ex2_start = time.perf_counter()
# print(modular_exponentiation_(a,n))
print(exp_powering_with_reductions(a, n-1, n))
ex2_end = time.perf_counter()
t1,t2,t3 = (ex1_end-ex1_start), (ex2_end-ex2_start), (ex3_end-ex3_start)
strt1,strt2,strt3 = seconds_to_prefix(t1), seconds_to_prefix(t2), seconds_to_prefix(t3)
diff = round(t1/t2,2)
diff2 = round(t2/t3,2)
print(" {}\t- naive algorithm \
        \n {}\t- using montgomery multiplication \
        \n {}\t- python's builtin modular multiplication".format(strt1,strt2,strt3))
print("The Montgomery algorithm makes modular exponentiation {} times faster in this case!".format(diff))
print("Python offers a further {} times increase in speed".format(diff2))

388119
388119
388119
 1.0328 s	- naive algorithm         
 0.1051 ms	- using montgomery multiplication         
 0.0484 ms	- python's builtin modular multiplication
The Montgomery algorithm makes modular exponentiation 9826.89 times faster in this case!
Python offers a further 2.17 times increase in speed


# BELOW THIS POINT
### The code is a bit of a mess, that is because I still have not figured out how the oracle works

In [59]:
raise NotImplementedError

NotImplementedError: 

## Prime number generation

Now we are ready to do generate prime numbers. 

In [75]:
def probable_prime(size):
    assert size > 0
    s = size-1
    smallest = 1<<s
    largest = (1<<(s+1))-1
    while True:
        # First pick a random number n of size 'size' bits
        n = random.randint(smallest,largest)
        if (not n&1) or divisible_by_small_prime(n):
            continue
        a = random.randint(2,n-1)

#         mod_computation = modular_exponentiation_old(a,n)
        mod_computation = exp_powering_with_reductions(a,(n-1),n)

        if mod_computation == 1:
            # By Fermat's little thm, if the result is not 1
            # then n is not prime, otherwise, it is likely to 
            # be prime
            return n

Let's use this to generate some "small" probably-prime numbers, or pseudoprimes.

In [85]:
for i in range(4):
    bits = 8*(i+1)
    print("{} bits: \t".format(bits), probable_prime(bits))

8 bits: 	 131
16 bits: 	 65479
24 bits: 	 15125441
32 bits: 	 2239310111


In RSA encryption we want much larger primes. Typically 1024 or 3072 bits. That is no obstacle for us, let's generate such pseudoprimes:

In [86]:
t1s = time.perf_counter()
p1024 = probable_prime(1024)
t1e = time.perf_counter()
print(p1024)
print("execution time: {} seconds".format(round((t1e-t1s),3)))

161893610166330130096494242968505876783895218554774147432167829488361997535014498222644044034360738004222229687752459298683288303123121185112965456815158388689486433843813312190204327032306715495659169264235240692590345626312292364487460312916337840592075152978389039402820933694081761747729051416225115724811
execution time: 0.27 seconds


In [78]:
t1s = time.perf_counter()
p3072 = probable_prime(2048)
t1e = time.perf_counter()
print(p3072)
print("execution time: {} seconds".format(round((t1e-t1s),3)))

27670654601581534075703186568095028570617252279032169172813483087019428480251596725533769154932008522824878441747522770846479580798926391140066578725888527721997267168130103056445818228672934271069519151149779496180970807949549638638653559584784191289329903883324289120454687363225504731497429229752450223873038618828195064291385431421755024891082128214491761896979298865252389000547540639959672710836791865478518989954023611810922406837622936693637024721494348981558288626778639499078426170674856500888315395167400025982691791357181497828782795910289268733166520708841342994234306448180143819030804193207531931446591
execution time: 22.367 seconds


## RSA signatures

In [87]:
e = 3
print(pow(e,-1,500))

167


In [88]:
def find_ed(N,T):
    e = 3
    d = pow(e,-1,T)
    return (e,d)

In [92]:
def generate_keys(size=512):
    """Function that returns a tuple of public and private keys of given size (in bits)"""
    p,q = probable_prime(size),probable_prime(size)
    N = p*q
    # e is less than T. And it is coprime to T and N
    # e*d mod T has to be 1
    # e is typically 3
    T = N - p - q + 1 # (p-1)(q-1)
    e = 3
    d = exp_powering_with_reductions(e,-1,T)
    assert (N&1), f"{N} is even"
    assert d != 0 
#     d = pow(e,-1,T)
    return ((N,e),(N,d))


def encrypt(val, public_key):
    """public key is a tuple containing N,e """
    N,e = public_key
    return exp_powering_with_reductions(val,e,N)
#     return pow(val,e,N)

def decrypt(cipher, private_key):
    """Private key is a tuple containing N,d """
    N,d = private_key
    return exp_powering_with_reductions(cipher, d,N)
    return pow(cipher,d,N)

In [99]:
while True:
    try:
        public_key, secret_key = generate_keys(100)
        break
    except Error as e:
        print(e)
        continue
# public_key, secret_key = dummy_keys()



In [100]:
public_key

(1106603033275934659886298762961330757449565658268905752902093, 3)

In [101]:
secret_key

(1106603033275934659886298762961330757449565658268905752902093,
 522268173713512477670467348266278816751755915702448601991902)

In [102]:
text = "HELLO WORLD!"

In [103]:
def str2ascii_list(string):
    return [ord(c) for c in string]

def ascii_list2str(alist):
    s = ""
    for c in alist:
        s+=chr(c)
    return s

def encrypt_string(string, public_key):
    ascii_list = [ord(i) for i in string]
    enc_list = [encrypt(j,public_key) for j in ascii_list]
    print(enc_list)
    enc_string  = ""
    for enc in enc_list:
        enc_string += chr(enc)
    return enc_list

def decrypt_string(enc_list,private_key):
    ascii_list = [decrypt(j, private_key) for j in enc_list]
    return ascii_list

In [104]:
text_asc = str2ascii_list(text)
print("original:")
print(text_asc)
print("encrypted:")
enc = [encrypt(c, public_key) for c in text_asc]
print(enc)
print("decrypted_list")
dec_list = [decrypt(c, secret_key) for c in enc]
print(dec_list)
print("result")
res = ascii_list2str(dec_list)
print(res)

original:
[72, 69, 76, 76, 79, 32, 87, 79, 82, 76, 68, 33]
encrypted:
[373248, 328509, 438976, 438976, 493039, 32768, 658503, 493039, 551368, 438976, 314432, 35937]
decrypted_list
[172315095570825730831307061129040642613707564299662250698979, 1016176660477914548138634116114045509080130145241916367147129, 862031913942412796801973562248356747943007354478504912955334, 862031913942412796801973562248356747943007354478504912955334, 233752557393709477898148919475078297904711709522924697334881, 1004564470272174430652852718178448999298423568464630193839763, 287827587604847074326092599195807080472694721414920692400222, 233752557393709477898148919475078297904711709522924697334881, 918540255434630152019763879193672837801032976678177238554034, 862031913942412796801973562248356747943007354478504912955334, 892896174020101958997289199983415081219287678967870539148091, 994209852245747913642266295188199620501960548752705757650427]
result


OverflowError: Python int too large to convert to C int

In [None]:
# ASCII has values from 32 to 126
# print("Public,private keys are: ", public_key, secret_key)
for i in range(32,127):
    enc = encrypt(i, public_key)
    dec = decrypt(enc, secret_key)
    if dec == i:
        print(f"{i}, ASCII: '{chr(i)}', Encryption: '{enc}', Decrypting encryption: '{dec}'")
    else:
        print(f"{i}, ASCII: '{chr(i)}', Encryption: '{enc}', Decrypting encryption: '{dec}' - FAIL")

In [None]:
encrypted = encrypt_string(text,public_key)
print(encrypted)
decrypted = decrypt_string(encrypted,secret_key)
print(decrypted)

Different way to encode strings

In [None]:
def string_to_num(string):
    num = 0
    for c in string:
        new_char = ord(c)
        num = (num<<7)+new_char
    return num

def num_to_string(num):
    string = ""
    while num>0:
        new_char = 0x7F
        new_char &= num # Now new_char is the ascii value of the last character of the string
        string = chr(new_char) + string # We prepend the new character to the string, since we are going in the reverse
        num = num >> 7
    return string

s = "AB"

string_to_num(s)


In [None]:
while True:
    try:
        public_key, private_key = generate_keys(100) # generate keys of size 100
        break
    except:
        continue
num_chars = 5
bits = num_chars * 8
largest, lowest = (1<<bits)-1, (1<<(bits-1))
for _ in range(5):
    num = random.randint(lowest,largest)
    s = num_to_string(num)
    e = encrypt(num, public_key)
    d = decrypt(e, private_key)
    ds = num_to_string(d)
    p  = f"{num} = {s} \t ---> {e} \t ---> {d} = {ds}"
    print(p)


In [None]:
def rand_string(length):
    bits = length * 8
    largest, lowest = (1<<bits)-1, (1<<(bits-1))
    assert largest.bit_length() == bits
    num = random.randint(lowest,largest)
    s = num_to_string(num)
    return s
def rand_legible_string(length):
    bits = length * 8
    s = ""
    for _ in range(length):
        s += chr(random.randint(32,126))
    return s


In [None]:
while True:
    try:
        pkey, skey = generate_keys(50)
        break
    except:
        continue
pkey,skey

In [None]:
data = np.array([])
for _ in range(1000):
    data = np.append(data, rand_legible_string(10))
print(data[:5])
print(len(data))


In [None]:
word = rand_legible_string(10)
word_code = string_to_num(word)
print("legible string: ", word)
print("encoded: ", word_code)

In [None]:
locked = encrypt(word_code,pkey)
print(locked)

In [None]:
unlocked = decrypt(locked,skey)
print(unlocked)

In [None]:
primitive_data = np.array([])

In [None]:
base = (1<<32)-1
print(bin(base))
for i in range(5):
    if i >0 :
        num = (1<<(8*i))-1
        shift = 32-8*i
    else:
        num = 1
        shift = 31
    num = num<<shift
    print(bin(num), len(bin(num))-2)

In [None]:
def generate_simple_data(bit_length):
    data = []
    for i in range(bit_length//8 +1):
        if i > 0 :
            num = (1<<(8*i))-1
            shift = bit_length-8*i
        else:
            num = 1
            shift = bit_length-1
        num = num<<shift
        data.append((num, bit_length-shift))
    return data

In [None]:
data = generate_simple_data(128)
len(data)

In [None]:
test = []
for num,size in data:
    tstart = time.perf_counter()
    decrypt(num,skey)
    tend = time.perf_counter()
    res = tend-tstart
    arr = [size, res]
#     print(arr)
    test.append(arr)
#     test[size] = tend-tstart

In [None]:
results = np.array(test)
x_values = results[:,0]
y_values = results[:,1]
results[:5]

In [None]:
dataframe = pd.DataFrame(results)
dataframe

In [None]:
plt.plot(x_values, y_values)

In [None]:
seconds_to_prefix(results[0][1])

In [None]:
messages = [string_to_num(rand_legible_string(10)) for _ in range(10000)]
print(messages[:3])
print(pkey, skey)

Remember public key is $N, e$

In [19]:
# We want to create two sets of random variables. One where the first bit is
M1, M2 = [],[]
N, d = skey
for m in messages[:10]:
#     print("message: ", str(m)[:3]+"...")
    val = exp_powering_with_reductions(m, d, N)
    if oracle(m,N):
        M1.append(m)
    else:
        M2.append(m)
print(f"M1 has {len(M1)} messages")
print(f"M2 has {len(M2)} messages")

NameError: name 'skey' is not defined

In [21]:
M = [i for i in range(2**4)]
pkey_candidate, skey_candidate = generate_keys(4)
pair1, pair2 = None, None
while True:
    N,d = skey_candidate
    assert d!=0
    print(skey_candidate)
    b1 = d>>(d.bit_length()-1)
    if b1&1:
        pair1 = pkey_candidate, skey_candidate
    else:
        pair2 = pkey_candidate, skey_candidate
    pkey_candidate, skey_candidate = generate_keys(4)   
    if pair1 is not None and pair2 is not None:
        break
print("public key: ", public_key)
print("secret_key:", secret_key)
print(f"All values: [{M[0]}, {M[1]}, ..., {M[-1]}]")
print(f"  = [{bin(M[0])}, {bin(M[1])}, ..., {bin(M[-1])}]")

NameError: name 'generate_keys' is not defined

In [17]:
# def oracle(a,b,N):
def oracle(M, N):
    """decides whether the first reduction of (a**b) is computed with a subtraction or not"""
    pwr_of_R = math.ceil(math.log(N)/math.log(2))
    R = 1<<pwr_of_R # R is the smallest power of 2 that is 
    # Then we precompute R',N'
    u,v = extended_euclidiean(R,N)
    Rp,Np = u+N, R-v
    res_bar = montgomery(1,R,N)
    a_bar = montgomery(M,R,N)
    b_1 = ... # we are trying to deduce this bit!
    res_bar = res_bar * res_bar
    res_bar = REDC(res_bar, R,Rp, N,Np, pwr_of_R)
    # if b_1 is 1 then this part will be executed
    res_bar = res_bar * a_bar
    # will a reduction occur after this step??
#     res_bar = REDC(res_bar, R,Rp, N,Np, pwr_of_R)
    T = res_bar
    # T in [0, RN − 1]
    assert 0<=T<= R*N -1
    print(f"0<= {std_form(T)} <= RN-1={std_form(R*N-1)}")
    m = ((T % R)*Np) % R
#     print(f"t is: {(T + m*N)} >> {pwr_of_R}")
    t = (T + m*N)>>pwr_of_R
    if t>=N: return True
    return False

In [18]:
M1, M2 = [],[]
N, d = skey
print("b_1 is: ", bin(d)[3])
print(bin(d))
for m in M:
#     print("message: ", str(m)[:3]+"...")
    val = exp_powering_with_reductions(m, d, N)
    if oracle(m,N):
        M1.append(m)
    else:
        M2.append(m)
print(f"M1 has {len(M1)} messages")
print(f"M2 has {len(M2)} messages")

NameError: name 'skey' is not defined