In [1]:
import numpy as np
import math
from itertools import groupby
from functools import partial

def make_twiddle(n1, n2):
    """Return the (n1, n2) table of twiddle factors exp(2*pi*i * a*b / (n1*n2)).

    Entry [a, b] uses the product of the row index a (0..n1-1) and the
    column index b (0..n2-1). The result is cast to complex64.
    """
    rows = np.arange(n1).reshape(-1, 1)
    cols = np.arange(n2).reshape(1, -1)
    exponents = rows * cols
    period = n1 * n2
    return np.exp(2j * np.pi * exponents / period).astype('complex64')

## Radix-n data streamlining

For a radix-2 FFT, it is possible to divide the input data in such a way that every butterfly operation reads from and writes to two different arrays in memory. This could be advantageous for an implementation on an FPGA. The trick is to separate data locations based on the parity of their index.

### Parity

The index of each element in the array can be written in binary. In the case of a radix-2 FFT, we can view the array as being reshaped into a $[2, 2, 2, \dots]$ shape. Every stage of the $n=2^k$ sized radix-2 FFT is performed along a different dimension of the $k$-dimensional reshaped array. The multi-index into this multi-dimensional array is the same as the binary notation of the linear index. That means that each time a 2-FFT is performed, the linear indices of the two elements involved differ in exactly one bit. Since flipping a single bit always changes the parity, separating the array based on the parity of the index guarantees that each 2-FFT operation reads and writes one (complex) number from each set. The parity is the sum of all the individual bits, modulo 2.

$$P_2(i) = \left(\sum_k b_k\right) \mod 2,\quad {\rm where}\, i := \sum_k b_k 2^k$$

We can extend this concept to the radix-4 FFT, on 4 data channels.

### 4-parity

In the radix-4 FFT we can view the array as being reshaped to a $[4, 4, 4, \dots]$ shape. The multi-index into this array is equivalent to the quaternary number notation of the linear index. Similar to the radix-2 parity, we can define the radix-4 parity as the sum of the quaternary digits of the index, modulo 4.

$$P_4(i) = \left(\sum_k q_k\right) \mod 4,\quad {\rm where}\, i := \sum_k q_k 4^k$$

This would ensure that each 4-FFT reads its data from the 4 different input channels. We define the `parity` function to work with any radix:

In [2]:
def digits(n, i):
    """Yield the base-n digits of i, least-significant digit first.

    Nothing is yielded when i is 0.
    """
    remaining = i
    while remaining != 0:
        remaining, digit = divmod(remaining, n)
        yield digit
            
def parity(n, i):
    """Return the sum of the base-n digits of i, reduced modulo n."""
    digit_sum = 0
    for digit in digits(n, i):
        digit_sum += digit
    return digit_sum % n

To see which index belongs in which channel, we use `groupby`, keyed on the `parity` function:

In [3]:
def channels(N, radix):
    """Group the indices 0..N-1 by their radix-`radix` parity.

    Returns the lazy `itertools.groupby` iterator of (parity, indices)
    pairs; each `indices` group must be consumed before advancing.
    """
    def channel_of(index):
        return parity(radix, index)

    by_channel = sorted(range(N), key=channel_of)
    return groupby(by_channel, channel_of)

For instance, for radix-2 and size 16, this creates the following channels:

In [4]:
# Print each channel id followed by the indices assigned to it.
for (g, i) in channels(16, 2):
    print(g, list(i))

0 [0, 3, 5, 6, 9, 10, 12, 15]
1 [1, 2, 4, 7, 8, 11, 13, 14]


### Example: radix-2 16-FFT

We can implement the 16-FFT as a radix-2 operation on an array of shape $[2,2,2,2]$.

In [5]:
def fft2(x0, x1):
    """In-place 2-point butterfly: x0 <- x0 + x1, x1 <- x0 - x1.

    Both results are computed (as complex64) before either input is
    overwritten, so the difference uses the original x0.
    """
    total = np.zeros(shape=x0.shape, dtype='complex64')
    delta = np.zeros(shape=x0.shape, dtype='complex64')
    total[...] = x0 + x1
    delta[...] = x0 - x1
    x0[:] = total
    x1[:] = delta

