In [18]:

from pynq import  Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import numpy as np
import time
import random
import sympy
import socket
import json
import time
from enum import IntEnum
import logging
from contextlib import contextmanager

# -------------------- Define the constants --------------------
# One entry per modulus. MODULUS_ROOT and MODULUS_INV are the per-modulus bases
# whose successive powers TwiddleFactorGenerator stores as the forward and
# inverse twiddle-factor tables.
MODULUS = [
    1073750017,
    1073815553,
    1073872897,
]
MODULUS_ROOT = [
    625534531,
    646391299,
    647613940,
]
MODULUS_INV = [
    627281114,
    777819041,
    538279817,
]

PLAINTEXT_MODULUS = 65537
DELTA = 100
BRAM_COUNT = 4
DATA_WIDTH = 32      # bits unpacked from every random word
POLY_DEGREE = 4096   # coefficients per polynomial
TF_SIZE = 2048       # entries per twiddle-factor table
# DMA buffer length in 32-bit words: one POLY_DEGREE-sized copy per modulus.
BUFFER_SIZE = POLY_DEGREE * len(MODULUS)

# -------------------- Configure the logger --------------------
# Configure root logging at INFO level.
logging.basicConfig(level=logging.INFO)
# General-purpose logger for status messages emitted by this notebook.
logger = logging.getLogger(__name__)
# Dedicated 'time' logger; time_measure() reports elapsed times through it.
time_logger = logging.getLogger('time')
time_logger.setLevel(logging.INFO)

# ---------------------- Timer ----------------------
@contextmanager
def time_measure(name):
    """Log the wall time spent inside the ``with`` body, in milliseconds.

    The measurement is taken with ``time.perf_counter`` and written to
    ``time_logger`` at INFO level as ``"<name padded to 30>: <ms> ms"``.
    The log line is emitted even when the body raises, so a failing step
    still shows how long it ran before the exception propagates.

    Args:
        name: Label for the measured step (a ``str``).
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = (time.perf_counter() - start) * 1000  # milliseconds
        time_logger.info(f"{name.ljust(30)}: {elapsed:.3f} ms")



In [19]:
# -------------------- Define the enums --------------------
class OpCode(IntEnum):
    """Operation codes written to the Crypto IP's OP_CODE register."""

    # Modular arithmetic
    ADD_MOD = 0
    SUB_MOD = 1
    MUL_MOD = 2

    # Transforms
    NTT = 3
    INTT = 4

    # Data movement
    POLY_WRITE = 5
    POLY_READ = 6
    TWIDDLE_WRITE = 7

    # Reduction by the plaintext modulus
    POLY_MOD_PLAINTEXTMODULUS = 8
    
class CryptoRegisters(IntEnum):
    """Byte offsets of the Crypto IP registers written by CryptoDriver."""

    CONTROL = 0x00       # start bit is written here; status is read back from it
    RAM_SELECT = 0x10    # first RAM operand
    RAM_SELECT1 = 0x18   # second RAM operand
    OP_CODE = 0x20       # operation selector (see OpCode)
    

