In [25]:
import numpy as np
import math
import time
from pynq import Overlay
from pynq import allocate
from pynq import MMIO

# Program the PL (FPGA fabric) with the polynomial-multiplier bitstream
ol = Overlay("Bitstreams/polymult_kyber2.bit")

# Kyber-v2 parameters
# -- modulus / NTT parameters
q = 3329          # prime modulus
n2 = 256          # number of polynomial coefficients
n = 128           # length of each half-size NTT (n2 // 2)
inv_n = 3303      # n^-1 mod q   (128 * 3303 = 1 mod 3329)
psin = 17         # root of unity used to build the twiddle factors
inv_psin = 1175   # psin^-1 mod q (17 * 1175 = 1 mod 3329)
# -- scheme parameters
k = 2             # number of polynomials per secret / error vector
eta1 = 2          # CBD parameter for s, e (key gen) and r (encryption)
eta2 = 3          # CBD parameter for e1, e2 (encryption)
# NOTE(review): round-3 Kyber512 uses eta1 = 3, eta2 = 2 -- confirm these
# values are intended.
du = 10           # compression bits for ciphertext part u
dv = 4            # compression bits for ciphertext part v

In [26]:
# Reverse the lowest `logn` bits of an integer (orders the NTT twiddles)
def bitReverse(num, logn):
    """Return ``num`` with its lowest ``logn`` bits in reversed order.

    Bits of ``num`` at positions >= logn are ignored.

    Args:
        num: integer to reverse.
        logn: number of low-order bits to reverse.

    Returns:
        Integer whose bit (logn - 1 - i) equals bit i of ``num``.
    """
    result = 0
    for _ in range(logn):
        # Shift the next low bit of num into the bottom of result.
        result = (result << 1) | (num & 1)
        num >>= 1
    return result

# Generate the twiddle factors for both the forward and the inverse NTT
def gen_tf(psin, inv_psin, n, q):
    """Return bit-reversed power tables of ``psin`` and ``inv_psin``.

    psis[i]     == psin     ** bitReverse(i, log2(n)) mod q
    inv_psis[i] == inv_psin ** bitReverse(i, log2(n)) mod q
    for i in range(n).

    Args:
        psin: root of unity for the forward transform.
        inv_psin: its inverse mod q, for the inverse transform.
        n: table length (assumed to be a power of two).
        q: modulus.

    Returns:
        (psis, inv_psis) as two lists of length n.
    """
    logn = int(np.log2(n))
    fwd_powers = [1] * n
    inv_powers = [1] * n
    # Successive powers, reduced mod q at every step.
    for idx in range(1, n):
        fwd_powers[idx] = fwd_powers[idx - 1] * psin % q
        inv_powers[idx] = inv_powers[idx - 1] * inv_psin % q
    order = [bitReverse(idx, logn) for idx in range(n)]
    psis = [fwd_powers[pos] for pos in order]
    inv_psis = [inv_powers[pos] for pos in order]
    return psis, inv_psis

# Function to generate scaling factors for point wise multiplication
def gen_pwmf(psin, n, q):
    """Return the per-slot factors used by point_wise_mult.

    pwmf[i] == psin ** (2 * bitReverse(i, log2(n)) + 1) mod q, for i in range(n).

    Args:
        psin: root of unity (same one passed to gen_tf).
        n: number of slots (assumed to be a power of two).
        q: modulus.

    Returns:
        List of n Python ints in [0, q).
    """
    logn = int(np.log2(n))
    # Three-argument pow reduces modulo q as it goes. This avoids building a
    # huge intermediate power, and the int() casts keep it exact even if
    # psin / q arrive as fixed-width NumPy integers (where ** would overflow).
    base, mod = int(psin), int(q)
    return [pow(base, 2 * bitReverse(i, logn) + 1, mod) for i in range(n)]

# Functions to generate Centered Binomial Distribution
def _cbd(n, eta):
    i = 0
    while i < eta:
        p1 = np.random.randint(0, 2, n)
        if i == 0:
            p = p1
        else:
            p = p + p1
        i = i + 1
    return p

