In [1]:
import numpy as np
import cv2
import taichi as ti
import taichi.math as tm


@ti.func
def log2_int(n: ti.u32):
    '''
    Integer base-2 logarithm (index of the highest set bit) of a 32-bit value.

    The value is narrowed by a binary search over bit ranges: whenever any bit
    in the upper half of the remaining range is set, the width of the lower
    half is added to the result and the value is shifted down by that width.
    Returns 0 for both n == 0 and n == 1.
    '''
    remaining = n
    exponent = 0
    if (remaining & ti.u32(0xffff0000)) != 0:
        exponent += 16
        remaining >>= 16
    if (remaining & ti.u32(0x0000ff00)) != 0:
        exponent += 8
        remaining >>= 8
    if (remaining & ti.u32(0x000000f0)) != 0:
        exponent += 4
        remaining >>= 4
    if (remaining & ti.u32(0x0000000c)) != 0:
        exponent += 2
        remaining >>= 2
    if (remaining & ti.u32(0x00000002)) != 0:
        # last step: nothing reads the value afterwards, so no shift is needed
        exponent += 1
    return exponent


def log2_int_static(n: np.uint32):
    '''
    Python-scope counterpart of log2_int: index of the highest set bit of n.

    Walks a table of (shift, mask) pairs from the widest bit range to the
    narrowest; whenever n has a bit inside the mask, the shift is added to the
    result and n is shifted down by the same amount.
    Returns 0 for both n == 0 and n == 1.
    '''
    steps = (
        (16, 0xffff0000),
        (8, 0x0000ff00),
        (4, 0x000000f0),
        (2, 0x0000000c),
        (1, 0x00000002),
    )
    exponent = 0
    for shift, mask in steps:
        if n & np.uint32(mask):
            exponent += shift
            n >>= shift
    return exponent


@ti.func
def fast_mod(x, m):
    '''
    Remainder of x divided by m, computed with a bit mask.

    Only equals x % m when m is a power of two.
    '''
    low_bits_mask = m - 1
    return x & low_bits_mask


@ti.func
def fast_div(x, log2_m):
    '''
    Quotient of x divided by the power of two m = 2 ** log2_m.

    The caller passes the exponent log2_m, not m itself; the division is a
    right shift by that exponent.
    '''
    quotient = x >> log2_m
    return quotient


@ti.func
def fast_mul(x, log2_m):
    '''
    Product of x and the power of two m = 2 ** log2_m.

    The caller passes the exponent log2_m, not m itself; the multiplication is
    a left shift by that exponent.
    '''
    product = x << log2_m
    return product


@ti.func
def print_complex_img(img: ti.template()):
    '''
    Debug helper: print the first complex number (.x, .y) of every cell of a
    2D field, prefixed with its indices.

    A cell in the last column ends its line; every other cell is followed by
    a space.
    '''
    for i, j in img:
        if j == img.shape[1] - 1:
            # last column: terminate the line
            print(f'{i}{j}[{img[i, j].x:.2f},{img[i, j].y:.2f}]')
        else:
            print(f'{i}{j}[{img[i, j].x:.2f},{img[i, j].y:.2f}]', end=' ')
    print()


@ti.func
def print_buffer(buffer: ti.template(), flag, size_x):
    '''
    Debug helper: print the first complex number (.x, .y) of every entry in
    row `flag` of the flattened ping-pong buffer.

    The flat index is split into (y, x) coordinates using size_x as the row
    length, so size_x has to be a power of two (fast_mod / fast_div are used).
    An entry in the last column ends its line; others are followed by a space.
    '''
    for i in range(buffer.shape[1]):
        y_coord = fast_div(i, log2_int(size_x))
        x_coord = fast_mod(i, size_x)
        if x_coord == size_x - 1:
            # last column of the logical row: terminate the line
            print(
                f'{y_coord}{x_coord}[{buffer[flag,i].x:.2f},{buffer[flag,i].y:.2f}]')
        else:
            print(
                f'{y_coord}{x_coord}[{buffer[flag,i].x:.2f},{buffer[flag,i].y:.2f}]', end=' ')
    print()