In [20]:
# -------------------- Define Twiddle Factor Generator --------------------
class TwiddleFactorGenerator:
    """Generate twiddle factors for NTT and INTT.

    For every modulus ``q`` with base ``r`` and inverse base ``r_inv`` two
    tables are built: ``[r**i mod q]`` and ``[r_inv**i mod q]`` for
    ``i`` in ``range(size)``.

    Args:
        modulus: Sequence of moduli. Defaults to the module-level ``MODULUS``.
        root: Sequence of bases, one per modulus. Defaults to ``MODULUS_ROOT``.
        root_inv: Sequence of inverse bases, one per modulus. Defaults to
            ``MODULUS_INV``.
        size: Number of entries per table. Defaults to ``TF_SIZE``.

    Attributes:
        twiddle_factors: ``list[list[int]]``, one table per modulus.
        inv_twiddle_factors: ``list[list[int]]``, one table per modulus.

    Raises:
        ValueError: If the three sequences do not have the same length.
    """

    def __init__(self, modulus=None, root=None, root_inv=None, size=None):
        # Globals are resolved at call time so the no-argument form keeps
        # its original behaviour.
        self.modulus = MODULUS if modulus is None else list(modulus)
        self.root = MODULUS_ROOT if root is None else list(root)
        self.root_inv = MODULUS_INV if root_inv is None else list(root_inv)
        self.size = TF_SIZE if size is None else size

        if not (len(self.modulus) == len(self.root) == len(self.root_inv)):
            raise ValueError(
                "modulus, root and root_inv must have the same length "
                f"(got {len(self.modulus)}, {len(self.root)}, {len(self.root_inv)})"
            )

        self.twiddle_factors, self.inv_twiddle_factors = self._generate_factors()

    def _generate_factors(self):
        """Return ``(tables, inverse_tables)`` with one table per modulus."""
        all_factors = []
        all_inv_factors = []
        for idx in range(len(self.modulus)):
            factors, inv_factors = self._generate_single_factor(idx)
            all_factors.append(factors)
            all_inv_factors.append(inv_factors)
        return all_factors, all_inv_factors

    def _generate_single_factor(self, idx):
        """Return the forward and inverse power tables for modulus ``idx``."""
        modulus = self.modulus[idx]
        root = self.root[idx]
        root_inv = self.root_inv[idx]

        factors = [pow(root, exp, modulus) for exp in range(self.size)]
        inv_factors = [pow(root_inv, exp, modulus) for exp in range(self.size)]
        return factors, inv_factors

In [21]:
# -------------------- Define Hardware Driver --------------------
class RandomGeneratorDriver(DefaultIP):
    """PYNQ driver for the HLS ``random_generator`` IP.

    The seed is written to one register, generation is started through the
    control register, and the generated words are read back over MMIO.
    """

    # Matches the IP ID generated by Vitis HLS.
    bindto = ['xilinx.com:hls:random_generator:1.0']

    CONTROL_OFFSET = 0x00   # start is written here; status is read back
    SEED_OFFSET = 0x10      # seed register
    DATA_OFFSET = 0x1000    # first generated 32-bit word
    READY_MASK = 0x2        # polled status bit (presumably HLS ap_done - TODO confirm)

    def __init__(self, description):
        super().__init__(description=description)

    def set_seed(self, seed):
        """Write ``seed`` to the seed register."""
        self.write(self.SEED_OFFSET, seed)

    def start_generation(self):
        """Set the start bit in the control register."""
        self.write(self.CONTROL_OFFSET, 0x01)

    def check_ready(self):
        """Return a non-zero value when the polled status bit is set."""
        return self.read(self.CONTROL_OFFSET) & self.READY_MASK

    def generate_random(self, seed: int, timeout: float = 5.0):
        """Seed the generator, start it and busy-wait until it reports ready.

        Args:
            seed: Value written to the seed register.
            timeout: Maximum time to wait, in seconds.

        Raises:
            TimeoutError: If the ready bit is not observed within ``timeout``.
        """
        self.set_seed(seed)
        self.start_generation()

        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while not self.check_ready():
            if time.monotonic() > deadline:
                raise TimeoutError("Random number generation timeout")

    def read_random_data(self) -> np.ndarray:
        """Read the generated words and unpack them into one bit per element.

        ``POLY_DEGREE // DATA_WIDTH`` words are read starting at
        ``DATA_OFFSET``; bit ``j`` of word ``i`` is stored at index
        ``i * DATA_WIDTH + j``. Each word is read from the device exactly once.

        Returns:
            ``np.ndarray`` of dtype ``uint32`` and length ``POLY_DEGREE``
            whose elements are 0 or 1.
        """
        buf = np.zeros(POLY_DEGREE, dtype=np.uint32)
        n_words = POLY_DEGREE // DATA_WIDTH

        # One MMIO access per word; the bit unpacking is done on the host.
        words = np.array(
            [self.read(self.DATA_OFFSET + 4 * i) for i in range(n_words)],
            dtype=np.uint32,
        )
        shifts = np.arange(DATA_WIDTH, dtype=np.uint32)
        bits = (words[:, None] >> shifts) & 1   # shape (n_words, DATA_WIDTH)
        buf[:n_words * DATA_WIDTH] = bits.ravel()
        return buf
    

