In [1]:
from pynq import  Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import numpy as np
import time
import random
import sympy
from enum import IntEnum
import logging
from contextlib import contextmanager

# -------------------- Define the constants --------------------
MODULUS         = [1073750017, 1073815553, 1073872897]
MODULUS_ROOT    = [625534531, 646391299, 647613940]
MODULUS_INV     = [627281114, 777819041, 538279817]

PLAINTEXT_MODULUS = 65537
DELTA = 100
BRAM_COUNT = 4
POLY_DEGREE = 4096
TF_SIZE = 2048
DATA_WIDTH = 32

# -------------------- Configure the logger --------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
time_logger = logging.getLogger('time')
time_logger.setLevel(logging.INFO)

# ---------------------- Timer ----------------------
@contextmanager
def time_measure(name):
    start = time.perf_counter()
    yield
    elapsed = (time.perf_counter() - start) * 1000  # 毫秒
    time_logger.info(f"{name.ljust(30)}: {elapsed:.3f} ms")



In [None]:
# -------------------- Define the enums --------------------
class OpCode(IntEnum):
    ADD_MOD                     = 0
    SUB_MOD                     = 1
    MUL_MOD                     = 2
    NTT                         = 3
    INTT                        = 4
    POLY_WRITE                  = 5
    POLY_READ                   = 6
    TWIDDLE_WRITE               = 7
    POLY_MOD_PLAINTEXTMODULUS   = 8
    
class CryptoRegisters(IntEnum):
    CONTROL = 0x0000
    RAM_SELECT = 0x0010
    RAM_SELECT1 = 0x0018
    OP_CODE = 0x0020
    

In [None]:
# -------------------- Define Twiddle Factor Generator --------------------
class TwiddleFactorGenerator:
    "Generate twiddle factors for NTT and INTT"
    def __init__(self):
        self.modulus = MODULUS
        self.root = MODULUS_ROOT
        self.root_inv = MODULUS_INV
        
        self.twiddle_factors, self.inv_twiddle_factors = self._generate_factors()
    
    def _generate_factors(self):
        "Generate twiddle factors"
        all_factors = []
        all_inv_factors = []
        for i in range(len(self.modulus)):
            factors, inv_factors = self._generate_single_factor(i)
            all_factors.append(factors)
            all_inv_factors.append(inv_factors)
        return all_factors, all_inv_factors
        
    def _generate_single_factor(self, idx):
        "Generate twiddle factors for a single modulus"
        factors = []
        inv_factors = []
        modulus = self.modulus[idx]
        root = self.root[idx]
        root_inv = self.root_inv[idx]
        
        for i in range(TF_SIZE):
            factor = pow(root, i, modulus)
            inv_factor = pow(root_inv, i, modulus)
            factors.append(factor)
            inv_factors.append(inv_factor)
        return factors, inv_factors

