In [None]:
!pip install cupy


In [None]:
pip install cryptography


In [7]:
import cupy as cp

LIMB_BITS = 64
LIMB_MASK = (1 << LIMB_BITS) - 1
NUM_LIMBS = 2 # 256-bit numbers with 64-bit limbs

def print_limb_array(name, arr):
    print(f"{name} shape: {arr.shape}")
    for i, row in enumerate(arr):
        print(f" {name}[{i}]: ", end='')
        print(' '.join(f"{x:#018x}" for x in row))

def pad_limbs(arr, target_limbs):
    diff = target_limbs - arr.shape[1]
    if diff > 0:
        padding = cp.zeros((arr.shape[0], diff), dtype=cp.uint64)
        return cp.concatenate([arr, padding], axis=1)
    return arr
    

def limb_mul(a, b):
    batch_size = a.shape[0]
    k_a = a.shape[1]
    k_b = b.shape[1]
    result = cp.zeros((batch_size, k_a + k_b), dtype=cp.uint64)
    for i in range(k_a):
        for j in range(k_b):
            # Move data to CPU for arbitrary-precision multiplication
            a_ij = cp.asnumpy(a[:, i])
            b_j = cp.asnumpy(b[:, j])
            for batch in range(batch_size):
                prod = int(a_ij[batch]) * int(b_j[0])  # b is broadcasted
                low = prod & LIMB_MASK
                high = prod >> LIMB_BITS
                result[batch, i + j] += low
                if i + j + 1 < k_a + k_b:
                    result[batch, i + j + 1] += high
    return result

def limbs_gte(a, b):
    # Compare if a >= b (both shape (batch, limbs))
    # Returns boolean mask array shape (batch,)
    batch, limbs = a.shape
    # Compare from most significant limb down to least
    for i in reversed(range(limbs)):
        gt_mask = a[:, i] > b[:, i]
        lt_mask = a[:, i] < b[:, i]
        # If greater found, return True
        if cp.any(gt_mask):
            return gt_mask
        # If less found, return False
        if cp.any(lt_mask):
            return ~lt_mask
    # If all limbs equal, a == b, return True
    return cp.ones(batch, dtype=bool)




def int_to_limbs(x, num_limbs=NUM_LIMBS):
    limbs = [(x >> (LIMB_BITS * i)) & LIMB_MASK for i in range(num_limbs)]
    return cp.array([limbs], dtype=cp.uint64)  # note the extra [] to make it 2D (1 x num_limbs)

def limbs_to_int(limbs):
    val = 0
    for i in reversed(range(limbs.shape[0])):
        val = (val << LIMB_BITS) | int(limbs[i])
    return val

def int_to_limbs(x, num_limbs):
    return cp.array([(x >> (64 * i)) & 0xFFFFFFFFFFFFFFFF for i in range(num_limbs)], dtype=cp.uint64)



def batch_limbs_to_int(batch_limbs):
    ints = []
    for limbs in batch_limbs:
        ints.append(limbs_to_int(limbs.get()))
    return ints
def batch_int_to_limbs(int_list, num_limbs):
    """
    Converts a list of Python ints to a 2D CuPy array of limbs.
    Each row is the limb representation of one integer.
    """
    LIMB_MASK = (1 << 64) - 1
    batch = []
    for x in int_list:
        limbs = [(x >> (64 * i)) & LIMB_MASK for i in range(num_limbs)]
        batch.append(limbs)
    return cp.array(batch, dtype=cp.uint64)
    
