In [1]:
import numpy as np
from pynq import Overlay, allocate
from time import time

In [2]:
# 1) Load the FPGA bitstream; `ntt_mult.bit` must be present on the board.
BITSTREAM = "ntt_mult.bit"
overlay = Overlay(BITSTREAM)

In [3]:
# 2) Grab the AXI DMA driver. The attribute name must match the DMA instance
#    name in the block design (here `axi_dma_0`).
try:
    dma = overlay.axi_dma_0
except AttributeError as exc:
    # List the IP the overlay does expose so a renamed DMA is easy to spot.
    raise AttributeError(
        "DMA 'axi_dma_0' not found in overlay; available IP: "
        + ", ".join(sorted(overlay.ip_dict))
    ) from exc

In [4]:
# 3) Prepare the input: a ramp 0, 1, ..., 1023 streamed into the accelerator.
#    NOTE(review): the core consumes N_IN words and returns N_OUT words —
#    presumably two 512-coefficient operands back to back; confirm against the
#    hardware design.
N_IN, N_OUT = 1024, 512
in_data = np.arange(N_IN, dtype=np.uint32)

# 4) Allocate contiguous buffers the DMA engine reads from / writes to.
in_buf = allocate(shape=(N_IN,), dtype=np.uint32)
out_buf = allocate(shape=(N_OUT,), dtype=np.uint32)

# 5) Copy the input into the DMA buffer (a direct array copy; no need to
#    round-trip through a Python list first).
in_buf[:] = in_data

In [5]:
# 6) Run the accelerator. Arm the receive channel first so the output buffer
#    is ready before any results stream back, then start sending the input,
#    and block until both transfers have completed.
#    The measured interval includes the Python-side cost of issuing the
#    transfers, not only the time spent in the hardware.
start_time = time()
dma.recvchannel.transfer(out_buf)
dma.sendchannel.transfer(in_buf)

dma.sendchannel.wait()              # wait for TX to finish
dma.recvchannel.wait()              # then RX
end_time = time()

In [11]:
# 8) Show the input ramp and the words the DMA wrote back.
print("Input: ", in_buf)
print("NTT Mult Output:", out_buf)

# 9) (Optional) Release the DMA buffers. Left disabled here because the
#    verification cell further down still reads `out_buf`; free them only
#    after that check has run.
# in_buf.freebuffer()
# out_buf.freebuffer()


