In [1]:
from pynq import  Overlay, DefaultHierarchy, allocate, DefaultIP, MMIO
import numpy as np
import time
import random
import sympy
from enum import IntEnum
import logging
from contextlib import contextmanager

# -------------------- Define the constants --------------------
MODULUS = 1073750017
MODULUS_ROOT = 625534531
MODULUS_INV = 627281114
MODULUS_N_INV = 1073487871
PLAINTEXT_MODULUS = 65537
DELTA = 100
BRAM_COUNT = 4
POLY_DEGREE = 4096
TF_SIZE = 2048
DATA_WIDTH = 32

# -------------------- Configure the logger --------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
time_logger = logging.getLogger('time')
time_logger.setLevel(logging.INFO)

# ---------------------- Timer ----------------------
@contextmanager
def time_measure(name):
    start = time.perf_counter()
    yield
    elapsed = (time.perf_counter() - start) * 1000  # 毫秒
    time_logger.info(f"{name.ljust(30)}: {elapsed:.3f} ms")


# -------------------- Define the enums --------------------
class OpCode(IntEnum):
    ADD_MOD = 0
    SUB_MOD = 1
    MUL_MOD = 2
    NTT = 3
    INTT = 4
    POLY_WRITE = 5
    POLY_READ = 6
    TWIDDLE_WRITE = 7
    
class CryptoRegisters(IntEnum):
    CONTROL = 0x0000
    RAM_SELECT = 0x0010
    MOD_INDEX = 0x0014
    OP_CODE = 0x0018
    
# -------------------- Define Twiddle Factor Generator --------------------
class TwiddleFactorGenerator:
    "Generate twiddle factors for NTT and INTT"
    def __init__(self):
        self.modulus = MODULUS
        self.root = MODULUS_ROOT
        self.root_inv = MODULUS_INV
        
        self.twiddle_factors, self.inv_twiddle_factors = self._generate_factors()
    
    def _generate_factors(self):
        "Generate twiddle factors"
        tf = [1] * TF_SIZE
        tf_inv = [1] * TF_SIZE
        
        for i in range(1, TF_SIZE):
            tf[i] = (tf[i-1] * self.root) % self.modulus
            tf_inv[i] = (tf_inv[i-1] * self.root_inv) % self.modulus
        return tf, tf_inv
    
