In [1]:
import pickle

from plot import *
from fft import *
from util import log
from maths import rmse, mse, mae, linear_interpolate


In [2]:
# Memo of the (spectrum, amplitudes) pair computed by dft3_at_level, keyed only
# by the integer level. fit_dft3 empties it after each level it processes.
DFT3_LEVEL_CACHE = {}


def dft3_at_level(data: np.ndarray, level: int, quantile: float = 0.75):
    """Keep only the strongest 3-D Fourier coefficients of one level of data.

    The real-input FFT of ``data`` is computed (or read from
    ``DFT3_LEVEL_CACHE`` when ``level`` is already present there) and every
    coefficient whose amplitude is not below the requested amplitude quantile
    is retained.

    Args:
        data: Array with three axes to transform. Ignored when ``level`` is
            already cached.
        level: Cache key for the spectrum. The cache is keyed by level only,
            so a cached entry is reused even if ``data`` differs.
        quantile: Amplitude quantile passed to ``np.quantile``; coefficients
            below that amplitude are dropped.

    Returns:
        A 5-tuple ``(fft_real, fft_imag, fft_i_indices, fft_j_indices,
        fft_k_indices)``. The real/imaginary parts are divided by 262144, cast
        to float16 and passed through ``encode_zlib``. The axis-0/1/2 indices
        are int16 arrays passed through ``encode_difference_uint8`` and then
        ``encode_zlib``. Entries are ordered with the axis-2 index varying
        slowest and the axis-0 index varying fastest.
    """
    if level in DFT3_LEVEL_CACHE:
        fft, amplitudes = DFT3_LEVEL_CACHE[level]
    else:
        fft = np.fft.rfftn(data)
        amplitudes = np.abs(fft)

        DFT3_LEVEL_CACHE[level] = fft, amplitudes

    cutoff_amp = np.quantile(amplitudes, quantile)

    # idft3_at_level rebuilds the spectrum on a (365 * 8, 361, 289) grid by
    # default, so never emit indices beyond it; smaller spectra are used whole.
    candidates = amplitudes[: 365 * 8, :361, :289]

    # `~(a < cutoff)` instead of `a >= cutoff`: a NaN comparison is False, so
    # NaN amplitudes count as kept rather than dropped.
    keep = ~(candidates < cutoff_amp)

    # nonzero() walks its input in row-major order, so transposing first makes
    # the axis-2 index the slowest-varying one and the axis-0 index the fastest.
    k_idx, j_idx, i_idx = np.nonzero(keep.transpose(2, 1, 0))
    kept = fft[i_idx, j_idx, k_idx]

    fft_real = (kept.real.astype("float32") / 262144).astype("float16")
    fft_real = encode_zlib(fft_real)

    fft_imag = (kept.imag.astype("float32") / 262144).astype("float16")
    fft_imag = encode_zlib(fft_imag)

    fft_i_indices = i_idx.astype("int16")
    fft_i_indices = encode_difference_uint8(fft_i_indices)
    fft_i_indices = encode_zlib(fft_i_indices)

    fft_j_indices = j_idx.astype("int16")
    fft_j_indices = encode_difference_uint8(fft_j_indices)
    fft_j_indices = encode_zlib(fft_j_indices)

    fft_k_indices = k_idx.astype("int16")
    fft_k_indices = encode_difference_uint8(fft_k_indices)
    fft_k_indices = encode_zlib(fft_k_indices)

    return fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices


def idft3_at_level(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices,
                   *, spectrum_shape=(365 * 8, 361, 289)):
    """Rebuild the data of one level from the sparse spectrum made by dft3_at_level.

    Args:
        fft_real: Encoded real parts; decoded with ``decode_zlib`` as float16.
        fft_imag: Encoded imaginary parts; decoded with ``decode_zlib`` as float16.
        fft_i_indices: Encoded axis-0 indices of the stored coefficients.
        fft_j_indices: Encoded axis-1 indices of the stored coefficients.
        fft_k_indices: Encoded axis-2 indices of the stored coefficients.
        spectrum_shape: Shape of the half-spectrum to fill before inverting.
            Defaults to the ``(365 * 8, 361, 289)`` grid.

    Returns:
        The result of ``np.fft.irfftn`` applied to a spectrum that is zero
        everywhere except at the stored coefficients.
    """
    ifft = np.zeros(spectrum_shape, dtype="complex64")

    fft_real = decode_zlib(fft_real, dtype="float16")
    fft_imag = decode_zlib(fft_imag, dtype="float16")
    # Undo the 1/262144 scaling applied before the float16 cast on the encode side.
    fft = fft_real.astype("complex64") * 262144 + fft_imag.astype("complex64") * 262144j

    fft_i_indices = decode_difference_uint8(decode_zlib(fft_i_indices))
    fft_j_indices = decode_difference_uint8(decode_zlib(fft_j_indices))
    fft_k_indices = decode_difference_uint8(decode_zlib(fft_k_indices))

    # Scatter every coefficient in one vectorised assignment. Only the first
    # len(fft) index entries are used.
    # NOTE(review): assumes no (i, j, k) triple repeats; dft3_at_level never
    # emits duplicates, but confirm for inputs from other producers.
    count = len(fft)
    ifft[fft_i_indices[:count], fft_j_indices[:count], fft_k_indices[:count]] = fft

    return np.fft.irfftn(ifft)

