In [8]:
import pickle

from plot import *
from fft import *
from maths import rmse, mae


In [43]:
# In-kernel memo of the expensive forward transform: maps a level number to the
# ``(fft, amplitudes)`` pair stored by dft3_at_level. Re-running this cell
# empties the cache.
DFT3_LEVEL_CACHE = dict()


def dft3_at_level(data: np.ndarray, level: int, quantile: float = 0.75):
    """Keep the strongest real-FFT coefficients of a 3-D array and encode them.

    The full ``np.fft.rfftn`` spectrum of ``data`` and its magnitudes are
    memoised in ``DFT3_LEVEL_CACHE`` under ``level``. Coefficients whose
    magnitude is below the ``quantile`` quantile of all magnitudes are dropped;
    the survivors are returned as five separately encoded streams.

    NOTE(review): the cache key is ``level`` alone, so calling this with
    different ``data`` for a level that is already cached silently reuses the
    earlier spectrum. Clear ``DFT3_LEVEL_CACHE`` when switching datasets.

    Args:
        data: 3-D real-valued array to transform.
        level: Cache key for the transform of ``data``.
        quantile: Fraction in [0, 1]; coefficients with magnitude below this
            quantile of the magnitude distribution are discarded.

    Returns:
        Tuple ``(fft_real, fft_imag, fft_i_indices, fft_j_indices,
        fft_k_indices)``: real and imaginary parts (divided by 32768, cast to
        float16, passed through ``encode_zlib``) and the three index streams
        (int16, passed through ``encode_difference_uint8`` then
        ``encode_zlib``).
    """
    if level in DFT3_LEVEL_CACHE:
        fft, amplitudes = DFT3_LEVEL_CACHE[level]
    else:
        fft = np.fft.rfftn(data)
        amplitudes = np.abs(fft)

        DFT3_LEVEL_CACHE[level] = fft, amplitudes

    cutoff_amp = np.quantile(amplitudes, quantile)

    print("Filtering frequencies...")
    # "not below the cutoff" rather than ">=" so the selection matches the
    # original `if amplitude < cutoff: continue` test for every value.
    keep = np.logical_not(amplitudes < cutoff_amp)

    # np.nonzero walks its argument in C order, so reversing the axes yields
    # coordinates with k varying slowest and i fastest. That is the order the
    # former nested k/j/i loops produced, and the order the difference encoding
    # of the index streams is applied to.
    fft_k_indices, fft_j_indices, fft_i_indices = np.nonzero(keep.transpose(2, 1, 0))
    kept = fft[fft_i_indices, fft_j_indices, fft_k_indices]

    print("Encoding DFT...")
    # Scale by 2**15 before the float16 cast so large coefficients stay within
    # float16 range; idft3_at_level multiplies by the same factor on decode.
    fft_real = (kept.real.astype("float32") / 32768).astype("float16")
    fft_real = encode_zlib(fft_real)

    fft_imag = (kept.imag.astype("float32") / 32768).astype("float16")
    fft_imag = encode_zlib(fft_imag)

    fft_i_indices = fft_i_indices.astype("int16")
    fft_i_indices = encode_difference_uint8(fft_i_indices)
    fft_i_indices = encode_zlib(fft_i_indices)

    fft_j_indices = fft_j_indices.astype("int16")
    fft_j_indices = encode_difference_uint8(fft_j_indices)
    fft_j_indices = encode_zlib(fft_j_indices)

    fft_k_indices = fft_k_indices.astype("int16")
    fft_k_indices = encode_difference_uint8(fft_k_indices)
    fft_k_indices = encode_zlib(fft_k_indices)

    return fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices


