In [1]:

from pynq import  Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import numpy as np
import time
import random
import sympy
import socket
import json
import time
from enum import IntEnum
import logging
from contextlib import contextmanager

# -------------------- Define the constants --------------------
MODULUS         = [1073750017, 1073815553, 1073872897]
MODULUS_ROOT    = [625534531, 646391299, 647613940]
MODULUS_INV     = [627281114, 777819041, 538279817]

PLAINTEXT_MODULUS = 65537
DELTA = 100
BRAM_COUNT = 4
POLY_DEGREE = 4096
TF_SIZE = 2048
DATA_WIDTH = 32

# -------------------- Configure the logger --------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
time_logger = logging.getLogger('time')
time_logger.setLevel(logging.INFO)

# ---------------------- Timer ----------------------
@contextmanager
def time_measure(name):
    start = time.perf_counter()
    yield
    elapsed = (time.perf_counter() - start) * 1000  # 毫秒
    time_logger.info(f"{name.ljust(30)}: {elapsed:.3f} ms")



In [2]:
# -------------------- Define the enums --------------------
class OpCode(IntEnum):
    ADD_MOD                     = 0
    SUB_MOD                     = 1
    MUL_MOD                     = 2
    NTT                         = 3
    INTT                        = 4
    POLY_WRITE                  = 5
    POLY_READ                   = 6
    TWIDDLE_WRITE               = 7
    POLY_MOD_PLAINTEXTMODULUS   = 8
    
class CryptoRegisters(IntEnum):
    CONTROL = 0x0000
    RAM_SELECT = 0x0010
    RAM_SELECT1 = 0x0018
    OP_CODE = 0x0020
    

In [3]:
# -------------------- Define Twiddle Factor Generator --------------------
class TwiddleFactorGenerator:
    "Generate twiddle factors for NTT and INTT"
    def __init__(self):
        self.modulus = MODULUS
        self.root = MODULUS_ROOT
        self.root_inv = MODULUS_INV
        
        self.twiddle_factors, self.inv_twiddle_factors = self._generate_factors()
    
    def _generate_factors(self):
        "Generate twiddle factors"
        all_factors = []
        all_inv_factors = []
        for i in range(len(self.modulus)):
            factors, inv_factors = self._generate_single_factor(i)
            all_factors.append(factors)
            all_inv_factors.append(inv_factors)
        return all_factors, all_inv_factors
        
    def _generate_single_factor(self, idx):
        "Generate twiddle factors for a single modulus"
        factors = []
        inv_factors = []
        modulus = self.modulus[idx]
        root = self.root[idx]
        root_inv = self.root_inv[idx]
        
        for i in range(TF_SIZE):
            factor = pow(root, i, modulus)
            inv_factor = pow(root_inv, i, modulus)
            factors.append(factor)
            inv_factors.append(inv_factor)
        return factors, inv_factors

In [4]:
# -------------------- Define Hardware Driver --------------------
class RandomGeneratorDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)    
    bindto = ['xilinx.com:hls:random_generator:1.0']  # 匹配Vitis HLS生成的IP ID
        
    def set_seed(self, seed):
        self.write(0x10, seed)
        
    def start_generation(self):
        self.write(0x00, 0x01)
        
    def check_ready(self):
        return self.read(0x00) & 0x2
    
    def generate_random(self, seed:int, timeout:float = 5.0):
        "Generate a random number with a given seed (With timeout)"
        self.set_seed(seed)
        self.start_generation()
        
        start_time = time.time() 
        while not self.check_ready():
            if time.time() - start_time > timeout:
                raise TimeoutError("Random number generation timeout")
    
    def read_random_data(self) -> np.ndarray:
        "Read the random data from the BRAM with POLY_DEGREE"
        buf = np.zeros(POLY_DEGREE, dtype=np.uint32)
        # for each random number, which is 32-bit
        # I want to write each bit to each address of the BRAM
        
        for i in range(POLY_DEGREE//DATA_WIDTH):
            for j in range(DATA_WIDTH):
                buf[i * DATA_WIDTH + j] = (self.read(0x1000 + i * 4) >> j) & 0x1
        return buf
    

In [5]:
class CryptoDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)
    bindto = ['xilinx.com:hls:Crypto:1.0']
    
    TF_BASE     = 0x08000
    TF_INV_BASE = 0x20000
    POLY_BASE   = 0x10000
    
    def start_operation(self):
        self.write(CryptoRegisters.CONTROL, 0x1)
        
    def check_busy(self):
        return (self.read(CryptoRegisters.CONTROL) & 0x4) == 0
    
    def _setup_operation(self, addr: CryptoRegisters, value: int):
        "Setup the register value"
        self.write(addr, int(value))
        
    def _wait_operation(self, timeout: float = 5.0):
        "Wait for the operation to finish"
        start_time = time.time()
        while self.check_busy():
            if time.time() - start_time > timeout:
                raise TimeoutError("Operation timeout")
            pass
        
    def load_twiddle_factor(self, tf: list, tf_inv: list):
        "Load twiddle factors to the BRAM"

        for i in range(3):
            for j in range(TF_SIZE):
                self.write(self.TF_BASE     + i * TF_SIZE * 4 + j * 4,  tf[i][j])
                self.write(self.TF_INV_BASE + i * TF_SIZE * 4 + j * 4,  tf_inv[i][j])
        
        self._setup_operation(CryptoRegisters.OP_CODE, 7)
        self.start_operation()
        self._wait_operation()
     
        
    def load_polynomial(self, poly: list, ram_sel: int = 0):
        "Load the polynomial to the BRAM"
 
        for i in range(POLY_DEGREE * 3):
            self.write(self.POLY_BASE + i * 4, (int(poly[i])))
                       
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.OP_CODE, 5)
        self.start_operation()
        self._wait_operation()
        
    def execute_operation(self, op_code: OpCode, ram_sel: int = 0, ram_sel1: int = 0):
        "Execute the operation"
        self._setup_operation(CryptoRegisters.OP_CODE, op_code)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.RAM_SELECT1, ram_sel1)
        self.start_operation()
        self._wait_operation()
        
    def read_result(self, ram_sel: int = 0) -> list:
        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.POLY_READ)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self.start_operation()
        self._wait_operation()
        
        result = []
        for i in range(POLY_DEGREE * 3):
            result.append(self.read(self.POLY_BASE + i * 4))
        return result
    

