In [3]:
import numpy as np

from plot import *
from fft import *
from maths import rmse


In [82]:
def dft3_at_level(data: np.ndarray, level: int, quantile: float = 0.75):
    if level in DFT3_LEVEL_CACHE:
        fft = DFT3_LEVEL_CACHE[level]
    else:
        fft = np.fft.rfftn(data)
        DFT3_LEVEL_CACHE[level] = fft

    amplitudes = np.abs(fft)

    fft_real = []
    fft_imag = []
    fft_i_indices = []
    fft_j_indices = []
    fft_k_indices = []

    cutoff_amp = np.quantile(amplitudes, quantile)

    i_dict = {}
    # j_dict = {}
    k_dict = {}

    last_j = 0

    for k in range(289):
        for j in range(361):
            for i in range(365 * 8):
                if amplitudes[i, j, k] < cutoff_amp:
                    continue

                fft_real.append(fft[i, j, k].real)
                fft_imag.append(fft[i, j, k].imag)

                if i not in i_dict:
                    i_dict[i] = len(i_dict)

                if k not in k_dict:
                    k_dict[k] = len(k_dict)

                if (dj := j - last_j) > 254:
                    fft_i_indices.append(i_dict[i])
                    fft_k_indices.append(k_dict[k])
                    fft_j_indices.append(254)
                    fft_real.append(0)
                    fft_imag.append(0)

                    last_j += 254
                    dj = j - last_j

                fft_i_indices.append(i_dict[i])
                if dj < 0:
                    fft_j_indices.append(255)
                    last_j = 0
                else:
                    fft_j_indices.append(dj)
                    last_j += dj

                fft_k_indices.append(k_dict[k])

    fft_real = np.array(fft_real, dtype="float32") / 32768
    fft_imag = np.array(fft_imag, dtype="float32") / 32768

    print("i indices:", len(i_dict))
    print("k indices:", len(k_dict))

    if len(i_dict) > 255:
        fft_i_indices = np.array(fft_i_indices, dtype="uint16")
    else:
        fft_i_indices = np.array(fft_i_indices, dtype="uint8")

    fft_j_indices = np.array(fft_j_indices, dtype="uint8")

    if len(k_dict) > 255:
        fft_k_indices = np.array(fft_k_indices, dtype="uint16")
    else:
        fft_k_indices = np.array(fft_k_indices, dtype="uint8")

    return fft_real.astype("float16"), fft_imag.astype("float16"), \
        fft_i_indices, fft_j_indices, fft_k_indices, \
        list(i_dict.keys()), list(k_dict.keys())


def idft3_at_time_and_level(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices, i_dict,  k_dict):
    fft = np.zeros((365 * 8, 361, 289), dtype="complex64")

    fft_real = fft_real.astype("float32") * 32768
    fft_imag = fft_imag.astype("float32") * 32768

    j = 0
    for idx in range(len(fft_i_indices)):
        i = i_dict[fft_i_indices[idx]]

        dj = fft_j_indices[idx]
        if dj == 255:
            j = 0
        else:
            j += dj

        k = k_dict[fft_k_indices[idx]]

        fft[i, j, k] = fft_real[idx] + 1j * fft_imag[idx]

    return np.fft.irfftn(fft)

In [83]:
def fit_dft3_at_level(filename: str, variable: str, level: int, **kwargs):
    print("Loading data...")
    data = load_variable_at_level(filename, variable, level)

    print("Performing DFT...")
    fft = dft3_at_level(data, level, **kwargs)

    print("Performing IDFT...")
    prediction = idft3_at_time_and_level(*fft)

    print(f"Original Stdev:  {data.astype('float32').std()} m/s")
    print(f"Predicted RMSE: {rmse(data, prediction)} m/s")
    print(f"Frequencies: {len(fft[0])}")
    print(f"Size/level: {sum(el.nbytes for el in fft if hasattr(el, 'nbytes')) / (1024 ** 2)} mB")
    print(f"Size/year: {sum(el.nbytes for el in fft if hasattr(el, 'nbytes')) * 72 / (1024 ** 2)} mB")


In [None]:
fit_dft3_at_level("MERRA2_100.tavg3_3d_asm_Nv.1980{:0>2}{:0>2}.nc4", "U",
                  level=71, quantile=0.997)

In [84]:
fit_dft3_at_level("MERRA2_100.tavg3_3d_asm_Nv.1980{:0>2}{:0>2}.nc4", "U",
                  level=71, quantile=0.997)


Loading data...
Performing DFT...
i indices: 1529
k indices: 241
Performing IDFT...
Original Stdev:  6.392603397369385 m/s
Predicted RMSE: 1.2533602702880768 m/s
Frequencies: 913978
Size/level: 6.9730987548828125 mB
Size/year: 502.0631103515625 mB