@ti.func
def fft_pass(size: ti.template(),
             log2_size: ti.template(),
             pixel_size,
             log2_pixel_size,
             inverse: ti.template(),
             buffer: ti.template(),
             target_):
    '''
    Run log2_size radix-2 stages over every consecutive run of `size` entries
    of the flattened buffer, ping-ponging between buffer[0, :] and buffer[1, :].

    size / log2_size: transform length (power of two) and its exponent.
    pixel_size: total number of flattened entries that are processed.
    log2_pixel_size: accepted for symmetry with the caller; not used here.
    inverse: when true the twiddle factors are conjugated (no scaling is
        applied in this function).
    buffer: [2, pixel_size] field of vec4, each vec4 holding two complex
        numbers (xy and zw).
    target_: row of `buffer` written by the first stage; the input is read
        from the other row. Subsequent stages alternate.
    '''
    ti.loop_config(block_dim=size)
    for i in range(pixel_size):
        target = target_
        for s in range(log2_size):
            b = size >> (s + 1)  # [4 2 1] # block size
            log2_b = log2_size - s - 1
            # opt w_ = (k // b_) * b_
            # overall rotation #begin offset
            w = fast_mul(fast_div(i, log2_b), log2_b)
            # even input: position inside the row plus the start of the row
            ev = fast_mod(w + i, size) + \
                fast_mul(fast_div(i, log2_size), log2_size)
            od = ev + b
            # w still contains the row offset (a multiple of size). That offset
            # only adds whole turns to the angle, so strip it before scaling:
            # the angle then stays within one turn instead of growing with the
            # flat index, which keeps the float cos/sin accurate.
            phi = -2 * tm.pi / size * fast_mod(w, size)

            twiddle = tm.vec2(tm.cos(phi), tm.sin(phi))
            if inverse:
                twiddle.y = -twiddle.y

            od_v = buffer[not target, od]
            buffer[target, i] = \
                buffer[not target, ev] + \
                tm.vec4(tm.cmul(od_v.xy, twiddle),
                        tm.cmul(od_v.zw, twiddle))

            target = not target
            # NOTE(review): the next stage reads entries written by other
            # threads of the same block. Confirm that mem_sync() alone gives
            # the required ordering on the target backend, or whether an
            # execution barrier (ti.simt.block.sync()) is needed as well.
            ti.simt.block.mem_sync()


@ti.func
def fft_2d(xy: ti.template(), buffer: ti.template(), inverse: ti.template()):
    '''
    In-place 2D FFT of `xy`, done as a row pass, a transpose and a second
    row pass, using `buffer` as ping-pong scratch space.

    xy : [y_size, x_size][4] 2 complex
    buffer: [2, x_size * y_size][4] 2 * 2 complex
    inverse: when true the conjugated twiddles are used and the result is
        divided by x_size (during the transpose) and by y_size (during the
        copy back).

    Both dimensions of xy must be powers of two because indices are built
    with shifts and masks.
    '''
    x_size = ti.static(xy.shape[1])
    y_size = ti.static(xy.shape[0])
    log2_x_size = ti.static(log2_int_static(x_size))
    log2_y_size = ti.static(log2_int_static(y_size))

    pixel_size = ti.static(x_size * y_size)
    log2_pixel_size = ti.static(log2_int_static(pixel_size))

    # ti.block_local(buffer)

    # row of `buffer` that the next pass writes; the other row is its input
    target: ti.i32 = 1
    # buffer copy: flatten xy row by row (index = i * x_size + j) into row 0
    for i, j in xy:
        index = fast_mul(i, log2_x_size) + j
        buffer[0, index] = xy[i, j]

    # horizontal
    fft_pass(x_size, log2_x_size, pixel_size, log2_pixel_size,
             inverse, buffer, target)

    # fft_pass alternates rows once per stage; after an odd number of stages
    # flip `target` so that `not target` names the row holding the result
    if fast_mod(log2_x_size, 2) == 1:
        target = not target

    # swap x,y: transpose the flattened layout so that the former columns
    # become contiguous runs of y_size entries
    ti.loop_config(block_dim=x_size)
    for i, j in xy:
        old_index = fast_mul(i, log2_x_size) + j
        new_index = fast_mul(j, log2_y_size) + i
        if inverse:
            # first half of the inverse normalisation
            buffer[target, new_index] = buffer[not target, old_index]/x_size
        else:
            buffer[target, new_index] = buffer[not target, old_index]
    # the transposed data is now the input of the next pass
    target = not target

    # vertical
    fft_pass(y_size, log2_y_size, pixel_size, log2_pixel_size,
             inverse, buffer, target)

    # same parity correction as after the horizontal pass
    if fast_mod(log2_y_size, 2) == 1:
        target = not target

    # copy back: undo the transpose while writing the result into xy
    for i, j in xy:
        index = fast_mul(j, log2_y_size) + i
        if inverse:
            # second half of the inverse normalisation
            xy[i, j] = buffer[not target, index] / y_size
        else:
            xy[i, j] = buffer[not target, index]


