In [1]:
from pynq import Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import time
import random
import numpy as np
import pynq.lib.dma
from pynq import allocate
import asyncio

In [2]:
# Open the Crypto overlay from its .bit file; every later cell reaches the hardware through `overlay`.
overlay = Overlay('/home/xilinx/pynq/overlays/Crypto/Crypto.bit')
# IPython help query (not plain Python): prints the overlay docstring, which lists the
# IP blocks, hierarchies, interrupts, GPIO outputs and memories found in the design.
overlay?

[0;31mType:[0m            Overlay
[0;31mString form:[0m     <pynq.overlay.Overlay object at 0xb38b1760>
[0;31mFile:[0m            /usr/local/share/pynq-venv/lib/python3.10/site-packages/pynq/overlay.py
[0;31mDocstring:[0m      
Default documentation for overlay /home/xilinx/pynq/overlays/Crypto/Crypto.bit. The following
attributes are available on this overlay:

IP Blocks
----------
Top_0                : pynq.overlay.DefaultIP
encode_0             : pynq.overlay.DefaultIP
random_generator_0   : pynq.overlay.DefaultIP
processing_system7_0 : pynq.overlay.DefaultIP

Hierarchies
-----------
None

Interrupts
----------
None

GPIO Outputs
------------
None

Memories
------------
axi_bram_ctrl_0      : Memory
axi_bram_ctrl_1      : Memory
axi_bram_ctrl_2      : Memory
axi_bram_ctrl_3      : Memory
axi_bram_ctrl_4      : Memory
PSDDR                : Memory
[0;31mClass docstring:[0m
This class keeps track of a single bitstream's state and contents.

The overlay class holds the state

In [4]:
def _read_int_lines(path):
    """Return the integers stored one per line in the text file at ``path``.

    Blank lines are skipped, so an empty line at the end of the file does not
    abort the load with ``ValueError``. The ``with`` statement closes the file,
    which is why no explicit ``close()`` call is needed.
    """
    with open(path, 'r') as f:
        return [int(line) for line in f if line.strip()]


# Inputs for the encode core and the expected ("golden") output used to check it.
poly = _read_int_lines('./poly.txt')
basis = _read_int_lines('./basis.txt')
golden_ret = _read_int_lines('./ret.txt')

class Encode():
    """Host-side helper that drives the ``encode_0`` IP through its register window.

    The polynomial, the basis and the result each occupy ``NUM_WORDS`` consecutive
    32-bit words starting at ``poly_start_addr``, ``basis_start_addr`` and
    ``ret_start_addr`` respectively.
    """

    NUM_WORDS = 4096     # words transferred per buffer
    WORD_BYTES = 4       # address stride between consecutive words
    CTRL_OFFSET = 0x00   # control register (see the bit layout in ``encode``)
    AP_START = 0x1       # control bit 0
    AP_DONE = 0x2        # control bit 1

    def __init__(self):
        self.encode_module = overlay.encode_0
        self.poly_start_addr = 0x4000
        self.basis_start_addr = 0x8000
        self.ret_start_addr = 0xC000

    def _require_words(self, values, what):
        """Raise ``ValueError`` if ``values`` holds fewer than ``NUM_WORDS`` entries."""
        if len(values) < self.NUM_WORDS:
            raise ValueError(
                f'{what} needs at least {self.NUM_WORDS} values, got {len(values)}')

    def _write_words(self, base_addr, values, what):
        """Write the first ``NUM_WORDS`` entries of ``values`` starting at ``base_addr``."""
        # Check up front so a short input cannot leave the buffer half-written.
        self._require_words(values, what)
        for i in range(self.NUM_WORDS):
            # int() lets NumPy integer scalars through as plain Python ints.
            self.encode_module.write(base_addr + i * self.WORD_BYTES, int(values[i]))

    def write_poly(self, poly):
        """Copy the first ``NUM_WORDS`` polynomial coefficients into the core."""
        self._write_words(self.poly_start_addr, poly, 'poly')

    def write_basis(self, basis):
        """Copy the first ``NUM_WORDS`` basis values into the core."""
        self._write_words(self.basis_start_addr, basis, 'basis')

    def read_ret(self):
        """Return the ``NUM_WORDS`` result words as a list of ints."""
        return [
            self.encode_module.read(self.ret_start_addr + i * self.WORD_BYTES)
            for i in range(self.NUM_WORDS)
        ]

    def encode(self, poly, basis, timeout=None):
        '''
        Load ``poly`` and ``basis``, start the core, wait for ap_done and return the result.

        ``timeout`` is an optional limit in seconds on the wait for ap_done; the
        default ``None`` waits indefinitely. Raises ``ValueError`` (before any
        register is written) when an input is too short, and ``TimeoutError``
        when ``timeout`` expires.

        0x0000 : Control signals
                bit 0  - ap_start (Read/Write/COH)
                bit 1  - ap_done (Read/COR)
                bit 2  - ap_idle (Read)
                bit 3  - ap_ready (Read/COR)
                bit 7  - auto_restart (Read/Write)
                bit 9  - interrupt (Read)
                others - reserved
        '''
        # Validate both inputs first so a bad basis is caught before poly is written.
        self._require_words(poly, 'poly')
        self._require_words(basis, 'basis')

        print('Writing poly')
        self.write_poly(poly)
        print('Writing basis')
        self.write_basis(basis)
        print('Starting encoding')
        # Start the clock before setting ap_start so the start write itself is included.
        start = time.perf_counter()
        self.encode_module.write(self.CTRL_OFFSET, self.AP_START)
        deadline = None if timeout is None else start + timeout
        # ap_done is clear-on-read, so the read that observes it also acknowledges it.
        while not (self.encode_module.read(self.CTRL_OFFSET) & self.AP_DONE):
            if deadline is not None and time.perf_counter() > deadline:
                raise TimeoutError(f'encode core did not signal ap_done within {timeout} s')
        end = time.perf_counter()
        print('Encoding done in', end-start, 's')
        print('Encoding done')
        return self.read_ret()
    
print('Start encoding')
encode = Encode()
ret = encode.encode(poly, basis)

print('Encoding done')
print('Checking correctness')
# Compare the hardware output with the golden reference word by word.
# The `else` belongs to the `for`: it runs only when the loop finishes without
# hitting `break`, so 'All correct' is no longer printed after a mismatch.
for i in range(4096):
    if ret[i] != golden_ret[i]:
        print('Error at index', i)
        break
else:
    print('All correct')



Start encoding
Writing poly
Writing basis
Starting encoding
Encoding done in 19.28334069252014 s
Encoding done
Encoding done
Checking correctness
All correct


In [2]:
class Top_Driver(DefaultIP):
    """Driver for the HLS ``Top`` core, selected through ``bindto``.

    Each computation writes an opcode and two input vectors into the core's
    register window and reads the result words back from the same window.
    """

    bindto = ['xilinx.com:hls:Top:1.0']

    OP_OFFSET = 0x10        # opcode register
    INPUT1_OFFSET = 0x200   # first word of input1
    INPUT2_OFFSET = 0x400   # first word of input2
    RESULT_OFFSET = 0x600   # first word of the result
    WORD_BYTES = 4          # address stride between consecutive words

    def __init__(self, description):
        super().__init__(description=description)

    def _write_words(self, base, values):
        """Write ``values`` to consecutive 32-bit words starting at ``base``."""
        for i, value in enumerate(values):
            self.write(base + i * self.WORD_BYTES, value)

    def compute(self, input1, input2, OP, PENum):
        """Run one operation and return ``PENum`` result words as a list.

        ``input1`` and ``input2`` are iterables of ints; ``OP`` is the opcode
        written to the opcode register.

        NOTE(review): the result is read immediately after the inputs are
        written, with no start or done handshake; presumably the core needs
        none, but confirm against the HLS interface.
        """
        # Write the opcode, then both operand vectors.
        self.write(self.OP_OFFSET, OP)
        self._write_words(self.INPUT1_OFFSET, input1)
        self._write_words(self.INPUT2_OFFSET, input2)

        # Read the result.
        return [
            self.read(self.RESULT_OFFSET + i * self.WORD_BYTES)
            for i in range(PENum)
        ]

    def compute_pipeline(self, input1_list, input2_list, OP_list, PENum):
        """Run ``compute`` once per (input1, input2, OP) triple and return the results in order.

        Iteration stops at the shortest of the three lists.
        """
        # Delegating keeps the register sequence identical to compute() by construction.
        return [
            self.compute(input1, input2, OP, PENum)
            for input1, input2, OP in zip(input1_list, input2_list, OP_list)
        ]
    
# class RNG_Driver(DefaultIP):
#     def __init__(self, description):
#         super().__init__(description=description)

#     bindto = ['xilinx.com:hls:random_generator:1.0']

#     def generate_random_number(self, seed):
#         self.write(0x10, seed)
#         self.write(0x00, 1)  # start the random number generator
#         while not (self.read(0x1c) & 0x1):  # wait for out_data_ap_vld to become 1
#             pass
#         random_number = self.read(0x18)
#         self.read(0x1c)  # clear the done flag
#         return random_number

# Open the Crypto overlay again, this time after Top_Driver has been defined.
# NOTE(review): presumably required so the `bindto` entry of Top_Driver is applied to Top_0 — confirm.
overlay = Overlay('/home/xilinx/pynq/overlays/Crypto/Crypto.bit')

# class BRAM_Driver():
#     def __init__(self):
#         bram_num = 8
#         for i in range(bram_num):
#             exec(f'bram{i} = overlay.axi_bram_ctrl_{i}')
        


In [None]:
import random

# Opcodes written to the Top core's opcode register.
class OP:
    ADD = 0
    SUB = 1
    MUL = 2

# Moduli; the reference model applies moulus[i] to words i*PENum .. (i+1)*PENum - 1.
moulus = [1073750017, 1073815553, 1073872897]
PENum = 32

# Handle to the Top core.
top = overlay.Top_0

# Generate three groups of random inputs, each len(moulus) * PENum words long.
input1_list, input2_list, OP_list = [], [], []
for _ in range(3):
    input1, input2 = [], []
    for modulus in moulus:
        for _pe in range(PENum):
            input1.append(random.randint(0, modulus - 1))
            input2.append(random.randint(0, modulus - 1))
    input1_list.append(input1)
    input2_list.append(input2)
    OP_list.append(OP.MUL)  # use multiplication

# Time the one-call-per-group path.
naive_time = time.time()
results_compute = []
for input1, input2, op in zip(input1_list, input2_list, OP_list):
    result = top.compute(input1, input2, op, PENum)
    results_compute.append(result)
naive_time = time.time() - naive_time
# Time the batched path.
pipeline_time = time.time()
results_pipeline = top.compute_pipeline(input1_list, input2_list, OP_list, PENum)
pipeline_time = time.time() - pipeline_time


# Software reference: apply the operation to each word, modulo that word's modulus.
# An opcode missing from the table yields an empty reference list for that group.
_REFERENCE_OPS = {
    OP.ADD: lambda x, y: x + y,
    OP.SUB: lambda x, y: x - y,
    OP.MUL: lambda x, y: x * y,
}
correct_results = []
standard_time = time.time()
for input1, input2, op in zip(input1_list, input2_list, OP_list):
    combine = _REFERENCE_OPS.get(op)
    correct_result = []
    if combine is not None:
        for i, modulus in enumerate(moulus):
            for j in range(PENum):
                k = i * PENum + j
                correct_result.append(combine(input1[k], input2[k]) % modulus)
    correct_results.append(correct_result)
standard_time = time.time() - standard_time

# The driver reads back only PENum words per group, while the reference holds
# len(moulus) * PENum words, so the check covers the words that were actually read.
# NOTE(review): this assumes result word k corresponds to input word k — confirm
# against the core, and read back more words if all moduli should be verified.
expected_readback = np.array([reference[:PENum] for reference in correct_results]).flatten()

results_compute = np.array(results_compute).flatten()
results_pipeline = np.array(results_pipeline).flatten()
correct_results = np.array(correct_results).flatten()

print(f"Standard method time: {standard_time}")
print(f"Naive method time: {naive_time}")
print(f"Pipeline method time: {pipeline_time}")


# # Print the computed results and the reference results
# print(f"Computation results (compute method): {results_compute}")
# print(f"Computation results (compute_pipeline method): {results_pipeline}")
# print(f"Correct results: {correct_results}")

# Verify element by element. The previous `x.all() == y.all()` compared two single
# booleans and therefore passed for almost any data.
assert np.array_equal(results_compute, expected_readback), "The computation result using compute method is incorrect!"
assert np.array_equal(results_pipeline, expected_readback), "The computation result using compute_pipeline method is incorrect!"
print("The computation results are correct!")

In [None]:
rng = overlay.random_generator_0
# NOTE(review): generate_random_number is only defined in the commented-out RNG_Driver;
# confirm that a driver providing it is bound to random_generator_0 before running this cell.
for i in range(1):
    seed = np.random.randint(0, 100)
    # Generate once and reuse the value, so the number printed is the number stored.
    random_number = rng.generate_random_number(seed)
    print(random_number)
    # Address the BRAM through the overlay: the bram0 shortcut is created in a later cell.
    overlay.axi_bram_ctrl_0.write(0, random_number)
    

In [None]:
# Read word 0 back through the overlay handle; the bram0 shortcut is created in a later cell.
print(overlay.axi_bram_ctrl_0.read(0))

In [117]:
bram_num = 8
# Bind bram0..bram7 to overlay.axi_bram_ctrl_0..7 by attribute lookup instead of exec().
# The explicit unpacking fails loudly if bram_num ever stops matching the eight names.
_brams = [getattr(overlay, f'axi_bram_ctrl_{i}') for i in range(bram_num)]
bram0, bram1, bram2, bram3, bram4, bram5, bram6, bram7 = _brams

In [None]:
# simulate CKKS encryption, polynomial degree 4096
# Software reference for the hardware flow below: enc = (a*s + m + e) mod MOD.
# `a * s` is NumPy's element-wise product of the two coefficient arrays.
MOD = 193
Poly_degree = 4096
a = np.random.randint(0, MOD, Poly_degree)
s = np.random.randint(0, 2, Poly_degree)
m = np.random.randint(0, MOD, Poly_degree)
e = np.random.randint(0, 2, Poly_degree)

start_time = time.time()
# Reduce with MOD rather than a repeated literal, so changing MOD changes every step.
enc1 = (a * s) % MOD
enc2 = (enc1 + m) % MOD
enc = (enc2 + e) % MOD
end_time = time.time()
print("Time: ", end_time - start_time)

In [119]:
# Words per BRAM bank, as an int so it can serve as the split point below.
# NOTE(review): presumably 8192 bytes per BRAM divided by 4 bytes per word — confirm.
bram_size = 8192 // 4


def _write_split(values, low_bank, high_bank):
    """Write ``Poly_degree`` words of ``values`` across two BRAM banks.

    Words below ``bram_size`` go to ``low_bank``; the rest go to ``high_bank``,
    with addresses restarting at 0.
    """
    for i in range(Poly_degree):
        word = int(values[i])  # the BRAM write needs a plain Python int
        if i < bram_size:
            low_bank.write(i * 4, word)
        else:
            high_bank.write((i - bram_size) * 4, word)


# write a to BRAM
_write_split(a, bram0, bram1)

# write s to BRAM
_write_split(s, bram2, bram3)

# write e to BRAM
_write_split(e, bram6, bram7)

# start_time opens the end-to-end measurement that the final stage closes;
# write_start_time / write_end_time bracket only the transfer of m.
start_time = time.time()
# write m to BRAM
write_start_time = time.time()
_write_split(m, bram4, bram5)
write_end_time = time.time()
        



In [120]:
# Stage 1: read a and s from BRAM, compute a*s, and write the product back over a.
# Words 0..2047 are in the lower banks (bram0 / bram2), words 2048..4095 in the
# upper banks (bram1 / bram3), each addressed from offset 0.
# NOTE(review): Top_Driver.compute in this notebook is declared as
# compute(input1, input2, OP, PENum) and iterates over its inputs, whereas this call
# passes two scalars and no PENum — confirm which driver version this cell targets.
for i in range(4096):
    in_upper_bank = i >= 2048
    a_bank, s_bank = (bram1, bram3) if in_upper_bank else (bram0, bram2)
    byte_offset = (i - 2048) * 4 if in_upper_bank else i * 4

    temp1 = a_bank.read(byte_offset)
    temp2 = s_bank.read(byte_offset)
    temp3 = top.compute(temp1, temp2, 2)  # opcode 2 is MUL
    a_bank.write(byte_offset, temp3)





In [121]:
      
# Stage 2: read a*s and m from BRAM, compute a*s + m, and write the sum back over a*s.
# The running value lives in bram0 / bram1; m lives in bram4 / bram5.
# NOTE(review): Top_Driver.compute in this notebook is declared as
# compute(input1, input2, OP, PENum) and iterates over its inputs, whereas this call
# passes two scalars and no PENum — confirm which driver version this cell targets.
for i in range(Poly_degree):
    in_upper_bank = i >= 2048
    acc_bank, m_bank = (bram1, bram5) if in_upper_bank else (bram0, bram4)
    byte_offset = (i - 2048) * 4 if in_upper_bank else i * 4

    temp1 = acc_bank.read(byte_offset)
    temp2 = m_bank.read(byte_offset)
    temp3 = top.compute(temp1, temp2, 0)  # opcode 0 is ADD
    acc_bank.write(byte_offset, temp3)



In [None]:


# Stage 3: read a*s + m and e from BRAM, compute a*s + m + e, and write it back.
# The running value lives in bram0 / bram1; e lives in bram6 / bram7.
# NOTE(review): Top_Driver.compute in this notebook is declared as
# compute(input1, input2, OP, PENum) and iterates over its inputs, whereas this call
# passes two scalars and no PENum — confirm which driver version this cell targets.
for i in range(4096):
    in_upper_bank = i >= 2048
    acc_bank, e_bank = (bram1, bram7) if in_upper_bank else (bram0, bram6)
    byte_offset = (i - 2048) * 4 if in_upper_bank else i * 4

    temp1 = acc_bank.read(byte_offset)
    temp2 = e_bank.read(byte_offset)
    temp3 = top.compute(temp1, temp2, 0)  # opcode 0 is ADD
    acc_bank.write(byte_offset, temp3)

# start_time and write_start_time / write_end_time are set in the BRAM-write cell:
# the first print is the time since that cell's start_time, the second is the time
# spent writing m.
end_time = time.time()
print("Time: ", end_time - start_time)
print("Time: ", write_end_time - write_start_time)

# Read the result from BRAM0, BRAM1: lower half first, then upper half.
result = np.zeros(4096)
result[:2048] = [bram0.read(word * 4) for word in range(2048)]
result[2048:] = [bram1.read(word * 4) for word in range(2048)]

# Compare the hardware result with the software reference `enc`.
matches_reference = np.array_equal(enc, result)
print("a*s + m + e is correct" if matches_reference else "a*s + m + e is wrong")
        

In [None]:
# NOTE(review): presumably the ideal time to move 4096 words at one word per cycle
# of a 100 MHz clock — confirm the intended meaning.
time1 = 4096 /100000000
print("Time (s): ", time1)

# Draw the random words before starting each clock, so the timings measure the
# writes only and not the random number generation.
# .tolist() converts the NumPy integers to plain Python ints for write().
bram_values = np.random.randint(0, 100, bram0.size//4).tolist()
bram_time = time.time()
for i, num in enumerate(bram_values):
    bram0.write(i*4, num)
bram_time = time.time() - bram_time
print("BRAM write time: ", bram_time)

# The DDR loop writes 4096 words while the BRAM loop writes bram0.size//4 words,
# so compare the two totals per word if the counts differ.
# NOTE(review): this writes through the raw PSDDR handle starting at offset 0; confirm
# that range is safe to overwrite. A buffer from pynq.allocate may be the safer target.
ddr_values = np.random.randint(0, 100, 4096).tolist()
ddr_time = time.time()
for i, num in enumerate(ddr_values):
    overlay.PSDDR.write(i*4, num)
ddr_time = time.time() - ddr_time
print("DDR write time: ", ddr_time)
    


In [None]:

timer1 = overlay.axi_timer_0


async def wait_for_timer1(cycles):
    """Program timer 0 with ``cycles``, start it, and wait for its interrupt.

    The register and field names are those of ``timer1.register_map``; the
    meanings in the comments follow the AXI Timer register map.
    """
    timer1.register_map.TLR0 = cycles      # value to load into the counter
    timer1.register_map.TCSR0.LOAD0 = 1    # pulse LOAD0 to copy TLR0 into the counter
    timer1.register_map.TCSR0.LOAD0 = 0
    timer1.register_map.TCSR0.ENIT0 = 1    # enable the timer interrupt
    timer1.register_map.TCSR0.ENT0 = 1     # enable the timer
    timer1.register_map.TCSR0.UDT0 = 1     # count down
    await timer1.interrupt.wait()
    timer1.register_map.TCSR0.T0INT = 1    # write 1 to clear the interrupt flag


# Calling wait_for_timer1(...) only creates a coroutine object; nothing runs until
# it is handed to an event loop. Schedule it on the loop that is already running
# (the Jupyter case), otherwise run it to completion here.
loop = asyncio.get_event_loop()
if loop.is_running():
    timer_task = loop.create_task(wait_for_timer1(1000000))
else:
    loop.run_until_complete(wait_for_timer1(1000000))