def barrett_reduce_gpu(x, n, mu):
    print(f"barret_x: {[hex(int(i)) for i in x[0].get()]}")
    print(f"barret_n: {[hex(int(i)) for i in n[0].get()]}")
    print(f"barret_mu: {[hex(int(i)) for i in mu[0].get()]}")
    batch_size = x.shape[0]
    k = n.shape[1]
    print(f"x.shape={x.shape}, n.shape={n.shape}, mu.shape={mu.shape}")
    # For 2-limb x, take q1 = floor(x / β^{k-1}) = x itself if k = 1
    #q1 = x[:, (k-1):] #    print("q1:", [hex(int(v)) for v in cp.asnumpy(q1[0])])
    if x.shape[1] < 2 * k:
        x = pad_limbs(x, 2 * k)
    print(f"barret_x: {[hex(int(i)) for i in x[0].get()]}")  
    q1 = x[:, k-1:-1] 
    # for k=1, floor(x / β^0) = x
    # Ensure q1 and mu are same length
    mu = pad_limbs(mu, x.shape[1])
    q1 = pad_limbs(q1, x.shape[1])
    print("q1:", [hex(int(v)) for v in cp.asnumpy(q1[0])])
    # q2 = q1 * mu
    q2 = limb_mul(q1, mu)
    print(f"q2: {[hex(int(i)) for i in q2[0].get()]}")

    # Take upper limbs (top k+1 = 3 limbs here)
    q3 = q2[:, -(k+1):]    
    print(f"q3: {[hex(int(i)) for i in q3[0].get()]}")

    # Multiply q3 * n directly without padding
    q3n = limb_mul(q3, n)  # Let limb_mul handle different sizes
    print(f"q3*n: {[hex(int(i)) for i in q3n[0].get()]}")
    # r = x - q3 * n
    # Convert limbs arrays to integers
    x_int = limbs_to_int(x[0].get())     # Assuming x shape (1, k)
    q3n_int = limbs_to_int(q3n[0].get())

# Subtract
    r_int = x_int - q3n_int
    if r_int < 0:
        r_int += (1 << (64 * x.shape[1]))  # If negative, add modulus range or handle properly

# Convert back to limbs
    r = int_to_limbs(r_int, n.shape[1])    
    if r.ndim == 1:
        r = r[cp.newaxis, :]
    print(f"r : {[hex(int(i)) for i in r[0].get()]}")
    print(type(r))
    print(type(n))
    result_int = r.dot(1 << cp.arange(NUM_LIMBS * 64, step=64))  # assuming 64-bit limbs
    n_int = int("".join([f"{x:016x}" for x in reversed(n[0].tolist())]), 16)
    r_int = r_int % n_int
    r = int_to_limbs(r_int, n.shape[1])  

    if r.ndim == 1:
        r = r[cp.newaxis, :]
    print(f"Final result: {[hex(int(i)) for i in r[0].get()]}")
    return r

def mod_mul(a, b, n, mu):
    print(f"\nModular Multiplication:")

    
    x = limb_mul(a, b)
    
    r = barrett_reduce(x, n, mu)
    return r


def modexp_barrett_gpu(base, exp, n, mu):
    """
    base: (batch_size, num_limbs) cp.uint64
    exp: Python int (scalar exponent)
    n, mu: (batch_size, num_limbs) cp.uint64
    Returns: base^exp % n using Barrett reduction, shape (batch_size, num_limbs)
    """
    batch_size, num_limbs = base.shape
    result = cp.zeros((batch_size, num_limbs), dtype=cp.uint64)
    result[:, 0] = 1  # Initialize to 1 in multi-limb representation

    base_copy = cp.copy(base)

    while exp > 0:
        if exp & 1:
            prod = limb_mul(result, base_copy)
            result = barrett_reduce_gpu(prod, n, mu)

        # base = base^2 mod n
        prod = limb_mul(base_copy, base_copy)
        base_copy = barrett_reduce_gpu(prod, n, mu)

        exp >>= 1

    return result





In [8]:
import cupy as cp
import time

def compute_barrett_mu(n_int, k):
    """Compute Barrett mu = floor(b^{2k} / n)"""
    b = 1 << 64
    mu = (1 << (2 * k * 64)) // n_int
    return mu

