# Precomputation

We start by precomputing various residue number system (RNS) representations in the constructor of the `MontgomeryArithmeticGPU` class. Note that these computations depend only on the modulus $N$ (and the chosen RNS bases), not on any of the inputs to the Montgomery product.

Keep in mind that $N$, $R$ and $R'$ (from the lecture notes) must satisfy the following conditions:

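1. $R > N$
2. $(R, N) = (R', N) = (R, R') = 1$

where $R$ and $R'$ are the products of the moduli in the main and the auxiliary RNS basis, respectively. In addition, $R'$ must be large enough to hold the unreduced Montgomery product $C = (xy + QN)/R$: for $x, y < N$ we have $C < N^2/R + N < 2N$, so $R' \ge 2N$ suffices.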


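Here $(a, b) = 1$ means that $a$ and $b$ are relatively prime, i.e. $\gcd(a, b) = 1$; for example, $(R, R') = 1$ means $R$ and $R'$ are relatively prime.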

# Montgomery Product (GPU-friendly version)

In [1]:
import math
import random
from math import gcd


def to_rns(x, rns_basis):
    """
    Convert an integer to its residue number system (RNS) representation.

    Args:
        x: integer to convert
        rns_basis: RNS basis (sequence of moduli)

    Returns:
        A list whose i-th entry is ``x mod rns_basis[i]`` (Python ``%``
        semantics, so every residue is non-negative for positive moduli).
    """
    # Each residue depends only on its own modulus, so on a GPU every modulus
    # could be handled by a separate thread.
    # TODO: replace the plain `%` with a faster reduction such as Barrett reduction.
    return [x % modulus for modulus in rns_basis]


def from_rns(x_rns, R, rns_basis):
    """
    Reconstruct an integer from its RNS representation (Chinese Remainder Theorem).

    Args:
        x_rns: the RNS representation of 'x' with respect to 'rns_basis'
        R: the product of the moduli in 'rns_basis' (passed in to avoid recomputing it)
        rns_basis: the RNS basis

    Returns:
        An integer in [0, R) congruent to x_rns[i] modulo rns_basis[i] for every i
        (unique when the moduli are pairwise coprime and R is their product).
    """
    # x = sum_i r_i * M_i * (M_i^-1 mod m_i) with M_i = R / m_i. Every term is
    # independent, so the terms could be produced in parallel and then summed
    # with a reduction.
    total = sum(
        residue * (R // modulus) * pow(R // modulus, -1, modulus)
        for modulus, residue in zip(rns_basis, x_rns)
    )
    return total % R


def add(x_rns, y_rns, rns_basis):
    """
    Adds two RNS representations, residue by residue.

    Args:
        x_rns: RNS representation of an integer 'x' with respect to 'rns_basis'
        y_rns: RNS representation of an integer 'y' with respect to 'rns_basis'
        rns_basis: the RNS basis

    Returns
        The representation of 'x + y' with respect to 'rns_basis'.

    Raises:
        ValueError: if x_rns, y_rns and rns_basis do not all have the same length
            (previously the mismatch was silently truncated / zero-padded,
            yielding a wrong representation).
    """
    if not (len(x_rns) == len(y_rns) == len(rns_basis)):
        raise ValueError("x_rns, y_rns and rns_basis must have the same length")
    # There are no carries between residues, so on a GPU every channel could be
    # added by its own thread.
    return [(a + b) % m for m, a, b in zip(rns_basis, x_rns, y_rns)]

def mult(x_rns, y_rns, rns_basis):
    """
    Multiplies two RNS representations, residue by residue.

    Args:
        x_rns: RNS representation of an integer 'x' with respect to 'rns_basis'
        y_rns: RNS representation of an integer 'y' with respect to 'rns_basis'
        rns_basis: the RNS basis

    Returns
        The representation of 'x * y' with respect to 'rns_basis'.

    Raises:
        ValueError: if x_rns, y_rns and rns_basis do not all have the same length
            (previously the mismatch was silently truncated / zero-padded,
            yielding a wrong representation).
    """
    if not (len(x_rns) == len(y_rns) == len(rns_basis)):
        raise ValueError("x_rns, y_rns and rns_basis must have the same length")
    # Residues are independent, so on a GPU every channel could be multiplied
    # by its own thread.
    return [(a * b) % m for m, a, b in zip(rns_basis, x_rns, y_rns)]

def convert(x_rns, R, in_basis, out_basis):
    """
    Base extension: re-express an RNS value with respect to a different basis.

    The integer is reconstructed exactly with the CRT (``from_rns``) and then
    reduced modulo every modulus of the output basis (``to_rns``). This exact
    reconstruction uses big-integer arithmetic rather than per-residue work.

    Args:
        x_rns: RNS representation of 'x' with respect to 'in_basis'
        R: product of the moduli in 'in_basis'
        in_basis: the RNS basis of the input
        out_basis: the RNS basis for the output

    Returns:
        The RNS representation of 'x' in basis 'out_basis', where 'x' is the
        integer in [0, R) represented by 'x_rns'.
    """
    integer_value = from_rns(x_rns, R, in_basis)
    return to_rns(integer_value, out_basis)
# Self-test of the RNS helpers on small values. Each check prints
# "<recovered> == <expected> ? <match>".

rns_basis_test = [3, 5, 7, 11, 17]
aux_basis_test = [19, 23, 29]
R_test = math.prod(rns_basis_test)
R_aux_test = math.prod(aux_basis_test)

print(f"R = {R_test}")
x_test = 17
x_rns_test = to_rns(x_test, rns_basis_test)
y_test = 42
y_rns_test = to_rns(y_test, rns_basis_test)

round_trips = [
    (from_rns(x_rns_test, R_test, rns_basis_test), x_test),
    (from_rns(y_rns_test, R_test, rns_basis_test), y_test),
    (from_rns(add(x_rns_test, y_rns_test, rns_basis_test), R_test, rns_basis_test), x_test + y_test),
    (from_rns(mult(x_rns_test, y_rns_test, rns_basis_test), R_test, rns_basis_test), x_test * y_test),
    (from_rns(convert(x_rns_test, R_test, rns_basis_test, aux_basis_test), R_aux_test, aux_basis_test), x_test),
]
for recovered, expected in round_trips:
    print(f"{recovered} == {expected} ? {recovered == expected}")


R = 19635
17 == 17 ? True
42 == 42 ? True
59 == 59 ? True
714 == 714 ? True
17 == 17 ? True


In [2]:

class MontgomeryArithmeticGPU:
    """
    Montgomery modular arithmetic modulo N computed in a residue number system.

    The main RNS basis has product R (the Montgomery radix) and the auxiliary
    basis has product R'. Every residue-wise operation is independent of the
    others, which is what makes the scheme GPU-friendly.
    """

    def __init__(self, N, rns_basis, aux_basis):
        """
        Constructor for Montgomery arithmetic (GPU-friendly version).

        All precomputations depend only on N and the two bases, never on the
        operands of a Montgomery product.

        Args:
            N - modulus
            rns_basis - the RNS basis to be used in the concurrent Montgomery
                multiplication; its product R is the Montgomery radix
            aux_basis - auxiliary basis; its product is R'

        Raises:
            AssertionError - if R <= N or R' < 2N
            ValueError - (raised by pow) if R, R' and N are not pairwise coprime
        """
        self.N = N
        self.R = math.prod(rns_basis)
        self.aux_R = math.prod(aux_basis)

        assert (N < self.R), "R must be Larger than N"
        # For operands below N the unreduced Montgomery product
        # C = (xy + QN) / R is below N^2/R + N < 2N. C is reconstructed from the
        # auxiliary basis, so R' must exceed every possible C.
        assert self.aux_R >= 2 * N, "R' must be at least 2N"

        self.rns_basis = rns_basis
        self.aux_basis = aux_basis

        # --- precomputations (depend only on N and the bases) ---
        self.R_inv_for_N = pow(self.R, -1, N)          # R^-1 mod N
        self.N_inv_for_R = pow(N, -1, self.R)          # N^-1 mod R
        # Not used within this class; kept for existing callers/debugging.
        self.rns_N_inv_for_R = to_rns(self.N_inv_for_R, rns_basis)
        # -N^-1 mod R in the main basis: Q = -Z * N^-1 mod R makes Z + QN ≡ 0 (mod R).
        self.rns_neg_N_inv_for_R = to_rns(-1 * self.N_inv_for_R, rns_basis)

        # Not used within this class; kept for existing callers/debugging.
        self.N_inv_for_aux_R = pow(N, -1, self.aux_R)
        self.rns_N_inv_for_aux_R = to_rns(self.N_inv_for_aux_R, aux_basis)
        # N in the auxiliary basis (hoisted out of montgomery_mult).
        self.rns_N_for_aux = to_rns(N, aux_basis)

        self.R_inv_for_aux_R = pow(self.R, -1, self.aux_R)  # R^-1 mod R'
        self.rns_R_inv_for_aux_R = to_rns(self.R_inv_for_aux_R, aux_basis)

        # R^2 mod N, used to map integers into Montgomery form.
        self.R_squared = pow(self.R, 2, self.N)

    def to_montgomery_form(self, x):
        """
        Converting an integer to Montgomery form.

        Uses the R^2 trick: MontMul(x, R^2 mod N) = x * R^2 * R^-1 = xR (mod N).

        Args:
            x - integer to be converted

        Returns:
            xbar - the Montgomery form of x, that is xbar = xR mod N, in [0, N)
        """
        return self.montgomery_mult(x, self.R_squared)

    def montgomery_mult(self, mx, my):
        """
        Multiplication of two integers (in Montgomery form), computed in RNS.

        Args:
            mx - Montgomery form of an integer x (xbar = xR mod N)
            my - Montgomery form of an integer y (ybar = yR mod N)

        Returns:
            Montgomery form of xy, i.e. mx * my * R^-1 mod N, fully reduced to [0, N).
        """
        # Reduce the operands so that mx * my < N^2. This keeps C below 2N <= R'
        # (checked in the constructor) and makes negative inputs safe.
        mx %= self.N
        my %= self.N

        # Step 1: Z = mx * my in both bases (i.e. Z mod R and Z mod R').
        Z = mult(to_rns(mx, self.rns_basis), to_rns(my, self.rns_basis), self.rns_basis)
        Z_prim = mult(to_rns(mx, self.aux_basis), to_rns(my, self.aux_basis), self.aux_basis)

        # Step 2: Q = -Z * N^-1 mod R, so that Z + Q*N is divisible by R.
        Q = mult(self.rns_neg_N_inv_for_R, Z, self.rns_basis)

        # Step 3: base-extend Q (an exact integer in [0, R)) to the auxiliary basis.
        Q_prim = convert(Q, self.R, self.rns_basis, self.aux_basis)

        # Step 4: C' = (Z' + Q' * N) * R^-1 in the auxiliary basis. Division by R
        # is impossible in the main basis (R ≡ 0 there), hence the auxiliary one.
        C_prim = mult(
            add(Z_prim, mult(Q_prim, self.rns_N_for_aux, self.aux_basis), self.aux_basis),
            self.rns_R_inv_for_aux_R,
            self.aux_basis,
        )

        # Step 5: recover C. Since 0 <= C < 2N <= R', it is reconstructed exactly
        # from the auxiliary residues. (Base-extending to the main basis first
        # would lose information whenever R < 2N.)
        C = from_rns(C_prim, self.aux_R, self.aux_basis)
        # C < 2N, so one final reduction yields the canonical representative.
        return C % self.N

    def exp(self, x, e):
        """
        Montgomery exponentiation: x^e mod N via left-to-right binary
        square-and-multiply, with intermediate values kept in Montgomery form.

        Args:
            x - base (any integer; it is reduced modulo N)
            e - non-negative integer exponent

        Returns:
            x^e mod N as an ordinary integer in [0, N).

        Raises:
            ValueError - if e is negative.
        """
        if e < 0:
            raise ValueError("exponent must be non-negative")
        x_mont = self.to_montgomery_form(x)
        acc = self.to_montgomery_form(1)  # Montgomery form of 1, i.e. R mod N
        for bit in bin(e)[2:]:  # binary digits, most significant first
            acc = self.montgomery_mult(acc, acc)          # modular squaring
            if bit == '1':
                acc = self.montgomery_mult(acc, x_mont)   # modular multiplication
        # MontMul(acc, 1) = acc * R^-1 mod N converts back to normal form.
        return self.montgomery_mult(acc, 1)

# Quick demo on a toy modulus: R = 3 * 7 * 13 = 273, R' = 5 * 11 * 17 = 935.
rns_basis = [3, 7, 13]
aux_basis = [5, 11, 17]
N = 67
x = 31
y = 17

m = MontgomeryArithmeticGPU(N, rns_basis, aux_basis)
res = m.montgomery_mult(x, y)
print(res)  # 25 == 31 * 17 * 273^-1 mod 67
print(f"exp {pow(3, 7, N)}", m.exp(3, 7))

25
exp 43 43


# Test Case

In [3]:
# Montgomery product of 42 and 17 modulo 67, compared with a direct computation.
rns_basis = [3, 7, 13]
aux_basis = [5, 11, 17]
N = 67
x, y = 42, 17

m = MontgomeryArithmeticGPU(N, rns_basis, aux_basis)
res = m.montgomery_mult(x, y)
R = math.prod(rns_basis)
print(res)  # 49
print(f"{res} == {x * y * pow(R, -1, N) % N} ")



49
49 == 49 


In [4]:
# R' for aux_basis = [5, 11, 17]: the product of the auxiliary moduli.
math.prod([5, 11, 17])

935

You can verify the result by computing it directly:

In [5]:
# Direct (non-Montgomery) reference value: x * y * R^-1 mod N.
R = math.prod(rns_basis)
(x * y * pow(R, -1, N)) % N

49

# Applications to RSA Encryption / Decryption

In [6]:
# Simple textbook RSA (no padding) whose modular exponentiations run on
# MontgomeryArithmeticGPU. Unpadded RSA is for demonstration only.
class RSA_mong:
    """
    Textbook RSA with key generation from two given primes p and q.

    The public exponent e is the smallest integer >= 2 coprime to phi(N), and
    d = e^-1 mod phi(N). Encryption and decryption use Montgomery exponentiation.
    """

    def __init__(self, p, q, rns_basis, aux_basis):
        """
        Args:
            p, q - primes (primality is not checked); N = p * q
            rns_basis - main RNS basis; its product must exceed N
            aux_basis - auxiliary RNS basis; its product must exceed N

        Raises:
            AssertionError - if a basis product does not exceed N
            ValueError - if no suitable e exists or e has no inverse mod phi
        """
        self.p = p
        self.q = q
        self.N = p * q
        self.phi = (p - 1) * (q - 1)
        self.e = self.find_e()
        self.d = self.mod_inverse(self.e, self.phi)

        # Validate the bases before building the Montgomery engine, so a basis
        # that is too small is reported with these descriptive messages.
        assert math.prod(rns_basis) > self.N, f"rns basis {math.prod(rns_basis)} should be bigger than N {self.N}"
        assert math.prod(aux_basis) > self.N, f"aux basis {math.prod(aux_basis)} should be bigger than N {self.N}"

        self.mong = MontgomeryArithmeticGPU(self.N, rns_basis, aux_basis)

    def find_e(self):
        """Return the smallest e >= 2 with gcd(e, phi) == 1; raise ValueError if none is below phi."""
        e = 2
        while e < self.phi:
            if gcd(e, self.phi) == 1:
                return e
            e += 1
        raise ValueError("No suitable 'e' value found")

    def mod_inverse(self, a, m):
        """
        Return the inverse of a modulo m (in [0, m) for positive m).

        Raises:
            ValueError - if a has no inverse modulo m.
        """
        # Built-in pow handles negative a correctly. The recursive extended_gcd
        # can report a gcd of -1 for negative inputs and would wrongly reject them.
        try:
            return pow(a, -1, m)
        except ValueError:
            raise ValueError("Modular inverse does not exist") from None

    def extended_gcd(self, a, b):
        """Return (g, x, y) with a*x + b*y == g, computed by the recursive extended Euclidean algorithm."""
        if a == 0:
            return b, 0, 1
        else:
            g, x, y = self.extended_gcd(b % a, a)
            return g, y - (b // a) * x, x

    def string_to_int(self, message):
        """Encode a string as UTF-8 and interpret the bytes as a big-endian integer."""
        return int.from_bytes(message.encode(), 'big')

    def int_to_string(self, message_int):
        """Inverse of string_to_int (0 maps to the empty string)."""
        return message_int.to_bytes((message_int.bit_length() + 7) // 8, 'big').decode()

    def encrypt(self, message):
        """
        Encrypt a string: c = m^e mod N, where m is the string's integer encoding.

        Raises:
            AssertionError - if the encoded message is not smaller than N.
        """
        message_int = self.string_to_int(message)
        assert self.N > message_int, f"message {message_int} should be smaller than N {self.N}"
        return self.mong.exp(message_int, self.e)  # equivalent to pow(message_int, self.e, self.N)

    def decrypt(self, encrypted):
        """Decrypt an integer ciphertext: m = c^d mod N, decoded back to a string."""
        decrypted_int = self.mong.exp(encrypted, self.d)  # equivalent to pow(encrypted, self.d, self.N)
        return self.int_to_string(decrypted_int)


# RSA_mong demo with ~35-digit primes.
# Source of p: https://t5k.org/curios/index.php?start=10&stop=11
p = 10010010010010010010010010010010011  # prime
q = 100100011111111110011011100010011011  # presumably prime (not verified here)
# RNS moduli taken from https://prime-numbers.info/list/primes
rns_basis = [3, 7, 13, 19, 29, 37, 43, 53, 61, 71, 83, 97, 103, 109, 127, 137, 149, 157, 167, 179,
             191, 197, 211, 227, 233, 241, 257, 269, 277, 283, 307, 313, 331, 347, 353, 367, 379, 389, 401]
aux_basis = [5, 11, 17, 23, 31, 41, 47, 59, 67, 79, 89, 101, 107, 113, 131, 139, 151, 163, 173, 181,
             193, 199, 223, 229, 239, 251, 263, 271, 281, 293, 311, 317, 337, 349, 359, 373, 383, 397, 409]
rsa = RSA_mong(p, q, rns_basis, aux_basis)

message = "Hello CS_58009,  LBC & FHE !"
print(f"Message string: {message}")

encrypted_message = rsa.encrypt(message)
print(f"Encrypted message: {encrypted_message}")

decrypted_message = rsa.decrypt(encrypted_message)
print(f"Decrypted message: {decrypted_message}")

Message string: Hello CS_58009,  LBC & FHE !
Encrypted message: 122578925036443991289781353393213471679972579707928169479215735991301
Decrypted message: Hello CS_58009,  LBC & FHE !


In [7]:
# Sanity check: raises ValueError if R and R' share a factor (this is the
# inverse computed in MontgomeryArithmeticGPU.__init__ as R_inv_for_aux_R).
pow(math.prod(rns_basis), -1, math.prod(aux_basis))
# Length of a long sample number string (including its trailing spaces),
# followed by the digit counts of N, R and R'.
print(len("7721223245249930633031798799077552269136415756723407808837440470360087790642827701198477023560353910422891495169138662851272809287733600831981538375778403257729343800468897884871615137703401940054173753583879747921815027071273   "))
print(*(len(str(value)) for value in (p * q, math.prod(rns_basis), math.prod(aux_basis))), sep="\n")


229
70
82
83


In [8]:
# RSA_mong demo with ~100-digit factors.
p = 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003469  # presumably prime (not verified here)
q = 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000016599  # presumably prime (not verified here)

# Main basis: the primes 2..641; auxiliary basis: the primes 643..1303.
rns_basis = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641,]
aux_basis = [643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289, 1291, 1297, 1301, 1303,]
rsa = RSA_mong(p, q, rns_basis, aux_basis)

message = "Special Topics in CS: Lattice-Based Cryptography and Homomorphic Encryption Schemes"
print(f"Message string   : {message}")

encrypted_message = rsa.encrypt(message)
print(f"Encrypted message: {encrypted_message}")

decrypted_message = rsa.decrypt(encrypted_message)
print(f"Decrypted message: {decrypted_message}")
# Took about 0.5 seconds in the author's run.

Message string   : Special Topics in CS: Lattice-Based Cryptography and Homomorphic Encryption Schemes
Encrypted message: 68320436182364694500499454149333911809143604338173703109240965140961642263886990661876841202465901552677734458688764879333672680890397149763877603823160473008751736511747364447032023814222970169055903
Decrypted message: Special Topics in CS: Lattice-Based Cryptography and Homomorphic Encryption Schemes


Miscellaneous tests

In [9]:
def chinese_remainder(moduli, residues):
    """
    Reconstruct an integer from its residues using the Chinese Remainder Theorem.

    Assumes that all moduli are pairwise coprime.

    Args:
        moduli: the moduli m_i
        residues: the residues r_i = x mod m_i, in the same order as 'moduli'

    Returns:
        The x in [0, prod(moduli)) with x ≡ r_i (mod m_i) for every i.
    """
    modulus_product = math.prod(moduli)
    # x = sum_i r_i * M_i * (M_i^-1 mod m_i), where M_i = prod / m_i.
    total = sum(
        residue * (modulus_product // modulus) * pow(modulus_product // modulus, -1, modulus)
        for modulus, residue in zip(moduli, residues)
    )
    return total % modulus_product

In [10]:
# Understanding the residue number system: represent two integers by their
# residues, operate channel by channel, and recover the results with the CRT.
num1 = 42
num2 = 17
# Pairwise-coprime moduli
mod1 = 19
mod2 = 37
mod3 = 67
moduli = [mod1, mod2, mod3]

# Residues of num1
residues_of_num1 = [num1 % mod for mod in moduli]
num1_mod1, num1_mod2, num1_mod3 = residues_of_num1
print(f"num1 : {num1} calculated as {chinese_remainder(moduli, residues_of_num1)} using CRT ")

# Residues of num2
residues_of_num2 = [num2 % mod for mod in moduli]
num2_mod1, num2_mod2, num2_mod3 = residues_of_num2
print(f"num2 : {num2} calculated as {chinese_remainder(moduli, residues_of_num2)} using CRT ")

# Residues of num1 + num2: add channel by channel
residues_of_add = [(a + b) % mod for a, b, mod in zip(residues_of_num1, residues_of_num2, moduli)]
add_rns_result_mod1, add_rns_result_mod2, add_rns_result_mod3 = residues_of_add
print(f"num1+num2 : {num1+num2} calculated as {chinese_remainder(moduli, residues_of_add)} using CRT ")

# Residues of num1 * num2: multiply channel by channel
residues_of_mult = [(a * b) % mod for a, b, mod in zip(residues_of_num1, residues_of_num2, moduli)]
mult_rns_result_mod1, mult_rns_result_mod2, mult_rns_result_mod3 = residues_of_mult
print(f"num1*num2 : {num1*num2} calculated as {chinese_remainder(moduli, residues_of_mult)} using CRT ")


num1 : 42 calculated as 42 using CRT 
num2 : 17 calculated as 17 using CRT 
num1+num2 : 59 calculated as 59 using CRT 
num1*num2 : 714 calculated as 714 using CRT 


In [11]:
# Montgomery arithmetic: halving modulo an odd N.
# Here X and N are both odd, so X + N is even; (X + N) // 2 is therefore exact and
# congruent to X * 2^-1 (mod N).
# https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
N = 7
X = 95
print((X + N) // 2 % N == X * pow(2, -1, N) % N)

True
