In [30]:
import numpy as np


# Ring / scheme parameters:
#   N     - ring dimension (number of coefficients; also the message capacity in bits)
#   MOD   - prime modulus q; q - 1 = 7680 = 2**9 * 15, so 512-th roots of unity exist
#   DELTA - error scale factor; the same value as the literal 2 keygen multiplies E by
N, MOD, DELTA = 512, 7681, 2

def mulmod(a, b):
    """Return the product a * b reduced modulo MOD."""
    product = a * b
    return product % MOD

def addmod(a, b):
    """Return (a + b) mod MOD for operands already reduced into [0, MOD).

    The sum of two reduced operands is below 2 * MOD, so one conditional
    subtraction is enough to bring it back into range.
    """
    total = a + b
    if total < MOD:
        return total
    return total - MOD

def submod(a, b):
    """Return (a - b) mod MOD for operands already reduced into [0, MOD).

    The difference lies in (-MOD, MOD), so adding MOD once when it is
    negative is enough to bring it back into range.
    """
    diff = a - b
    if diff >= 0:
        return diff
    return diff + MOD

def modpow(x, e):
    """Return x**e mod MOD using right-to-left square-and-multiply.

    e must be a non-negative integer (a negative exponent never terminates).
    """
    result, base = 1, x
    while e != 0:
        if e % 2 == 1:
            result = mulmod(result, base)
        base = mulmod(base, base)
        e //= 2
    return result


def bit_reverse(x, logn):
    """Return the lowest `logn` bits of x in reversed order.

    Bits of x above position logn - 1 are ignored; logn == 0 yields 0.
    """
    reversed_bits = 0
    for _ in range(logn):
        # Shift the current least-significant bit of x into the result.
        reversed_bits = (reversed_bits << 1) | (x & 1)
        x >>= 1
    return reversed_bits