@ti.func
def vec4_cconj(vec: tm.vec4):
    '''
    Complex conjugate of both complex numbers packed in a vec4
    (xy is the first number, zw the second): the imaginary parts are negated.
    '''
    first = tm.vec2(vec.x, -vec.y)
    second = tm.vec2(vec.z, -vec.w)
    return tm.vec4(first, second)


@ti.kernel
def quad_fft(img: ti.template(),
             freq: ti.template(),
             buffer: ti.template(),
             inverse: ti.template()):
    '''
    Transform four real channels with two complex FFTs.

    img: [size_y, size_x][4] rgba
    freq: [size_y, size_x, 2][4] (rfgf, bfaf)
    buffer: [2, size_y * size_x][4]
    inverse: compile-time flag selecting the direction.

    Forward: img is transformed in place as the two complex images
    (r + i*g) and (b + i*a); the spectra of the individual channels are then
    separated and stored in freq. img holds the packed spectrum afterwards.

    Inverse: the channel spectra in freq are packed back into
    (R + i*G, B + i*A), written into img and inverse-transformed in place.
    '''
    n = ti.static(img.shape[0])
    m = ti.static(img.shape[1])
    if ti.static(not inverse):
        fft_2d(img, buffer, inverse=inverse)
        for i, j in img:
            z1 = img[i, j]
            # Conjugate of the packed spectrum at the mirrored frequency.
            # The mirrored index is (-i mod n, -j mod m); the modulo keeps
            # row 0 / column 0 mapped onto themselves instead of indexing
            # one past the end of the field.
            z2 = vec4_cconj(img[(n - i) % n, (m - j) % m])
            # first channel = (Z + conj(Z_mirror)) / 2,
            # second channel = (Z - conj(Z_mirror)) / (2i)
            rfgf = tm.vec4(0.5 * (z1.xy + z2.xy),
                           tm.cmul(tm.vec2(0, -0.5), z1.xy - z2.xy))
            bfaf = tm.vec4(0.5 * (z1.zw + z2.zw),
                           tm.cmul(tm.vec2(0, -0.5), z1.zw - z2.zw))
            freq[i, j, 0] = rfgf
            freq[i, j, 1] = bfaf
    else:
        for i, j in img:
            # pack two channel spectra into one complex spectrum: F1 + i*F2
            img[i, j].xy = freq[i, j, 0].xy + \
                tm.cmul(tm.vec2(0, 1), freq[i, j, 0].zw)
            img[i, j].zw = freq[i, j, 1].xy + \
                tm.cmul(tm.vec2(0, 1), freq[i, j, 1].zw)
        fft_2d(img, buffer, inverse=inverse)


# numpy replacement for quad_fft
def _split_packed_spectrum(packed):
    '''Split the FFT of (p + i*q), p and q real, into (FFT(p), FFT(q)).'''
    # conj(Z[(-k) mod N]) on both axes: a plain reversal maps k -> N-1-k, so
    # the reversed array is rolled by one to put index 0 back onto itself.
    mirrored = np.conj(np.roll(packed[::-1, ::-1], shift=(1, 1), axis=(0, 1)))
    return 0.5 * (packed + mirrored), -0.5j * (packed - mirrored)


def numpy_quad_fft(img, freq, buffer, inverse):
    '''
    NumPy reference implementation with the same call signature as quad_fft.

    img: field-like object whose to_numpy() yields a [size_y, size_x, 4]
        array (r, g, b, a) and that accepts from_numpy().
    freq: field-like object with a [size_y, size_x, 2, 4] array; slot 0 holds
        (R.re, R.im, G.re, G.im), slot 1 holds (B.re, B.im, A.re, A.im).
    buffer: unused, kept so the function is a drop-in replacement.
    inverse: False fills freq from img, True rebuilds img from freq.

    Both objects are written back with from_numpy() in either direction.
    '''
    np_img = img.to_numpy()
    np_freq = freq.to_numpy()

    if not inverse:
        r, g, b, a = np.moveaxis(np_img, 2, 0)
        # two complex FFTs carry four real channels
        R, G = _split_packed_spectrum(np.fft.fft2(r + 1j * g))
        B, A = _split_packed_spectrum(np.fft.fft2(b + 1j * a))
        for slot, (first, second) in enumerate(((R, G), (B, A))):
            np_freq[:, :, slot, 0] = first.real
            np_freq[:, :, slot, 1] = first.imag
            np_freq[:, :, slot, 2] = second.real
            np_freq[:, :, slot, 3] = second.imag
    else:
        # repack each pair of channel spectra as F1 + i*F2 and invert
        rg_ifft_in = np_freq[:, :, 0, 0] + np_freq[:, :, 0, 1]*1j + \
            1j*(np_freq[:, :, 0, 2] + np_freq[:, :, 0, 3]*1j)
        ba_ifft_in = np_freq[:, :, 1, 0] + np_freq[:, :, 1, 1]*1j + \
            1j*(np_freq[:, :, 1, 2] + np_freq[:, :, 1, 3]*1j)
        fft_res_rg = np.fft.ifft2(rg_ifft_in)
        fft_res_ba = np.fft.ifft2(ba_ifft_in)
        np_img[:, :, 0] = fft_res_rg.real
        np_img[:, :, 1] = fft_res_rg.imag
        np_img[:, :, 2] = fft_res_ba.real
        np_img[:, :, 3] = fft_res_ba.imag
    img.from_numpy(np_img)
    freq.from_numpy(np_freq)