def idft3_at_level(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices,
                   shape=(365 * 8, 361, 289)):
    """Rebuild a 3-D field from the sparse encoded spectrum of dft3_at_level.

    Args:
        fft_real: Encoded real parts (decoded with ``decode_zlib`` as float16).
        fft_imag: Encoded imaginary parts (decoded the same way).
        fft_i_indices: Encoded axis-0 coordinates of the kept coefficients.
        fft_j_indices: Encoded axis-1 coordinates of the kept coefficients.
        fft_k_indices: Encoded axis-2 coordinates of the kept coefficients.
        shape: Shape of the half-spectrum to fill before inversion. Defaults to
            the ``(365 * 8, 361, 289)`` grid this notebook was written for.

    Returns:
        Result of ``np.fft.irfftn`` on the zero-filled spectrum. With the
        default output size the last axis has length ``2 * (shape[-1] - 1)``.
    """
    ifft = np.zeros(shape, dtype="complex64")

    print("Decoding DFT...")
    fft_real = decode_zlib(fft_real, dtype="float16")
    fft_imag = decode_zlib(fft_imag, dtype="float16")
    # Undo the 2**15 scaling applied before the float16 cast on the encode side.
    fft = fft_real.astype("complex64") * 32768 + fft_imag.astype("complex64") * 32768j

    fft_i_indices = decode_zlib(fft_i_indices)
    fft_i_indices = decode_difference_uint8(fft_i_indices)

    fft_j_indices = decode_zlib(fft_j_indices)
    fft_j_indices = decode_difference_uint8(fft_j_indices)

    fft_k_indices = decode_zlib(fft_k_indices)
    fft_k_indices = decode_difference_uint8(fft_k_indices)

    print("Building IDFT...")
    # One fancy-indexed assignment replaces the per-coefficient Python loop.
    # Index streams are trimmed to the number of coefficients, matching the
    # loop, which only ever read the first len(fft) entries.
    count = len(fft)
    ifft[
        np.asarray(fft_i_indices)[:count],
        np.asarray(fft_j_indices)[:count],
        np.asarray(fft_k_indices)[:count],
    ] = fft

    return np.fft.irfftn(ifft)

In [41]:
def fit_dft3_at_level(filename: str, variable: str, level: int, **kwargs):
    """Compress one variable at one level with the 3-D DFT and report the error.

    Loads the data, obtains the encoded spectrum (from the on-disk model file if
    it exists, otherwise by running ``dft3_at_level`` and saving the result),
    reconstructs the field with ``idft3_at_level`` and prints the original
    standard deviation, MAE, RMSE and encoded size.

    Args:
        filename: Data file name/template passed to ``load_variable_at_level``
            and ``get_year_from_filename``.
        variable: Name of the variable to load.
        level: Level to load; also the cache key used by ``dft3_at_level``.
        **kwargs: Forwarded to ``dft3_at_level``. ``quantile`` is also part of
            the model file name; when omitted, 0.75 (the default of
            ``dft3_at_level``) is used for the name.

    Returns:
        None. Results are printed.
    """
    # Fall back to dft3_at_level's default instead of raising KeyError when the
    # caller relies on that default.
    quantile = kwargs.get("quantile", 0.75)
    output = f"models/3D-dft/{get_year_from_filename(filename)}/{variable}-{level}-{quantile}.bin"

    print("Loading data...")
    data = load_variable_at_level(filename, variable, level)

    if os.path.exists(output):
        print("Loading DFT...")
        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # load model files produced by this notebook.
        with open(output, "rb") as file:
            fft = pickle.load(file)

    else:
        print("Performing DFT...")
        fft = dft3_at_level(data, level, **kwargs)

        print("Saving DFT...")
        # The per-year directory may not exist on the first run.
        os.makedirs(os.path.dirname(output), exist_ok=True)
        # Write to a temporary name and rename, so an interrupted dump never
        # leaves a truncated file that a later run would try to load.
        partial_output = output + ".tmp"
        with open(partial_output, "wb") as file:
            pickle.dump(fft, file, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(partial_output, output)

    print("Performing IDFT...")
    prediction = idft3_at_level(*fft)

    data = data.astype("float32")
    # Total bytes across the five encoded streams.
    size = sum(el.nbytes for el in fft)
    # NOTE(review): the factor 36 is presumably the number of levels per year,
    # and "m/s" assumes a wind variable. Confirm both before reusing.
    lines = f"""
    Original Stdev: {data.std()} m/s
    Predicted MAE:  {mae(data, prediction)} m/s
    Predicted RMSE: {rmse(data, prediction)} m/s

    Size/level: {size / (1000 ** 2)} MB
    Size/year: {size * 36 / (1000 ** 2)} MB
    """
    print(lines)


In [44]:
# Fit variable "U" at level 35, keeping coefficients at or above the 0.9935
# amplitude quantile.
fit_dft3_at_level(
    filename="MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4",
    variable="U",
    level=35,
    quantile=0.9935,
)


Loading data...
Performing DFT...
Filtering frequencies...
Encoding DFT...
Saving DFT...
Performing IDFT...
Decoding DFT...
Building IDFT...


KeyboardInterrupt: 