In [1]:
import numpy as np
from numpy.fft import fft2, ifft2

In [2]:
def fourierExtrapolation2D(img, n_predict, n_harm=10):
    """Extrapolate a 2-D array with a planar trend plus low-frequency harmonics.

    2-D analogue of the classic 1-D "Fourier extrapolation":

    1. A least-squares plane ``a*y + b*x + c`` is removed from *img*.
    2. The residual is transformed with ``fft2``.
    3. The residual is rebuilt from the lowest ``1 + 2*n_harm`` frequencies
       along each axis (every row-frequency / column-frequency combination).
    4. The harmonic model is evaluated on an enlarged grid, and the plane is
       added back.

    Beyond the original extent the harmonic part repeats with period ``h``
    rows / ``w`` columns. The "prediction" is therefore a periodic
    continuation of the detrended image on top of the linear trend.

    Parameters
    ----------
    img : array_like, shape (h, w)
        Input image; converted to float64.
    n_predict : int
        Number of extra rows and extra columns to produce. They are appended
        after the last row and after the last column.
    n_harm : int, optional
        Number of harmonic pairs kept per axis (default 10, the value that
        used to be hard-coded). When ``1 + 2*n_harm`` is at least the axis
        length, every frequency on that axis is kept. If that holds for both
        axes, the output reproduces *img* on its original extent.

    Returns
    -------
    numpy.ndarray, shape (h + n_predict, w + n_predict)
        Reconstructed image; ``result[:h, :w]`` approximates *img*.

    Raises
    ------
    ValueError
        If *img* is not a non-empty 2-D array, or if *n_predict* or *n_harm*
        is negative.
    """
    img = np.asarray(img, dtype=float)
    if img.ndim != 2 or img.size == 0:
        raise ValueError("img must be a non-empty 2-D array")
    if n_predict < 0:
        raise ValueError("n_predict must be non-negative")
    if n_harm < 0:
        raise ValueError("n_harm must be non-negative")
    h, w = img.shape

    # Fit the plane jointly. Two separate polyfits each carry an intercept,
    # and np.outer over (slope, intercept) pairs yields (h*w, 2) arrays that
    # cannot broadcast against (h, w). lstsq also copes with h == 1 or
    # w == 1: the all-zero column gets the minimum-norm coefficient 0, where
    # polyfit would divide by a zero column scale.
    t_y, t_x = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    design = np.column_stack([t_y.ravel(), t_x.ravel(), np.ones(h * w)])
    (slope_y, slope_x, intercept), *_ = np.linalg.lstsq(
        design, img.ravel(), rcond=None
    )
    img_notrend = img - (slope_y * t_y + slope_x * t_x + intercept)

    # 2-D spectrum of the detrended image and the per-axis sample frequencies.
    img_freqdom = fft2(img_notrend)
    f_y = np.fft.fftfreq(h)
    f_x = np.fft.fftfreq(w)

    # Lowest |frequency| first. The stable sort puts +f ahead of -f, so
    # truncating at an odd count keeps the DC term plus matching +/-f pairs.
    n_keep = 1 + 2 * n_harm
    keep_y = np.argsort(np.abs(f_y), kind="stable")[:n_keep]
    keep_x = np.argsort(np.abs(f_x), kind="stable")[:n_keep]

    # Consider the sum over kept (i, j) of
    #     |X[i, j]| / (h*w) * cos(2*pi*(f_y[i]*ty + f_x[j]*tx) + angle(X[i, j])).
    # It equals Re(E_y @ C @ E_x.T), where C is the kept coefficients
    # divided by h*w, and E_y[t, i] = exp(2j*pi*f_y[i]*t) (E_x likewise) on
    # the extended axes. Evaluating it this way avoids allocating a full
    # output-sized grid per (i, j) pair.
    t_y_new = np.arange(h + n_predict)
    t_x_new = np.arange(w + n_predict)
    coeffs = img_freqdom[np.ix_(keep_y, keep_x)] / (h * w)
    basis_y = np.exp(2j * np.pi * np.outer(t_y_new, f_y[keep_y]))
    basis_x = np.exp(2j * np.pi * np.outer(t_x_new, f_x[keep_x]))
    restored_img = (basis_y @ coeffs @ basis_x.T).real

    # Add the fitted plane back, evaluated on the extended grid.
    restored_img += (
        slope_y * t_y_new[:, None] + slope_x * t_x_new[None, :] + intercept
    )
    return restored_img
