In [1]:
import numpy as np

from plot import *
from fft import *
from maths import rmse


In [94]:
def dft3_at_level(data: np.ndarray, level:int, quantile: float = 0.75):
    """Keep only the largest-amplitude real-FFT coefficients of a 3-D field.

    The real N-dimensional FFT of ``data`` is computed (or fetched from
    ``DFT3_LEVEL_CACHE`` under the key ``level``), and every coefficient whose
    amplitude is strictly greater than the ``quantile``-th quantile of all
    amplitudes is retained.

    NOTE(review): the cache is keyed by ``level`` only. Calling this with
    different ``data`` for a level that is already cached reuses the old
    spectrum and ignores ``data`` - clear the cache when switching inputs.

    Parameters
    ----------
    data : np.ndarray
        3-D array to transform.
    level : int
        Cache key for the computed spectrum.
    quantile : float, default 0.75
        Amplitude quantile (in ``[0, 1]``) used as the cut-off.

    Returns
    -------
    tuple of np.ndarray
        ``(real, imag, i, j, k)``: real and imaginary parts of the retained
        coefficients divided by 32768 and stored as float16, followed by
        their indices along the three spectrum axes (uint16, uint16, uint8).
        Entries are ordered by ``k``, then ``j``, then ``i``. All five arrays
        are empty when no coefficient exceeds the cut-off.

    Raises
    ------
    ValueError
        If the spectrum is not 3-D, or if its first two axes are too long for
        their indices to fit in uint16.
    """
    # Reuse the notebook-level cache when it exists; create it otherwise so the
    # function does not depend on a cell that may not have been run.
    cache = globals().setdefault("DFT3_LEVEL_CACHE", {})
    if level in cache:
        fft = cache[level]
    else:
        fft = np.fft.rfftn(data)
        cache[level] = fft

    if fft.ndim != 3:
        raise ValueError(f"expected a 3-D spectrum, got {fft.ndim} dimension(s)")

    amplitudes = np.abs(fft)
    # The cut-off is taken over the whole spectrum, including the k >= 256
    # coefficients that are never retained below.
    cutoff_amp = np.quantile(amplitudes, quantile)

    n_i, n_j, n_k = amplitudes.shape
    if n_i > 65536 or n_j > 65536:
        raise ValueError("first two spectrum axes must not exceed 65536 entries (uint16 indices)")
    # Only k < 256 is ever retained - presumably so the index fits in uint8.
    k_limit = min(n_k, 256)

    keep = amplitudes[:, :, :k_limit] > cutoff_amp
    # nonzero() walks in C order; transposing makes k the slowest and i the
    # fastest axis, which is the ordering callers have always received.
    k_idx, j_idx, i_idx = np.nonzero(keep.transpose(2, 1, 0))
    kept = fft[i_idx, j_idx, k_idx]

    # Scale down before the float16 cast (the inverse multiplies by 32768).
    fft_real = kept.real.astype("float32") / 32768
    fft_imag = kept.imag.astype("float32") / 32768

    return fft_real.astype("float16"), fft_imag.astype("float16"), \
        i_idx.astype("uint16"), j_idx.astype("uint16"), k_idx.astype("uint8")


def idft3_at_time_and_level(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices,
                            shape=(365 * 8, 361, 289)):
    """Rebuild a field from a sparse set of real-FFT coefficients.

    A zero-filled complex64 spectrum of the given ``shape`` is created, the
    supplied coefficients are written at their ``(i, j, k)`` positions, and
    the inverse real FFT of that spectrum is returned.

    Parameters
    ----------
    fft_real, fft_imag : array-like
        Real and imaginary parts of the coefficients; each is multiplied by
        32768 after conversion to float32.
    fft_i_indices, fft_j_indices, fft_k_indices : array-like of int
        Position of each coefficient along the three spectrum axes.
    shape : tuple of int, default (365 * 8, 361, 289)
        Shape of the half-spectrum to fill; the default is the shape this
        function has always used.

    Returns
    -------
    np.ndarray
        Result of ``np.fft.irfftn`` on the filled spectrum. With no explicit
        output size, the last axis has length ``2 * (shape[-1] - 1)``.
    """
    fft = np.zeros(shape, dtype="complex64")

    real = np.asarray(fft_real).astype("float32") * 32768
    imag = np.asarray(fft_imag).astype("float32") * 32768

    i_idx = np.asarray(fft_i_indices, dtype=np.intp)
    j_idx = np.asarray(fft_j_indices, dtype=np.intp)
    k_idx = np.asarray(fft_k_indices, dtype=np.intp)

    # Scatter every coefficient in one vectorised assignment; the two
    # components are written through the real/imag views of the spectrum.
    fft.real[i_idx, j_idx, k_idx] = real
    fft.imag[i_idx, j_idx, k_idx] = imag

    return np.fft.irfftn(fft)

In [95]:
def fit_dft3_at_level(filename: str, variable: str, level: int, **kwargs):
    """Load one level of a variable, compress it with the 3-D DFT and report the fit.

    The field is loaded with ``load_variable_at_level``, reduced to a sparse
    coefficient set by ``dft3_at_level`` (extra keyword arguments are passed
    through to it), reconstructed with ``idft3_at_time_and_level`` and then
    compared against the original. The results are printed; nothing is
    returned.

    Parameters
    ----------
    filename : str
        Passed unchanged to ``load_variable_at_level``.
    variable : str
        Passed unchanged to ``load_variable_at_level``.
    level : int
        Level to load; also handed to ``dft3_at_level``.
    **kwargs
        Forwarded to ``dft3_at_level`` (for example ``quantile``).
    """
    bytes_per_unit = 1024 ** 2

    print("Loading data...")
    field = load_variable_at_level(filename, variable, level)

    print("Performing DFT...")
    compressed = dft3_at_level(field, level, **kwargs)

    print("Performing IDFT...")
    reconstruction = idft3_at_time_and_level(*compressed)

    # Combined in-memory size of the five arrays making up the compressed form.
    stored_bytes = sum(part.nbytes for part in compressed)

    print(f"Original Stdev:  {field.astype('float32').std()} m/s")
    print(f"Predicted RMSE: {rmse(field, reconstruction)} m/s")
    print(f"Frequencies: {len(compressed[0])}")
    print(f"Size/level: {stored_bytes / bytes_per_unit} mB")
    # NOTE(review): 72 is presumably the number of levels per file - confirm.
    print(f"Size/year: {stored_bytes * 72 / bytes_per_unit} mB")


In [97]:
%matplotlib notebook

# Fit variable "U" at level 71. `quantile` is forwarded to dft3_at_level, where
# only coefficients whose amplitude exceeds that quantile are kept.
# NOTE(review): the two `{:0>2}` placeholders in the file name are presumably
# filled with month and day by load_variable_at_level - confirm.
fit_dft3_at_level("MERRA2_100.tavg3_3d_asm_Nv.1980{:0>2}{:0>2}.nc4", "U",
                  level=71, quantile=0.996)


Loading data...
Performing DFT...
Performing IDFT...
Original Stdev:  6.392603397369385 m/s
Predicted RMSE: 1.147797871369466 m/s
Frequencies: 1218507
Size/level: 10.458529472351074 mB
Size/year: 753.0141220092773 mB