In [None]:
# -------------------- Define Hardware Driver --------------------
class RandomGeneratorDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)    
    bindto = ['xilinx.com:hls:random_generator:1.0']  # 匹配Vitis HLS生成的IP ID
        
    def set_seed(self, seed):
        self.write(0x10, seed)
        
    def start_generation(self):
        self.write(0x00, 0x01)
        
    def check_ready(self):
        return self.read(0x00) & 0x2
    
    def generate_random(self, seed:int, timeout:float = 5.0):
        "Generate a random number with a given seed (With timeout)"
        self.set_seed(seed)
        self.start_generation()
        
        start_time = time.time() 
        while not self.check_ready():
            if time.time() - start_time > timeout:
                raise TimeoutError("Random number generation timeout")
    
    def read_random_data(self) -> np.ndarray:
        "Read the random data from the BRAM with POLY_DEGREE"
        buf = np.zeros(POLY_DEGREE, dtype=np.uint32)
        # for each random number, which is 32-bit
        # I want to write each bit to each address of the BRAM
        
        for i in range(POLY_DEGREE//DATA_WIDTH):
            for j in range(DATA_WIDTH):
                buf[i * DATA_WIDTH + j] = (self.read(0x1000 + i * 4) >> j) & 0x1
        return buf
    

In [None]:
class CryptoDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)
    bindto = ['xilinx.com:hls:Crypto:1.0']
    
    TF_BASE     = 0x08000
    TF_INV_BASE = 0x20000
    POLY_BASE   = 0x10000
    
    def start_operation(self):
        self.write(CryptoRegisters.CONTROL, 0x1)
        
    def check_busy(self):
        return (self.read(CryptoRegisters.CONTROL) & 0x4) == 0
    
    def _setup_operation(self, addr: CryptoRegisters, value: int):
        "Setup the register value"
        self.write(addr, int(value))
        
    def _wait_operation(self, timeout: float = 5.0):
        "Wait for the operation to finish"
        start_time = time.time()
        while self.check_busy():
            if time.time() - start_time > timeout:
                raise TimeoutError("Operation timeout")
            pass
        
    def load_twiddle_factor(self, tf: list, tf_inv: list):
        "Load twiddle factors to the BRAM"

        for i in range(3):
            for j in range(TF_SIZE):
                self.write(self.TF_BASE     + i * TF_SIZE * 4 + j * 4,  tf[i][j])
                self.write(self.TF_INV_BASE + i * TF_SIZE * 4 + j * 4,  tf_inv[i][j])
        
        self._setup_operation(CryptoRegisters.OP_CODE, 7)
        self.start_operation()
        self._wait_operation()
     
        
    def load_polynomial(self, poly: list, ram_sel: int = 0):
        "Load the polynomial to the BRAM"
 
        for i in range(POLY_DEGREE * 3):
            self.write(self.POLY_BASE + i * 4, (int(poly[i])))
                       
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.OP_CODE, 5)
        self.start_operation()
        self._wait_operation()
        
    def execute_operation(self, op_code: OpCode, ram_sel: int = 0, ram_sel1: int = 0):
        "Execute the operation"
        self._setup_operation(CryptoRegisters.OP_CODE, op_code)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.RAM_SELECT1, ram_sel1)
        self.start_operation()
        self._wait_operation()
        
    def read_result(self, ram_sel: int = 0) -> list:
        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.POLY_READ)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self.start_operation()
        self._wait_operation()
        
        result = []
        for i in range(POLY_DEGREE * 3):
            result.append(self.read(self.POLY_BASE + i * 4))
        return result
    

In [None]:

# Initialize the overlay
logger.info("Initializing the Crypto System")
overlay = Overlay("/home/xilinx/pynq/overlays/Crypto/Crypto.bit")

# Initialize the drivers
crypto = overlay.Crypto_0
random_gen = overlay.random_generator_0


# Generate the twiddle factors
with time_measure("Generate Twiddle Factors"):
    tf_gen = TwiddleFactorGenerator()
    crypto.load_twiddle_factor(tf_gen.twiddle_factors, tf_gen.inv_twiddle_factors)