def cbd(n, eta):
    """Sample ``n`` coefficients from the centred binomial distribution.

    Computed as the difference of two independent _cbd(n, eta) draws, so
    every value lies in [-eta, eta].
    """
    positive, negative = _cbd(n, eta), _cbd(n, eta)
    return positive - negative
    
def cbd_vector(n, eta, k):
    """Sample ``k`` independent CBD polynomials of length ``n`` as int16.

    The stacked result has shape (k, n). np.squeeze then drops every
    length-1 axis, so the shape is (n,) when k == 1.
    """
    rows = [cbd(n, eta) for _ in range(k)]
    return np.array(rows, dtype=np.int16).squeeze()

# Compression: map values of Z_q onto d-bit integers
def compress(x, q, d):
    """Return round((2**d / q) * x) mod 2**d as int16.

    Args:
        x: integer scalar or array (values of Z_q).
        q: modulus.
        d: number of output bits.

    Returns:
        int16 array (or scalar) with values in [0, 2**d).
    """
    levels = 2 ** d
    scaled = np.round(levels / q * x)
    return np.remainder(scaled.astype(np.int16), levels)

# De-compression function
def decompress(x, q, d):
    """Kyber Decompress_q(x, d): map d-bit values back into Z_q.

    Computes round(q * x / 2**d) with ties rounded *up*, which is the
    rounding the Kyber specification uses, in exact integer arithmetic.
    np.round would round ties to even, giving e.g. decompress(1, 3329, 1)
    == 1664 instead of 1665.

    Args:
        x: integer scalar or array of d-bit values.
        q: modulus.
        d: number of bits x was compressed to.

    Returns:
        int16 array (or scalar) with values in [0, q).
    """
    q1 = 2**d
    # Widen before multiplying: q * x overflows int16 inputs such as the
    # arrays returned by compress().
    x = np.asarray(x, dtype=np.int64)
    # floor((q*x + q1/2) / q1) == round-half-up of q*x/q1.
    x = ((x * q + q1 // 2) // q1).astype(np.int16)
    x = np.remainder(x, q)
    return x

In [None]:
# 128-point forward NTT built from Cooley-Tukey (CT) butterflies
def ct_ntt(a, psis, q, n):
    """Forward NTT of length ``n``, computed in place.

    Runs log2(n) stages of Cooley-Tukey butterflies. The stage with
    ``groups`` butterfly groups uses twiddles psis[groups .. 2*groups - 1],
    so psis is expected in the bit-reversed order produced by gen_tf.

    Args:
        a: mutable sequence of n coefficients; overwritten with the result.
        psis: twiddle factors, read at indices 1 .. n - 1.
        q: modulus.
        n: transform length (assumed to be a power of two).

    Returns:
        The same object ``a``, now holding the transform modulo q.
    """
    half = n
    groups = 1
    while groups < n:
        half //= 2
        for grp in range(groups):
            zeta = psis[groups + grp]
            start = 2 * grp * half
            for lo in range(start, start + half):
                hi = lo + half
                # Read both inputs before overwriting either slot.
                top, bottom = a[lo], a[hi] * zeta
                a[lo] = (top + bottom) % q
                a[hi] = (top - bottom) % q
        groups *= 2
    return a
  
# 128-point inverse NTT built from Gentleman-Sande (GS) butterflies
def gs_intt(a, inv_psis, q, n, inv_n):
    """Inverse NTT of length ``n``, computed in place, then scaled by inv_n.

    Runs log2(n) stages of Gentleman-Sande butterflies. The stage with
    ``half`` butterfly groups uses twiddles inv_psis[half .. 2*half - 1].
    Finally every coefficient is multiplied by inv_n (expected to be
    n^-1 mod q).

    Args:
        a: mutable sequence of n coefficients; overwritten with the result.
        inv_psis: inverse twiddle factors, read at indices 1 .. n - 1.
        q: modulus.
        n: transform length (assumed to be a power of two).
        inv_n: final scaling factor.

    Returns:
        The same object ``a``, now holding the inverse transform modulo q.
    """
    span = 1
    blocks = n
    while blocks > 1:
        half = blocks // 2
        for blk in range(half):
            zeta = inv_psis[half + blk]
            start = 2 * span * blk
            for lo in range(start, start + span):
                hi = lo + span
                # Read both inputs before overwriting either slot.
                top, bottom = a[lo], a[hi]
                a[lo] = (top + bottom) % q
                a[hi] = (top - bottom) * zeta % q
        span *= 2
        blocks //= 2
    for idx, coeff in enumerate(a[:n]):
        a[idx] = coeff * inv_n % q
    return a

# 256 point NTT using two 128 point NTTs
def ntt_256(x, psis, q, n):
    """Forward NTT of a 2n-coefficient polynomial via two n-point NTTs.

    The input is split into its even- and odd-indexed coefficients and each
    half is transformed independently with ct_ntt.

    Args:
        x: sequence (list or NumPy array) of at least 2*n integer
            coefficients; only the first 2*n are used.
        psis: twiddle factors for ct_ntt (as produced by gen_tf).
        q: modulus.
        n: length of each half-size transform.

    Returns:
        (ye, yo): the NTTs of the even and odd halves, as lists.
    """
    # Slicing by 2*n removes the dependency on the notebook-global n2.
    # Converting to Python ints keeps the butterfly products exact: NumPy
    # fixed-width scalars (e.g. the int16 values from cbd_vector) overflow in
    # a[j + t] * S under NumPy >= 2 promotion rules.
    xe = [int(c) for c in x[0:2 * n:2]]
    xo = [int(c) for c in x[1:2 * n:2]]
    ye = ct_ntt(xe, psis, q, n)
    yo = ct_ntt(xo, psis, q, n)
    return ye, yo

# 256-point INTT assembled from two 128-point INTTs
def intt_256(ye, yo, inv_psis, q, n, inv_n):
    """Inverse of ntt_256: inverse-transform both halves and re-interleave.

    gs_intt works in place, so ``ye`` and ``yo`` are overwritten.

    Returns:
        List of 2*n coefficients: ze[i] at index 2*i and zo[i] at 2*i + 1.
    """
    ze = gs_intt(ye, inv_psis, q, n, inv_n)
    zo = gs_intt(yo, inv_psis, q, n, inv_n)
    return [coeff for idx in range(n) for coeff in (ze[idx], zo[idx])]

# Point-wise multiplication in NTT domain
def point_wise_mult(y1e, y1o, y2e, y2o, pwmf, q=None, n=None):
    """Multiply two operands slot by slot in the NTT domain.

    For each slot i, the pair (even, odd) is treated as the degree-1
    polynomial even + odd*X, and the product is reduced modulo
    (X^2 - pwmf[i], q):

        y3e[i] = y1e[i]*y2e[i] + y1o[i]*y2o[i]*pwmf[i]   (mod q)
        y3o[i] = y1e[i]*y2o[i] + y1o[i]*y2e[i]           (mod q)

    Args:
        y1e, y1o: even/odd NTT halves of the first operand.
        y2e, y2o: even/odd NTT halves of the second operand.
        pwmf: per-slot factors (as produced by gen_pwmf).
        q: modulus. Defaults to the module-level ``q`` (earlier versions of
            this function always read that global).
        n: number of slots. Defaults to len(pwmf).

    Returns:
        (y3e, y3o) as lists of Python ints in [0, q).
    """
    if q is None:
        # Backward compatibility: fall back to the notebook-level modulus.
        q = globals()["q"]
    if n is None:
        n = len(pwmf)
    y3e, y3o = [], []
    for i in range(n):
        # Python ints cannot overflow, unlike NumPy int16/int32 scalars.
        a0, a1 = int(y1e[i]), int(y1o[i])
        b0, b1 = int(y2e[i]), int(y2o[i])
        w = int(pwmf[i])
        y3e.append((a0 * b0 + a1 * b1 % q * w) % q)
        y3o.append((a0 * b1 + a1 * b0) % q)
    return y3e, y3o

In [27]:
# Polynomial multiplication under mod (x^n + 1) in Software,
# i.e. a negative wrapped convolution, via the NTT -> point-wise -> INTT route
def poly_mul_sw(x1, x2):
    """Multiply two polynomials in software using the NTT helpers.

    Reads the notebook-level q, n, inv_n, psis, inv_psis and pwmf, which must
    be defined before this is called.

    Returns:
        List of 2*n product coefficients.
    """
    # ntt_256 returns (even, odd); concatenating both tuples gives
    # (y1e, y1o, y2e, y2o) in the order point_wise_mult expects.
    forward = ntt_256(x1, psis, q, n) + ntt_256(x2, psis, q, n)
    even, odd = point_wise_mult(*forward, pwmf)
    return intt_256(even, odd, inv_psis, q, n, inv_n)

# Polynomial multiplication under mod (x^n + 1) in Hardware
# [i.e negative wrapped convolution] using NTT-INTT method
def poly_mul_hw(x1, x2):
    """Multiply two polynomials on the FPGA accelerator.

    The operands are packed into one int32 word per coefficient (x1 in the
    low 16 bits, x2 in the high 16 bits) and streamed to the accelerator
    over the DMA send channel. The n2 int16 result coefficients are read
    back over the receive channel.

    Relies on the notebook-level globals q, n2, mmio, dma_send and dma_recv.

    Returns:
        List of n2 result coefficients (NumPy int16 scalars).
    """
    # Reduce both operands into [0, q) before packing:
    #  * a negative x1 (e.g. a CBD secret coefficient) would otherwise borrow
    #    from the high half in x1 + 2**16 * x2;
    #  * values in [0, q) fit in 15 bits, so the two halves never interact.
    # The product is unchanged modulo q, and every caller reduces mod q.
    # Working in int64 also avoids overflowing int16 inputs: under NumPy >= 2,
    # 2**16 * int16_array raises OverflowError.
    lo = np.remainder(np.asarray(x1, dtype=np.int64), q)
    hi = np.remainder(np.asarray(x2, dtype=np.int64), q)
    packed = lo + (hi << 16)

    # Physically contiguous DMA buffers. The context managers free them
    # deterministically instead of relying on garbage collection after `del`.
    with allocate(shape=(n2,), dtype=np.int32) as input_buffer:
        with allocate(shape=(n2,), dtype=np.int16) as output_buffer:
            input_buffer[0:n2] = packed

            # starting the operation
            mmio.write(0x0, 1)

            dma_send.transfer(input_buffer)
            dma_recv.transfer(output_buffer)
            # transfer() only starts a DMA; wait for both channels to finish
            # before reading the result, or output_buffer may be stale.
            dma_send.wait()
            dma_recv.wait()

            # Copy the result out before the buffer is freed.
            y = list(output_buffer)

    return y

In [28]:
# Kyber PKE functions entirely in SW
# Key generation function (to be performed by server)
def key_gen():
    """Generate a key pair using the software polynomial multiplier.

    Uses the notebook-level q, k, n2 and eta1.

    Returns:
        s: secret vector of k CBD(eta1) polynomials (int16, from cbd_vector).
        a: public matrix A, shape (k, k, n2), uniform in [0, q).
        b: public vector, shape (k, n2): b[i] = sum_j A[i][j]*s[j] + e[i] mod q.
    """
    a = np.random.randint(q, size=(k, k, n2))
    s = cbd_vector(n2, eta1, k)
    e = cbd_vector(n2, eta1, k)
    # cbd_vector squeezes the leading axis away when k == 1; restore it so
    # the row indexing below works for every k.
    s_rows = np.atleast_2d(s)
    e_rows = np.atleast_2d(e)
    # t = A*s + e with exactly one noise polynomial per row. Adding both
    # e[0] and e[1] to every row doubles the noise and makes the rows'
    # noise identical.
    b = np.array([
        (sum(np.array(poly_mul_sw(a[i, j], s_rows[j])) for j in range(k))
         + e_rows[i]) % q
        for i in range(k)
    ])
    return s, a, b

# Encryption function (to be performed by client)
def encrypt(a, b, m):
    """Encrypt an encoded message with the software polynomial multiplier.

    Uses the notebook-level q, k, n2, eta1, eta2, du and dv.

    Args:
        a: public matrix A, shape (k, k, n2).
        b: public vector, shape (k, n2).
        m: encoded message polynomial (n2 values in Z_q).

    Returns:
        (u, v): u = compress(A^T*r + e1, du), v = compress(b^T*r + e2 + m, dv).
    """
    r = cbd_vector(n2, eta1, k)
    e1 = cbd_vector(n2, eta2, k)
    e2 = cbd(n2, eta2)
    # cbd_vector squeezes the leading axis away when k == 1; restore it.
    r_rows = np.atleast_2d(r)
    e1_rows = np.atleast_2d(e1)
    # u[i] = sum_j A[j][i]*r[j] + e1[i]: column i of A, one noise term per row.
    u = np.array([
        (sum(np.array(poly_mul_sw(a[j, i], r_rows[j])) for j in range(k))
         + e1_rows[i]) % q
        for i in range(k)
    ])
    v = (sum(np.array(poly_mul_sw(b[j], r_rows[j])) for j in range(k))
         + e2 + m) % q
    u = compress(u, q, du)
    v = compress(v, q, dv)
    return u, v

# Decryption function (to be performed by server)
def decrypt(s, u, v):
    """Recover the noisy encoded message with the software multiplier.

    Args:
        s: secret vector (k polynomials, or one polynomial when k == 1).
        u, v: compressed ciphertext parts from encrypt().

    Returns:
        d = v - s^T*u mod q. Values near q/2 encode a 1 bit and values near 0
        encode a 0 bit.
    """
    u = decompress(u, q, du)
    v = decompress(v, q, dv)
    # Pair row i of s with row i of u, for any number of rows (atleast_2d
    # restores the axis cbd_vector squeezes away when k == 1).
    p = sum(np.array(poly_mul_sw(s_i, u_i))
            for s_i, u_i in zip(np.atleast_2d(s), np.atleast_2d(u))) % q
    d = (v - p) % q
    return d

In [31]:
# Get pre-computed factors: NTT/INTT twiddles and the point-wise
# multiplication factors (poly_mul_sw reads these as globals)
psis, inv_psis = gen_tf(psin, inv_psin, n, q)
pwmf = gen_pwmf(psin, n, q)

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() follows the adjustable wall clock.
start_sw = time.perf_counter()

# Randomly generated binary message, m, encoded into Z_q
m = np.random.randint(2, size=(n2,))
ms = decompress(m, q, 1)

# Generating private key (s) and public keys (a, b)
s, a, b = key_gen()

# Encrypting the message using public keys to provide cipher texts (u, v)
u, v = encrypt(a, b, ms)

# Decrypt the cipher using private key to obtain back the message (d)
d = decrypt(s, u, v)

# Decoding the decrypted message: compress(d, q, 1) == round(2*d/q) mod 2,
# which is 1 exactly for q/4 < d < 3q/4 (833..2496 for q = 3329). A
# hand-written floor(q/4) < d < floor(3q/4) test wrongly maps d == 2496 to 0.
md = compress(d, q, 1)

end_sw = time.perf_counter()

# Comparison and printing results
print("Actual message    :\n", m)
print("Decrypted message :\n", md)

if np.array_equal(m, md):
    print("Actual message and decrypted message are the same!")
else:
    print("There is mismatch ....")

print()
print("Time taken by SW only =", end_sw - start_sw, "seconds")

Actual message    :
 [1 0 0 1 0 0 0 0 1 0 0 1 0 1 1 1 0 0 1 1 1 0 1 1 0 1 0 0 1 1 0 0 1 1 1 1 1
 0 0 1 1 1 0 1 0 1 0 1 1 0 1 0 1 0 1 0 1 1 0 1 1 0 1 0 1 1 0 1 1 1 1 0 1 0
 1 0 0 1 0 1 1 1 1 0 0 1 0 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 1 0 1 1 0 1 0 1 0
 1 0 1 1 1 0 0 0 1 0 1 1 1 1 1 0 1 1 1 0 0 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0
 1 0 0 0 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 0 0 1 1 1 0 1 0 0 1 0 1 1 0 1 1 1 1
 1 0 1 1 0 1 0 1 1 0 1 1 1 1 1 0 1 0 0 1 0 1 1 0 1 0 1 1 0 1 0 0 0 0 1 1 1
 0 1 1 0 0 0 0 1 0 0 1 1 1 0 0 0 1 1 1 0 1 0 1 1 0 1 1 1 0 1 1 1 1 0]
Decrypted message :
 [1 0 0 1 0 0 0 0 1 0 0 1 0 1 1 1 0 0 1 1 1 0 1 1 0 1 0 0 1 1 0 0 1 1 1 1 1
 0 0 1 1 1 0 1 0 1 0 1 1 0 1 0 1 0 1 0 1 1 0 1 1 0 1 0 1 1 0 1 1 1 1 0 1 0
 1 0 0 1 0 1 1 1 1 0 0 1 0 0 0 1 0 1 0 0 1 0 0 0 1 0 1 0 1 0 1 1 0 1 0 1 0
 1 0 1 1 1 0 0 0 1 0 1 1 1 1 1 0 1 1 1 0 0 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0
 1 0 0 0 1 0 1 0 1 0 1 0 1 1 0 0 0 0 0 0 0 1 1 1 0 1 0 0 1 0 1 1 0 1 1 1 1
 1 0 1 1 0 1 0 1 1 0 1 1 1 1 1 0 1 0 0 1 0 1 1 0 1 0 1 1 0 1 0 

In [23]:
# Kyber PKE functions with PolyMult in HW and rest in SW
# Key generation function (to be performed by server)
def key_gen2():
    """Generate a key pair, doing polynomial multiplication on the FPGA.

    Uses the notebook-level q, k, n2 and eta1.

    Returns:
        s: secret vector of k CBD(eta1) polynomials (int16, from cbd_vector).
        a: public matrix A, shape (k, k, n2), uniform in [0, q).
        b: public vector, shape (k, n2): b[i] = sum_j A[i][j]*s[j] + e[i] mod q.
    """
    a = np.random.randint(q, size=(k, k, n2))
    s = cbd_vector(n2, eta1, k)
    e = cbd_vector(n2, eta1, k)
    # cbd_vector squeezes the leading axis away when k == 1; restore it.
    s_rows = np.atleast_2d(s)
    e_rows = np.atleast_2d(e)
    # t = A*s + e with exactly one noise polynomial per row (adding e[0] and
    # e[1] to every row doubles the noise and makes the rows' noise identical).
    b = np.array([
        (sum(np.array(poly_mul_hw(a[i, j], s_rows[j])) for j in range(k))
         + e_rows[i]) % q
        for i in range(k)
    ])
    return s, a, b

# Encryption function (to be performed by client)
def encrypt2(a, b, m):
    """Encrypt an encoded message, doing polynomial multiplication on the FPGA.

    Uses the notebook-level q, k, n2, eta1, eta2, du and dv.

    Args:
        a: public matrix A, shape (k, k, n2).
        b: public vector, shape (k, n2).
        m: encoded message polynomial (n2 values in Z_q).

    Returns:
        (u, v): u = compress(A^T*r + e1, du), v = compress(b^T*r + e2 + m, dv).
    """
    r = cbd_vector(n2, eta1, k)
    e1 = cbd_vector(n2, eta2, k)
    e2 = cbd(n2, eta2)
    # cbd_vector squeezes the leading axis away when k == 1; restore it.
    r_rows = np.atleast_2d(r)
    e1_rows = np.atleast_2d(e1)
    # u[i] = sum_j A[j][i]*r[j] + e1[i]: column i of A, one noise term per row.
    u = np.array([
        (sum(np.array(poly_mul_hw(a[j, i], r_rows[j])) for j in range(k))
         + e1_rows[i]) % q
        for i in range(k)
    ])
    v = (sum(np.array(poly_mul_hw(b[j], r_rows[j])) for j in range(k))
         + e2 + m) % q
    u = compress(u, q, du)
    v = compress(v, q, dv)
    return u, v

# Decryption function (to be performed by server)
def decrypt2(s, u, v):
    """Recover the noisy encoded message, doing polynomial multiplication on the FPGA.

    Args:
        s: secret vector (k polynomials, or one polynomial when k == 1).
        u, v: compressed ciphertext parts from encrypt2().

    Returns:
        d = v - s^T*u mod q. Values near q/2 encode a 1 bit and values near 0
        encode a 0 bit.
    """
    u = decompress(u, q, du)
    v = decompress(v, q, dv)
    # Pair row i of s with row i of u, for any number of rows (atleast_2d
    # restores the axis cbd_vector squeezes away when k == 1).
    p = sum(np.array(poly_mul_hw(s_i, u_i))
            for s_i, u_i in zip(np.atleast_2d(s), np.atleast_2d(u))) % q
    d = (v - p) % q
    return d

In [32]:
# Base address of control register (axilite)
# NOTE(review): hard-coded; confirm it matches the AXI-Lite address assigned
# to the accelerator in the overlay's block design.
base_addr = 0x43C00000
addr_range = 0x10000
mmio = MMIO(base_addr, addr_range)

# Direct memory access (axistream)
dma = ol.axi_dma_0
dma_send = dma.sendchannel
dma_recv = dma.recvchannel

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() follows the adjustable wall clock.
start_hw = time.perf_counter()

# Randomly generated binary message, m, encoded into Z_q
m = np.random.randint(2, size=(n2,))
ms = decompress(m, q, 1)

# Generating private key (s) and public keys (a, b)
s, a, b = key_gen2()

# Encrypting the message using public keys to provide cipher texts (u, v)
u, v = encrypt2(a, b, ms)

# Decrypt the cipher using private key to obtain back the message (d)
d = decrypt2(s, u, v)

# Decoding the decrypted message: compress(d, q, 1) == round(2*d/q) mod 2,
# which is 1 exactly for q/4 < d < 3q/4 (833..2496 for q = 3329). A
# hand-written floor(q/4) < d < floor(3q/4) test wrongly maps d == 2496 to 0.
md = compress(d, q, 1)

end_hw = time.perf_counter()

# Comparison and printing results
print("Actual message    :\n", m)
print("Decrypted message :\n", md)

if np.array_equal(m, md):
    print("Actual message and decrypted message are the same!")
else:
    print("There is mismatch ....")

print()
print("Time taken by HW-SW =", end_hw - start_hw, "seconds")

Actual message    :
 [1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1
 0 1 1 0 1 1 0 1 1 0 0 1 0 0 1 0 1 0 0 0 0 1 1 1 0 1 1 0 0 1 0 0 0 0 0 1 0
 1 1 0 0 0 0 1 0 0 1 1 1 0 0 1 1 0 1 0 1 0 1 1 0 0 0 1 0 0 0 0 0 0 1 1 0 0
 0 1 0 0 0 1 1 0 1 0 1 1 0 0 1 0 1 1 0 0 1 0 1 0 1 0 0 1 0 0 1 1 1 0 1 1 1
 0 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0
 0 1 0 1 0 1 1 1 0 1 0 1 1 0 1 0 1 1 1 0 1 1 0 1 0 1 0 0 0 0 0 0 1 0 0 1 0
 1 1 1 1 0 1 1 1 1 1 1 1 0 1 1 0 0 0 1 1 1 0 1 1 1 0 0 0 1 0 1 1 0 1]
Decrypted message :
 [1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1
 0 1 1 0 1 1 0 1 1 0 0 1 0 0 1 0 1 0 0 0 0 1 1 1 0 1 1 0 0 1 0 0 0 0 0 1 0
 1 1 0 0 0 0 1 0 0 1 1 1 0 0 1 1 0 1 0 1 0 1 1 0 0 0 1 0 0 0 0 0 0 1 1 0 0
 0 1 0 0 0 1 1 0 1 0 1 1 0 0 1 0 1 1 0 0 1 0 1 0 1 0 0 1 0 0 1 1 1 0 1 1 1
 0 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 0
 0 1 0 1 0 1 1 1 0 1 0 1 1 0 1 0 1 1 1 0 1 1 0 1 0 1 0 0 0 0 0 

In [33]:
# Speed-up of the HW/SW flow over the SW-only flow (ratio of elapsed times)
SF = (end_sw - start_sw) / (end_hw - start_hw)
print(f"Speed-up factor = {SF}")

Speed-up factor = 7.935091636636969