In [6]:

# Initialize the overlay
logger.info("Initializing the Crypto System")
overlay = Overlay("/home/xilinx/pynq/overlays/Crypto/Crypto.bit")

# Initialize the drivers
crypto = overlay.Crypto_0
random_gen = overlay.random_generator_0


# Generate the twiddle factors
with time_measure("Generate Twiddle Factors"):
    tf_gen = TwiddleFactorGenerator()
    crypto.load_twiddle_factor(tf_gen.twiddle_factors, tf_gen.inv_twiddle_factors)

with time_measure("Generate Random Numbers NTT_A"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_a = random_gen.read_random_data() 
    crypto.load_polynomial(random_num_a.tolist() * 3, 0)
    crypto.execute_operation(OpCode.NTT, 0)
    
with time_measure("Generate Random Numbers NTT_S"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_s = random_gen.read_random_data()
    crypto.load_polynomial(random_num_s.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)
    
with time_measure("NTT A * S"):
    crypto.execute_operation(OpCode.MUL_MOD, 0, 1)
    
with time_measure("Generate Random Numbers NTT_E"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_e = random_gen.read_random_data()
    crypto.load_polynomial(random_num_e.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)

with time_measure("Add the result to E"):
    crypto.execute_operation(OpCode.ADD_MOD, 0, 1)
    
with time_measure("Load Delta"):
    delta = [DELTA] * POLY_DEGREE * 3
    crypto.load_polynomial(delta, 1)

with open("/home/xilinx/jupyter_notebooks/message.txt", "w") as f:
    for i in range(POLY_DEGREE * 10):
        f.write(f"{random.randint(0, PLAINTEXT_MODULUS - 1)}\n")
        
# with time_measure("Total Encryption"):
#     tcpSerSock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#     tcpSerSock.bind(addr)
#     tcpSerSock.listen(5)

#     try:
#         print("Waiting for connection")
#         tcpCliSock, addr = tcpSerSock.accept()
#         print("...connected from:", addr)
#         while True:
#             with time_measure("Read the message"):
#                 try:
#                     length_bytes = tcpCliSock.recv(4)
#                     if not length_bytes or len(length_bytes) != 4:
#                         print("Invalid message size")
#                         break
#                     length = int.from_bytes(length_bytes, byteorder='big')
#                     print(f"Message length: {length}")
                
      
#                     data = b''
#                     while len(data) < length:
#                         packet = tcpCliSock.recv(min(length - len(data), buff_size))
#                         if not packet:
#                             print("Connection closed during data transfer")
#                             break
#                         data += packet
                    
#                     # 解析JSON数据
#                     numbers = json.loads(data.decode('utf-8'))
#                     print(f"First 10 numbers: {numbers[:10]}")
#                     print(f"Received message: {len(numbers)} numbers")
                    
#                     if len(numbers) != POLY_DEGREE:
#                         print(f"Warning: Expected {POLY_DEGREE} numbers, got {len(numbers)}")
#                         if len(numbers) < POLY_DEGREE:
#                             # 补齐数据
#                             numbers.extend([0] * (POLY_DEGREE - len(numbers)))
#                             print(f"Padded message to {POLY_DEGREE} numbers")
#                         else:
#                             # 截断数据
#                             numbers = numbers[:POLY_DEGREE]
#                             print(f"Truncated message to {POLY_DEGREE} numbers")
                
#                 except Exception as e:
#                     print(f"Error: {e}")
#                     break
#                 data = [int(x) for x in data.decode().split("\n") if x] * 3
#         with time_measure("Encrypt the message"):
#             with time_measure("Load Message"):
#                 crypto.load_polynomial(data, 2)
#             with time_measure("MUL MOD"):
#                 crypto.execute_operation(OpCode.MUL_MOD, 2, 1)
#             with time_measure("MOD PLAINTEXT MODULUS"):
#                 crypto.execute_operation(OpCode.POLY_MOD_PLAINTEXTMODULUS, 2)
#             with time_measure("NTT"):
#                 crypto.execute_operation(OpCode.NTT, 2)
#             with time_measure("Add the result to NTT_A * NTT_S + NTT_E"):
#                 crypto.execute_operation(OpCode.ADD_MOD, 2, 0)
                
#     except Exception as e:
#         print(f"Error: {e}")
        
#     finally:
#         if 'tcpCliSock' in locals():
#             tcpCliSock.close()
#         tcpSerSock.close()


INFO:__main__:Initializing the Crypto System


INFO:time:Generate Twiddle Factors      : 721.252 ms
INFO:time:Generate Random Numbers NTT_A : 447.373 ms
INFO:time:Generate Random Numbers NTT_S : 455.393 ms
INFO:time:NTT A * S                     : 0.283 ms
INFO:time:Generate Random Numbers NTT_E : 442.698 ms
INFO:time:Add the result to E           : 0.294 ms
INFO:time:Load Delta                    : 322.653 ms


In [7]:
from socket import *
import json
import struct

# Constants
POLY_DEGREE = 4096
BUFFER_SIZE = 4096 * 4

def handle_client(sock, crypto):
    try:
        # Receive numbers from client
        length_data = sock.recv(4)
        if not length_data:
            return
        data_length = struct.unpack('!I', length_data)[0]
        
        received = bytearray()
        while len(received) < data_length:
            remaining = data_length - len(received)
            received += sock.recv(min(remaining, BUFFER_SIZE))
        
        numbers = json.loads(received.decode('utf-8'))
        
        # Validate and pad data
        if len(numbers) != POLY_DEGREE:
            numbers = numbers[:POLY_DEGREE] + [0]*(POLY_DEGREE - len(numbers))
        print(f"Received {len(numbers)} numbers")
        print(f"First 10 numbers: {numbers[:10]}")
        # FPGA Encryption Process
        # ------------------------
        # 1. Prepare polynomial data (replicate 3x for FPGA buffer)
        poly_data = numbers * 3
        
        # 2. Load data to FPGA
        with time_measure("Encrypt the message"):
            with time_measure("Load Message"):
                crypto.load_polynomial(poly_data, 2)    
        
        # 3. Execute encryption operations
            with time_measure("Encrypt the message"):
                crypto.execute_operation(OpCode.MUL_MOD, 2, 1)
                crypto.execute_operation(OpCode.POLY_MOD_PLAINTEXTMODULUS, 2)  
                crypto.execute_operation(OpCode.NTT, 2)
                crypto.execute_operation(OpCode.ADD_MOD, 2, 0)
        
        # 4. Read encrypted result from FPGA
        encrypted_result = crypto.read_result(2)  # Use actual read method
        print(encrypted_result[:10])
        # # Send back encrypted result
        # send_numbers(sock, encrypted_result)
        
    except Exception as e:
        print(f"Client handling error: {str(e)}")

def main():
    server_address = ('192.168.2.99', 123)
    
    with socket(AF_INET, SOCK_STREAM) as server_sock:
        server_sock.bind(server_address)
        server_sock.listen(1)
        print("FPGA Server listening...")
        
        try:
            while True:
                client_sock, addr = server_sock.accept()
                print(f"Connection from: {addr}")
                handle_client(client_sock, crypto)
                client_sock.close()
                
        except KeyboardInterrupt:
            print("\nServer terminated")
        except Exception as e:
            print(f"Server error: {str(e)}")
        finally:
            server_sock.close()

if __name__ == "__main__":
    main()

OSError: [Errno 98] Address already in use

In [None]:
import numpy as np

data = np.ones((10000, 10000), dtype=np.uint32)

In [None]:

    
        
# with time_measure("Verification"):
#     expected_result = []
#     for i in range(3):
#         delta_msg = [x * DELTA % PLAINTEXT_MODULUS for x in message]
#         M = sympy.ntt(delta_msg, MODULUS[i])
#         NTT_A = sympy.ntt(random_num_a, MODULUS[i])
#         NTT_S = sympy.ntt(random_num_s, MODULUS[i])
#         NTT_E = sympy.ntt(random_num_e, MODULUS[i])
#         temp_expected_result = [(a*s+e+m)%MODULUS[i] for a,s,e,m in zip(NTT_A, NTT_S, NTT_E, M)]
#         expected_result.extend(temp_expected_result)
    
#     expected_result = np.array(expected_result)
#     actual_result = np.array(crypto.read_result(2))
#     if np.array_equal(expected_result, actual_result):
#         logger.info("Verification successful: The results match.")
#     else:
#         logger.error("Verification failed: The results do not match.")
#         for i in range(len(expected_result)):
#             if expected_result[i] != actual_result[i]:
#                 logger.error(f"Expected: {expected_result[i]}, Actual: {actual_result[i]}")
#         quit()