In [9]:
import numpy as np
import math
from itertools import groupby
from functools import partial

def make_twiddle(n1, n2):
    def w(k, n):
        return np.exp(2j * np.pi * k / n)

    I1 = np.arange(n1)
    I2 = np.arange(n2)
    return w(I1[:,None] * I2[None,:], n1*n2).astype('complex64')

## Radix-n data streamlining

For a radix-2 FFT, it is possible to divide the input data in such a way that every butterfly operation reads from and writes to two different arrays in memory. This could be advantageous for an FPGA implementation. The trick is to separate data locations based on the parity of their index.

### Parity

The index of each element in the array can be written in binary. In a radix-2 FFT, we can view the array as being reshaped to a $[2, 2, 2, \dots]$ shape. Every stage of the radix-2 FFT of size $n=2^k$ is performed along a different dimension of this $k$-dimensional reshaped array. The multi-index into this array is the same as the binary notation of the linear index. This means that the linear indices of the two elements involved in each 2-FFT differ in exactly one bit. Flipping a single bit changes the parity of the index. Separating the array based on the parity of the index therefore guarantees that each 2-FFT operation reads and writes one (complex) number from each set. The parity is the sum of all the individual bits, modulo 2.

$$P_2(i) = \left(\sum_k b_k\right) \mod 2,\quad {\rm where}\, i := \sum_k b_k 2^k$$

We can extend this concept to the radix-4 FFT, on 4 data channels.

### 4-parity

In the radix-4 FFT we can view the array as being reshaped to a $[4, 4, 4, \dots]$ shape. The multi-index into this array is equivalent to the quaternary (base-4) notation of the linear index. Analogous to the radix-2 parity, we define the radix-4 parity as the sum of the quaternary digits of the index, modulo 4.

$$P_4(i) = \left(\sum_k q_k\right) \mod 4,\quad {\rm where}\, i := \sum_k q_k 4^k$$

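This ensures that each 4-FFT reads its data from the 4 different input channels. The four operands of a 4-FFT differ in only one quaternary digit, which takes each of the values 0–3. Their digit sums, and hence their parities, are therefore all distinct modulo 4. We define the `parity` function to work with any radix: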

In [10]:
def digits(n, i):
    """Yield the base-``n`` digits of ``i``, least significant digit first.

    For example ``digits(4, 27)`` yields 3, 2, 1 because 27 = 1*16 + 2*4 + 3.
    ``i == 0`` yields nothing, so its digit sum (and hence its parity) is 0.

    Raises ``ValueError`` (on first iteration, since this is a generator) if
    ``n < 2`` or ``i < 0``. For such inputs repeated ``divmod`` either never
    reaches 0 (looping forever, e.g. ``n == 1`` or ``i < 0``) or divides by
    zero (``n == 0``).
    """
    if n < 2:
        raise ValueError("base n must be >= 2, got {}".format(n))
    if i < 0:
        raise ValueError("i must be non-negative, got {}".format(i))
    while i:
        i, q = divmod(i, n)
        yield q
            
def parity(n, i):
    """Return the radix-``n`` parity of ``i``: its base-``n`` digit sum modulo ``n``."""
    digit_sum = 0
    for d in digits(n, i):
        digit_sum += d
    return digit_sum % n

To see which index belongs in which channel, we use `groupby`, keyed on the `parity` function:

In [11]:
def channels(N, radix):
    """Partition the indices ``0 .. N-1`` by their radix-``radix`` parity.

    Yields ``(p, members)`` pairs in increasing order of the parity ``p``.
    ``members`` is the ascending list of indices ``i < N`` with
    ``parity(radix, i) == p``. Parities that no index attains are skipped.

    ``members`` is a real list rather than a ``groupby`` sub-iterator. Those
    sub-iterators are invalidated as soon as the outer iterator advances, so
    materialising a raw ``groupby`` result (e.g. ``list(...)`` or ``dict(...)``)
    would yield empty groups. Callers that call ``list(members)`` or use it as
    an index work unchanged.
    """
    # Compute each parity once. Sorting (parity, index) pairs keeps the indices
    # within a group in ascending order.
    keyed = sorted((parity(radix, i), i) for i in range(N))
    for p, group in groupby(keyed, key=lambda pair: pair[0]):
        yield p, [i for _, i in group]

For instance, for radix 2 and size 16, this creates the following channels:

In [12]:
# Show the members of each radix-2 parity channel for N = 16.
for (g, i) in channels(16, 2):
    print(g, list(i))

0 [0, 3, 5, 6, 9, 10, 12, 15]
1 [1, 2, 4, 7, 8, 11, 13, 14]


### Example: radix-2 16-FFT

We can implement the 16-FFT as a radix-2 operation on an array of shape $[2,2,2,2]$.

In [13]:
def fft2(x0, x1):
    """In-place radix-2 butterfly: ``x0, x1 <- x0 + x1, x0 - x1``.

    ``x0`` and ``x1`` must be writable NumPy arrays (typically views into a
    larger array) of the same shape. They are overwritten with the sum and the
    difference. Both results are computed before either input is written, so
    the in-place update is safe.

    The results are stored back in the inputs' own dtype. No intermediate
    ``complex64`` buffer is used, so ``complex128`` inputs keep double precision.
    """
    total = x0 + x1
    diff = x0 - x1
    # `[...]` (rather than `[:]`) also writes into 0-d array views.
    x0[...] = total
    x1[...] = diff

To check whether our grouping works, we create a dummy array, and a dummy `fft2_check` that just checks if all arguments are in different channels.

In [14]:
def fft2_check(x0, x1):
    """Assert that no element of ``x0`` equals the corresponding element of ``x1``.

    Given channel labels, this verifies that the two operands of every radix-2
    butterfly live in different channels.
    """
    assert not np.any(x0 == x1)

In [15]:
# Label each of the 16 positions with the id (0 or 1) of its parity channel.
x = np.zeros(shape=(16,), dtype=int)
for (g, i) in channels(16, 2):
    x[list(i)] = g

In [16]:
N = 16
radix = 2

# One axis of length `radix` per FFT stage. Use round() rather than int():
# math.log may return a value just below the exact integer, which int() would
# truncate. The assert rejects sizes that are not a power of the radix.
factors = [radix] * round(math.log(N, radix))
assert radix ** len(factors) == N, "N must be a power of radix"
s = x.reshape(factors).transpose()

try:
    # Along axis k, the slices at index 0 and 1 are the two operands of one
    # stage's butterflies. fft2_check asserts they are in different channels.
    for k in range(len(factors)):
        fft2_check(np.take(s, 0, axis=k), np.take(s, 1, axis=k))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


In [22]:
# For a stronger test, use random complex input instead of 0..N-1:
# x = (np.random.normal(size=N) + 1j * np.random.normal(size=N)).astype('complex64')
x = np.arange(0, N, dtype='complex64')
# Reshape to [2, 2, 2, 2] and transpose. This reverses the binary digit order
# of the indices (bit-reversed input ordering).
s = x.copy().reshape(factors).transpose()

# First stage: 2-point butterflies along the last axis, no twiddles needed.
fft2(s[...,0], s[...,1])

for k in range(1, len(factors)):
    # Conjugated twiddles exp(-2πi·j·m / 2**(k+1)), shape (2, 2**k).
    # The trailing [:,:] is a no-op slice.
    w = make_twiddle(2, 2**k).conj()[:,:]
    # Axis 1 selects the butterfly operand (fft2 argument); w broadcasts over axis 0.
    # reshape may return a copy when `s` is non-contiguous (as after the
    # transpose). That is harmless because `s` is rebound to `z` below.
    z = s.reshape([-1, 2, 2**k])
    z *= w
    fft2(z[..., 0, :], z[..., 1, :])
    s = z

In [23]:
s.flatten()  # radix-2 FFT result, flattened for display

array([120.        +0.j       ,  -7.9999995+40.21872j  ,
        -8.       +19.31371j  ,  -8.       +11.972846j ,
        -8.        +8.j       ,  -8.        +5.3454294j,
        -8.        +3.3137083j,  -8.        +1.591299j ,
        -8.        +0.j       ,  -8.        -1.591299j ,
        -8.        -3.3137083j,  -8.        -5.3454294j,
        -8.        -8.j       ,  -8.       -11.972846j ,
        -8.       -19.31371j  ,  -7.9999995-40.21872j  ], dtype=complex64)

In [24]:
# Largest absolute deviation from NumPy's reference FFT; expected to be at
# single-precision (complex64) rounding level.
np.abs(s.ravel() - np.fft.fft(x)).max()

3.577337255373132e-06

### Example: radix-4 64-FFT

In [25]:
# Show the members of each radix-4 parity channel for N = 64.
for g, i in channels(64, 4):
    print(g, list(i))

0 [0, 7, 10, 13, 19, 22, 25, 28, 34, 37, 40, 47, 49, 52, 59, 62]
1 [1, 4, 11, 14, 16, 23, 26, 29, 35, 38, 41, 44, 50, 53, 56, 63]
2 [2, 5, 8, 15, 17, 20, 27, 30, 32, 39, 42, 45, 51, 54, 57, 60]
3 [3, 6, 9, 12, 18, 21, 24, 31, 33, 36, 43, 46, 48, 55, 58, 61]


In [26]:
from itertools import combinations
import math

def fft4_check(x0, x1, x2, x3):
    """Assert that the four operands are element-wise pairwise distinct.

    Given channel labels, this verifies that the four operands of every
    radix-4 butterfly come from four different channels.
    """
    operands = (x0, x1, x2, x3)
    for i, a in enumerate(operands):
        for b in operands[i + 1:]:
            assert np.all(a != b)

In [27]:
N = 64
radix = 4

# Label each position with the id (0..3) of its 4-parity channel.
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = g

In [28]:
# One axis of length `radix` per FFT stage. Use round() rather than int():
# math.log may return a value just below the exact integer, which int() would
# truncate. The assert rejects sizes that are not a power of the radix.
factors = [radix] * round(math.log(N, radix))
assert radix ** len(factors) == N, "N must be a power of radix"
# No transpose needed here: the check below covers every axis.
s = x.reshape(factors)

try:
    # Along axis k, the four slices at index 0..3 are the operands of one
    # stage's butterflies. fft4_check asserts they come from distinct channels.
    for k in range(len(factors)):
        fft4_check(*(np.take(s, i, axis=k) for i in range(radix)))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


## Implementing the radix-4 FFT

In [29]:
def fft4(x0, x1, x2, x3):
    """In-place radix-4 butterfly (4-point DFT with the ``exp(-2πi·jk/4)`` kernel).

    On return ``(x0, x1, x2, x3)`` hold ``(X0, X1, X2, X3)``, where::

        X0 = x0 + x1 + x2 + x3
        X1 = x0 - 1j*x1 - x2 + 1j*x3
        X2 = x0 - x1 + x2 - x3
        X3 = x0 + 1j*x1 - x2 - 1j*x3

    This matches ``np.fft.fft`` on four points. The inputs must be writable
    NumPy arrays (or views) of a common shape. All intermediates are computed
    before any input is overwritten.
    """
    a = x0 + x2  # even-index sum
    b = x1 + x3  # odd-index sum
    c = x0 - x2  # even-index difference
    d = x1 - x3  # odd-index difference
    # `[...]` (rather than `[:]`) also writes into 0-d array views.
    x0[...] = a + b
    x1[...] = c - 1j*d
    x2[...] = a - b
    x3[...] = c + 1j*d

In [30]:
# Radix-4 FFT of 0..N-1, following the same scheme as the radix-2 cell.
x = np.arange(0, N, dtype='complex64')
# Reshape to [4, 4, 4] and transpose. This reverses the base-4 digit order of
# the indices (digit-reversed input ordering).
s = x.copy().reshape(factors).transpose()

# First stage: 4-point butterflies along the last axis, no twiddles needed.
fft4(*(s[...,k] for k in range(4)))

for k in range(1, len(factors)):
    # Conjugated twiddles exp(-2πi·j·m / 4**(k+1)), shape (4, 4**k).
    # The trailing [:,:] is a no-op slice.
    w = make_twiddle(4, 4**k).conj()[:,:]
    # Axis 1 selects the butterfly operand (fft4 argument l); w broadcasts over axis 0.
    # reshape may return a copy when `s` is non-contiguous (as after the
    # transpose). That is harmless because `s` is rebound to `z` below.
    z = s.reshape([-1, 4, 4**k])
    z *= w
    fft4(*(z[..., l, :] for l in range(4)))
    s = z

In [31]:
# Largest absolute deviation from NumPy's reference FFT; expected to be at
# single-precision (complex64) rounding level.
np.abs(s.ravel() - np.fft.fft(x)).max()

2.5034746158780763e-05

## OpenCL

Each of the transforms is in place, meaning the butterflies within a stage can be performed in any order. However, we don't yet know the pattern in the indices.

In [46]:
# Fill `x` with the integer id of the array (channel) to which each location belongs
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = g
# Same digit-reversing transpose as the FFT input ordering.
s = x.reshape(factors).transpose()

# perm[k, b] is the channel of the k-th operand of first-stage butterfly b
# (the FFT's first stage calls fft4 on s[..., k], k = 0..3).
perm = np.array([s[...,k].flatten() for k in range(4)])
perm

array([[0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2],
       [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3],
       [2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0],
       [3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1]])

This pattern is the same for each iteration. Note that the permutation is always the same cycle starting at a different number. This means the top row gives us which of the four permutations is needed. Do we implement four different versions of `fft4` to implement each permutation?

In [33]:
# Row c lists the first-stage butterflies whose first operand lives in channel c.
# Use a stable sort so the butterflies within each row stay in ascending order.
# The default argsort kind does not guarantee any particular order among ties.
colid = np.argsort(perm[0], kind='stable').reshape(4,4)
colid

array([[ 0,  7, 10, 13],
       [ 1,  4, 11, 14],
       [ 2,  5,  8, 15],
       [ 3,  6,  9, 12]])

These are the indices into the index array that select which of the four permuted versions of `fft4` needs to be used.

Now for the indices into each channel:

In [49]:
# Here `x` is the index (offset) of each location within its channel's buffer.
# The members of each channel, in ascending order, are numbered 0 .. N//radix - 1.
# This assumes every channel has exactly N // radix members, as the printed
# channel listing shows for N = 64.
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = np.arange(N//radix, dtype=int)
s = x.reshape(factors).transpose()

# idx[k, b] is the in-channel offset of the k-th operand of first-stage butterfly b.
idx = np.array([s[...,k].flatten() for k in range(4)])
idx

array([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
       [ 4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7],
       [ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11],
       [12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15]])

So the first permutation is called with these indices

In [35]:
# First-stage offsets of the butterflies in colid[0], i.e. those whose first operand is in channel 0.
idx[:,colid[0]]

array([[ 0,  3,  2,  1],
       [ 4,  7,  6,  5],
       [ 8, 11, 10,  9],
       [12, 15, 14, 13]])

And so on; these all contain the same columns in a different order, meaning we can run each `fft4` permutation **with the same indices!**

In [36]:
# Second stage (k = 1 in the FFT loop): the data are viewed as (-1, 4, 4) and
# fft4 operand l is z[:, l, :]. idx[l] lists the in-channel offsets of operand l.
# NOTE(review): relies on the global `s` holding the offset table built earlier (In [49]).
z = s.reshape([-1, 4, 4])
idx = np.array([z[:,l,:].flatten() for l in range(4)])
idx

array([[ 0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12],
       [ 1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13],
       [ 2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14],
       [ 3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15]])

In [37]:
# Select the same column positions (colid[0]) from the second-stage offset table.
idx[:,colid[0]]

array([[ 0, 12,  8,  4],
       [ 1, 13,  9,  5],
       [ 2, 14, 10,  6],
       [ 3, 15, 11,  7]])

In [38]:
# Third stage (k = 2 in the FFT loop): the data are viewed as (-1, 4, 16).
# idx[l] lists the in-channel offsets of fft4 operand l.
z = s.reshape([-1, 4, 16])
idx = np.array([z[:,l,:].flatten() for l in range(4)])

In [42]:
# Select the same column positions (colid[0]) from the third-stage offset table.
idx[:,colid[0]]

array([[ 0, 13, 10,  7],
       [ 0, 13, 10,  7],
       [ 0, 13, 10,  7],
       [ 0, 13, 10,  7]])

Only the last instance is different. However, note that these are the same indices as we found in computing the 4-parity.

### Twiddles

In [70]:
# Fill `x` with the integer id of the array (channel) to which each location belongs
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = g
channel = x.reshape(factors).transpose()

# Here `x` is the index (offset) of each location within its channel's buffer
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = np.arange(N//radix, dtype=int)
index = x.reshape(factors).transpose()

# Channel of each first-stage fft4 operand. Build it from `channel`, not from the
# global `s`, which at this point holds the offset table rather than channel ids.
perm = np.array([channel[...,k].flatten() for k in range(4)])

# step 1: twiddles make_twiddle(4, 4).conj() for the k = 1 pass. Element
# [i, j, k] is multiplied by w[j, k]; each is printed as channel[offset] * twiddle.
w = make_twiddle(4, 4).conj()
print(index.reshape([-1, 4, 4])[0])
for i in range(4):
    for j in range(4):
        for k in range(4):
            print("{}[{}] * {}".format(channel[i, j, k], index[i,j,k], w[j,k]), end=", ")
        print()

# step 2: twiddles for the k = 2 pass, with the data viewed as (-1, 4, 16)
w = make_twiddle(4, 16).conj()
print(channel.reshape([-1, 4, 16])[0])
print(index.reshape([-1, 4, 16])[0])
print(w)

[[ 0  4  8 12]
 [ 1  5  9 13]
 [ 2  6 10 14]
 [ 3  7 11 15]]
0[0] * (1-0j), 1[4] * (1-0j), 2[8] * (1-0j), 3[12] * (1-0j), 
1[1] * (1-0j), 2[5] * (0.9238795042037964-0.3826834261417389j), 3[9] * (0.7071067690849304-0.7071067690849304j), 0[13] * (0.3826834261417389-0.9238795042037964j), 
2[2] * (1-0j), 3[6] * (0.7071067690849304-0.7071067690849304j), 0[10] * (6.123234262925839e-17-1j), 1[14] * (-0.7071067690849304-0.7071067690849304j), 
3[3] * (1-0j), 0[7] * (0.3826834261417389-0.9238795042037964j), 1[11] * (-0.7071067690849304-0.7071067690849304j), 2[15] * (-0.9238795042037964+0.3826834261417389j), 
1[0] * (1-0j), 2[4] * (1-0j), 3[8] * (1-0j), 0[12] * (1-0j), 
2[1] * (1-0j), 3[5] * (0.9238795042037964-0.3826834261417389j), 0[9] * (0.7071067690849304-0.7071067690849304j), 1[13] * (0.3826834261417389-0.9238795042037964j), 
3[2] * (1-0j), 0[6] * (0.7071067690849304-0.7071067690849304j), 1[10] * (6.123234262925839e-17-1j), 2[14] * (-0.7071067690849304-0.7071067690849304j), 
0[3] * (1-0j), 1

We can split the twiddles into four parts and hard-code them; the indices can be computed in real time.

Thus, the algorithm becomes

In [236]:
N = 64
radix = 4
M = N // radix  # number of elements held by each channel
# NOTE(review): int() truncates, so a float result just below 3.0 would undercount; consider round().
depth = int(math.log(N, radix))
factors = [radix] * depth

# let x be the input array
x = np.arange(N, dtype='complex64')

# we'll execute the algorithm on `radix` number of channels (one length-M buffer each)
s = tuple(np.zeros(M, dtype='complex64') for r in range(radix))

# we perform the transpose at the same time as the group:
# `i` is the digit-reversed (transposed) ordering of the linear indices, and
# channel g receives the entries of x[i] at the positions whose 4-parity is g
i = np.arange(N, dtype=int).reshape(factors).transpose().flatten()
for g, j in channels(N, radix):
    s[g][:] = x[i][list(j)]
    
# we start with a step 1 vs. M // radix
# NOTE(review): this stage loop is still a sketch and does not compute the FFT yet:
#  - `s[c][k]` is a numpy scalar, not a view, so fft4's `x0[:] = ...` cannot
#    write back (expected to raise TypeError). 1-element slices such as
#    `s[c][k:k+1]` would be needed.
#  - p takes the values M // radix, M // radix**2, ... (4 and 1 for N = 64),
#    i.e. two passes, while `depth` is 3.
#  - no twiddle factors are applied between passes.
#  - the cell ends in a statement, so the output shown under it appears to be
#    stale (from an earlier version that displayed `s`).
p = M // radix
q = 1
while p > 0:
    for a in range(radix):
        # four rotations of the channel order, all using the same offsets
        fft4(s[0][q*a], s[1][q*a + p], s[2][q*a + 2*p], s[3][q*a + 3*p])
        fft4(s[1][q*a], s[2][q*a + p], s[3][q*a + 2*p], s[0][q*a + 3*p])
        fft4(s[2][q*a], s[3][q*a + p], s[0][q*a + 2*p], s[1][q*a + 3*p])
        fft4(s[3][q*a], s[0][q*a + p], s[1][q*a + 2*p], s[2][q*a + 3*p])
    q *= radix
    p //= radix


(array([ 0.+0.j, 52.+0.j, 40.+0.j, 28.+0.j, 49.+0.j, 37.+0.j, 25.+0.j,
        13.+0.j, 34.+0.j, 22.+0.j, 10.+0.j, 62.+0.j, 19.+0.j,  7.+0.j,
        59.+0.j, 47.+0.j], dtype=complex64),
 array([16.+0.j,  4.+0.j, 56.+0.j, 44.+0.j,  1.+0.j, 53.+0.j, 41.+0.j,
        29.+0.j, 50.+0.j, 38.+0.j, 26.+0.j, 14.+0.j, 35.+0.j, 23.+0.j,
        11.+0.j, 63.+0.j], dtype=complex64),
 array([32.+0.j, 20.+0.j,  8.+0.j, 60.+0.j, 17.+0.j,  5.+0.j, 57.+0.j,
        45.+0.j,  2.+0.j, 54.+0.j, 42.+0.j, 30.+0.j, 51.+0.j, 39.+0.j,
        27.+0.j, 15.+0.j], dtype=complex64),
 array([48.+0.j, 36.+0.j, 24.+0.j, 12.+0.j, 33.+0.j, 21.+0.j,  9.+0.j,
        61.+0.j, 18.+0.j,  6.+0.j, 58.+0.j, 46.+0.j,  3.+0.j, 55.+0.j,
        43.+0.j, 31.+0.j], dtype=complex64))