Input:  [   0    1    2 ... 1021 1022 1023]
NTT Mult Output: [1936 1663  878 7262 5453 3132  299 4635  778 4090 6890 1497 3273 4537
 5289 5529 5257 4473 3177 1369 6730 3898  554 4379   11 2812 5101 6878
  462 1215 1456 1185  402 6788 4981 2662 7512 4169  314 3628 6430 1039
 2817 4083 4837 5079 4809 4027 2733  927 6290 3460  118 3945 7260 2382
 4673 6452   38  793 1036  767 7667 6374 4569 2252 7104 3763 7591 3226
 6030  641 2421 3689 4445 4689 4421 3641 2349  545 5910 3082 7423 3571
 6888 2012 4305 6086 7355  431  676  409 7311 6020 4217 1902 6756 3417
 7247 2884 5690  303 2085 3355 4113 4359 4093 3315 2025  223 5590 2764
 7107 3257 6576 1702 3997 5780 7051  129  376  111 7015 5726 3925 1612
 6468 3131 6963 2602 5410   25 1809 3081 3841 4089 3825 3049 1761 7642
 5330 2506 6851 3003 6324 1452 3749 5534 6807 7568  136 7554 6779 5492
 3693 1382 6240 2905 6739 2380 5190 7488 1593 2867 3629 3879 3617 2843
 1557 7440 5130 2308 6655 2809 6132 1262 3561 5348 6623 7386 7637 7376
 6603 5318 3521 

In [12]:
# Elapsed wall-clock time of the DMA round trip timed in the run cell above.
hw_time = end_time - start_time
print('Hardware Time = {} sec'.format(hw_time))

Hardware Time = 0.004726886749267578 sec


In [18]:
import numpy as np
from time import time


First 8 output values: [239, 4814, 1228, 3417, 1212, 4290, 6624, 1829]
Time taken: 0.055737 seconds


In [22]:

# Parameters of the software NTT reference model.
#   MOD  - modulus; 7681 = 15 * 512 + 1 (presumably prime, as an NTT requires).
#   ROOT - presumably a primitive N-th root of unity mod MOD.
#          TODO confirm: pow(ROOT, N, MOD) == 1 and pow(ROOT, N // 2, MOD) == MOD - 1.
#   N    - transform length (a power of two).
MOD, ROOT, N = 7681, 7146, 512


In [23]:
def mulmod(a, b):
    """Return the product ``a * b`` reduced modulo ``MOD`` (always in [0, MOD))."""
    product = a * b
    return product % MOD

def addmod(a, b):
    """Return ``a + b`` modulo ``MOD`` using a single conditional subtraction.

    Only one wrap is corrected, so the result is fully reduced only when
    ``a + b < 2 * MOD`` (e.g. both operands already in [0, MOD)).
    """
    total = a + b
    if total >= MOD:
        total -= MOD
    return total

def submod(a, b):
    """Return ``a - b`` modulo ``MOD`` using a single conditional addition.

    Only one wrap is corrected, so the result lands in [0, MOD) only when
    ``-MOD <= a - b < MOD`` (e.g. both operands already in [0, MOD)).
    """
    difference = a - b
    if difference < 0:
        difference += MOD
    return difference

def modpow(x, e):
    """Return ``x ** e`` reduced modulo ``MOD``.

    Delegates to the built-in three-argument ``pow``, which performs the same
    square-and-multiply exponentiation natively. ``e == 0`` still yields ``1``.
    A negative ``e`` no longer silently returns ``1``. On Python 3.8+ it gives
    the modular-inverse power. On older interpreters ``pow`` raises ValueError.
    """
    return pow(x, e, MOD)



In [24]:
def bit_reverse(x, logn):
    """Reverse the lowest ``logn`` bits of ``x``; higher bits of ``x`` are ignored.

    Example: with logn=3, 0b110 -> 0b011.
    """
    reversed_bits = 0
    for _ in range(logn):
        # Shift the accumulator left and append the current lowest bit of x.
        reversed_bits = (reversed_bits << 1) | (x & 1)
        x >>= 1
    return reversed_bits



In [25]:
def ntt(a):
    """Forward number-theoretic transform of ``N`` values modulo ``MOD``.

    Iterative radix-2 Cooley-Tukey: the input is permuted into bit-reversed
    order, then log2(N) butterfly stages are applied. Uses the module-level
    ``MOD``, ``ROOT`` and ``N``.
    Assumes ``ROOT`` is a primitive N-th root of unity mod ``MOD``.

    Parameters
    ----------
    a : sequence of int
        Exactly ``N`` coefficients, presumably already reduced into [0, MOD)
        (``addmod``/``submod`` correct only a single wrap).

    Returns
    -------
    list of int
        A new list; the caller's sequence is not modified.

    Raises
    ------
    ValueError
        If ``N`` is not a positive power of two, or ``len(a) != N``.
    """
    if N <= 0 or N & (N - 1):
        raise ValueError(f"N must be a positive power of two, got {N}")
    if len(a) != N:
        raise ValueError(f"ntt expects exactly {N} values, got {len(a)}")
    # Derived from N (previously hard-coded to 9, which silently broke for other N).
    log_n = N.bit_length() - 1

    # Bit-reversal permutation; builds a fresh list, so the input is untouched.
    a = [a[bit_reverse(i, log_n)] for i in range(N)]

    # Cooley-Tukey butterflies, doubling the sub-transform length each stage.
    length = 2
    while length <= N:
        half = length // 2
        # ROOT^(N/length) is the twiddle step for sub-transforms of this length.
        wlen = modpow(ROOT, N // length)
        for start in range(0, N, length):
            w = 1
            for j in range(half):
                u = a[start + j]
                v = mulmod(a[start + j + half], w)
                a[start + j] = addmod(u, v)
                a[start + j + half] = submod(u, v)
                w = mulmod(w, wlen)
        length <<= 1
    return a


In [27]:
# Software reference: one forward NTT of the ramp 0..N-1 (reduced mod MOD), timed.
input_data = [value % MOD for value in range(N)]

start_time = time()
output_data = ntt(input_data.copy())
end_time = time()
sw_time = end_time - start_time

print("First 8 output values:", output_data[:8])
print("Time taken: {:.6f} seconds".format(sw_time))

First 8 output values: [239, 4814, 1228, 3417, 1212, 4290, 6624, 1829]
Time taken: 0.065758 seconds


In [29]:
# Ratio of software to hardware wall-clock time (> 1 means the FPGA path was faster).
# NOTE(review): the two timings may not measure the same work. The software cell
# times one forward NTT of N values. The hardware cell times a full DMA round
# trip through the `ntt_mult` core (1024 words in, 512 out), which presumably
# performs a complete NTT-based multiplication. Confirm before quoting this number.
speedup = sw_time / hw_time
print(f"Speedup (SW time / HW time) = {speedup:.2f}x")

10.40906517718987


In [7]:
# Reference values for the hardware output; compared element-by-element with
# `out_buf` in the verification cell below.
# NOTE(review): how these numbers were generated is not shown in this notebook
# (presumably a software model of the ntt_mult core fed the same 0..1023 ramp).
# Record their provenance so the check can be reproduced.
exp=[1936, 1663, 878, 7262, 5453, 3132, 299, 4635, 778, 4090, 6890, 1497, 3273, 4537, 5289, 5529, 5257, 4473, 3177, 1369, 6730, 3898, 554, 4379, 11, 2812, 5101, 6878, 462, 1215, 1456, 1185, 402, 6788, 4981, 2662, 7512, 4169, 314, 3628, 6430, 1039, 2817, 4083, 4837, 5079, 4809, 4027, 2733, 927, 6290, 3460, 118, 3945, 7260, 2382, 4673, 6452, 38, 793, 1036, 767, 7667, 6374, 4569, 2252, 7104, 3763, 7591, 3226, 6030, 641, 2421, 3689, 4445, 4689, 4421, 3641, 2349, 545, 5910, 3082, 7423, 3571, 6888, 2012, 4305, 6086, 7355, 431, 676, 409, 7311, 6020, 4217, 1902, 6756, 3417, 7247, 2884, 5690, 303, 2085, 3355, 4113, 4359, 4093, 3315, 2025, 223, 5590, 2764, 7107, 3257, 6576, 1702, 3997, 5780, 7051, 129, 376, 111, 7015, 5726, 3925, 1612, 6468, 3131, 6963, 2602, 5410, 25, 1809, 3081, 3841, 4089, 3825, 3049, 1761, 7642, 5330, 2506, 6851, 3003, 6324, 1452, 3749, 5534, 6807, 7568, 136, 7554, 6779, 5492, 3693, 1382, 6240, 2905, 6739, 2380, 5190, 7488, 1593, 2867, 3629, 3879, 3617, 2843, 1557, 7440, 5130, 2308, 6655, 2809, 6132, 1262, 3561, 5348, 6623, 7386, 7637, 7376, 6603, 5318, 3521, 1212, 6072, 2739, 6575, 2218, 5030, 7330, 1437, 2713, 3477, 3729, 3469, 2697, 1413, 7298, 4990, 2170, 6519, 2675, 6000, 1132, 3433, 5222, 6499, 7264, 7517, 7258, 6487, 5204, 3409, 1102, 5964, 2633, 6471, 2116, 4930, 7232, 1341, 2619, 3385, 3639, 3381, 2611, 1329, 7216, 4910, 2092, 6443, 2601, 5928, 1062, 3365, 5156, 6435, 7202, 7457, 7200, 6431, 5150, 3357, 1052, 5916, 2587, 6427, 2074, 4890, 7194, 1305, 2585, 3353, 3609, 3353, 2585, 1305, 7194, 4890, 2074, 6427, 2587, 5916, 1052, 3357, 5150, 6431, 7200, 7457, 7202, 6435, 5156, 3365, 1062, 5928, 2601, 6443, 2092, 4910, 7216, 1329, 2611, 3381, 3639, 3385, 2619, 1341, 7232, 4930, 2116, 6471, 2633, 5964, 1102, 3409, 5204, 6487, 7258, 7517, 7264, 6499, 5222, 3433, 1132, 6000, 2675, 6519, 2170, 4990, 7298, 1413, 2697, 3469, 3729, 3477, 2713, 1437, 7330, 5030, 2218, 6575, 2739, 6072, 1212, 3521, 5318, 6603, 7376, 7637, 7386, 6623, 5348, 3561, 1262, 6132, 
2809, 6655, 2308, 5130, 7440, 1557, 2843, 3617, 3879, 3629, 2867, 1593, 7488, 5190, 2380, 6739, 2905, 6240, 1382, 3693, 5492, 6779, 7554, 136, 7568, 6807, 5534, 3749, 1452, 6324, 3003, 6851, 2506, 5330, 7642, 1761, 3049, 3825, 4089, 3841, 3081, 1809, 25, 5410, 2602, 6963, 3131, 6468, 1612, 3925, 5726, 7015, 111, 376, 129, 7051, 5780, 3997, 1702, 6576, 3257, 7107, 2764, 5590, 223, 2025, 3315, 4093, 4359, 4113, 3355, 2085, 303, 5690, 2884, 7247, 3417, 6756, 1902, 4217, 6020, 7311, 409, 676, 431, 7355, 6086, 4305, 2012, 6888, 3571, 7423, 3082, 5910, 545, 2349, 3641, 4421, 4689, 4445, 3689, 2421, 641, 6030, 3226, 7591, 3763, 7104, 2252, 4569, 6374, 7667, 767, 1036, 793, 38, 6452, 4673, 2382, 7260, 3945, 118, 3460, 6290, 927, 2733, 4027, 4809, 5079, 4837, 4083, 2817, 1039, 6430, 3628, 314, 4169, 7512, 2662, 4981, 6788, 402, 1185, 1456, 1215, 462, 6878, 5101, 2812, 11, 4379, 554, 3898, 6730, 1369, 3177, 4473, 5257, 5529, 5289, 4537, 3273, 1497, 6890, 4090, 778, 4635, 299, 3132, 5453, 7262, 878, 1663, 1936, 1697]


In [9]:
# Compare every hardware output word against the reference values in `exp`.
# Report where mismatches occur and only claim success when everything matches
# (the previous loop printed a bare "Error" per mismatch and "Done" regardless).
n_out = len(out_buf)
if len(exp) < n_out:
    raise ValueError(f"reference `exp` has {len(exp)} values, need at least {n_out}")
actual = np.asarray(out_buf, dtype=np.int64)
expected = np.asarray(exp[:n_out], dtype=np.int64)
mismatch_idx = np.flatnonzero(actual != expected)
for i in mismatch_idx[:10]:  # show the first few; the count below covers the rest
    print(f"Error at index {i}: got {actual[i]}, expected {expected[i]}")
if mismatch_idx.size:
    print(f"Done: {mismatch_idx.size} of {n_out} outputs differ")
else:
    print(f"Done: all {n_out} outputs match")

Done