# -------------------- Define Hardware Driver --------------------
class RandomGeneratorDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)    
    bindto = ['xilinx.com:hls:random_generator:1.0']  # 匹配Vitis HLS生成的IP ID
        
    def set_seed(self, seed):
        self.write(0x10, seed)
        
    def start_generation(self):
        self.write(0x00, 0x01)
        
    def check_ready(self):
        return self.read(0x00) & 0x2
    
    def generate_random(self, seed:int, timeout:float = 5.0):
        "Generate a random number with a given seed (With timeout)"
        self.set_seed(seed)
        self.start_generation()
        
        start_time = time.time() 
        while not self.check_ready():
            if time.time() - start_time > timeout:
                raise TimeoutError("Random number generation timeout")
    
    def read_random_data(self) -> np.ndarray:
        "Read the random data from the BRAM with POLY_DEGREE"
        buf = np.zeros(POLY_DEGREE, dtype=np.uint32)
        # for each random number, which is 32-bit
        # I want to write each bit to each address of the BRAM
        
        for i in range(POLY_DEGREE//DATA_WIDTH):
            for j in range(DATA_WIDTH):
                buf[i * DATA_WIDTH + j] = (self.read(0x1000 + i * 4) >> j) & 0x1
        return buf
    
class CryptoDriver(DefaultIP):
    def __init__(self, description):
        super().__init__(description=description)
    bindto = ['xilinx.com:hls:Crypto:1.0']
    
    TF_BASE = 0x2000
    TF_INV_BASE = 0x8000
    POLY_BASE = 0x4000
    
    def start_operation(self):
        self.write(CryptoRegisters.CONTROL, 0x1)
        
    def check_busy(self):
        return (self.read(CryptoRegisters.CONTROL) & 0x4) == 0
    
    def _setup_operation(self, addr: CryptoRegisters, value: int):
        "Setup the register value"
        self.write(addr, int(value))
        
    def _wait_operation(self, timeout: float = 5.0):
        "Wait for the operation to finish"
        start_time = time.time()
        while self.check_busy():
            if time.time() - start_time > timeout:
                raise TimeoutError("Operation timeout")
            pass
        
    def load_twiddle_factor(self, tf: list, tf_inv: list):
        "Load twiddle factors to the BRAM"
        if len(tf) != TF_SIZE or len(tf_inv) != TF_SIZE:
            raise ValueError("Invalid twiddle factor size")
        
        for i in range(TF_SIZE):
            self.write(self.TF_BASE + i * 4, tf[i])
            self.write(self.TF_INV_BASE + i * 4, tf_inv[i])
        
        self._setup_operation(CryptoRegisters.OP_CODE, 7)
        self.start_operation()
        self._wait_operation()
     
        
    def load_polynomial(self, poly: list, ram_sel: int = 0):
        "Load the polynomial to the BRAM"
        if len(poly) != POLY_DEGREE:
            raise ValueError("Invalid polynomial size")
        
        if not 0 <= ram_sel < BRAM_COUNT:
            raise ValueError("Invalid BRAM selection")
        
        for i in range(POLY_DEGREE):
            self.write(self.POLY_BASE + i * 4, (int(poly[i])))
           
            
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.OP_CODE, 5)
        self.start_operation()
        self._wait_operation()
        
    def execute_operation(self, op_code: OpCode, mod_index: int = 0, ram_sel: int = 0):
        "Execute the operation"
        self._setup_operation(CryptoRegisters.OP_CODE, op_code)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self._setup_operation(CryptoRegisters.MOD_INDEX, mod_index)
        self.start_operation()
        self._wait_operation()
        
    def read_result(self, ram_sel: int = 0) -> list:
        self._setup_operation(CryptoRegisters.OP_CODE, OpCode.POLY_READ)
        self._setup_operation(CryptoRegisters.RAM_SELECT, ram_sel)
        self.start_operation()
        self._wait_operation()
        
        result = []
        for i in range(POLY_DEGREE):
            result.append(self.read(self.POLY_BASE + i * 4))
        return result
    
    
        
# -------------------- Define the Memory Controller --------------------
class MemoryController:
    def __init__(self, overlay):
        self.brams = [
            overlay.axi_bram_ctrl_0,
            overlay.axi_bram_ctrl_1,
            overlay.axi_bram_ctrl_2,
            overlay.axi_bram_ctrl_3
        ]
        
    def write(self, bram_sel: int, data: np.ndarray):
        if not 0 <= bram_sel < BRAM_COUNT:
            raise ValueError("Invalid BRAM selection")
        
        if len(data) != POLY_DEGREE:
            raise ValueError("Invalid data size")
        
        for i in range(POLY_DEGREE):
            self.brams[bram_sel].write(i * 4, int(data[i]))
            
    def read(self, bram_sel: int) -> np.ndarray:
        if not 0 <= bram_sel < BRAM_COUNT:
            raise ValueError("Invalid BRAM selection")
        
        buf = np.zeros(POLY_DEGREE, dtype=np.uint32)
        for i in range(POLY_DEGREE):
            buf[i] = self.brams[bram_sel].read(i * 4)
        return buf



# Initialize the overlay
logger.info("Initializing the Crypto System")
overlay = Overlay("/home/xilinx/pynq/overlays/Crypto/Crypto.bit")

# Initialize the drivers
crypto = overlay.Crypto_0
random_gen = overlay.random_generator_0
memory = MemoryController(overlay)

# Generate the twiddle factors
with time_measure("Generate Twiddle Factors"):
    tf_gen = TwiddleFactorGenerator()
    crypto.load_twiddle_factor(tf_gen.twiddle_factors, tf_gen.inv_twiddle_factors)