To check whether our grouping works, we create a dummy array, and a dummy `fft2_check` that just checks if all arguments are in different channels.

In [6]:
def fft2_check(x0, x1):
    """Assert that x0 and x1 differ at every position (raises AssertionError otherwise)."""
    assert not np.any(x0 == x1)

In [7]:
# Label each of the 16 positions with the id of the channel it belongs to.
x = np.zeros(shape=(16,), dtype=int)
for (g, i) in channels(16, 2):
    x[list(i)] = g

In [8]:
N = 16
radix = 2

# One axis of length `radix` per FFT stage.
# NOTE(review): int(math.log(...)) relies on the floating-point log being exact for this N/radix.
factors = [radix] * int(math.log(N, radix))
s = x.reshape(factors).transpose()

# For every axis, the two halves (index 0 and index 1) must carry different channel ids everywhere.
try:
    for k in range(len(factors)):
        fft2_check(np.take(s, 0, axis=k), np.take(s, 1, axis=k))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


In [9]:
# x = (np.random.normal(size=N) + 1j * np.random.normal(size=N)).astype('complex64')
# Deterministic ramp input 0..N-1.
x = np.arange(0, N, dtype='complex64')
# Copy of the input with one length-2 axis per stage; transpose() reverses the axis order.
s = x.copy().reshape(factors).transpose()

# First stage: butterflies along the last axis, no twiddle factors.
fft2(s[...,0], s[...,1])

for k in range(1, len(factors)):
    # Conjugated twiddles of shape (2, 2**k), broadcast over the leading axis of `z`.
    w = make_twiddle(2, 2**k).conj()[:,:]
    z = s.reshape([-1, 2, 2**k])
    z *= w
    # Butterfly across the middle (length-2) axis.
    fft2(z[..., 0, :], z[..., 1, :])
    s = z

In [10]:
# Display the transformed data as a flat array.
s.flatten()

array([120.        +0.j       ,  -7.9999995+40.21872j  ,
        -8.       +19.31371j  ,  -8.       +11.972846j ,
        -8.        +8.j       ,  -8.        +5.3454294j,
        -8.        +3.3137083j,  -8.        +1.591299j ,
        -8.        +0.j       ,  -8.        -1.591299j ,
        -8.        -3.3137083j,  -8.        -5.3454294j,
        -8.        -8.j       ,  -8.       -11.972846j ,
        -8.       -19.31371j  ,  -7.9999995-40.21872j  ], dtype=complex64)

In [11]:
# Largest absolute deviation from NumPy's reference FFT of the same input.
abs(s.flatten() - np.fft.fft(x)).max()

3.577337255373132e-06

### Example: radix-4 64-FFT

In [12]:
# Print the four radix-4 channels of a 64-element array.
for g, i in channels(64, 4):
    print(g, list(i))

0 [0, 7, 10, 13, 19, 22, 25, 28, 34, 37, 40, 47, 49, 52, 59, 62]
1 [1, 4, 11, 14, 16, 23, 26, 29, 35, 38, 41, 44, 50, 53, 56, 63]
2 [2, 5, 8, 15, 17, 20, 27, 30, 32, 39, 42, 45, 51, 54, 57, 60]
3 [3, 6, 9, 12, 18, 21, 24, 31, 33, 36, 43, 46, 48, 55, 58, 61]


In [13]:
from itertools import combinations
import math

def fft4_check(x0, x1, x2, x3):
    """Assert that every pair of the four arrays differs at every position."""
    operands = (x0, x1, x2, x3)
    for first in range(len(operands)):
        for second in range(first + 1, len(operands)):
            assert np.all(operands[first] != operands[second])

In [14]:
N = 64
radix = 4

# Label each of the N positions with the id of its radix-4 channel.
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = g

In [15]:
# One axis of length `radix` per FFT stage.
# NOTE(review): int(math.log(...)) relies on the floating-point log being exact for this N/radix.
factors = [radix] * int(math.log(N, radix))
s = x.reshape(factors)

