In [29]:
import numpy as np
from itertools import groupby
from functools import partial

def make_twiddle(n1, n2):
    """Build the (n1, n2) table of twiddle factors.

    Entry ``[j, k]`` equals ``exp(2*pi*1j * j*k / (n1*n2))``, i.e. the
    positive-exponent root of unity for the index product ``j*k``.
    The table is returned as a ``complex64`` array.
    """
    # The outer product of the two index ranges gives j*k for every entry.
    exponents = np.outer(np.arange(n1), np.arange(n2))
    period = n1 * n2
    return np.exp(2j * np.pi * exponents / period).astype('complex64')

## Radix-n data streamlining

For a radix-2 FFT, it is possible to divide input data in such a way that every butterfly operation reads and writes to two different arrays in memory. This could be advantageous for an implementation on the FPGA. The trick is to separate data locations based on the parity of their index.

### Parity

The index of each element in the array can be written in binary. In the case of a radix-2 FFT, we can see the array being reshaped in a $[2, 2, 2, \dots]$ shape. Every stage of the $n=2^k$ sized radix-2 FFT is being performed in a different dimension of the $k$-dimensional reshaped array. The indexing in this multi-dimensional array is the same as the binary notation for the linear index. That means that each time the 2-FFT is performed, the multi-index of the elements involved will differ by one bit in the linear index. Separating the array based on the parity of the index will guarantee that each 2-FFT operation reads and writes one (complex) number from each set. The parity is the sum of all the individual bits, modulo 2.

$$P_2(i) = \left(\sum_k b_k\right) \mod 2,\quad {\rm where}\, i := \sum_k b_k 2^k$$

We can extend this concept to the radix-4 FFT, on 4 data channels.

### 4-parity

In the radix-4 FFT we can view the array as being reshaped to a $[4, 4, 4, \dots]$ shape. The multi-index into this array is equivalent to the quaternary (base-4) notation of the linear index. Similar to the radix-2 parity, we can define the radix-4 parity as the sum of the quaternary digits of the index, modulo 4.

$$P_4(i) = \left(\sum_k q_k\right) \mod 4,\quad {\rm where}\, i := \sum_k q_k 4^k$$

This would ensure that each 4-FFT reads its data from the 4 different input channels. We define the `parity` function to work with any radix:

In [2]:
def digits(n, i):
    """Generates the n-numbered digits of i, in reverse order.

    Digits are yielded least-significant first; ``i == 0`` yields nothing.

    Parameters:
        n: the base (radix); must be at least 2.
        i: the non-negative number to decompose.

    Raises:
        ValueError: if ``n < 2`` or ``i < 0``. With floor division those
            inputs never reach zero, so the loop would not terminate.
            Because this is a generator, the error surfaces when the first
            digit is requested, not when the generator is created.
    """
    if n < 2:
        raise ValueError("base n must be at least 2, got %r" % (n,))
    if i < 0:
        raise ValueError("i must be non-negative, got %r" % (i,))
    while i != 0:
        i, q = divmod(i, n)
        yield q
            
def parity(n, i):
    """Return the radix-n parity of i: the sum of its base-n digits, modulo n."""
    digit_sum = 0
    for q in digits(n, i):
        digit_sum += q
    return digit_sum % n

To see which index belongs in which channel, we sort the indices by their parity and apply `groupby`, keyed on the `parity` function:

In [3]:
def channels(N, radix):
    """Group the indices 0 .. N-1 by their radix-`radix` parity.

    Returns the lazy ``itertools.groupby`` iterator of ``(parity, indices)``
    pairs; each ``indices`` group must be consumed before advancing to the
    next pair.
    """
    def channel_of(i):
        return parity(radix, i)

    # groupby only merges adjacent items, so sort by the same key first.
    ordered = sorted(range(N), key=channel_of)
    return groupby(ordered, key=channel_of)

For instance, for radix-2 and size 16, this creates the following channels:

In [4]:
# Print each radix-2 channel together with the indices assigned to it.
for channel, indices in channels(16, 2):
    print(channel, [*indices])

0 [0, 3, 5, 6, 9, 10, 12, 15]
1 [1, 2, 4, 7, 8, 11, 13, 14]


### Example: radix-2 16-FFT

We can implement the 16-FFT as a radix-2 operation on an array of shape $[2,2,2,2]$.