In [22]:
# NOTE(review): the same bitstream is loaded again in the initialization cell
# further down. This earlier load presumably exists so that a device is active
# when CryptoDriver's class body calls allocate() - TODO confirm, then keep
# only one of the two loads.
overlay = Overlay("/home/xilinx/pynq/overlays/Crypto/Crypto.bit")
class CryptoDriver(DefaultIP):
    """PYNQ driver for the HLS ``Crypto`` IP.

    Register writes select an operation (``OpCode``) and the RAM operands,
    twiddle-factor tables are written over MMIO, and polynomial data is moved
    through an AXI DMA engine supplied by the caller.
    """

    bindto = ['xilinx.com:hls:Crypto:1.0']

    # Byte offsets inside the IP's address space. POLY_BASE is not used by the
    # methods below.
    TF_BASE     = 0x08000
    TF_INV_BASE = 0x20000
    POLY_BASE   = 0x10000

    # DMA staging buffers. They are allocated once, when the class body runs,
    # and shared by every instance; read_result() returns dma_recv_buf itself,
    # so a later read overwrites an earlier result.
    dma_send_buf = allocate(shape=(BUFFER_SIZE,), dtype=np.uint32)
    dma_recv_buf = allocate(shape=(BUFFER_SIZE,), dtype=np.uint32)

    def __init__(self, description):
        super().__init__(description=description)

    def start_operation(self):
        """Set the start bit in the CONTROL register."""
        self.write(CryptoRegisters.CONTROL, 0x1)

    def check_busy(self):
        """Return True while bit 2 (0x4) of CONTROL is clear.

        Bit 2 is presumably the HLS ap_idle flag - TODO confirm.
        """
        return (self.read(CryptoRegisters.CONTROL) & 0x4) == 0

    def _setup_operation(self, addr: CryptoRegisters, value: int):
        """Write ``value`` (coerced to ``int``) to register ``addr``."""
        self.write(addr, int(value))

    def _wait_operation(self, timeout: float = 5.0):
        """Busy-wait until the IP is no longer busy.

        Raises:
            TimeoutError: If the IP is still busy after ``timeout`` seconds.
        """
        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while self.check_busy():
            if time.monotonic() > deadline:
                raise TimeoutError("Operation timeout")

    def load_twiddle_factor(self, tf: list, tf_inv: list):
        """Write the forward and inverse twiddle tables, then run TWIDDLE_WRITE.

        Args:
            tf: One table of at least ``TF_SIZE`` integers per modulus.
            tf_inv: The inverse tables, same layout as ``tf``.
        """
        # Table i starts TF_SIZE 32-bit words after table i - 1.
        for i in range(len(MODULUS)):
            table_offset = i * TF_SIZE * 4
            for j in range(TF_SIZE):
                self.write(self.TF_BASE     + table_offset + j * 4, tf[i][j])
                self.write(self.TF_INV_BASE + table_offset + j * 4, tf_inv[i][j])

        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.TWIDDLE_WRITE)
        self.start_operation()
        self._wait_operation()

    def load_polynomial(self, dma, poly: list, ram_sel: int = 0):
        """Stream ``poly`` into the RAM selected by ``ram_sel`` via DMA.

        Args:
            dma: AXI DMA object providing ``sendchannel``.
            poly: Exactly ``BUFFER_SIZE`` values (it fills ``dma_send_buf``).
            ram_sel: Destination RAM index.

        Raises:
            TimeoutError: If the IP does not go idle after the transfer.
        """
        # Stage the data before the IP starts waiting for it.
        self.dma_send_buf[:] = poly

        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.POLY_WRITE)
        self.start_operation()
        dma.sendchannel.transfer(self.dma_send_buf)
        dma.sendchannel.wait()
        # Wait for idle as well, so the next register write cannot race with
        # the tail of this operation.
        self._wait_operation()

    def execute_operation(self, op_code: OpCode, ram_sel: int = 0, ram_sel1: int = 0):
        """Run ``op_code`` on the selected RAM(s) and wait for completion.

        Raises:
            TimeoutError: If the IP does not finish in time.
        """
        self._setup_operation(CryptoRegisters.OP_CODE, op_code)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.RAM_SELECT1, ram_sel1)
        self.start_operation()
        self._wait_operation()

    def read_result(self, dma, ram_sel: int = 0):
        """Read the RAM selected by ``ram_sel`` back through DMA.

        Args:
            dma: AXI DMA object providing ``recvchannel``.
            ram_sel: Source RAM index.

        Returns:
            The shared ``dma_recv_buf``, filled with the received data.

        Raises:
            TimeoutError: If the IP does not go idle after the transfer.
        """
        self.dma_recv_buf[:] = 0
        # Arm the receive channel before the IP starts streaming.
        dma.recvchannel.transfer(self.dma_recv_buf)
        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.POLY_READ)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self.start_operation()
        # Do not hand the buffer back until the DMA has finished writing it.
        dma.recvchannel.wait()
        self._wait_operation()
        return self.dma_recv_buf
    

