In [2]:
import numpy as np
from numpy import fft

from genfft.fft import make_twiddle
from genfft.plan import Array

In [3]:
# Two-pass decomposition of a length-N DFT, with N the product of two factors.
# N and the twiddle dimensions are derived from `factors` so that changing the
# factorisation in one place keeps the whole cell consistent.
factors = [4, 9]
N = factors[0] * factors[1]

x = np.arange(N)
# Reshape with the factors reversed, then transpose: `s` has shape
# (factors[0], factors[1]).
s = x.reshape(factors[::-1]).T

# Notw fft: plain FFTs along axis 1 (length factors[1]).
y = fft.fft(s, axis=1)

# Twiddle: element-wise multiply by the conjugated output of make_twiddle.
# NOTE(review): assumes make_twiddle(n, m) returns an array broadcastable
# against `y` (shape (n, m)) - confirm against genfft.fft.
y *= make_twiddle(*factors).conj()

# Second pass: fft.fft works on the last axis by default, so transposing
# before and after transforms along the original axis 0 (length factors[0]).
z = fft.fft(y.T).T.flatten()

In [4]:
# Largest absolute deviation between numpy's direct FFT and the two-pass result.
np.max(np.abs(fft.fft(x) - z))

2.3517349180080774e-06

In [5]:
from functools import reduce
import operator

In [6]:
def reduce_factors_back(f):
    """Merge the last two factors of ``f`` into their product.

    Returns a new list; ``f`` itself is not modified.
    Example: ``[4, 4, 4]`` becomes ``[4, 16]``.
    """
    leading = f[:-2]
    merged = f[-2] * f[-1]
    return leading + [merged]

In [7]:
f = [4, 4, 4]
N = reduce(operator.mul, f)

# Seeded private generator: the random test signal is identical on every
# Restart & Run All, and the global numpy random state is left untouched.
x = np.random.RandomState(0).normal(size=[N])
# Reshape with reversed factors, then reverse the axes again so that the
# array shape follows the order of `f`.
s = x.reshape(f[::-1]).transpose()

# Notw fft: plain FFTs along the last axis.
y = fft.fft(s, axis=-1)

# Twiddle for the last two factors, then FFT along axis 1.
# NOTE(review): relies on make_twiddle(n, m) broadcasting over the leading
# axis of `y` - confirm against genfft.fft.
y *= make_twiddle(*f[-2:]).conj()
y = fft.fft(y, axis=1)

# Twiddle: fold the two processed factors into one and repeat the
# twiddle + FFT step along the remaining axis.
f = reduce_factors_back(f)
y = y.reshape(f) * make_twiddle(*f).conj()
z = fft.fft(y, axis=0).flatten()

In [8]:
# Largest absolute deviation between numpy's direct FFT and the three-pass result.
np.max(np.abs(fft.fft(x) - z))

5.486820709182914e-07

In [9]:
def reduce_factors(f):
    """Merge the first two factors of ``f`` into their product.

    Returns a new list; ``f`` itself is not modified.
    Example: ``[4, 4, 4, 4, 2]`` becomes ``[16, 4, 4, 2]``.
    """
    merged = f[0] * f[1]
    trailing = f[2:]
    return [merged] + trailing

In [12]:
from textwrap import indent, wrap

def for_loop(var, start, n, step, body):
    """Return the text of a C-style ``for`` loop wrapping ``body``.

    var   -- name of the loop variable
    start -- initial value (number or expression text)
    n     -- number of iterations
    step  -- increment applied on every iteration
    body  -- loop body text; each non-blank line is indented by four spaces

    The upper bound is ``n*step`` when ``start`` is zero (``0`` or ``"0"``),
    otherwise it is written as ``start+n*step``.
    """
    span = n * step
    starts_at_zero = start == 0 or start == "0"
    if starts_at_zero:
        end = "{}".format(span)
    else:
        end = "{}+{}".format(start, span)

    header = "for ({var}={start}; {var}<{end}; {var}+={inc}) {{".format(
        var=var, start=start, end=end, inc=step)
    return header + "\n" + indent(body, "    ") + "\n" + "}"

def inc_var(var):
    """Return the character whose code point follows that of ``var`` ("a" -> "b")."""
    next_code = 1 + ord(var)
    return chr(next_code)

def twiddle_loop(a, var="a", start="0"):
    """Return C-style text that applies a twiddle codelet over array ``a``.

    For a 2-D ``a`` a single ``twiddle_<shape[-2]>`` call is produced, using
    ``start`` as the offset into ``output``. For higher ranks, the leading
    axis becomes a ``for`` loop (variable ``var``) and the function recurses
    on ``a[0]`` with the next variable name, starting from ``var``.

    ``a`` must provide ``ndim``, ``shape``, ``real.stride`` and indexing.
    """
    if a.ndim != 2:
        count = a.shape[0]
        step = a.real.stride[0]
        inner = twiddle_loop(a[0], inc_var(var), start=var)
        return for_loop(var, start, count, step, inner)

    template = ("twiddle_{x}(output+{start}, output+{start}+1, "
                "w_{x}_{y}, {stride}, 0, {y}, 2)")
    rows = a.shape[-2]
    cols = a.shape[-1]
    return template.format(x=rows, start=start, y=cols, stride=a.real.stride[-2])
    