In [168]:
def fft2(x0, x1):
    """In-place 2-point butterfly: x0 <- x0 + x1, x1 <- x0 - x1.

    Both results are computed (and cast to complex64) before either input
    is overwritten, so the difference uses the original value of x0.
    """
    total = (x0 + x1).astype('complex64')
    difference = (x0 - x1).astype('complex64')
    x0[:] = total
    x1[:] = difference

To check whether our grouping works, we create a dummy array, and a dummy `fft2_check` that just checks if all arguments are in different channels.

In [169]:
def fft2_check(x0, x1):
    """Assert that x0 and x1 differ at every position.

    Used as a stand-in for `fft2` on an array of channel labels: a failure
    means two butterfly operands share a channel.
    """
    clash = np.any(x0 == x1)
    assert not clash

In [170]:
# Label every position of a length-16 array with the channel it belongs to.
x = np.zeros(16, dtype=int)
for channel, indices in channels(16, 2):
    x[list(indices)] = channel

In [171]:
N = 16
radix = 2

# Count the radix stages with exact integer arithmetic.  This replaces
# int(math.log(N, radix)), which truncates when the float logarithm lands
# just below an integer, and which needs `math` to be imported before this
# cell runs.
n_stages = 0
while radix ** (n_stages + 1) <= N:
    n_stages += 1
assert radix ** n_stages == N, "N must be an exact power of radix"

factors = [radix] * n_stages
s = x.reshape(factors).transpose()

# Each butterfly pairs index 0 with index 1 along one axis; the two operands
# must always carry different channel labels.
try:
    for k in range(len(factors)):
        fft2_check(np.take(s, 0, axis=k), np.take(s, 1, axis=k))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


In [172]:
# Radix-2 FFT of a length-N ramp, computed stage by stage on a copy of x.
# Alternative random test input:
# x = (np.random.normal(size=N) + 1j * np.random.normal(size=N)).astype('complex64')
x = np.arange(0, N, dtype='complex64')
# View the copy as a [2, 2, ...] array with the axis order reversed.
s = x.copy().reshape(factors).transpose()

# First stage: butterfly along the last axis of s; no twiddle factors are applied here.
fft2(s[...,0], s[...,1])

for k in range(1, len(factors)):
    # Conjugated twiddle table of shape (2, 2**k).
    w = make_twiddle(2, 2**k).conj()[:,:]
    # Regroup so that axis 1 (length 2) is the butterfly axis of this stage.
    # NOTE(review): reshape may return a copy rather than a view of s; either
    # way the result is kept because s is rebound to z at the end of the loop.
    z = s.reshape([-1, 2, 2**k])
    # w broadcasts over the leading axis of z.
    z *= w
    # In-place butterfly between the two slices along axis 1.
    fft2(z[..., 0, :], z[..., 1, :])
    s = z

In [173]:
# Display the transformed values as a flat array in row-major order.
s.flatten(order='C')

array([120.        +0.j       ,  -7.9999995+40.21872j  ,
        -8.       +19.31371j  ,  -8.       +11.972846j ,
        -8.        +8.j       ,  -8.        +5.3454294j,
        -8.        +3.3137083j,  -8.        +1.591299j ,
        -8.        +0.j       ,  -8.        -1.591299j ,
        -8.        -3.3137083j,  -8.        -5.3454294j,
        -8.        -8.j       ,  -8.       -11.972846j ,
        -8.       -19.31371j  ,  -7.9999995-40.21872j  ], dtype=complex64)

In [174]:
# Largest absolute deviation from NumPy's reference FFT of the same input.
np.abs(s.flatten() - np.fft.fft(x)).max()

3.577337255373132e-06

### Example: radix-4 64-FFT

In [175]:
# Print each radix-4 channel together with the indices assigned to it.
for channel, indices in channels(64, 4):
    print(channel, [*indices])

0 [0, 7, 10, 13, 19, 22, 25, 28, 34, 37, 40, 47, 49, 52, 59, 62]
1 [1, 4, 11, 14, 16, 23, 26, 29, 35, 38, 41, 44, 50, 53, 56, 63]
2 [2, 5, 8, 15, 17, 20, 27, 30, 32, 39, 42, 45, 51, 54, 57, 60]
3 [3, 6, 9, 12, 18, 21, 24, 31, 33, 36, 43, 46, 48, 55, 58, 61]


In [176]:
from itertools import combinations
import math

def fft4_check(x0, x1, x2, x3):
    """Assert that the four operands differ pairwise at every position.

    Used as a stand-in for `fft4` on an array of channel labels: a failure
    means two of the four butterfly operands share a channel.
    """
    lanes = (x0, x1, x2, x3)
    for first in range(len(lanes)):
        for second in range(first + 1, len(lanes)):
            assert np.all(lanes[first] != lanes[second])