def process_parameter(name: str, bram_sel: int) -> list:
    with time_measure(f"Generate {name}"):
        seed = random.randint(0, 2**32 - 1)
        logger.debug(f"Use the seed {seed} to generate the random number")
        random_gen.generate_random(seed)
        random_num = random_gen.read_random_data()
    with time_measure(f"NTT {name}"):
        crypto.load_polynomial(random_num)
        crypto.execute_operation(OpCode.NTT)
        memory.write(bram_sel, crypto.read_result(0))
    
    NTT_random_num = crypto.read_result(0)
    return NTT_random_num
    
A = process_parameter("A", 0)
S = process_parameter("S", 1)
E = process_parameter("E", 2)


with time_measure("Preprocess A * S"):
    crypto.load_polynomial(memory.read(0), 0)
    crypto.load_polynomial(memory.read(1), 1)
    crypto.execute_operation(OpCode.MUL_MOD, 0, 0)


with time_measure("Add the result to E"):
    crypto.load_polynomial(memory.read(2), 1)
    crypto.execute_operation(OpCode.ADD_MOD, 0, 0)
    memory.write(0, crypto.read_result(0))

with time_measure("Total Encryption"):
    with time_measure("Message Reading"):       
        with open("/home/xilinx/jupyter_notebooks/message.txt", "r") as f:
            message = [int(line.strip()) for line in f.readlines()[:POLY_DEGREE]]
                
    with time_measure("Encrypt the message"):    
        DELTA_seq = [DELTA] * POLY_DEGREE
        crypto.load_polynomial(message, 0)
        crypto.load_polynomial(DELTA_seq, 1)
        crypto.execute_operation(OpCode.MUL_MOD, 0, 0)
        message_delta = crypto.read_result(0)
        message_delta = [x % PLAINTEXT_MODULUS for x in message_delta]
        crypto.load_polynomial(message_delta, 1)
        crypto.execute_operation(OpCode.NTT, 0, 1)

    with time_measure("Add the message_delta to AS+E"):      
        crypto.load_polynomial(memory.read(0), 0)
        crypto.execute_operation(OpCode.ADD_MOD, 0, 0)            
            

with time_measure("Verification"):
    # Convert message to numpy array and multiply by DELTA
    message_np = np.array(message) * DELTA % PLAINTEXT_MODULUS
    # Perform NTT on the message
    M = sympy.ntt(message_np, MODULUS)
    # Calculate the expected result
    expected_result = [(a*s+e+m)%MODULUS for a,s,e,m in zip(A,S,E,M)]
    # Read the actual result from the hardware
    actual_result = np.array(crypto.read_result(0))

    # Compare the results
    if np.array_equal(expected_result, actual_result):
            logger.info("Verification successful: The results match.")
    else:
            logger.error("Verification failed: The results do not match.")
            logger.debug(f"Expected: {expected_result}")
            logger.debug(f"Actual: {actual_result}")

    print("Verification complete.")


# Decrypt the message


INFO:__main__:Initializing the Crypto System


INFO:time:Generate Twiddle Factors      : 116.806 ms
INFO:time:Generate A                    : 106.098 ms
INFO:time:NTT A                         : 326.206 ms
INFO:time:Generate S                    : 103.113 ms
INFO:time:NTT S                         : 328.851 ms
INFO:time:Generate E                    : 103.430 ms
INFO:time:NTT E                         : 327.396 ms
INFO:time:Preprocess A * S              : 443.091 ms
INFO:time:Add the result to E           : 432.104 ms
INFO:time:Message Reading               : 202.269 ms
INFO:time:Encrypt the message           : 424.109 ms
INFO:time:Add the message_delta to AS+E : 227.249 ms
INFO:time:Total Encryption              : 868.959 ms
INFO:__main__:Verification successful: The results match.
INFO:time:Verification                  : 552.746 ms


Verification complete.
