# In-Place Inverse Pass

In [2]:
def permute_indices(N, R_seq):
    """Return the index table that un-scrambles the in-place passes' output.

    Starts from the identity ``[0, 1, ..., N-1]`` and applies one remapping
    pass per radix in ``R_seq[1:]``, last radix first.  In a pass, indices
    are split into contiguous groups of ``group = N // consumed`` entries
    (``consumed`` = product of the radices already applied) and, inside a
    group, offset ``off`` is sent to ``(off % radix) * group // radix +
    off // radix``.  Callers gather their work buffer through the table
    (``out[i] = work[perm[i]]``).

    Args:
        N: number of points.
        R_seq: radix sequence of the matching FFT pass.

    Returns:
        A list of ``N`` indices.
    """
    perm = list(range(N))
    consumed = 1
    for radix in reversed(R_seq[1:]):
        group = N // consumed

        def remap(v, group=group, radix=radix):
            # Split v into (which group, offset inside the group).
            base, off = divmod(v, group)
            return (off % radix) * group // radix + off // radix + base * group

        perm = [remap(v) for v in perm]
        consumed *= radix
    return perm

In [3]:
import numpy as np

def fft(x, R_seq):
    """Compute the DFT of ``x`` in place with a mixed-radix pass.

    Radices are consumed last-to-first.  Each stage applies a direct
    radix-``R`` DFT, with the inter-stage twiddle ``exp(-2*pi*i*t*k*P/N)``
    folded in, to every group of ``R`` strided points of a working copy and
    writes the results back to the same slots.  The working copy ends in a
    scrambled order, which ``permute_indices`` undoes while copying back
    into ``x``.

    Args:
        x: mutable sequence of ``N`` samples; overwritten with the result.
        R_seq: radices whose product must equal ``len(x)``.

    Returns:
        ``x`` itself, now holding the DFT.

    Raises:
        ValueError: if ``len(x)`` differs from the product of ``R_seq``.
        TypeError: if ``x`` is a NumPy array whose dtype cannot hold complex
            values (the imaginary parts would be silently discarded).
    """
    N = len(x)
    if N != int(np.prod(R_seq)):
        raise ValueError(
            f"len(x) = {N} does not match the product of R_seq {list(R_seq)}"
        )
    if isinstance(x, np.ndarray) and x.dtype.kind in "biuf":
        raise TypeError(
            f"fft writes complex results into x in place; got dtype {x.dtype}"
        )
    X = list(x)

    P = 1
    for R in reversed(R_seq):
        T = N // R
        n_k = T // P  # number of distinct butterfly offsets k per block
        for i in range(T):
            k = i % n_k
            s = i // n_k
            idx = [k + r * N//(P*R) + s * N//P for r in range(R)]
            res = [0] * R
            for t in range(R):
                for r in range(R):
                    w0 = r*t/R      # radix-R DFT phase
                    w1 = t*k*P/N    # inter-stage twiddle phase
                    twiddle = np.exp(-2j*np.pi * (w0+w1))
                    res[t] += X[idx[r]] * twiddle
            for t in range(R):
                X[idx[t]] = res[t]
        P *= R

    perm = permute_indices(N, R_seq)
    for i in range(N):
        x[i] = X[perm[i]]
    return x

# Compare the direct-DFT in-place fft against NumPy on a mixed-radix size.
R_seq = [9, 3, 6, 6]
N = np.prod(R_seq)
np.random.seed(0)
x0 = np.random.randn(N) + 1j * np.random.randn(N)
x1 = fft(np.copy(x0), R_seq)  # fft overwrites its argument, so pass a copy
print(np.linalg.norm(np.fft.fft(x0) - x1))

2.9498114564208594e-12


# Embedded FFT for Inverse Pass

In [4]:
def factors(N):
    """Return the prime factors of ``N`` in non-decreasing order.

    Trial division: strip all factors of 2, then odd candidates while
    ``f*f`` does not exceed the remaining cofactor; a cofactor left above 1
    is itself prime.

    >>> factors(12)
    [2, 2, 3]
    >>> factors(1)
    []

    Raises:
        ValueError: if ``N < 1``.  This is checked up front because the
            factor-of-2 loop never terminates for ``N == 0``.
    """
    if N < 1:
        raise ValueError(f"factors() needs a positive integer, got {N}")
    factor_seq = []
    f = 2
    while N % f == 0:
        factor_seq.append(f)
        N //= f
    f = 3
    while f*f <= N:
        while N % f == 0:
            factor_seq.append(f)
            N //= f
        f += 2
    if N != 1:
        factor_seq.append(N)
    return factor_seq

In [5]:
def fft_internal_outplace(y,k_,N_,P_,R_):
    """Length-``R_`` sub-DFT of ``y`` using two ping-pong buffers.

    Intended as the per-butterfly kernel of an outer FFT pass.  ``R_`` is
    factored with ``factors`` and the factors are consumed last-to-first.
    Every stage adds ``t*k_*P*P_/N_`` to the butterfly phase (``P`` being the
    stage's stride), which folds the outer pass's twiddle into the result.

    Args:
        y: the ``R_`` input samples (not modified).
        k_: butterfly offset of the outer pass.
        N_: total length of the outer transform.
        P_: stride of the outer pass.
        R_: length of this sub-DFT.

    Returns:
        A list of ``R_`` values; outer passes read entry ``t`` as output ``t``.
    """
    n = R_
    radices = factors(n)
    cur = [0] * n
    nxt = [0] * n
    cur[:] = y[:]
    stride = 1
    for radix in reversed(radices):
        for i in range(n // radix):
            blk, off = divmod(i, stride)
            reads = [i + t*n//radix for t in range(radix)]
            for t in range(radix):
                acc = 0
                for r in range(radix):
                    phi = r*t/radix + t*blk*stride/n + t*k_*stride*P_/N_
                    acc += cur[reads[r]] * np.exp(-2j*np.pi * phi)
                nxt[stride*radix*blk + off + stride*t] = acc
        # The freshly written buffer becomes the input of the next stage.
        cur, nxt = nxt, cur
        stride *= radix
    return cur

def fft_internal_inplace(y,k_,N_,P_,R_):
    """Length-``R_`` sub-DFT of ``y`` computed in place, then un-scrambled.

    Intended as the per-butterfly kernel of an outer FFT pass.  ``R_`` is
    factored with ``factors``; each stage overwrites groups of strided slots
    of a working copy with a radix DFT whose phase also carries
    ``t*k_*P*P_/N_`` (the outer pass's twiddle share).  The scrambled
    working copy is copied back into ``y`` through ``permute_indices``.

    Args:
        y: the ``R_`` input samples; overwritten with the result.
        k_: butterfly offset of the outer pass.
        N_: total length of the outer transform.
        P_: stride of the outer pass.
        R_: length of this sub-DFT.

    Returns:
        ``y`` itself.
    """
    n = R_
    radices = factors(n)
    work = [0] * n
    work[:] = y[:]

    stride = 1
    for radix in reversed(radices):
        for i in range(n // radix):
            blk, off = divmod(i, stride)
            slots = [blk + r * n//(stride*radix) + off * n//stride for r in range(radix)]
            out = []
            for t in range(radix):
                acc = 0
                for r in range(radix):
                    phi = r*t/radix + t*blk*stride/n + t*k_*stride*P_/N_
                    acc += work[slots[r]] * np.exp(-2j*np.pi * phi)
                out.append(acc)
            # Write back only after all outputs are formed (inputs are read above).
            for slot, val in zip(slots, out):
                work[slot] = val
        stride *= radix

    for i, j in enumerate(permute_indices(n, radices)):
        y[i] = work[j]
    return y

# Use the in-place sub-DFT kernel for every butterfly.
fft_internal = fft_internal_inplace
R_seq = [9,3,6,6]
N = 1
for R in R_seq:
    N *= R
np.random.seed(0)
x = np.random.randn(N) + 1j*np.random.randn(N)

# fft_inplace is defined in a later cell (In [8]); that cell must have been
# run before this one.  fft_inplace overwrites its argument, so pass a copy:
# otherwise the check only works because np.fft.fft(x) happens to be
# evaluated first, and x is left holding the transform afterwards.
print(np.linalg.norm(np.fft.fft(x) - fft_inplace(np.copy(x), R_seq)))

2.393768410351114e-12


In [6]:
# Route every butterfly of fft_inplace / fft_outplace (both call the global
# fft_internal) through the ping-pong-buffer sub-DFT kernel.
fft_internal = fft_internal_outplace

In [7]:
# Switch the per-butterfly kernel back to the in-place variant (which
# un-scrambles its output with permute_indices) for the cells below.
fft_internal = fft_internal_inplace

In [8]:
import numpy as np

def fft_inplace(x,R_seq):
    """In-place mixed-radix FFT whose butterflies are delegated to a kernel.

    Radices are consumed last-to-first.  For every group of ``R`` strided
    slots, the module-level ``fft_internal`` kernel is called with the
    gathered values and ``(k, N, P, R)``; it is expected to return the
    radix-``R`` DFT with the inter-stage twiddle for ``k`` already applied.
    The scrambled result is copied back into ``x`` via ``permute_indices``.

    Args:
        x: mutable sequence of ``N`` samples; overwritten with the result.
        R_seq: radices whose product must equal ``len(x)``.

    Returns:
        ``x`` itself, now holding the DFT.

    Raises:
        ValueError: if ``len(x)`` differs from the product of ``R_seq``.
        TypeError: if ``x`` is a NumPy array whose dtype cannot hold complex
            values (the imaginary parts would be silently discarded).
    """
    N = len(x)
    if N != int(np.prod(R_seq)):
        raise ValueError(
            f"len(x) = {N} does not match the product of R_seq {list(R_seq)}"
        )
    if isinstance(x, np.ndarray) and x.dtype.kind in "biuf":
        raise TypeError(
            f"fft_inplace writes complex results into x in place; got dtype {x.dtype}"
        )
    X = np.zeros(N, dtype=np.complex128)
    X[:] = x[:]

    P = 1
    for R in reversed(R_seq):
        T = N // R
        for i in range(T):
            k, p = i // P, i % P
            idx = [k + r * N//(P*R) + p * N//P for r in range(R)]
            Y = fft_internal([X[j] for j in idx], k, N, P, R)
            for t in range(R):
                X[idx[t]] = Y[t]
        P *= R
    perm = permute_indices(N, R_seq)
    for i in range(N):
        x[i] = X[perm[i]]
    return x

def fft_outplace(x,R_seq):
    """Out-of-place mixed-radix FFT using two ping-pong buffers.

    Radices are consumed first-to-last.  Each stage reads ``R`` values spaced
    ``N // R`` apart, transforms them with the module-level ``fft_internal``
    kernel, and scatters the outputs to the other buffer.  No final
    permutation is applied.  ``x`` is only read.

    Args:
        x: sequence of ``N`` samples.
        R_seq: radices of the transform.

    Returns:
        A new list holding the transform.
    """
    n = len(x)
    cur = [0] * n
    nxt = [0] * n
    cur[:] = x[:]
    stride = 1
    for radix in R_seq:
        for i in range(n // radix):
            blk, off = divmod(i, stride)
            gathered = [cur[i + t*n//radix] for t in range(radix)]
            out = fft_internal(gathered, blk, n, stride, radix)
            base = stride*radix*blk + off
            for r in range(radix):
                nxt[base + stride*r] = out[r]
        cur, nxt = nxt, cur
        stride *= radix
    return cur

# Symmetric-Pass Optimization for the In-Place Inverse Pass

In [9]:
def conjmul(a,b):
    """Return the pair ``(a * b, a * conj(b))``.

    Both products are built from the same four real multiplications, so
    forming them together avoids repeating that work.
    """
    rr = a.real * b.real
    ii = a.imag * b.imag
    ri = a.real * b.imag
    ir = a.imag * b.real
    prod = rr - ii + (ri + ir) * 1j
    prod_conj = rr + ii + (-ri + ir) * 1j
    return prod, prod_conj

def fft(x, R_seq):
    """In-place mixed-radix FFT that pairs conjugate-symmetric butterfly outputs.

    Radices are consumed last-to-first.  Within a radix-``R`` butterfly,
    outputs ``t`` and ``R-t`` use the same products of the inputs with
    ``exp(-2*pi*i*r*t/R)`` up to conjugating the root of unity, so
    ``conjmul`` yields both from one set of real multiplications.  Output 0
    is the plain sum of the inputs.  Each output then gets its inter-stage
    twiddle, and the scrambled result is copied back into ``x`` via
    ``permute_indices``.

    Args:
        x: mutable sequence of ``N`` samples; overwritten with the result.
        R_seq: radices whose product must equal ``len(x)``.

    Returns:
        ``x`` itself, now holding the DFT.

    Raises:
        ValueError: if ``len(x)`` differs from the product of ``R_seq``.
        TypeError: if ``x`` is a NumPy array whose dtype cannot hold complex
            values (the imaginary parts would be silently discarded).
    """
    N = len(x)
    if N != int(np.prod(R_seq)):
        raise ValueError(
            f"len(x) = {N} does not match the product of R_seq {list(R_seq)}"
        )
    if isinstance(x, np.ndarray) and x.dtype.kind in "biuf":
        raise TypeError(
            f"fft writes complex results into x in place; got dtype {x.dtype}"
        )
    X = np.zeros_like(x, dtype=np.complex128)
    X[:] = x

    P = 1
    for R in reversed(R_seq):
        T = N // R
        for i in range(T):
            k, s = i // P, i % P
            idx = [k + r * N//(P*R) + s * N//P for r in range(R)]
            vals = [X[src] for src in idx]
            # t = 0: every root of unity and the twiddle are 1.
            X[idx[0]] = np.sum(vals)
            # For even R, t = R/2 writes the same slot twice (idx[t] is
            # idx[R-t]); both values agree up to rounding.
            for t in range(1, R//2+1):
                y_0 = vals[0]
                y_1 = vals[0]
                for r in range(1, R):
                    phi = r*t/R
                    twiddle = np.exp(-2j * np.pi * phi)
                    m_0, m_1 = conjmul(vals[r], twiddle)
                    y_0 += m_0  # contributes to output t
                    y_1 += m_1  # contributes to output R-t
                X[idx[t]] = y_0 * np.exp(-2j * np.pi * (t*k*P)/N)
                X[idx[R-t]] = y_1 * np.exp(-2j * np.pi * ((R-t)*k*P)/N)
        P *= R

    perm = permute_indices(N, R_seq)
    for i in range(N):
        x[i] = X[perm[i]]
    return x

# Compare the conjugate-symmetric in-place fft against NumPy on odd radices.
R_seq = [3, 5, 7, 9]
N = np.prod(R_seq)
np.random.seed(0)
x0 = np.random.randn(N) + 1j * np.random.randn(N)
x1 = fft(np.copy(x0), R_seq)  # fft overwrites its argument, so pass a copy
print(np.linalg.norm(np.fft.fft(x0) - x1))

1.104342542587768e-12