In [23]:

# Initialize the overlay
logger.info("Initializing the Crypto System")
# NOTE(review): this loads the same bitstream as the Overlay(...) call in the
# CryptoDriver cell; the driver classes are defined by the time this one runs.
overlay = Overlay("/home/xilinx/pynq/overlays/Crypto/Crypto.bit")

# Initialize the drivers
crypto = overlay.Crypto_0
random_gen = overlay.random_generator_0
dma = overlay.axi_dma_0


# Generate the twiddle factors
# The timed section covers both building the tables on the host and writing
# them to the IP.
with time_measure("Generate Twiddle Factors"):
    tf_gen = TwiddleFactorGenerator()
    crypto.load_twiddle_factor(tf_gen.twiddle_factors, tf_gen.inv_twiddle_factors)

# Random polynomial A: generated, replicated three times to fill the DMA
# buffer, loaded into RAM 0 and transformed in place with NTT.
# NOTE(review): the seed comes from Python's `random` module, which is not a
# cryptographically secure source - confirm this is acceptable for A/S/E.
with time_measure("Generate Random Numbers NTT_A"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_a = random_gen.read_random_data() 
    crypto.load_polynomial(dma, random_num_a.tolist() * 3, 0)
    crypto.execute_operation(OpCode.NTT, 0)
    
# Random polynomial S: same procedure, stored in RAM 1.
with time_measure("Generate Random Numbers NTT_S"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_s = random_gen.read_random_data()
    crypto.load_polynomial(dma, random_num_s.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)
    
# Multiply RAM 0 by RAM 1 (presumably the product is left in RAM 0 - TODO confirm).
with time_measure("NTT A * S"):
    crypto.execute_operation(OpCode.MUL_MOD, 0, 1)
    
# Random polynomial E: loaded into RAM 1, overwriting S there.
with time_measure("Generate Random Numbers NTT_E"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_e = random_gen.read_random_data()
    crypto.load_polynomial(dma, random_num_e.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)

# Add RAM 1 (E) to RAM 0.
with time_measure("Add the result to E"):
    crypto.execute_operation(OpCode.ADD_MOD, 0, 1)
    
# Fill RAM 1 with the constant DELTA in every coefficient; no NTT is applied.
with time_measure("Load Delta"):
    delta = [DELTA] * POLY_DEGREE * 3
    crypto.load_polynomial(dma, delta, 1)


INFO:__main__:Initializing the Crypto System
INFO:time:Generate Twiddle Factors      : 718.299 ms
INFO:time:Generate Random Numbers NTT_A : 131.339 ms
INFO:time:Generate Random Numbers NTT_S : 131.759 ms
INFO:time:NTT A * S                     : 0.288 ms
INFO:time:Generate Random Numbers NTT_E : 134.164 ms
INFO:time:Add the result to E           : 0.277 ms
INFO:time:Load Delta                    : 10.366 ms


In [27]:
from socket import *
import json
import struct
import time
from contextlib import contextmanager

# Assumes the names below were already imported or defined by earlier cells.
# from fpga_interface import OpCode, crypto, dma

# Constants
# NOTE: POLY_DEGREE and BUFFER_SIZE rebind names first defined in the setup
# cell. BUFFER_SIZE here (16384) is the socket receive chunk size used by
# handle_client, not the 4096 * 3 DMA buffer length.
POLY_DEGREE = 1 << 12
BUFFER_SIZE = 16 * 1024
MAX_REQUESTS = 100  # maximum number of requests to process

# Timer class that accumulates the time spent in each processing stage
class TimeStats:
    """Running totals (in ms) for the three timed stages plus a request count.

    ``count`` is not updated by ``add_time``; the caller increments it once
    per successfully handled request.
    """

    def __init__(self):
        self.load_time = self.encrypt_time = self.read_time = 0
        self.count = 0

    def add_time(self, operation, duration):
        """Add ``duration`` to the total for ``operation``.

        Labels other than the three known stage names are ignored.
        """
        attr_by_label = {
            "Load Message": "load_time",
            "Encrypt the message": "encrypt_time",
            "Read Encrypted Result": "read_time",
        }
        attr = attr_by_label.get(operation)
        if attr is None:
            return
        setattr(self, attr, getattr(self, attr) + duration)

    def print_stats(self):
        """Print per-request averages, or a notice when nothing was counted."""
        requests = self.count
        if requests == 0:
            print("No operations performed")
            return

        total = self.load_time + self.encrypt_time + self.read_time
        report = [
            "\n===== Performance Statistics =====",
            f"Total requests processed: {requests}",
            f"Average Load Message time: {self.load_time/requests:.3f} ms",
            f"Average Encrypt Message time: {self.encrypt_time/requests:.3f} ms",
            f"Average Read Result time: {self.read_time/requests:.3f} ms",
            f"Average total operation time: {total/requests:.3f} ms",
            "=================================",
        ]
        print("\n".join(report))

# Global timing-statistics object; time_measure() adds durations to it and
# handle_client() increments its request count.
time_stats = TimeStats()

@contextmanager
def time_measure(operation):
    """Measure the ``with`` body, print the duration and record it.

    The duration in milliseconds is printed as
    ``INFO:time:<operation padded to 30>: <ms> ms`` and passed to
    ``time_stats.add_time``. Both happen even when the body raises.

    NOTE: this rebinds the ``time_measure`` defined in the setup cell, which
    logs through ``time_logger`` and does not update ``time_stats``.

    Args:
        operation: Label of the measured step (a ``str``).
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted.
    start = time.perf_counter()
    try:
        yield
    finally:
        duration = (time.perf_counter() - start) * 1000  # convert to milliseconds
        print(f"INFO:time:{operation:<30}: {duration:.3f} ms")
        time_stats.add_time(operation, duration)

def send_numbers(sock, numbers):
    """Send ``numbers`` as UTF-8 JSON preceded by a 4-byte length header.

    The header is the payload length as an unsigned 32-bit big-endian integer
    (``struct`` format ``'!I'``).

    Args:
        sock: Connected socket-like object providing ``sendall``.
        numbers: JSON-serialisable object (a list of ints in this notebook).
    """
    data = json.dumps(numbers).encode('utf-8')
    # send() may transmit only part of its argument; sendall() keeps going
    # until every byte has been written or an error is raised.
    sock.sendall(struct.pack('!I', len(data)) + data)
    
def _recv_exact(sock, size, chunk_size=4096):
    """Read exactly ``size`` bytes from ``sock``, or raise ConnectionError on EOF."""
    received = bytearray()
    while len(received) < size:
        chunk = sock.recv(min(size - len(received), chunk_size))
        if not chunk:
            # recv() returning b'' means the peer closed the connection.
            raise ConnectionError(
                f"Connection closed by peer after {len(received)} of {size} bytes"
            )
        received += chunk
    return bytes(received)


def handle_client(sock, crypto):
    """Serve one request: receive a list, encrypt it on the FPGA, send it back.

    Wire format in both directions: a 4-byte big-endian length header followed
    by a UTF-8 JSON list. The received list is truncated or zero-padded to
    ``POLY_DEGREE`` entries. Errors are printed rather than raised, so one bad
    client does not stop the server loop.

    Args:
        sock: Connected client socket.
        crypto: Driver object providing ``load_polynomial``,
            ``execute_operation`` and ``read_result``.
    """
    try:
        # Receive numbers from client
        first = sock.recv(4)
        if not first:
            # Peer connected and closed without sending anything.
            return
        # recv(4) may deliver fewer than 4 bytes; fetch the rest of the header.
        header = first + _recv_exact(sock, 4 - len(first))
        data_length = struct.unpack('!I', header)[0]

        received = _recv_exact(sock, data_length, BUFFER_SIZE)

        numbers = json.loads(received.decode('utf-8'))
        if not isinstance(numbers, list):
            raise ValueError(
                f"Expected a JSON list of numbers, got {type(numbers).__name__}"
            )

        # Validate and pad data
        if len(numbers) != POLY_DEGREE:
            numbers = numbers[:POLY_DEGREE] + [0]*(POLY_DEGREE - len(numbers))
        print(f"Received {len(numbers)} numbers")
        print(f"First 10 numbers: {numbers[:10]}")

        # FPGA Encryption Process
        # 1. Prepare polynomial data (replicate 3x for FPGA buffer)
        poly_data = numbers * 3

        # 2. Load data to FPGA
        with time_measure("Load Message"):
            crypto.load_polynomial(dma, poly_data, 2)

        # 3. Execute encryption operations
        with time_measure("Encrypt the message"):
            crypto.execute_operation(OpCode.MUL_MOD, 2, 1)
            crypto.execute_operation(OpCode.POLY_MOD_PLAINTEXTMODULUS, 2)
            crypto.execute_operation(OpCode.NTT, 2)
            crypto.execute_operation(OpCode.ADD_MOD, 2, 0)

        # 4. Read encrypted result from FPGA
        with time_measure("Read Encrypted Result"):
            encrypted_result = crypto.read_result(dma, 2)
            print(encrypted_result[:10])

        # Optional: send the result back to the client
        result_list = [int(x) for x in encrypted_result]
        send_numbers(sock, result_list)

        # Update the request counter (successful requests only)
        time_stats.count += 1

    except Exception as e:
        print(f"Client handling error: {str(e)}")

def main(server_address=('192.168.2.99', 123), max_requests=None):
    """Run the FPGA encryption server for a fixed number of connections.

    Connections are accepted one at a time and each is passed to
    ``handle_client``. Timing statistics are printed once: on normal
    completion, or on early exit if at least one request succeeded.

    Args:
        server_address: ``(host, port)`` to bind. Port numbers below 1024
            normally need elevated privileges.
        max_requests: Number of connections to serve; defaults to
            ``MAX_REQUESTS``.
    """
    if max_requests is None:
        max_requests = MAX_REQUESTS

    finished = False
    # The with-statement closes the listening socket on every exit path.
    with socket(AF_INET, SOCK_STREAM) as server_sock:
        server_sock.bind(server_address)
        server_sock.listen(1)
        print(f"FPGA Server listening... (will process {max_requests} requests)")

        try:
            for request_index in range(max_requests):
                client_sock, addr = server_sock.accept()
                # Close the client socket even if the handler is interrupted.
                with client_sock:
                    print(f"Connection from: {addr} [{request_index+1}/{max_requests}]")
                    handle_client(client_sock, crypto)

            finished = True
            print(f"\nCompleted {max_requests} requests")

        except KeyboardInterrupt:
            print("\nServer terminated by user")
        except Exception as e:
            print(f"Server error: {str(e)}")
        finally:
            # Print the statistics once, including after an early exit.
            if finished or time_stats.count > 0:
                time_stats.print_stats()

# In a notebook cell __name__ is "__main__", so running this cell starts the
# server, which keeps accepting connections until main() returns.
if __name__ == "__main__":
    main()

FPGA Server listening... (will process 100 requests)
Connection from: ('192.168.2.1', 40752) [1/100]
Received 4096 numbers
First 10 numbers: [10008, 9268, 10694, 1203, 3246, 10766, 5392, 4004, 7089, 5475]
INFO:time:Load Message                  : 9.904 ms
INFO:time:Encrypt the message           : 19.985 ms
[133346330 961680399 692215205 764693654 869828802 111731800 734148517
 329528522 743380351 356902276]
INFO:time:Read Encrypted Result         : 1.821 ms
Connection from: ('192.168.2.1', 40766) [2/100]
Received 4096 numbers
First 10 numbers: [3402, 9957, 5200, 4756, 12182, 3490, 4029, 8359, 939, 7652]
INFO:time:Load Message                  : 9.855 ms
INFO:time:Encrypt the message           : 19.982 ms
[ 131980913   95150106 1053059887  515908888  698904395    8894063
  211352371  972369323  343295122  357721447]
INFO:time:Read Encrypted Result         : 1.825 ms
Connection from: ('192.168.2.1', 40768) [3/100]
Received 4096 numbers
First 10 numbers: [9175, 539, 10051, 9803, 7261, 919