# FFT Hardware Acceleration on PYNQ-Z2
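This notebook demonstrates how to drive the Vitis DSP Library FFT IP on the PYNQ-Z2 board. It generates a test tone, converts it to int16 fixed point, streams it through the FFT with AXI DMA, and plots the result.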

In [None]:
from pynq import Overlay
from pynq import allocate
import numpy as np
import matplotlib.pyplot as plt

# Load the overlay: program the PL with the FFT bitstream and grab handles
# to the two IP blocks this notebook talks to.
BITSTREAM_PATH = "../build/overlay/fft.bit"
overlay = Overlay(BITSTREAM_PATH)
dma = overlay.axi_dma_0      # DMA engine used to stream samples to/from the FFT
fft_ip = overlay.fft_top_0   # FFT core; its control register lives at offset 0x00

In [None]:
# Parameters
N = 1024        # FFT length (complex samples per frame)
SSR = 2         # complex samples carried per 64-bit DMA word
scale = 2**15   # fixed-point scaling: float 1.0 -> 2**15 (Q1.15)

# The packing below hard-codes two 32-bit complex samples per 64-bit word,
# so fail loudly instead of silently producing garbage if SSR is changed.
assert SSR == 2, "packing assumes SSR == 2 (two complex int16 samples per uint64 word)"
assert N % SSR == 0, "N must be a multiple of SSR"

# Generate input signal: one second sampled at N points, so a tone at `freq` Hz
# falls exactly on DFT bin `freq`.
t = np.linspace(0, 1, N, endpoint=False)
freq = 50 # 50 Hz
sig = 0.5 * np.exp(1j * 2 * np.pi * freq * t)

# Convert to fixed point int16.
# Round to nearest (a bare astype truncates toward zero, biasing every sample)
# and saturate to the int16 range, so a full-scale +1.0 maps to 32767 instead
# of wrapping around to -32768.
int16_min, int16_max = np.iinfo(np.int16).min, np.iinfo(np.int16).max
sig_fixed_real = np.clip(np.round(np.real(sig) * scale), int16_min, int16_max).astype(np.int16)
sig_fixed_imag = np.clip(np.round(np.imag(sig) * scale), int16_min, int16_max).astype(np.int16)

# Allocate buffers (one uint64 word per SSR-sample group)
input_buffer = allocate(shape=(N//SSR,), dtype=np.uint64)
output_buffer = allocate(shape=(N//SSR,), dtype=np.uint64)

# Pack data, vectorized:
#   * each complex sample -> 32-bit word: imag in bits 31:16, real in bits 15:0
#   * word i of the buffer: sample 2*i in bits 31:0, sample 2*i+1 in bits 63:32
# Viewing int16 as uint16 yields the two's-complement bit pattern, i.e. the same
# result as `int(x) & 0xFFFF` on a Python int. np.uint64 shift amounts keep the
# arithmetic in uint64 (mixing uint64 with Python ints can promote to float64).
real_bits = sig_fixed_real.view(np.uint16).astype(np.uint64)
imag_bits = sig_fixed_imag.view(np.uint16).astype(np.uint64)
sample_words = (imag_bits << np.uint64(16)) | real_bits
lanes = sample_words.reshape(-1, SSR)   # row i = (sample 2*i, sample 2*i+1)
input_buffer[:] = (lanes[:, 1] << np.uint64(32)) | lanes[:, 0]

In [None]:
# Start the FFT IP through its control register at offset 0x00:
# bit 0 is ap_start and bit 7 is auto_restart, so the value written is 0x81.
AP_CTRL_OFFSET = 0x00
AP_START = 0x01       # bit 0
AUTO_RESTART = 0x80   # bit 7
fft_ip.write(AP_CTRL_OFFSET, AP_START | AUTO_RESTART)

# Transfer: arm the send (MM2S) channel, then the receive (S2MM) channel,
# and only then block until each one reports completion.
dma_jobs = ((dma.sendchannel, input_buffer), (dma.recvchannel, output_buffer))
for channel, buf in dma_jobs:
    channel.transfer(buf)
for channel in (dma.sendchannel, dma.recvchannel):
    channel.wait()

In [None]:
# Unpack output, vectorized. This is the inverse of the input packing:
#   * word i of the buffer: sample 2*i in bits 31:0, sample 2*i+1 in bits 63:32
#   * each 32-bit sample: imag in bits 31:16, real in bits 15:0 (two's-complement int16)
# NOTE(review): assumes the FFT core emits results with the same packing as its
# input; confirm against the IP configuration.
words = np.asarray(output_buffer, dtype=np.uint64)
low32 = np.uint64(0xFFFFFFFF)
low16 = np.uint64(0xFFFF)
# column_stack + ravel interleaves (s0, s1) per word back into natural sample order
sample_words = np.column_stack((words & low32, words >> np.uint64(32))).ravel()
out_real = (sample_words & low16).astype(np.uint16).view(np.int16)
out_imag = ((sample_words >> np.uint64(16)) & low16).astype(np.uint16).view(np.int16)
# Convert to float64 before combining so the result stays complex128 and the
# division by `scale` never happens in int16 arithmetic.
output_data = (out_real.astype(np.float64) + 1j * out_imag.astype(np.float64)) / scale
output_mag = np.abs(output_data)

# Plot
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
# |sig| of a pure complex tone is a constant 0.5 (a flat line), so show the
# real and imaginary parts of the input instead.
ax_in.plot(t, np.real(sig), label="real")
ax_in.plot(t, np.imag(sig), label="imag")
ax_in.set(title="Input Signal", xlabel="Time (s)", ylabel="Amplitude")
ax_in.legend()
ax_out.plot(output_mag)
ax_out.set(title="Output Magnitude", xlabel="Output index (bin)", ylabel="|X[k]|")
fig.tight_layout()
plt.show()

peak_idx = np.argmax(output_mag)