def rsa_barrett_encrypt_decrypt_test(msg_int, e, d, n_int, num_limbs=2, num_iterations=1000):
    print("\n=== RSA Parameters ===")
    print(f"Modulus (n):      {n_int}")
    print(f"Public Key (e):   {e}")
    print(f"Private Key (d):  {d}")
    print(f"Message:          {msg_int}")

    mu_int = compute_barrett_mu(n_int, num_limbs)
    print(f"Computed Mu:      {mu_int}")

    # Convert to limb arrays
    base = batch_int_to_limbs([msg_int], num_limbs)
    n = batch_int_to_limbs([n_int], num_limbs)
    mu = batch_int_to_limbs([mu_int], 2 * num_limbs)

    total_enc_time = 0.0
    total_dec_time = 0.0
    decrypted_int = -1
    cipher_int = -1

    for i in range(num_iterations):
        # Encrypt
        start = time.time()
        enc_limbs = modexp_barrett_gpu(base, e, n, mu)
        cp.cuda.Device(0).synchronize()
        enc_time = time.time() - start
        total_enc_time += enc_time * 1000

        # Decrypt
        start = time.time()
        dec_limbs = modexp_barrett_gpu(enc_limbs, d, n, mu)
        cp.cuda.Device(0).synchronize()
        dec_time = time.time() - start
        total_dec_time += dec_time * 1000

        # Convert result to integer
        decrypted_int = batch_limbs_to_int(dec_limbs)[0]
        cipher_int = batch_limbs_to_int(enc_limbs)[0]

        if decrypted_int != msg_int:
            print(f"\nVerification: FAILED! Decrypted message ({decrypted_int}) does NOT match original ({msg_int}).")
            return

        if i == num_iterations - 1:
            print(f"\n--- Final Verification (last iteration) ---")
            print(f"Encrypted: {cipher_int}")
            print(f"Decrypted: {decrypted_int}")
            print("Verification: SUCCESS!" if decrypted_int == msg_int else "FAILED!")

    print(f"\n--- Benchmarking Summary ({num_iterations} iterations) ---")
    print(f"Average Encryption Time: {total_enc_time / num_iterations:.4f} ms")
    print(f"Average Decryption Time: {total_dec_time / num_iterations:.4f} ms")



In [11]:
p = 61
q = 53
n = p * q          # 3233
e = 17
d = 2753
# Use a large message to ensure x >= n during encryption
msg = 123  # This is much larger than n

rsa_barrett_encrypt_decrypt_test(msg, e, d, n, num_limbs=2, num_iterations=1)