In [3]:
def fit_dft3_at_level(filename: str, variable: str, level: int, verbose: bool = True, **kwargs):
    """Fit (or load from disk) the sparse DFT model of one level.

    The model is stored at ``models/3D-dft/{variable}/{quantile}/{level}.bin``.
    When that file exists it is unpickled; otherwise the data is loaded, passed
    to ``dft3_at_level`` and the result is pickled to that path.

    Args:
        filename: Passed through to ``load_variable_at_level``.
        variable: Variable name; used in the model path and passed to the loader.
        level: Level index; used in the model path and passed to the loader
            and to ``dft3_at_level``.
        verbose: When true, print progress, reconstruct the data with
            ``idft3_at_level`` and print error and size statistics. It is also
            forwarded as the loader's ``cache`` argument.
        **kwargs: Forwarded to ``dft3_at_level``. ``quantile`` falls back to
            0.75 when omitted.

    Returns:
        None. Results are written to disk and, when ``verbose``, printed.
    """
    # Resolve the quantile once so the model path and the transform always
    # agree; 0.75 matches dft3_at_level's own default.
    quantile = kwargs.setdefault("quantile", 0.75)
    output = f"models/3D-dft/{variable}/{quantile}/{level}.bin"

    if os.path.exists(output):
        # NOTE: unpickling runs arbitrary code; only load model files you produced.
        with open(output, "rb") as file:
            fft = pickle.load(file)
        if verbose:
            print("Loading data...")
            data = load_variable_at_level(filename, variable, level, cache=verbose)
    else:
        if verbose:
            print("Loading data...")

        data = load_variable_at_level(filename, variable, level, cache=verbose)
        if verbose:
            print("Performing DFT...")

        fft = dft3_at_level(data, level, **kwargs)

        # Create the per-variable/per-quantile folder on first use, and write
        # through a temporary file so an interrupted dump never leaves a
        # truncated model that a later run would load as if it were valid.
        os.makedirs(os.path.dirname(output), exist_ok=True)
        partial = f"{output}.tmp"
        with open(partial, "wb") as file:
            pickle.dump(fft, file, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(partial, output)

    if verbose:
        data = data.astype("float32")

        print("Performing IDFT...")
        prediction = idft3_at_level(*fft)

        # NOTE(review): relies on every encoded element exposing `.nbytes` —
        # confirm against what encode_zlib returns.
        size = sum(el.nbytes for el in fft)
        lines = f"""
        Original Stdev: {data.std()} m/s
        Predicted MAE:  {mae(data, prediction)} m/s
        Predicted RMSE: {rmse(data, prediction)} m/s

        Size/level: {size / (1000 ** 2)} MB
        Size/year: {size * 36 / (1000 ** 2)} MB
        """
        print(lines)


def fit_dft3(filename: str, variable: str, skip_levels: int, **kwargs):
    """Fit the DFT model for every ``skip_levels``-th level, plus level 35.

    Args:
        filename: Passed through to ``fit_dft3_at_level``.
        variable: Passed through to ``fit_dft3_at_level``.
        skip_levels: Step between fitted levels in ``range(0, 36, skip_levels)``;
            level 35 is always included.
        **kwargs: Forwarded to ``fit_dft3_at_level`` (e.g. ``quantile``).
    """
    levels = sorted(set(range(0, 36, skip_levels)).union({35}))
    for lev in tqdm(levels):
        print(lev)
        try:
            fit_dft3_at_level(filename, variable, lev, verbose=False, **kwargs)
        finally:
            # The cache is keyed by level only, so drop it even when a level
            # fails or is interrupted; otherwise a rerun (possibly for another
            # variable) would silently reuse the stale spectrum.
            DFT3_LEVEL_CACHE.clear()

In [None]:
# Fit every second level (plus level 35) of U, keeping the top 0.65% of coefficients.
fit_dft3(
    filename="MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4",
    variable="U",
    skip_levels=2,
    quantile=0.9935,
)


  0%|          | 0/19 [00:00<?, ?it/s]

0
2


  0%|          | 0/12 [00:00<?, ?it/s]

In [None]:
def test_fit_dft3(filename: str, variable: str, quantile: float, skip_levels: int):
    data_variance = 0
    mae_error = 0
    mse_error = 0
    predicted_levels = set()
    predicted_window = []
    predicted_window_levels = []

    def predict_from_fft(level):
        nonlocal data_variance, mae_error, mse_error

        predicted_levels.add(level)

        data = load_variable_at_level(filename, variable, level, cache=False, folder="raw").astype("float32")

        log("Loading DFT")
        with open(f"models/3D-dft/{get_year_from_filename(filename)}/{variable}-{level}-{quantile}.bin", "rb") as file:
            fft = pickle.load(file)
        log("Performing IDFT")
        pred = idft3_at_level(*fft)

        log("Calculating Error")
        predicted_window.append(pred)
        predicted_window_levels.append(level)
        if len(predicted_window_levels) > 2:
            predicted_window.pop(0)
            predicted_window_levels.pop(0)

        mae_loss = mae(data, pred)
        mse_loss = mse(data, pred)
        var = data.var()

        data_variance += var
        mae_error += mae_loss
        mse_error += mse_loss

        print(f"""
        Level {level}:
            Original Stdev: {var ** 0.5} m/s
            Predicted MAE:  {mae_loss} m/s
            Predicted RMSE: {mse_loss ** 0.5} m/s
        """)

    def interpolate_from_fft(level):
        nonlocal data_variance, mae_error, mse_error

        predicted_levels.add(level)

        # lev0 = max(0, lev - lev % skip_levels - 1 * skip_levels)
        lev1 = max(0, lev - lev % skip_levels)
        lev2 = min(35, lev - lev % skip_levels + skip_levels)
        # lev3 = min(35, lev - lev % skip_levels + 2 * skip_levels)
        t = (lev - lev1) / (lev2 - lev1)

        for i in (lev1, lev2):
            if i in predicted_window_levels:
                predicted_window.append(predicted_window[predicted_window_levels.index(i)])
                predicted_window_levels.append(i)

                if len(predicted_window_levels) > 2:
                    predicted_window.pop(0)
                    predicted_window_levels.pop(0)
            else:
                predict_from_fft(i)

        print(predicted_window_levels)
        data = load_variable_at_level(filename, variable, lev, cache=False).astype("float32")

        log("Interpolating DFT")
        pred = linear_interpolate(predicted_window, 0, t)

        log("Calculating Error")
        mae_loss = mae(data, pred)
        mse_loss = mse(data, pred)
        var = data.var()

        data_variance += var
        mae_error += mae_loss
        mse_error += mse_loss

        log(f"""
        Level {level}:
            Original Stdev: {var ** 0.5} m/s
            Predicted MAE:  {mae_loss} m/s
            Predicted RMSE: {mse_loss ** 0.5} m/s
        """)

    for lev in tqdm(range(36)):
        if lev in predicted_levels:
            continue

        if lev % skip_levels == 0 or lev == 35:
            predict_from_fft(lev)
            continue

        interpolate_from_fft(lev)

    lines = f"""
    Original Stdev: {(data_variance / 36) ** 0.5} m/s
    Predicted MAE:  {(mae_error / 36)} m/s
    Predicted RMSE: {(mse_error / 36) ** 0.5} m/s
    """
    log(lines)


# Evaluate U with a stored model on every third level and interpolation in between.
test_fit_dft3(
    filename="MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4",
    variable="U",
    quantile=0.9935,
    skip_levels=3,
)