# Along every axis, the `radix` slices must carry pairwise different channel ids everywhere.
try:
    for k in range(len(factors)):
        fft4_check(*(np.take(s, i, axis=k) for i in range(radix)))
except AssertionError:
    print("failed")
else:
    print("succeeded")

succeeded


## Implementing the radix-4 FFT

In [16]:
def fft4(x0, x1, x2, x3, w0=1, w1=1, w2=1, w3=1):
    """In-place 4-point DFT butterfly across x0..x3, with optional input weights.

    Each input xk is first multiplied by wk. All four outputs are computed
    from the original inputs before anything is written back.
    """
    t0, t1, t2, t3 = w0*x0, w1*x1, w2*x2, w3*x3
    even_sum = t0 + t2
    odd_sum = t1 + t3
    even_diff = t0 - t2
    odd_diff = t1 - t3
    x0[:] = even_sum + odd_sum
    x1[:] = even_diff - 1j*odd_diff
    x2[:] = even_sum - odd_sum
    x3[:] = even_diff + 1j*odd_diff

In [17]:
# Deterministic ramp input 0..N-1.
x = np.arange(0, N, dtype='complex64')
# Copy of the input with one length-4 axis per stage; transpose() reverses the axis order.
s = x.copy().reshape(factors).transpose()

# First stage: butterflies along the last axis, no twiddle factors.
fft4(*(s[...,k] for k in range(4)))

for k in range(1, len(factors)):
    # Conjugated twiddles of shape (4, 4**k), broadcast over the leading axis of `z`.
    w = make_twiddle(4, 4**k).conj()[:,:]
    z = s.reshape([-1, 4, 4**k])
    z *= w
    # Butterfly across the middle (length-4) axis.
    fft4(*(z[..., l, :] for l in range(4)))
    s = z

In [18]:
# Largest absolute deviation from NumPy's reference FFT of the same input.
abs(s.flatten() - np.fft.fft(x)).max()

2.5034746158780763e-05

## OpenCL

Each of the transforms is in-place, meaning we can do them in any order, but we don't yet know the pattern in the indices.

In [19]:
# Fill `x` with the integer id of the array to which the location belongs
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    x[list(i)] = g  # np.arange(N//radix, dtype=int)
s = x.reshape(factors).transpose()

# Row k holds the channel ids seen by butterfly input k (slice k of the last axis), flattened.
perm = np.array([s[...,k].flatten() for k in range(4)])
perm

array([[0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2],
       [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3],
       [2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0],
       [3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 3, 0, 2, 3, 0, 1]])

This pattern is the same for each iteration. Note that the permutation is always the same cycle starting at a different number. This means the top row gives us which of the four permutations is needed. Do we implement four different versions of `fft4` to implement each permutation?

In [20]:
# Sort the column positions by the channel id in the top row of `perm`;
# row r of `colid` then lists the four columns whose top-row entry is r.
colid = np.argsort(perm[0]).reshape(4,4)
colid

array([[ 0,  7, 10, 13],
       [ 1,  4, 11, 14],
       [ 2,  5,  8, 15],
       [ 3,  6,  9, 12]])

These are the indices into the index array that will select which of the four permuted versions of `fft4` needs to be used.

Now for the indices into each channel:

