In [17]:
from pynq import Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import time
import numpy as np
import pynq.lib.dma
from pynq import allocate

import asyncio

In [18]:
class Top_Driver(DefaultIP):
    """PYNQ driver for the HLS ``TOP_MODULE`` IP core.

    Exposes a single ``compute`` call that pushes two operands and an
    opcode into the core's memory-mapped registers and reads the result back.
    """

    # IP type string(s) this driver is bound to.
    bindto = ['xilinx.com:hls:TOP_MODULE:1.0']

    def __init__(self, description):
        super().__init__(description=description)

    def compute(self, input1, input2, OP):
        """Run one operation on the core and return its result register.

        Parameters
        ----------
        input1 : int
            First operand, written to register offset 0x10.
        input2 : int
            Second operand, written to register offset 0x18.
        OP : int
            Operation selector, written to register offset 0x30.
            The notebook passes 2 for the ``a*s`` step and 0 for the
            additions; the exact opcode meaning is defined by the HLS core.

        Returns
        -------
        int
            Value read from register offset 0x20 right after the writes.
            No start/done handshake is performed here.
            NOTE(review): presumably the core needs no explicit start -- confirm
            against the HLS interface.
        """
        # Same order as before: operand 1, operand 2, then the opcode.
        register_writes = ((0x10, input1), (0x18, input2), (0x30, OP))
        for offset, value in register_writes:
            self.write(offset, value)
        return self.read(0x20)
# Load the Crypto overlay from its bitstream file. This runs after Top_Driver is
# defined above.
# NOTE(review): PYNQ presumably needs the driver class to exist before the Overlay
# is created for its `bindto` to take effect -- confirm before reordering.
overlay = Overlay('/home/xilinx/pynq/overlays/Crypto/Crypto.bit')

In [19]:
# Handle to the TOP_MODULE_0 IP instance of the overlay. Later cells call
# top.compute(...), so this is presumably a Top_Driver bound via `bindto`.
top = overlay.TOP_MODULE_0

In [20]:
# Number of AXI BRAM controllers used by this notebook (axi_bram_ctrl_0 .. _7).
bram_num = 8

# Look the controllers up by attribute name rather than exec()-ing generated
# source: same objects, but no dynamic code execution and the names below are
# visible to readers and tooling.
brams = [getattr(overlay, f'axi_bram_ctrl_{i}') for i in range(bram_num)]

# Keep the individual names that the following cells use. The unpacking must
# list exactly `bram_num` names; a mismatch fails loudly here instead of later.
bram0, bram1, bram2, bram3, bram4, bram5, bram6, bram7 = brams

In [21]:
# simulate CKKS encryption, polynomial degree 4096
# Software reference for the hardware run below: enc = (a*s + m + e) mod MOD.
# Note that numpy's `*` is element-wise, so this is a per-coefficient product.
MOD = 193
Poly_degree = 4096
a = np.random.randint(0, MOD, Poly_degree)  # values in [0, MOD)
s = np.random.randint(0, 2, Poly_degree)    # values in {0, 1}
m = np.random.randint(0, MOD, Poly_degree)  # values in [0, MOD)
e = np.random.randint(0, 2, Poly_degree)    # values in {0, 1}

# perf_counter() is the high-resolution monotonic clock meant for measuring
# short intervals such as this few-millisecond computation.
start_time = time.perf_counter()
# Reduce with MOD (not a repeated literal) so the modulus is defined in one place.
enc1 = (a * s) % MOD
enc2 = (enc1 + m) % MOD
enc = (enc2 + e) % MOD
end_time = time.perf_counter()
print("Time: ", end_time - start_time)

Time:  0.005315065383911133


In [22]:
# Words per BRAM: 8192 / 4 = 2048. Integer division keeps this an int so it can
# be used for index arithmetic (the previous `/` produced the float 2048.0).
# NOTE(review): presumably each BRAM controller maps 8192 bytes of 32-bit words -- confirm.
bram_size = 8192 // 4


