In [64]:
from pynq import Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import time
import numpy as np
import pynq.lib.dma
from pynq import allocate
import asyncio

In [98]:
file_pth = 'test.txt'
with open(file_pth, 'r') as f:
    data = f.read()
    data = data.split('\n')
    data = [int(i) for i in data if i]
    data = np.array(data)
print(data)


[   1  213  213    4  123  412  312 3124  123  312]


In [65]:
class Top_Driver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)

    bindto = ['xilinx.com:hls:TOP_MODULE:1.0']

    def compute(self, input1, input2, OP):
        self.write(0x10, input1)
        self.write(0x18, input2)
        self.write(0x30, OP)
        while self.read(0x24) == 0:
            pass
        return self.read(0x20)
overlay = Overlay('/home/xilinx/pynq/overlays/Crypto/Crypto.bit')

In [66]:
# 100ns * 10000 = 1ms
start_time = time.time()
for i in range(1):
    overlay.TOP_MODULE_0.compute(1, 2, 2)
end_time = time.time()
print("Used time: ", end_time - start_time)

Used time:  0.004181861877441406


In [67]:
top = overlay.TOP_MODULE_0

In [68]:
bram_num = 8
for i in range(bram_num):
    exec(f'bram{i} = overlay.axi_bram_ctrl_{i}')

In [69]:
# simulate CKKS encryption, polynomial degree 4096
MOD = 193
Poly_degree = 4096
a = np.random.randint(0, MOD, Poly_degree)
s = np.random.randint(0, 2, Poly_degree)
m = np.random.randint(0, MOD, Poly_degree)
e = np.random.randint(0, 2, Poly_degree)

start_time = time.time()
enc1 = (a*s)%193
enc2 = (enc1 + m)%193
enc = (enc2 + e)%193
end_time = time.time()
print("Time: ", end_time - start_time)

Time:  0.004825592041015625


In [70]:
bram_size = 8192/4
# write a to BRAM
for i in range(Poly_degree):
    temp = int(a[i])
    if i < 2048:
        bram0.write(i*4, temp)
    else:
        bram1.write((i-2048)*4, temp)
        
# write s to BRAM
for i in range(Poly_degree):
    temp = int(s[i])
    if i < 2048:
        bram2.write(i*4, temp)
    else:
        bram3.write((i-2048)*4, temp)
        
# write e to BRAM
for i in range(Poly_degree):
    temp = int(e[i])
    if i < 2048:
        bram6.write(i*4, temp)
    else:
        bram7.write((i-2048)*4, temp)
        
start_time = time.time()
# write m to BRAM
write_start_time = time.time()
for i in range(Poly_degree):
    temp = int(m[i])
    if i < 2048:
        bram4.write(i*4, temp)
    else:
        bram5.write((i-2048)*4, temp)
write_end_time = time.time()
        



In [71]:
# read a, s from BRAM and compute a*s
for i in range(4096):
    if i < 2048:
        temp1 = bram0.read(i*4)
        temp2 = bram2.read(i*4)
    else:
        temp1 = bram1.read((i-2048)*4)
        temp2 = bram3.read((i-2048)*4)
        
    temp3 = top.compute(temp1, temp2, 2)
    if i < 2048:
        bram0.write(i*4, temp3)
    else:
        bram1.write((i-2048)*4, temp3)





In [72]:
      
# read a*s from BRAM and compute a*s + m
for i in range(Poly_degree):
    if i < 2048:
        temp1 = bram0.read(i*4)
        temp2 = bram4.read(i*4)
    else:
        temp1 = bram1.read((i-2048)*4)
        temp2 = bram5.read((i-2048)*4)
        
    temp3 = top.compute(temp1, temp2, 0)
    
    if i < 2048:
        bram0.write(i*4, temp3)
    else:
        bram1.write((i-2048)*4, temp3)



In [73]:


# read a*s + m from BRAM and compute a*s + m + e
for i in range(4096):
    if i < 2048:
        temp1 = bram0.read(i*4)
        temp2 = bram6.read(i*4)
    else:
        temp1 = bram1.read((i-2048)*4)
        temp2 = bram7.read((i-2048)*4)
    temp3 = top.compute(temp1, temp2, 0)
    if i < 2048:
        bram0.write(i*4, temp3)
    else:
        bram1.write((i-2048)*4, temp3)
end_time = time.time()
print("Time: ", end_time - start_time)
print("Time: ", write_end_time - write_start_time)

# Read the result from BRAM0, BRAM1
result = np.zeros(4096)
for i in range(4096):
    if i < 2048:
        result[i] = bram0.read(i*4)
    else:
        result[i] = bram1.read((i-2048)*4)
        
if np.array_equal(enc, result):
    print("a*s + m + e is correct")
else:
    print("a*s + m + e is wrong")
        

Time:  2.6975605487823486
Time:  0.12714123725891113
a*s + m + e is correct


In [None]:
time1 = 4096 /100000000
print("Time (s): ", time1)

bram_time = time.time()
for i in range(bram0.size//4):
    num = np.random.randint(0, 100)
    bram0.write(i*4, num)
bram_time = time.time() - bram_time
print("BRAM write time: ", bram_time)

ddr_time = time.time()
for i in range(4096):
    num = np.random.randint(0, 100)
    overlay.PSDDR.write(i*4, num)
ddr_time = time.time() - ddr_time
print("DDR write time: ", ddr_time)
    


Time (s):  4.096e-05
BRAM write time:  0.21050381660461426
DDR write time:  0.5393013954162598


In [85]:

timer1 = overlay.axi_timer_0
async def wait_for_timer1(cycles):
    timer1.register_map.TLR0 = cycles
    timer1.register_map.TCSR0.LOAD0 = 1
    timer1.register_map.TCSR0.LOAD0 = 0
    timer1.register_map.TCSR0.ENIT0 = 1
    timer1.register_map.TCSR0.ENT0 = 1
    timer1.register_map.TCSR0.UDT0 = 1
    await timer1.interrupt.wait()
    timer1.register_map.TCSR0.T0INT = 1
    
wait_for_timer1(1000000)

<coroutine object wait_for_timer1 at 0xaa1e1bf8>