In [3]:
!pip install cupy


Defaulting to user installation because normal site-packages is not writeable
Collecting cupy
  Downloading cupy-13.4.1.tar.gz (3.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m37.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting fastrlock>=0.5 (from cupy)
  Using cached fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Using cached fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl (54 kB)
Building wheels for collected packages: cupy
  Building wheel for cupy (setup.py) ... [?25ldone
[?25h  Created wheel for cupy: filename=cupy-13.4.1-cp311-cp311-linux_x86_64.whl size=71903078 sha256=3d26e06b3b8c4bcebdcedf0066246cf0b5a096b98add05a1206ef89de3f6dfab
  Stored in directory: /home/anchaudhary/.cache/pip/wheels/a8/ac/7c/9a43f828e4b9ec5ad848523874d3e915e07914b1a9e2599fe4
Successfully

In [27]:
pip install cryptography


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [32]:
import cupy as cp

LIMB_BITS = 64
LIMB_MASK = (1 << LIMB_BITS) - 1
NUM_LIMBS = 2 # 256-bit numbers with 64-bit limbs
def print_limb_array(name, arr):
    print(f"{name} shape: {arr.shape}")
    for i, row in enumerate(arr):
        # show first 3 and last 3 limbs in hex
        print(f" {name}[{i}]: ", end='')
        print(' '.join(f"{x:#018x}" for x in row[:3]), end=' ... ')
        print(' '.join(f"{x:#018x}" for x in row[-3:]))

def int_to_limbs(x, num_limbs=NUM_LIMBS):
    limbs = [(x >> (LIMB_BITS * i)) & LIMB_MASK for i in range(num_limbs)]
    return cp.array(limbs, dtype=cp.uint64)

def batch_int_to_limbs(int_list, num_limbs=NUM_LIMBS):
    batch = [int_to_limbs(x, num_limbs) for x in int_list]
    return cp.stack(batch)

def limbs_to_int(limbs):
    result = 0
    for i in range(limbs.shape[0]):
        result += int(limbs[i]) << (LIMB_BITS * i)
    return result

def batch_limbs_to_int(batch_limbs):
    ints = []
    for limbs in batch_limbs:
        ints.append(limbs_to_int(limbs.get()))
    return ints


def pad_limbs(arr, target_limbs):
    diff = target_limbs - arr.shape[1]
    if diff > 0:
        padding = cp.zeros((arr.shape[0], diff), dtype=cp.uint64)
        return cp.concatenate([arr, padding], axis=1)
    return arr

def limb_mul(a, b):
    batch_size, k = a.shape
    result = cp.zeros((batch_size, 2 * k), dtype=cp.uint64)
    for i in range(k):
        for j in range(k):
            result[:, i + j] += a[:, i] * b[:, j]
    return result


def limb_sub(A, B):
    batch_size, n = A.shape
    R = cp.zeros_like(A)
    borrow = cp.zeros(batch_size, dtype=cp.uint64)
    for i in range(n):
        sub = A[:, i].astype(cp.uint64) - B[:, i].astype(cp.uint64) - borrow
        borrow = (sub > A[:, i]).astype(cp.uint64)
        R[:, i] = sub & LIMB_MASK
    return R, borrow

def limb_gte(A, B):
    batch_size, n = A.shape
    res = cp.ones(batch_size, dtype=cp.bool_)
    for i in reversed(range(n)):
        a = A[:, i]
        b = B[:, i]
        gt = a > b
        lt = a < b
        res = cp.where(res, cp.where(gt, True, cp.where(lt, False, True)), res)
    return res

def barrett_reduce_gpu(x, n, mu):
    batch_size = x.shape[0]
    k = NUM_LIMBS

    print(f"x.shape={x.shape}, n.shape={n.shape}, mu.shape={mu.shape}")

    # For 2-limb x, take q1 = floor(x / β^{k-1}) = x itself if k = 1
    q1 = x  # for k=1, floor(x / β^0) = x
    print(f"q1: {[hex(int(i)) for i in q1[0].get()]}")

    # Ensure q1 and mu are same length
    mu = pad_limbs(mu, x.shape[1])
    q1 = pad_limbs(q1, x.shape[1])

    # q2 = q1 * mu
    q2 = limb_mul(q1, mu)
    print(f"q2: {[hex(int(i)) for i in q2[0].get()]}")

    # Take upper limbs (top k+1 = 2 limbs here)
    q3 = q2[:, -2:]
    print(f"q3: {[hex(int(i)) for i in q3[0].get()]}")

    # Multiply q3 * n
    n_pad = pad_limbs(n, q3.shape[1])
    q3n = limb_mul(q3, n_pad)[:, :x.shape[1]]
    print(f"q3*n: {[hex(int(i)) for i in q3n[0].get()]}")
    # r = x - q3 * n
    r = x - q3n
    print(f"x: {[hex(int(i)) for i in x[0].get()]}")
    print(f"r : {[hex(int(i)) for i in r[0].get()]}")

    # If underflow occurred, correct
    for i in range(2):
        mask = cp.all(r >= n, axis=1)
        if not mask.any():
            break
        r = r - pad_limbs(n, r.shape[1])

    print(f"Final result: {[hex(int(i)) for i in r[0].get()]}")
    return r

def mod_mul(a, b, n, mu):
    print(f"\nModular Multiplication:")

    
    x = limb_mul(a, b)
    
    r = barrett_reduce(x, n, mu)
    return r

def mod_exp(base, exp, n, mu):
    batch_size = base.shape[0]
    result = cp.zeros_like(base)
    for i in range(batch_size):
        # Initialize r as [1, 0] for proper multiplication
        r = cp.zeros(NUM_LIMBS, dtype=cp.uint64)
        r[0] = 1  # Set first limb to 1
        b = base[i].copy()  # Make a copy to avoid modifying original
        e = exp
        while e > 0:
            if e & 1:
                r = mod_mul(r[None, :], b[None, :], n[i:i + 1], mu[i:i + 1])[0]
            b = mod_mul(b[None, :], b[None, :], n[i:i + 1], mu[i:i + 1])[0]
            e >>= 1
        result[i] = r
    return result

def compute_barrett_mu(n):
    batch_size = n.shape[0]
    mu = cp.zeros((batch_size, NUM_LIMBS + 1), dtype=cp.uint64)
    for i in range(batch_size):
        n_int = limbs_to_int(n[i])
        # For k-limb modulus, compute ⌊β^(2k)/n⌋
        b_pow = 1 << (2 * NUM_LIMBS * LIMB_BITS)
        mu_int = b_pow // n_int
        mu_limbs = batch_int_to_limbs([mu_int], NUM_LIMBS + 1)[0]
        mu[i] = mu_limbs
    return mu



In [33]:
def rsa_barrett_demo():
    e, d, n_int = generate_rsa_keypair()

    print("\nRSA Parameters:")
    print(f"e = {e}")
    print(f"d = {hex(d)}")
    print(f"n = {hex(n_int)}")

    # Two test messages < n
    message_ints = [
        0x123456789abcdef0,  # Use smaller test messages first
        0xfedcba9876543210
    ]

    m = batch_int_to_limbs(message_ints, NUM_LIMBS)
    n = batch_int_to_limbs([n_int] * len(message_ints), NUM_LIMBS)
    mu = compute_barrett_mu(n)

    print("\nTest Messages:")
    print("Original:", [hex(x) for x in message_ints])

    # Encrypt: c = m^e mod n
    c = mod_exp(m, e, n, mu)
    c_ints = batch_limbs_to_int(c)
    print("Encrypted:", [hex(x) for x in c_ints])

    # Decrypt: m = c^d mod n
    decrypted = mod_exp(c, d, n, mu)
    d_ints = batch_limbs_to_int(decrypted)
    print("Decrypted:", [hex(x) for x in d_ints])

    # Verify
    for i in range(len(message_ints)):
        assert message_ints[i] == d_ints[i], f"Mismatch at index {i}: {hex(message_ints[i])} vs {hex(d_ints[i])}"
    print("\n✅ RSA encryption and decryption verified using Barrett reduction.")


In [34]:
from sympy import randprime
import random

def generate_64bit_prime():
    # Generate a random 64-bit prime
    return randprime(2**63, 2**64 - 1)

def generate_rsa_keypair():
    # Generate primes p, q
    p = generate_64bit_prime()
    q = generate_64bit_prime()
    while p == q:
        q = generate_64bit_prime()

    n = p * q  # 128-bit modulus
    phi = (p - 1) * (q - 1)

    # Public exponent (common choice)
    e = 65537

    # Compute private exponent d (modular inverse of e mod phi)
    # Use extended Euclidean algorithm for modular inverse
    d = modinv(e, phi)

    return e, d, n

def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)

def modinv(a, m):
    g, x, y = egcd(a, m)
    if g != 1:
        raise Exception('Modular inverse does not exist')
    else:
        return x % m


In [35]:
import cupy as cp
from sympy import randprime

LIMB_BITS = 64
LIMB_MASK = (1 << LIMB_BITS) - 1
NUM_LIMBS = 2  # 128-bit numbers with 64-bit limbs

# (Include your limb/int conversion and modular arithmetic functions here)
# For brevity, assume all those functions like batch_int_to_limbs, batch_limbs_to_int,
# mod_exp, mod_mul, compute_barrett_mu, etc. are already defined exactly as you wrote.

def generate_64bit_prime():
    return randprime(2**63, 2**64 - 1)

def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)

def modinv(a, m):
    g, x, _ = egcd(a, m)
    if g != 1:
        raise Exception('Modular inverse does not exist')
    else:
        return x % m

def generate_rsa_keypair():
    p = generate_64bit_prime()
    q = generate_64bit_prime()
    while p == q:
        q = generate_64bit_prime()

    n = p * q
    phi = (p - 1) * (q - 1)
    e = 65537
    d = modinv(e, phi)
    return e, d, n

def rsa_barrett_demo():
    e, d, n_int = generate_rsa_keypair()

    print(f"Generated RSA keypair:")
    print(f"  e = {e}")
    print(f"  d = {hex(d)}")
    print(f"  n = {hex(n_int)}")

    # Some test messages < n
    message_ints = [
        0x123456789abcdef0,
        0xfedcba9876543210
    ]

    m = batch_int_to_limbs(message_ints, NUM_LIMBS)
    n = batch_int_to_limbs([n_int] * len(message_ints), NUM_LIMBS)
    mu = compute_barrett_mu(n)

    print("Original messages:", [hex(x) for x in batch_limbs_to_int(m)])

    # Encrypt
    c = mod_exp(m, e, n, mu)
    print("Encrypted ciphertexts:", [hex(x) for x in batch_limbs_to_int(c)])

    # Decrypt
    decrypted = mod_exp(c, d, n, mu)
    print("Decrypted messages:", [hex(x) for x in batch_limbs_to_int(decrypted)])

    # Verify correctness
    m_ints = batch_limbs_to_int(m)
    d_ints = batch_limbs_to_int(decrypted)
    for i in range(len(m_ints)):
        assert m_ints[i] == d_ints[i], f"Mismatch at index {i}: {hex(m_ints[i])} != {hex(d_ints[i])}"
    print("✅ RSA encryption and decryption verified successfully.")

if __name__ == "__main__":
    rsa_barrett_demo()


Generated RSA keypair:
  e = 65537
  d = 0x7b91f9c3f61badb4d98241f4b0df044d
  n = 0xaca6018186537d4f7f4196ab80ab4211
Original messages: ['0x123456789abcdef0', '0xfedcba9876543210']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x123456789abcdef0', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0xa5e20890f2a52100', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x75d1bce78e410000', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x805a2c8100000000', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x0', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x0', '0x0']

Modular Multiplication:
q2=['0x0', '0x0', '0x0', '0x0']
q3=['0x0', '0x0']
Final result=['0x0', '0x0']

Modular Multiplication:
q2=['0x0', '0

AssertionError: Mismatch at index 0: 0x123456789abcdef0 != 0x0