def _write_poly(values, low_bram, high_bram):
    """Write `values` word by word across two BRAMs.

    Elements with index < bram_size go to `low_bram` at byte offset index*4;
    the remaining elements go to `high_bram` at (index - bram_size)*4.
    Each element is converted to a Python int before the register write.
    """
    for index, value in enumerate(values):
        word = int(value)
        if index < bram_size:
            low_bram.write(index * 4, word)
        else:
            high_bram.write((index - bram_size) * 4, word)


# These three loads happen before start_time is taken, so they are not part of
# the timed region.
# write a to BRAM
_write_poly(a, bram0, bram1)

# write s to BRAM
_write_poly(s, bram2, bram3)

# write e to BRAM
_write_poly(e, bram6, bram7)

# start_time is read again in a later cell to report the total time, which
# therefore includes the write of m below. Kept on time.time() to match that cell.
start_time = time.time()
# write m to BRAM
write_start_time = time.time()
_write_poly(m, bram4, bram5)
write_end_time = time.time()
        



In [23]:
# read a, s from BRAM and compute a*s
# Element i of `a` lives in bram0 (i < 2048) or bram1 (offset by 2048); the
# matching element of `s` is in bram2 / bram3. The result overwrites `a` in
# bram0/bram1, so this cell modifies BRAM contents in place.
# Iterates over Poly_degree (instead of a literal 4096) like the next cell does.
for i in range(Poly_degree):
    # Select the bank pair and byte offset once per element.
    if i < 2048:
        acc_bram, operand_bram, offset = bram0, bram2, i * 4
    else:
        acc_bram, operand_bram, offset = bram1, bram3, (i - 2048) * 4

    # OP=2 is the opcode this notebook uses for the a*s step.
    # NOTE(review): presumably modular multiplication in the HLS core -- confirm.
    product = top.compute(acc_bram.read(offset), operand_bram.read(offset), 2)
    acc_bram.write(offset, product)





In [24]:
      
# read a*s from BRAM and compute a*s + m
# The running value sits in bram0/bram1 and `m` in bram4/bram5; each sum is
# written back over the running value, so re-running this cell adds m again.
for i in range(Poly_degree):
    # First 2048 words use the low bank pair, the rest the high bank pair.
    if i < 2048:
        acc_bram, addend_bram, offset = bram0, bram4, i * 4
    else:
        acc_bram, addend_bram, offset = bram1, bram5, (i - 2048) * 4

    # OP=0 is the opcode this notebook uses for the addition steps.
    running = acc_bram.read(offset)
    addend = addend_bram.read(offset)
    acc_bram.write(offset, top.compute(running, addend, 0))



In [25]:


# read a*s + m from BRAM and compute a*s + m + e
# The running value sits in bram0/bram1 and `e` in bram6/bram7; each sum is
# written back over the running value, so re-running this cell adds e again.
# Iterates over Poly_degree (instead of a literal 4096) like the a*s + m cell.
for i in range(Poly_degree):
    if i < 2048:
        acc_bram, noise_bram, offset = bram0, bram6, i * 4
    else:
        acc_bram, noise_bram, offset = bram1, bram7, (i - 2048) * 4

    # OP=0 is the opcode this notebook uses for the addition steps.
    total = top.compute(acc_bram.read(offset), noise_bram.read(offset), 0)
    acc_bram.write(offset, total)

# start_time was taken with time.time() in the BRAM-loading cell just before m
# was written, so the same clock is used here. The first figure covers writing
# m plus the three compute passes; the second is the write of m alone.
end_time = time.time()
print("Time: ", end_time - start_time)
print("Time: ", write_end_time - write_start_time)

# Read the result from BRAM0, BRAM1
# Integer buffer sized from Poly_degree: the values read back are register
# words, so they are stored as integers rather than float64. int64 is wide
# enough for any value a 32-bit register read can return.
result = np.zeros(Poly_degree, dtype=np.int64)
for i in range(Poly_degree):
    if i < 2048:
        result[i] = bram0.read(i * 4)
    else:
        result[i] = bram1.read((i - 2048) * 4)

# Compare the hardware result with the software reference `enc` computed earlier.
if np.array_equal(enc, result):
    print("a*s + m + e is correct")
else:
    print("a*s + m + e is wrong")
        

Time:  2.484999418258667
Time:  0.12742114067077637
a*s + m + e is correct