with time_measure("Generate Random Numbers NTT_A"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_a = random_gen.read_random_data() 
    crypto.load_polynomial(random_num_a.tolist() * 3, 0)
    crypto.execute_operation(OpCode.NTT, 0)
    
with time_measure("Generate Random Numbers NTT_S"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_s = random_gen.read_random_data()
    crypto.load_polynomial(random_num_s.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)
    
with time_measure("NTT A * S"):
    crypto.execute_operation(OpCode.MUL_MOD, 0, 1)
    
with time_measure("Generate Random Numbers NTT_E"):
    seed = random.randint(0, 2**32 - 1)
    random_gen.generate_random(seed)
    random_num_e = random_gen.read_random_data()
    crypto.load_polynomial(random_num_e.tolist()*3, 1)
    crypto.execute_operation(OpCode.NTT, 1)

with time_measure("Add the result to E"):
    crypto.execute_operation(OpCode.ADD_MOD, 0, 1)
    
with time_measure("Load Delta"):
    delta = [DELTA] * POLY_DEGREE * 3
    crypto.load_polynomial(delta, 1)

with time_measure("Total Encryption"):
    with time_measure("Message Reading"):
        with open("/home/xilinx/jupyter_notebooks/message.txt", "r") as f:
            message = [int(line.strip()) for line in f.readlines()]
        msg = message * 3
        
    with time_measure("Encrypt the message"):
        crypto.load_polynomial(msg, 2)
        crypto.execute_operation(OpCode.MUL_MOD, 2, 1)
        crypto.execute_operation(OpCode.POLY_MOD_PLAINTEXTMODULUS, 2)
        crypto.execute_operation(OpCode.NTT, 2)
        crypto.execute_operation(OpCode.ADD_MOD, 2, 0)
        
with time_measure("Verification"):
    expected_result = []
    for i in range(3):
        delta_msg = [x * DELTA % PLAINTEXT_MODULUS for x in message]
        M = sympy.ntt(delta_msg, MODULUS[i])
        NTT_A = sympy.ntt(random_num_a, MODULUS[i])
        NTT_S = sympy.ntt(random_num_s, MODULUS[i])
        NTT_E = sympy.ntt(random_num_e, MODULUS[i])
        temp_expected_result = [(a*s+e+m)%MODULUS[i] for a,s,e,m in zip(NTT_A, NTT_S, NTT_E, M)]
        expected_result.extend(temp_expected_result)
    
    expected_result = np.array(expected_result)
    actual_result = np.array(crypto.read_result(2))
    if np.array_equal(expected_result, actual_result):
        logger.info("Verification successful: The results match.")
    else:
        logger.error("Verification failed: The results do not match.")
        for i in range(len(expected_result)):
            if expected_result[i] != actual_result[i]:
                logger.error(f"Expected: {expected_result[i]}, Actual: {actual_result[i]}")

INFO:__main__:Initializing the Crypto System


INFO:time:Generate Twiddle Factors      : 692.701 ms


[[1, 625534531, 180790047, 986624439, 381780781, 810360320, 260402737, 553963441, 387049130, 806543545, 739895375, 333249896, 882255214, 581712732, 423496283, 126222482, 440707649, 370695803, 312017177, 41017673, 142790058, 974594381, 698239950, 4811621, 1044181966, 1052963739, 368395236, 9172197, 715296025, 382969377, 116247299, 217454590, 187094501, 632788151, 494954670, 580332292, 520444094, 956443095, 260447782, 408966222, 754587317, 791194670, 874903248, 882036381, 4607654, 988267514, 572531104, 548381812, 901096861, 587023463, 936252620, 569365762, 290055771, 86206039, 883974859, 110743389, 470197401, 770264767, 583956016, 343619297, 648345811, 224933888, 747161101, 209576904, 817737403, 400645471, 987640676, 568146135, 201194640, 377044554, 68361894, 583276735, 642344479, 341373928, 21246300, 715568548, 82262102, 542695835, 327687217, 1727698, 1071029070, 1017676574, 180523752, 477579089, 625727145, 325786494, 893935970, 414715582, 290047592, 258107995, 755953855, 814584580, 350

INFO:time:Generate Random Numbers NTT_A : 442.964 ms
INFO:time:Generate Random Numbers NTT_S : 427.365 ms
INFO:time:NTT A * S                     : 0.299 ms
INFO:time:Generate Random Numbers NTT_E : 423.040 ms
INFO:time:Add the result to E           : 0.300 ms
INFO:time:Load Delta                    : 306.510 ms
INFO:time:Message Reading               : 22.981 ms
INFO:time:Encrypt the message           : 322.471 ms
INFO:time:Total Encryption              : 355.276 ms
INFO:__main__:Verification successful: The results match.
INFO:time:Verification                  : 4365.861 ms