@ti.kernel
def field_transpose(src: ti.template(), dst: ti.template()):
    '''
    Write the transpose of the 2D field src into dst: dst[col, row] receives
    src[row, col] for every cell of src.
    '''
    for row, col in src:
        dst[col, row] = src[row, col]


@ti.kernel
def compute_frequency_filter(frequency_filter: ti.template(), x_scale: ti.f32, y_scale: ti.f32):
    '''
    Fill frequency_filter with exp(-d / 2), where d is the scaled distance of
    each cell from the origin of an unshifted spectrum.

    Indices are wrapped so that the second half of each axis counts as
    negative offsets; the squared offsets are weighted by x_scale and y_scale
    before the square root is taken.
    '''
    rows, cols = frequency_filter.shape
    for row, col in frequency_filter:
        # signed offsets in [-size // 2, size - size // 2)
        dy = (row + rows // 2) % rows - rows // 2
        dx = (col + cols // 2) % cols - cols // 2
        weighted_sq = dx * dx * x_scale + dy * dy * y_scale
        frequency_filter[row, col] = ti.exp(-ti.sqrt(weighted_sq) / 2)


@ti.kernel
def image_blur(freq: ti.template(), frequency_filter: ti.template()):
    '''
    Scale every spectrum entry by the filter weight of its (i, j) position.

    freq: [size_y, size_x, 2][4] (rfgf, bfaf)
    frequency_filter: [size_y, size_x] scalar weights
    '''
    for i, j, k in freq:
        gain = frequency_filter[i, j]
        # the same real weight applies to all four components
        freq[i, j, k] = freq[i, j, k] * gain


def main(image_path=r'C:/Users/Estelle/source/repos/TaichiSandBox/src/fft/ayanami.png'):
    '''
    Interactive demo: blur an image in the frequency domain every frame.

    image_path: image file to load; defaults to the path the demo was
        originally written for.

    Raises FileNotFoundError when the image cannot be read.

    Each frame the frequency filter is rebuilt from the two sliders, the
    original image is transformed (Taichi or NumPy implementation, chosen
    with the GUI buttons), multiplied by the filter, optionally visualised
    in an OpenCV window, transformed back and displayed.
    '''
    # both sizes are powers of two
    size_x = 1 << 10
    size_y = 1 << 9

    img_cv = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if img_cv is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError(f'could not read image file: {image_path}')
    img_cv = cv2.resize(img_cv, (size_x, size_y))
    img_cv = img_cv.astype(np.float32)
    img_cv = img_cv / 255.0
    # reorganize the image as rgb and inverse the y axis
    img_cv = img_cv[::-1, :, ::-1]
    # add alpha channel
    img_cv = np.concatenate(
        (img_cv, np.ones((size_y, size_x, 1), dtype=np.float32)), axis=2)

    ti.init(arch=ti.vulkan, kernel_profiler=True, debug=True)

    # load to ti
    rgba = ti.Vector.field(4, dtype=ti.f32, shape=(size_y, size_x))
    rgba_freq = ti.Vector.field(4, dtype=ti.f32, shape=(size_y, size_x, 2))
    buffer = ti.Vector.field(4, dtype=ti.f32, shape=(2, size_x * size_y))
    original = ti.Vector.field(4, dtype=ti.f32, shape=(size_y, size_x))

    original.from_numpy(img_cv)
    rgba.from_numpy(img_cv)
    transposed = ti.Vector.field(4, dtype=ti.f32, shape=(size_x, size_y))

    # per-frequency weights used to blur the image
    frequency_filter = ti.field(dtype=ti.f32, shape=(size_y, size_x))
    frequency_filter.fill(1)

    gui = ti.GUI("Real Time FFT Image Convolution", (size_x, size_y))

    # NOTE(review): confirm whether the 4th positional argument of gui.slider
    # is the step or the initial value; 0.5 is passed unchanged here.
    filter_x_scale = gui.slider("filter_x_scale", 0.0, 1.0, 0.5)
    filter_y_scale = gui.slider("filter_y_scale", 0.0, 1.0, 0.5)

    class Switch:
        '''A GUI button paired with a label that stores its on/off state.'''

        def __init__(self, name, value=False):
            self.button = gui.button(name,)
            self.label = gui.label(name)
            self.label.value = value

        # attribute access of value
        @property
        def value(self):
            return self.label.value

        @value.setter
        def value(self, value):
            self.label.value = value

    switch_profile_print = Switch("profile print", False)
    switch_fft_img_update = Switch("fft img", False)

    use_numpy_fft = False

    use_numpy_fft_button = gui.button("numpy impl")
    use_taichi_fft_button = gui.button("taichi impl")

    freq_window = "frequency amplitude"

    try:
        while gui.running:
            for e in gui.get_events(ti.GUI.PRESS):
                if e.key == ti.GUI.ESCAPE:
                    gui.running = False
                elif e.key == switch_profile_print.button:
                    switch_profile_print.value = not switch_profile_print.value
                elif e.key == switch_fft_img_update.button:
                    switch_fft_img_update.value = not switch_fft_img_update.value
                    if not switch_fft_img_update.value:
                        try:
                            cv2.destroyWindow(freq_window)
                        except cv2.error:
                            # best effort: the window may never have been
                            # shown or may already be closed
                            pass
                elif e.key == use_numpy_fft_button:
                    use_numpy_fft = True
                elif e.key == use_taichi_fft_button:
                    use_numpy_fft = False

            compute_frequency_filter(
                frequency_filter, filter_x_scale.value, filter_y_scale.value)
            if switch_profile_print.value:
                ti.profiler.clear_kernel_profiler_info()
            # the forward transform overwrites rgba, so start from a fresh copy
            rgba.copy_from(original)
            if use_numpy_fft:
                numpy_quad_fft(rgba, rgba_freq, buffer, False)
            else:
                quad_fft(rgba, rgba_freq, buffer, False)

            image_blur(rgba_freq, frequency_filter)
            if switch_fft_img_update.value:
                data = np.fft.fftshift(axes=(0, 1),
                                       x=rgba_freq.to_numpy()[::-1, :, :, ::-1]
                                       )
                # collapse the two vec4 slots into one complex image
                # freq: [size_y, size_x, 2][4] (rfgf, bfaf)
                # NOTE(review): component [0, 0] is used as the real part and
                # the other seven components are summed into the imaginary
                # part; confirm this is the intended visualisation.
                data = data[:, :, 0, 0] + data[:, :, 1, 0] * 1j +\
                    data[:, :, 0, 1] * 1j + data[:, :, 1, 1] * 1j + \
                    data[:, :, 0, 2] * 1j + data[:, :, 1, 2] * 1j + \
                    data[:, :, 0, 3] * 1j + data[:, :, 1, 3] * 1j
                data = np.abs(data)
                data = np.log(data+1)
                min_pixel = np.min(data, axis=(0, 1))
                max_pixel = np.max(data, axis=(0, 1))
                data = (data-min_pixel)/(max_pixel-min_pixel)
                cv2.imshow(freq_window, data)
                # HighGUI only repaints and handles window events while its
                # own event loop runs
                cv2.waitKey(1)
            if use_numpy_fft:
                numpy_quad_fft(rgba, rgba_freq, buffer, True)
            else:
                quad_fft(rgba, rgba_freq, buffer, True)
            if switch_profile_print.value:
                ti.profiler.print_kernel_profiler_info('trace')

            # ti.GUI expects the first axis to be x, the fields store y first
            field_transpose(rgba, transposed)
            gui.set_image(transposed)
            gui.show()
    finally:
        # runs on normal exit and on any exception, which then propagates
        cv2.destroyAllWindows()
        gui.close()


# Start the interactive demo only when the file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

[Taichi] version 1.6.0, llvm 15.0.1, commit f1c6fbbd, win, python 3.11.3
[Taichi] Starting on arch=vulkan