In [21]:
# Here `x` is the index into each source array
x = np.zeros(shape=(N,), dtype=int)
for g, i in channels(N, radix):
    #x[list(i)] = g  # np.arange(N//radix, dtype=int)
    # Number the members of each channel 0..N//radix-1 in order of increasing global index.
    x[list(i)] = np.arange(N//radix, dtype=int)
s = x.reshape(factors).transpose()

# Row k holds the in-channel indices seen by butterfly input k, flattened.
idx = np.array([s[...,k].flatten() for k in range(4)])
idx

array([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
       [ 4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7,  4,  5,  6,  7],
       [ 8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11,  8,  9, 10, 11],
       [12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15]])

So the first permutation is called with these indices

In [22]:
# In-channel indices for the columns listed in the first row of `colid`.
idx[:,colid[0]]

array([[ 0,  3,  2,  1],
       [ 4,  7,  6,  5],
       [ 8, 11, 10,  9],
       [12, 15, 14, 13]])

and so on ..., these all contain the same columns in different order. Meaning we can run each `fft4` permutation **with the same indices!**

In [23]:
# Same in-channel index table, now sliced along the middle axis of the [-1, 4, 4] view.
z = s.reshape([-1, 4, 4])
idx = np.array([z[:,l,:].flatten() for l in range(4)])
idx

array([[ 0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12,  0,  4,  8, 12],
       [ 1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13,  1,  5,  9, 13],
       [ 2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14,  2,  6, 10, 14],
       [ 3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15,  3,  7, 11, 15]])

In [24]:
# In-channel indices for the columns listed in the first row of `colid`.
idx[:,colid[0]]

array([[ 0, 12,  8,  4],
       [ 1, 13,  9,  5],
       [ 2, 14, 10,  6],
       [ 3, 15, 11,  7]])

In [25]:
# Same in-channel index table, sliced along the middle axis of the [-1, 4, 16] view.
z = s.reshape([-1, 4, 16])
idx = np.array([z[:,l,:].flatten() for l in range(4)])

In [73]:
# In-channel indices for the columns listed in the second row of `colid`.
# NOTE(review): this cell's execution count (73) is out of sequence; confirm the shown output on a fresh run.
idx[:,colid[1]]

array([[ 4,  1, 14, 11],
       [ 4,  1, 14, 11],
       [ 4,  1, 14, 11],
       [ 4,  1, 14, 11]])

Only the last instance is different. However, note that these are the same indices as we found in computing the 4-parity.

# An algorithm by recipe

In [27]:
from dataclasses import dataclass

@dataclass
class MultiChannel:
    """Index bookkeeping for an `N`-point radix-`radix` FFT split over `radix` channels.

    Positions 0..N-1 are assigned to channels by the module-level `channels`
    helper (grouping by digit-sum parity); the properties expose that
    assignment in the layouts used by the FFT stages.
    """
    N: int      # total number of points
    radix: int  # FFT radix, which is also the number of channels

    @property
    def depth(self):
        """Number of radix stages: the largest d such that radix**d <= N.

        Computed with integer arithmetic. ``int(math.log(N, radix))`` can come
        out one too small when the floating-point log lands just below an
        integer (e.g. N=243, radix=3).

        Raises:
            ValueError: if radix < 2 or N < 1.
        """
        if self.radix < 2 or self.N < 1:
            raise ValueError("MultiChannel requires radix >= 2 and N >= 1")
        stages = 0
        power = self.radix
        while power <= self.N:
            stages += 1
            power *= self.radix
        return stages

    @property
    def M(self):
        """N // radix, the length used for each channel's index range."""
        return self.N//self.radix

    @property
    def L(self):
        """M // radix."""
        return self.M//self.radix

    @property
    def factors(self):
        """Shape with one axis of length `radix` per stage: [radix] * depth."""
        return [self.radix] * self.depth

    @property
    def channels(self):
        """Fresh (channel id, indices) grouping from the module-level `channels` helper."""
        return channels(self.N, self.radix)

    @property
    def channel_loc(self):
        """Integer array of shape `factors` giving the channel id of every position."""
        x = np.zeros(shape=(self.N,), dtype=int)
        for g, i in self.channels:
            x[list(i)] = g
        return x.reshape(self.factors)

    @property
    def index_loc(self):
        """Integer array of shape `factors` giving each position's index within its channel.

        Members of a channel are numbered 0..M-1 in order of increasing
        position, so every channel must contain exactly M positions.
        """
        x = np.zeros(shape=(self.N,), dtype=int)
        for g, i in self.channels:
            x[list(i)] = np.arange(self.M, dtype=int)
        return x.reshape(self.factors)

    def mix(self, x):
        """Split flat array `x` into a tuple of per-channel copies, in ascending channel id."""
        return tuple(x[list(i)].copy() for g, i in self.channels)

    def unmix(self, s):
        """Inverse of `mix`: scatter the per-channel arrays `s[g]` back into one flat array.

        The result is always complex64, whatever the dtype of the inputs.
        """
        x = np.zeros(shape=(self.N,), dtype='complex64')
        for g, i in self.channels:
            x[list(i)] = s[g]
        return x

In [28]:
def fft4x(x0, x1, x2, x3, i0, i1, i2, i3, w0=1, w1=1, w2=1, w3=1):
    """In-place 4-point DFT butterfly on the elements x0[i0], x1[i1], x2[i2], x3[i3].

    Each selected element is first multiplied by its weight wk. All four
    outputs are computed before anything is written back.
    """
    t0 = w0*x0[i0]
    t1 = w1*x1[i1]
    t2 = w2*x2[i2]
    t3 = w3*x3[i3]
    even_sum, even_diff = t0 + t2, t0 - t2
    odd_sum, odd_diff = t1 + t3, t1 - t3
    x0[i0] = even_sum + odd_sum
    x1[i1] = even_diff - 1j*odd_diff
    x2[i2] = even_sum - odd_sum
    x3[i3] = even_diff + 1j*odd_diff

In [29]:
mc = MultiChannel(64, 4)

# let x be the input array
# NOTE: `N` here is the notebook-level value set in an earlier cell, not mc.N.
x = np.arange(0, N, dtype='complex64')

# <lin> ===========================================
# Reference: first radix-4 stage on the plain single-array layout.
s_lin = x.copy().reshape(mc.factors).transpose()
fft4(*(s_lin[...,k] for k in range(4)))
# </lin> ==========================================

# NOTE: `mc` and `x` are re-created below with the same values as above.
mc =  MultiChannel(N=64, radix=4)

x = np.arange(N, dtype='complex64')
# Split the transposed, flattened input into one array per channel.
s = mc.mix(x.reshape(mc.factors).transpose().flatten())
    
# 16 fft-4, the inner-loop is best left unfolded
# Each row of `ca` / `ia` gives the channel ids / in-channel indices of one butterfly.
ca = mc.channel_loc.reshape([-1, 4])
ia = mc.index_loc.reshape([-1, 4])

for c, i in zip(ca, ia):
    # Inside the comprehension `i` is a channel id taken from `c`;
    # the trailing `*i` is the loop's row of in-channel indices.
    fft4x(*[s[i] for i in c], *i)

# The multi-channel result must match the reference stage.
assert abs(mc.unmix(s) - s_lin.flatten()).max() < 1e-4

In [30]:
# <lin> ===========================================
# Reference: second stage, twiddle multiply followed by butterflies across the middle axis.
w = make_twiddle(4, 4).conj()
s_lin *= w
fft4(*(s_lin[..., l, :] for l in range(4)))
# </lin> ==========================================

# Channel/index tables with the last two axes swapped, so each row of 4 is one middle-axis butterfly.
ca = mc.channel_loc.transpose([0,2,1]).reshape([-1, 4])
ia = mc.index_loc.transpose([0,2,1]).reshape([-1, 4])
# Twiddles broadcast to [4,4,4] and laid out the same way: one row of 4 weights per butterfly.
w_4_4 = (np.ones(shape=[4,4,4]) * w).transpose([0,2,1]).reshape([-1,4])

# NOTE: this loop rebinds `w` (the twiddle matrix above) to a single row of weights.
for c, i, w in zip(ca, ia, w_4_4):
    fft4x(*[s[i] for i in c], *i, *w)

# The multi-channel result must match the reference after this stage.
assert abs(mc.unmix(s) - s_lin.flatten()).max() < 1e-4

In [31]:
# <lin> ===========================================
# Reference: third stage, (4, 16) twiddles reshaped to [4,4,4], then butterflies across the first axis.
w = make_twiddle(4, 16).conj()
s_lin *= w.reshape([4,4,4])
fft4(*(s_lin[l] for l in range(4)))
# </lin> ==========================================

In [32]:
# Channel/index tables with the first axis moved last, so each row of 4 is one first-axis butterfly.
ca = mc.channel_loc.transpose([1,2,0]).reshape([-1, 4])
ia = mc.index_loc.transpose([1,2,0]).reshape([-1, 4])
# Uses `w` from the previous cell (the (4, 16) twiddle matrix), laid out the same way.
w_4_16 = w.reshape([4,4,4]).transpose([1,2,0]).reshape([-1,4])

# NOTE: this loop rebinds `w` to a single row of weights, so the cell cannot be re-run on its own.
for c, i, w in zip(ca, ia, w_4_16):
    fft4x(*[s[i] for i in c], *i, *w)

# After the last stage the result must match both the reference and NumPy's FFT.
assert abs(mc.unmix(s) - s_lin.flatten()).max() < 1e-4
assert abs(mc.unmix(s) - np.fft.fft(x)).max() < 1e-4

## Observations

The `ca` array is always the same.

# Unrolling the recipe for permutations

The pattern of the permutation shifts by one on every iteration of the outer loop. To correct for this we need to permute all input factors.

In [33]:
ca = mc.channel_loc.reshape([4, 4, 4])
# For each block of 4, argsort the channel ids of its first row and add the block offset (0, 4, 8, 12).
perm = (ca[:,0,:].argsort(axis=1) + (np.arange(4)*4)[:,None]).flatten()

In [34]:
# Display the flattened permutation.
perm

array([ 0,  1,  2,  3,  7,  4,  5,  6, 10, 11,  8,  9, 13, 14, 15, 12])

In [35]:
# Closed form of the same permutation: block i, rotated by i positions.
[[i*4 + (rot-i)%4 for rot in range(4)] for i in range(4)]

[[0, 1, 2, 3], [7, 4, 5, 6], [10, 11, 8, 9], [13, 14, 15, 12]]

In [36]:
mc =  MultiChannel(N=64, radix=4)

# NOTE: `N` here is the notebook-level value, not mc.N.
x = np.arange(N, dtype='complex64')
s = mc.mix(x.reshape(mc.factors).transpose().flatten())
# Stage 1 uses no twiddles: 16 rows of ones.
w_0 = np.ones(shape=[16,4])

# Stage 2 and stage 3 twiddle rows, reordered by `perm` (computed in an earlier cell).
w = make_twiddle(4, 4).conj()
w_4_4 = (np.ones(shape=[4,4,4]) * w).transpose([0,2,1]).reshape([-1,4])[perm]

w = make_twiddle(4, 16).conj()
w_4_16 = w.reshape([4,4,4]).transpose([1,2,0]).reshape([-1,4])[perm]

# One row of 4 weights per butterfly, stages stacked in execution order.
W = np.r_[w_0, w_4_4, w_4_16]

fa = 1   # 1 during the first stage, 0 afterwards
fb = 1   # stride between the four indices of a butterfly (later stages)
fc = 4   # stride between consecutive `i` groups (later stages)
Wp = 0   # running row pointer into W

for k in range(3):
    for i in range(4):
        for rot in range(4):
            # First stage (fa == 1): all four indices equal i*4 + (rot - i) % 4.
            # Later stages (fa == 0): indices are i*fc + j*fb for j = 0..3.
            ix = [fa * (i*4 + (rot - i)%4) + (1-fa) * (i*fc + j*fb) for j in range(4)]
            # Channels are passed in cyclic order starting at `rot`.
            fft4x(*[s[(rot+j) % 4] for j in range(4)],
                  *ix,
                  *W[Wp])
            Wp += 1
    if fa == 0:
        fc //= 4
        fb *= 4
    else:
        fa = 0

In [37]:
# The unrolled multi-channel result must match NumPy's FFT of the same input.
assert abs(mc.unmix(s) - np.fft.fft(x)).max() < 1e-4

In [38]:
# Switch to the 1024-point, radix-4 configuration used by the following cells.
mc =  MultiChannel(N=1024, radix=4)
# perm = [i*mc.radix + (rot-i)%mc.radix for i in range(mc.L) for rot in range(mc.radix)]

# Channel id of every position, shaped with one length-4 axis per stage.
ca = mc.channel_loc.reshape(mc.factors)
# perm = (ca[...,0,:].argsort(axis=1) + (np.arange(16)*4)[:,None]).flatten()

In [63]:
def comp_perm(radix, i):
    """Rotate the lowest base-`radix` digit of `i` down by the parity of its higher digits.

    The result keeps every higher digit of `i`; only the last digit changes,
    to ``(last_digit - parity(radix, higher_part)) % radix``.

    Uses modular arithmetic rather than bit masks, so it is valid for any
    radix >= 2 and gives the same values as the masked form for power-of-two
    radices.
    """
    rem = i % radix
    base = i - rem  # `i` with its lowest digit cleared
    p = parity(radix, base)
    # The lowest digit of `base` is zero, so adding equals the original bitwise OR.
    return base + (rem - p) % radix

def comp_idx(radix, i, j, k):
    """Insert digit `j` at position `k` of the base-`radix` representation of `i`.

    The `k` lowest digits of `i` stay in place, `j` becomes the digit at
    position `k`, and the remaining higher digits of `i` move up one position.

    Uses `divmod` rather than bit masks, so it is valid for any radix >= 2
    and gives the same values as the masked form for power-of-two radices.
    """
    stride = radix**k
    high, low = divmod(i, stride)
    return low + j * stride + high * stride * radix

In [138]:
%%capture fft_1024_mc_cl

# x = np.arange(mc.N, dtype='complex64')
# Random complex input. NOTE(review): no seed is set, so the data differs on every run.
x = np.random.normal(size=mc.N) + 1j*np.random.normal(size=mc.N)
s = mc.mix(x.reshape(mc.factors).transpose().flatten())
# The first `radix` rows of W are all ones; the driver loop below reuses them for the whole first stage.
W = np.ones(shape=[mc.radix, mc.radix])
perm = np.array([comp_perm(mc.radix, i) for i in range(mc.M)])

# Append one block of permuted twiddle rows per remaining stage; n is radix**(k+1).
n = mc.radix
for k in range(mc.depth - 1):
    w = make_twiddle(mc.radix, n).conj()
    w_r_x = (np.ones(shape=[mc.M//n,mc.radix,n]) * w).transpose([0,2,1]).reshape([-1,mc.radix])[perm]
    W = np.r_[W, w_r_x]
    n *= mc.radix

# Emit the twiddle table as an OpenCL `__constant` array definition.
# The `=` before the brace initializer is required by OpenCL C (C99 syntax).
# The first weight of every row is skipped (`ws[1:]`), hence radix-1 columns.
print(f"__constant float2 W[{W.shape[0]}][{mc.radix-1}] = {{")
print( "    {" + "},\n    {".join(", ".join(f"(float2) ({w.real: f}f, {w.imag: f}f)" for w in ws[1:]) for ws in W) + "}};")

# Emit an empty-bodied `fft_4` helper signature.
print( "void fft_4(__restrict float2 *s0, __restrict float2 *s1, __restrict float2 *s2, __restrict float2 *s3,")
print( "           int i0, int i1, int i2, int i3, int iw)")
print( "{")
print( "}")

# Emit the loop skeleton of the top-level transform; the branch bodies are left empty.
print(f"void fft_{mc.N}(__restrict float2 *s0, __restrict float2 *s1, __restrict float2 *s2, __restrict float2 *s3)")
print( "{")
print( "    bool first = true;")
print(f"    for (int k = 0; k < {mc.depth}; ++k) {{")
print(f"        for (int i = 0; i < {mc.L}; ++i) {{")        
print( "            if (!first) {")
print( "            } else {")
print( "            }")
print( "        }")
print( "    }")
print( "}")

fa = True   # True only during the first stage
Wp = 0      # row pointer into W; stays 0 for the whole first stage

for k in range(mc.depth):
    for i in range(mc.L):
        if not fa:
            # Later stages: the four butterflies of this group share one index set
            # and consume the next four rows of W.
            ix = [comp_idx(mc.radix, i, j, k-1) for j in range(mc.radix)]
            Wp += 4
        else:
            # First stage: every butterfly uses one permuted index for all four channels.
            ix = [comp_perm(mc.radix, i*mc.radix)] * mc.radix
        # The four calls below pass the channels in cyclic rotations 0..3 (radix 4 hard-coded).
        fft4x(s[0], s[1], s[2], s[3], *ix, *W[Wp])
        if fa:
            ix = [comp_perm(mc.radix, i*mc.radix+1)] * mc.radix
        fft4x(s[1], s[2], s[3], s[0], *ix, *W[Wp+1])
        if fa:
            ix = [comp_perm(mc.radix, i*mc.radix+2)] * mc.radix
        fft4x(s[2], s[3], s[0], s[1], *ix, *W[Wp+2])
        if fa:
            ix = [comp_perm(mc.radix, i*mc.radix+3)] * mc.radix
        fft4x(s[3], s[0], s[1], s[2], *ix, *W[Wp+3])
    fa = False

In [135]:
# Largest absolute deviation from NumPy's reference FFT of the same input.
abs(mc.unmix(s) - np.fft.fft(x)).max()

7.588975361366367e-06

# Running on Vanilla OpenCL

In [259]:
import os
import pyopencl as cl
import pyopencl.cltypes
import numpy as np

# Preselect the OpenCL device via PYOPENCL_CTX (presumably read by cl.create_some_context(); see pyopencl docs).
# NOTE(review): the value is specific to one machine's OpenCL installation.
os.environ["PYOPENCL_CTX"] = "Intel(R) OpenCL HD Graphics"

def max_err(a, b):
    """Return the largest absolute element-wise difference between a and b."""
    deviation = np.absolute(a - b)
    return deviation.max()

# Summarise arrays with more than 100 elements when printing.
np.set_printoptions(threshold=100)

In [260]:
# Read the OpenCL kernel source; the `with` block closes the file handle deterministically.
with open("fft1024.cl", "r") as kernel_file:
    kernel = kernel_file.read()

In [268]:
# Create a context, compile the kernel source with the TESTING macro defined, and open a command queue.
ctx = cl.create_some_context()
prog = cl.Program(ctx, kernel).build(["-DTESTING"])
queue = cl.CommandQueue(ctx)

In [269]:
# Host-side input (ramp 0..1023) and a zeroed output array of the same shape and dtype.
x = np.arange(1024, dtype='complex64')
y = np.zeros_like(x)

In [270]:
mf = cl.mem_flags
# Read-only device buffer initialised from `x`; write-only device buffer of the same byte size for the result.
x_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_g = cl.Buffer(ctx, mf.WRITE_ONLY, x.nbytes)

In [271]:
# Enqueue the `fft_1024` kernel with global size (1,) and no explicit local size.
prog.fft_1024(queue, (1,), None, x_g, y_g)

<pyopencl._cl.Event at 0x7fc2e80af9b0>

In [272]:
# Copy the device output buffer back into the host array `y`.
cl.enqueue_copy(queue, y, y_g)

<pyopencl._cl.NannyEvent at 0x7fc2e80af7d0>

In [273]:
# Display the kernel output.
# NOTE(review): it is not compared against np.fft.fft(x) anywhere in this notebook.
y

array([ 7.1767319e+28+1.2452886e+29j,  3.5380017e+28+7.5806789e+28j,
       -2.7162168e+29+2.0000899e+29j, ..., -2.0000000e+00+2.0000000e+00j,
       -2.0000000e+00+0.0000000e+00j, -2.0000000e+00-2.0000000e+00j],
      dtype=complex64)

In [274]:
# Run the `test_parity_4` kernel on the integers 0..1023, one work-item per element,
# and copy the result back into `y`.
x = np.arange(1024, dtype=cl.cltypes.int)
y = np.zeros_like(x)
x_g = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=x)
y_g = cl.Buffer(ctx, mf.WRITE_ONLY, x.nbytes)
prog.test_parity_4(queue, (1024,), None, x_g, y_g)
cl.enqueue_copy(queue, y, y_g)

<pyopencl._cl.NannyEvent at 0x7fc2d049d7d0>

In [275]:
# Display the values returned by the `test_parity_4` kernel.
y

array([0, 1, 2, ..., 1, 2, 3], dtype=int32)

In [276]:
# Reference values from the pure-Python `parity` for the same inputs 0..1023.
y_ref = np.array([parity(4, i) for i in range(1024)])

In [277]:
# True when the kernel output agrees with the Python reference at every index.
np.all(y == y_ref)

True