def decl_vars(n):
    """Return a C declaration of ``n`` int loop variables named a, b, c, ...

    ``decl_vars(3)`` gives ``"int a, b, c;"``. When no variables are needed
    (``n <= 0``) an empty string is returned rather than the empty
    declaration ``"int ;"``.
    """
    if n <= 0:
        return ""
    names = [chr(ord("a") + i) for i in range(n)]
    return "int {};".format(", ".join(names))

In [14]:
f = [4, 4, 4, 4, 2]
N = reduce(operator.mul, f)

# Seeded private generator: the numeric reference input is reproducible and
# the global numpy random state is left untouched.
x = np.random.RandomState(0).normal(size=[N])
s = x.reshape(f).transpose()
# Same reshape/transpose applied to the flat indices 0..N-1, so `p` records
# which input element each position of `s` came from.
p = np.arange(N).reshape(f).transpose()

# Notw fft
# One loop variable per nesting level of the deepest twiddle loop.
print(decl_vars(len(f)-2))

# Input start position of every notw call: the source index of the first
# element of each 2-D trailing slice, doubled.
# NOTE(review): the factor 2 presumably converts complex indices to
# interleaved re/im float offsets - confirm against the codelet interface.
offsets = p[..., 0, 0].flatten() * 2
print("__constant short offset[{size}] = {{\n{values}\n}};".format(
    size=len(offsets),
    values=indent("\n".join(wrap(", ".join(str(o) for o in offsets))), "    ")))

# Number of notw calls: one per combination of the factors beyond the first
# two. The initial value 1 keeps a two-factor plan from raising on an empty
# sequence.
notw_count = reduce(operator.mul, f[2:], 1)
a_x = Array("x", "complex64", f).T

y = fft.fft(s, axis=-1)
p_y = np.arange(N).reshape(y.shape)
a_y = Array("y", "complex64", y.shape)

# Output distance between consecutive notw calls. Each call handles f[1]
# vectors spaced a_y.real.stride[-2] apart, so the next call starts that many
# elements further on. This is not the same quantity as the call count; the
# two merely coincide (both 32) for f = [4, 4, 4, 4, 2].
notw_stride = f[1] * a_y.real.stride[-2]

print(for_loop("a", 0, notw_count, 1,
               "notw_{}(input + offset[a], input + offset[a] + 1, output + a*{stride}, output + a*{stride} + 1, {ins}, {outs}, {v}, {ivs}, {ovs})".format(
                   f[0], stride=notw_stride, ins=a_x.real.stride[-1], outs=a_y.real.stride[-1], v=f[1], ivs=a_x.real.stride[-2], ovs=a_y.real.stride[-2])))


while len(f) > 1:
    # Twiddle: numeric reference step (multiply by conjugated twiddles, FFT
    # along the second-to-last axis) and the matching generated loop nest.
    y *= make_twiddle(*f[1::-1]).conj()
    y = fft.fft(y, axis=-2)
    print(twiddle_loop(a_y))
    # Fold the two leading factors into one and reshape the data, the index
    # map and the symbolic array accordingly.
    f = reduce_factors(f)
    y = y.reshape(f[::-1])
    p_y = p_y.reshape(f[::-1])
    a_y = a_y.reshape(f[::-1])

int a, b, c;
__constant short offset[32] = {
    0, 16, 32, 48, 4, 20, 36, 52, 8, 24, 40, 56, 12, 28, 44, 60, 2, 18,
    34, 50, 6, 22, 38, 54, 10, 26, 42, 58, 14, 30, 46, 62
};
for (a=0; a<32; a+=1) {
    notw_4(input + offset[a], input + offset[a] + 1, output + a*32, output + a*32 + 1, 256, 2, 4, 64, 8)
}
for (a=0; a<1024; a+=512) {
    for (b=a; b<a+512; b+=128) {
        for (c=b; c<b+128; c+=32) {
            twiddle_4(output+c, output+c+1, w_4_4, 8, 0, 4, 2)
        }
    }
}
for (a=0; a<1024; a+=512) {
    for (b=a; b<a+512; b+=128) {
        twiddle_4(output+b, output+b+1, w_4_16, 32, 0, 16, 2)
    }
}
for (a=0; a<1024; a+=512) {
    twiddle_4(output+a, output+a+1, w_4_64, 128, 0, 64, 2)
}
twiddle_2(output+0, output+0+1, w_2_256, 512, 0, 256, 2)