=== RSA Parameters ===
Modulus (n):      3233
Public Key (e):   17
Private Key (d):  2753
Message:          123
Computed Mu:      35815678700066871457955763998975535989257650685320310559683756265979934933
barret_x: ['0x7b', '0x0', '0x0', '0x0']
barret_n: ['0xca1', '0x0']
barret_mu: ['0x3103b8bc23ca78d5', '0x618dd14822d73875', '0xa0a9c4edddf37d30', '0x14455d5b74f433']
x.shape=(1, 4), n.shape=(1, 2), mu.shape=(1, 4)
barret_x: ['0x7b', '0x0', '0x0', '0x0']
q1: ['0x0', '0x0', '0x0', '0x0']
q2: ['0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0']
q3: ['0x0', '0x0', '0x0']
q3*n: ['0x0', '0x0', '0x0', '0x0', '0x0']
r : ['0x7b', '0x0']
<class 'cupy.ndarray'>
<class 'cupy.ndarray'>
Final result: ['0x7b', '0x0']
barret_x: ['0x3b19', '0x0', '0x0', '0x0']
barret_n: ['0xca1', '0x0']
barret_mu: ['0x3103b8bc23ca78d5', '0x618dd14822d73875', '0xa0a9c4edddf37d30', '0x14455d5b74f433']
x.shape=(1, 4), n.shape=(1, 2), mu.shape=(1, 4)
barret_x: ['0x3b19', '0x0', '0x0', '0x0']
q1: ['0x0', '0x0', '0x0', 

In [19]:
def batch_int_to_limbs(int_list, num_limbs):
    """
    Converts a list or array of ints to a 2D CuPy array of limbs.
    Each row is the limb representation of one integer.
    """
    import cupy as cp
    LIMB_MASK = (1 << 64) - 1
    batch = []
    for x in int_list:
        x = int(x)  # <-- Ensure Python int for bitwise ops
        limbs = [(x >> (64 * i)) & LIMB_MASK for i in range(num_limbs)]
        batch.append(limbs)
    return cp.array(batch, dtype=cp.uint64)
def modexp_barrett_gpu(base, exp, n, mu):
    # base, n, mu: (1, num_limbs) cp.uint64 arrays
    # result: (1, num_limbs) cp.uint64
    num_limbs = base.shape[1]
    result = cp.zeros((1, num_limbs), dtype=cp.uint64)
    result[0, 0] = 1  # Set to 1 in limb form

    base_copy = cp.copy(base)
    while exp > 0:
        if exp & 1:
            prod = limb_mul(result, base_copy)
            result = barrett_reduce_gpu(prod, n, mu)
        prod = limb_mul(base_copy, base_copy)
        base_copy = barrett_reduce_gpu(prod, n, mu)
        exp >>= 1
    return result
    
batch_msg = [123]  # Example batch
msg_cp = batch_int_to_limbs(batch_msg, num_limbs)
n_cp = batch_int_to_limbs([n], num_limbs)
mu_cp = batch_int_to_limbs([mu_int], num_limbs * 2)

# Encrypt
cipher_cp = modexp_barrett_gpu(msg_cp, e, n_cp, mu_cp)
# Decrypt
plain_cp = modexp_barrett_gpu(cipher_cp, d, n_cp, mu_cp)
decrypted_ints = batch_limbs_to_int(plain_cp)
print(decrypted_ints)
batch_msg = np.array(batch_msg)
if np.all(decrypted_ints == batch_msg):
    print("Batch RSA verification: SUCCESS! All messages decrypted correctly.")
else:
    print("Batch RSA verification: FAILED!")
    for i, (orig, dec) in enumerate(zip(batch_msg, decrypted_ints)):
        if orig != dec:
            print(f"Mismatch at index {i}: original={orig}, decrypted={dec}")


barret_x: ['0x7b', '0x0', '0x0', '0x0']
barret_n: ['0xca1', '0x0']
barret_mu: ['0x3103b8bc23ca78d5', '0x618dd14822d73875', '0xa0a9c4edddf37d30', '0x14455d5b74f433']
x.shape=(1, 4), n.shape=(1, 2), mu.shape=(1, 4)
barret_x: ['0x7b', '0x0', '0x0', '0x0']
q1: ['0x0', '0x0', '0x0', '0x0']
q2: ['0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0']
q3: ['0x0', '0x0', '0x0']
q3*n: ['0x0', '0x0', '0x0', '0x0', '0x0']
r : ['0x7b', '0x0']
<class 'cupy.ndarray'>
<class 'cupy.ndarray'>
Final result: ['0x7b', '0x0']
barret_x: ['0x3b19', '0x0', '0x0', '0x0']
barret_n: ['0xca1', '0x0']
barret_mu: ['0x3103b8bc23ca78d5', '0x618dd14822d73875', '0xa0a9c4edddf37d30', '0x14455d5b74f433']
x.shape=(1, 4), n.shape=(1, 2), mu.shape=(1, 4)
barret_x: ['0x3b19', '0x0', '0x0', '0x0']
q1: ['0x0', '0x0', '0x0', '0x0']
q2: ['0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0', '0x0']
q3: ['0x0', '0x0', '0x0']
q3*n: ['0x0', '0x0', '0x0', '0x0', '0x0']
r : ['0x3b19', '0x0']
<class 'cupy.ndarray'>
<class 'cupy.ndarray'>
Fi