In [177]:
N = 64
radix = 4

# Label every position of the length-N array with the channel it belongs to.
x = np.zeros(N, dtype=int)
for channel, indices in channels(N, radix):
    x[list(indices)] = channel

In [178]:
# Count the radix stages with exact integer arithmetic.  int(math.log(N, radix))
# truncates when the float logarithm lands just below an integer, which would
# silently drop a stage.
n_stages = 0
while radix ** (n_stages + 1) <= N:
    n_stages += 1
assert radix ** n_stages == N, "N must be an exact power of radix"

factors = [radix] * n_stages
s = x.reshape(factors)

# Each butterfly combines the `radix` slices along one axis; all of them must
# carry pairwise different channel labels.
try:
    for k in range(len(factors)):
        fft4_check(*(np.take(s, i, axis=k) for i in range(radix)))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


## Implementing the radix-4 FFT

In [179]:
def fft4(x0, x1, x2, x3):
    """In-place 4-point DFT butterfly across the four operands.

    All intermediate sums and differences are computed from the original
    inputs before any operand is overwritten.
    """
    even_sum, even_diff = x0 + x2, x0 - x2
    odd_sum, odd_diff = x1 + x3, x1 - x3
    # Multiplying by 1j rotates the odd difference by a quarter turn.
    rotated = 1j * odd_diff
    x0[:] = even_sum + odd_sum
    x1[:] = even_diff - rotated
    x2[:] = even_sum - odd_sum
    x3[:] = even_diff + rotated

In [180]:
# Radix-4 FFT of a length-N ramp, computed stage by stage on a copy of x.
x = np.arange(0, N, dtype='complex64')
# View the copy as a [4, 4, ...] array with the axis order reversed.
s = x.copy().reshape(factors).transpose()

# First stage: butterfly across the four slices of the last axis; no twiddle
# factors are applied here.
fft4(*(s[...,k] for k in range(4)))

for k in range(1, len(factors)):
    # Conjugated twiddle table of shape (4, 4**k).
    w = make_twiddle(4, 4**k).conj()[:,:]
    # Regroup so that axis 1 (length 4) is the butterfly axis of this stage.
    # NOTE(review): reshape may return a copy rather than a view of s; either
    # way the result is kept because s is rebound to z at the end of the loop.
    z = s.reshape([-1, 4, 4**k])
    # w broadcasts over the leading axis of z.
    z *= w
    # In-place butterfly across the four slices along axis 1.
    fft4(*(z[..., l, :] for l in range(4)))
    s = z

In [181]:
# Largest absolute deviation from NumPy's reference FFT of the same input.
np.abs(s.flatten() - np.fft.fft(x)).max()

2.5034746158780763e-05

## OpenCL

Each of the transforms is in-place, meaning we can do them in any order, but we don't yet know the pattern in the indices.

In [214]:
# Channel label of every element, arranged as the first radix-4 stage sees it:
# row k holds the labels of the k-th butterfly operand.
x = np.zeros(N, dtype=int)
for channel, indices in channels(N, radix):
    x[list(indices)] = channel
s = x.reshape(factors).transpose()

np.stack([s[..., k].ravel() for k in range(4)])

array([[0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2],
       [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3],
       [2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0],
       [3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1]])

This pattern is the same for each iteration. Now for the indices into each channel:

In [211]:
# Position of every element inside its own channel: the members of each
# channel are numbered 0 .. N//radix - 1 in increasing index order.
x = np.zeros(N, dtype=int)
for channel, indices in channels(N, radix):
    x[list(indices)] = np.arange(N // radix, dtype=int)
s = x.reshape(factors).transpose()

np.stack([s[..., k].ravel() for k in range(4)])

array([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
       [ 4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7],
       [ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11],
       [12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15]])

In [212]:
# Same in-channel positions, regrouped with a trailing block of length 4;
# row l holds the positions read by the l-th butterfly operand.
z = s.reshape((-1, 4, 4))
np.stack([z[:, l, :].ravel() for l in range(4)])

array([[ 0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12],
       [ 1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13],
       [ 2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14],
       [ 3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15]])

In [213]:
# Same in-channel positions, regrouped with a trailing block of length 16;
# row l holds the positions read by the l-th butterfly operand.
z = s.reshape((-1, 4, 16))
np.stack([z[:, l, :].ravel() for l in range(4)])

array([[ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15],
       [ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15],
       [ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15],
       [ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15]])