def ntt(a, root=7146):
    """Iterative radix-2 number-theoretic transform modulo MOD.

    The input (natural order) is bit-reverse permuted and then combined with
    log2(n) butterfly stages, so the output is in natural order. Given a
    primitive n-th root of unity, pointwise products of two transforms
    correspond (after `intt`) to cyclic convolution of the inputs.

    Args:
        a: Sequence of n integer coefficients (list or NumPy array of any
            integer dtype, including unsigned). Values are reduced modulo MOD
            first, so negative or unreduced coefficients are accepted.
        root: Root of unity used for the transform; assumed to be a primitive
            n-th root modulo MOD. The default 7146 is intended for n = 512 and
            pairs with intt's root_inv=7480 (7146 * 7480 ≡ 1 mod 7681).

    Returns:
        A new int64 NumPy array of length n; the input is not modified.

    Raises:
        ValueError: if len(a) is not a power of two.
    """
    # Work on a signed 64-bit copy. With an unsigned dtype (e.g. a uint32 DMA
    # buffer) `u - v` inside submod wraps around instead of going negative,
    # so its `+ MOD` correction never fires and the transform is corrupted.
    a = np.array(a, dtype=np.int64) % MOD
    n = len(a)
    if n & (n - 1):
        raise ValueError(f"NTT length must be a power of two, got {n}")
    logn = n.bit_length() - 1

    # Bit-reversal permutation so the in-place butterflies yield natural order.
    for i in range(n):
        j = bit_reverse(i, logn)
        if i < j:
            a[i], a[j] = a[j], a[i]

    len_ = 2
    while len_ <= n:
        half = len_ // 2
        # Per-stage twiddle step: root**(n / len_) has order len_ when root
        # is a primitive n-th root.
        wlen = modpow(root, n // len_)
        for i in range(0, n, len_):
            w = 1
            for j in range(half):
                u = a[i + j]
                v = mulmod(a[i + j + half], w)
                a[i + j] = addmod(u, v)
                a[i + j + half] = submod(u, v)
                w = mulmod(w, wlen)
        len_ <<= 1
    return a

def intt(a, root_inv=7480, n_inv=7666):
    """Inverse NTT modulo MOD.

    Runs the forward transform with the inverse root and multiplies every
    output coefficient by n_inv. The defaults are tied to N = 512 and
    MOD = 7681: 7480 is the modular inverse of ntt's default root 7146, and
    512 * 7666 ≡ 1 (mod 7681).

    Returns a Python list (not an array) of the scaled coefficients.
    """
    spectrum = ntt(a, root=root_inv)
    result = []
    for coeff in spectrum:
        result.append(mulmod(coeff, n_inv))
    return result


In [31]:
def string_to_bitarray(message):
    """Convert text to a flat array of bits, 8 per character, MSB first.

    Each character is encoded as a single byte (Latin-1), which is exactly
    what bitarray_to_string's `chr(byte)` decoding inverts.

    Args:
        message: text whose characters all have code points below 256.

    Returns:
        np.ndarray of dtype uint32 holding 8 * len(message) values in {0, 1}.

    Raises:
        ValueError: if a character does not fit in one byte. Without this
            check its high bits would be silently dropped and the message
            could not be recovered.
    """
    try:
        raw = message.encode("latin-1")
    except UnicodeEncodeError as exc:
        raise ValueError(
            "message may only contain characters with code points below 256"
        ) from exc
    bits = [(byte >> shift) & 1 for byte in raw for shift in range(7, -1, -1)]
    return np.array(bits, dtype=np.uint32)


In [32]:
def bitarray_to_string(bit_array):
    """Decode an MSB-first bit sequence into text, one character per 8 bits.

    Each complete group of 8 bits becomes chr(value); a trailing group shorter
    than 8 bits is ignored.
    """
    usable = len(bit_array) - len(bit_array) % 8
    decoded = []
    for start in range(0, usable, 8):
        code = 0
        for bit in bit_array[start:start + 8]:
            code = (code << 1) | bit
        decoded.append(chr(code))
    return "".join(decoded)


In [33]:
def keygen(A, S, E):
    """Compute the public key pk = A ⋆ S + DELTA·E (mod MOD).

    ⋆ is the polynomial product computed via NTT: both operands are
    transformed, multiplied pointwise and transformed back with intt.

    Args:
        A: public polynomial, N coefficients.
        S: secret polynomial, N coefficients.
        E: error polynomial, N coefficients. It is scaled by DELTA (2) so the
            noise left after decryption is a multiple of 2 and does not disturb
            the parity bit that decrypt extracts.

    Returns:
        list of N coefficients in [0, MOD).
    """
    assert len(A) == len(S) == len(E) == N

    # ntt() transforms its own copy, so the inputs need no defensive slicing.
    A_ntt = ntt(A)
    S_ntt = ntt(S)
    pk_ntt = [mulmod(A_ntt[i], S_ntt[i]) for i in range(N)]
    pk = intt(pk_ntt)
    # Scale the error by the shared DELTA constant instead of a second literal 2.
    return [addmod(pk[i], mulmod(E[i], DELTA)) for i in range(N)]


In [34]:
# Deterministic demo inputs: public polynomial A, secret key S and an
# all-ones error polynomial E.
A = [idx % MOD for idx in range(N)]
S = [(3 * idx + 1) % MOD for idx in range(N)]
E = [1] * N

pk = keygen(A, S, E)
print("Public Key (first 8):", pk[:8])


Public Key (first 8): [7633, 6814, 4459, 568, 2822, 3540, 2722, 368]


In [None]:
def encrypt(A, pk, R, M):
    """Encrypt the message vector M under public key pk.

    Computes V = A ⋆ R and W = pk ⋆ R + M (mod MOD), where ⋆ is the
    polynomial product computed via NTT.

    Args:
        A: public polynomial used in keygen, N coefficients.
        pk: public key from keygen, N coefficients.
        R: randomness polynomial, N coefficients.
        M: message vector of N entries (the notebook passes 0/1 bits);
            addmod assumes each entry is already reduced into [0, MOD).

    Returns:
        (V, W): two lists of N coefficients.
    """
    assert len(A) == len(pk) == len(R) == len(M) == N

    # R takes part in both products: transform it once, not twice.
    R_ntt = ntt(R)

    # V = A ⋆ R
    A_ntt = ntt(A)
    V = intt([mulmod(A_ntt[i], R_ntt[i]) for i in range(N)])

    # W = pk ⋆ R + M
    pk_ntt = ntt(pk)
    W = intt([mulmod(pk_ntt[i], R_ntt[i]) for i in range(N)])
    W = [addmod(W[i], M[i]) for i in range(N)]

    return V, W


In [36]:
# Fixed (non-random) encryption randomness, all ones like E.
R = [1 for _ in range(N)]
message = "HELLO WORLD!"
msg_bits = string_to_bitarray(message)
bit_len = len(msg_bits)

# The plaintext holds exactly N bits, i.e. N // 8 characters.
if bit_len > N:
    raise ValueError(f"Message too long! Max {N // 8} characters.")

# Zero-pad the message bits to a full N-entry plaintext.
M = np.zeros(N, dtype=np.uint32)
M[:bit_len] = msg_bits

V, W = encrypt(A, pk, R, M)

print("V (first 8):", V[:8])
print("W (first 8):", W[:8])


V (first 8): [239, 239, 239, 239, 239, 239, 239, 239]
W (first 8): [2877, 2878, 2877, 2877, 2878, 2877, 2877, 2877]


In [37]:
def decrypt(V, W, S):
    """Recover the message bits from ciphertext (V, W) with secret key S.

    W - V ⋆ S = 2·(E ⋆ R) + M (mod MOD), given pk = A ⋆ S + 2·E from keygen.
    The difference is lifted to the centered range [-(MOD // 2), MOD // 2]
    before its parity is taken. MOD is odd, so a negative noise value stored
    as value + MOD would otherwise have its parity flipped. The result is
    correct while |2·(E ⋆ R) + M| <= MOD // 2.

    Returns:
        list of N Python ints, each 0 or 1.
    """
    assert len(V) == len(W) == len(S) == N

    V_ntt = ntt(V)
    S_ntt = ntt(S)
    U = intt([mulmod(V_ntt[i], S_ntt[i]) for i in range(N)])

    half = MOD // 2
    M_out = []
    for i in range(N):
        # int() guards against unsigned inputs (e.g. uint32 hardware output)
        # wrapping around in the subtraction.
        d = submod(int(W[i]), int(U[i]))
        if d > half:
            d -= MOD  # centered lift: large residues stand for negatives
        M_out.append(d % 2)
    return M_out


In [38]:
# Decrypt the ciphertext (V, W) produced by encrypt() with the secret key S
# from the key-generation cell.
m_recovered = decrypt(V, W, S)
preview = m_recovered[:32]
print("Recovered Message (first 32 bits):", preview)


Recovered Message (first 32 bits): [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0]


In [45]:
# Compare the software decryption against the plaintext bits. decrypt()
# already returns the message bits (the noise term is a multiple of 2), so no
# inversion is needed. m_recovered is not mutated, so re-running this cell
# gives the same result.
mismatches = [i for i in range(N) if M[i] != m_recovered[i]]
for i in mismatches:
    print(f"Error at index {i}: {m_recovered[i]} != {M[i]}")
print(f"{len(mismatches)} of {N} bits differ")
    

In [46]:
# Display the first 32 recovered bits.
first_bits = m_recovered[:32]
print("Recovered Message (first 32 bits):", first_bits)

Recovered Message (first 32 bits): [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0]


In [22]:
# Display the first 32 plaintext bits for side-by-side comparison.
plain_preview = M[:32]
print(" Message (first 32 bits):", plain_preview)

 Message (first 32 bits): [0 1 0 0 1 0 0 0 0 1 0 0 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 1 1 0 0]


# Hardware Implementation

In [29]:
# Re-declare the ring parameters for the hardware section; they must equal the
# software constants (presumably also the values the bitstreams were built
# for — confirm against the hardware sources).
N, MOD = 512, 7681


In [None]:
# NOTE(review): `Overlay` and `allocate` are not imported anywhere in this
# notebook; they presumably come from `from pynq import Overlay, allocate`,
# which belongs in the top imports cell — confirm before a fresh-kernel run.
ol = Overlay("keygen.bit")
dma = ol.axi_dma_0

# Input: three words per coefficient (A, S, E interleaved); output: N words.
in_buffer = allocate(shape=(3 * N,), dtype=np.uint32)
out_buffer = allocate(shape=(N,), dtype=np.uint32)

In [None]:
# Interleave the operands: word 3*i + k holds coefficient i of (A, S, E)[k].
for offset, operand in enumerate((A, S, E)):
    in_buffer[offset::3] = operand

In [None]:
# Start sending in_buffer and receiving into out_buffer, then block until
# both channels have finished (same call order: send, recv, send, recv).
dma_channels = (dma.sendchannel, dma.recvchannel)
for channel, buf in zip(dma_channels, (in_buffer, out_buffer)):
    channel.transfer(buf)
for channel in dma_channels:
    channel.wait()

In [None]:
# Copy the key out of the DMA buffer: a bare slice is a view, so re-running
# the transfer cell would silently change pk_hw.
pk_hw = np.array(out_buffer[:N])
print("Public Key (first 8):", pk_hw[:8])
# Cross-check against the software reference computed by keygen().
print("Matches software pk:", np.array_equal(pk_hw, pk))

### Encryption

In [None]:
# Switch to the encryption bitstream and rebind `dma` to its DMA engine.
ol = Overlay("encrypt.bit")
dma = ol.axi_dma_0

In [None]:
# Input: four words per coefficient (A, pk_hw, R, M interleaved);
# output: N words of v followed by N words of w.
in_buffer = allocate(shape=(4 * N,), dtype=np.uint32)
out_buffer = allocate(shape=(2 * N,), dtype=np.uint32)

for offset, operand in enumerate((A, pk_hw, R, M)):
    in_buffer[offset::4] = operand


In [None]:

# Start sending in_buffer and receiving into out_buffer, then block until
# both channels have finished (same call order: send, recv, send, recv).
dma_channels = (dma.sendchannel, dma.recvchannel)
for channel, buf in zip(dma_channels, (in_buffer, out_buffer)):
    channel.transfer(buf)
for channel in dma_channels:
    channel.wait()

In [None]:

# Copy v and w out of the DMA buffer so later transfers cannot alter them;
# the hardware returns v in the first N words and w in the next N.
v = np.array(out_buffer[:N])
w = np.array(out_buffer[N:])

print("v (first 8):", v[:8])
print("w (first 8):", w[:8])
# Cross-check against the software encrypt() results.
print("v matches software V:", np.array_equal(v, V))
print("w matches software W:", np.array_equal(w, W))

### Decryption

In [None]:
# Switch to the decryption bitstream and rebind `dma` to its DMA engine.
ol = Overlay("decrypt.bit")
dma = ol.axi_dma_0


In [None]:
# Input: three words per coefficient (v, w, S interleaved); output: N words.
in_buffer = allocate(shape=(3 * N,), dtype=np.uint32)
out_buffer = allocate(shape=(N,), dtype=np.uint32)

for offset, operand in enumerate((v, w, S)):
    in_buffer[offset::3] = operand

In [None]:

# Start sending in_buffer and receiving into out_buffer, then block until
# both channels have finished (same call order: send, recv, send, recv).
dma_channels = (dma.sendchannel, dma.recvchannel)
for channel, buf in zip(dma_channels, (in_buffer, out_buffer)):
    channel.transfer(buf)
for channel in dma_channels:
    channel.wait()

In [None]:

# Copy the recovered bits out of the DMA buffer (a bare slice would be a view).
decrypted = np.array(out_buffer[:N])

print("Decrypted (first 8):", decrypted[:8])
# The decrypted bits should reproduce the zero-padded plaintext M exactly.
print("Matches plaintext bits:", np.array_equal(decrypted, M))

In [None]:
# Decode only the bits that carry the message; the N - bit_len zero-padding
# bits would otherwise decode to trailing NUL characters.
msg_recovered = bitarray_to_string(decrypted[:bit_len])

print("Message Sent: ", message)
print("Recovered Message: ", msg_recovered)
print("Match:", msg_recovered == message)