In [2]:
import pickle

from plot import *
from fft import *
from util import log
from maths import rmse, mse, mae, linear_interpolate


In [3]:
# Per-level memo of (rfftn coefficients, coefficient magnitudes), so several
# quantiles can be tried on one level without recomputing the FFT.
# NOTE(review): keyed on ``level`` only, not on the data itself. Clear it before
# transforming different data (another file/variable) at an already-cached level.
DFT3_LEVEL_CACHE = {}


def dft3_at_level(data: np.ndarray, level: int, quantile: float = 0.75):
    """Compress one level of ``data`` by keeping only its largest 3-D real-FFT coefficients.

    The spectrum (``np.fft.rfftn`` over all axes) and its magnitudes are memoised in
    ``DFT3_LEVEL_CACHE[level]``. Every coefficient whose magnitude is not below the
    ``quantile``-th quantile of all magnitudes is kept. Index bounds come from the
    spectrum's own shape, so inputs of any 3-D shape are supported.

    Args:
        data: 3-D array for this level; presumably (time, lat, lon) -- TODO confirm.
        level: vertical level index; used only as the cache key.
        quantile: magnitude quantile used as the keep/discard cutoff.

    Returns:
        ``(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices)``:
        the kept coefficients' real and imaginary parts (divided by 262144, cast to
        float16, then passed through ``encode_zlib``) and their per-axis indices
        (int16, then ``encode_difference_uint8`` and ``encode_zlib``). Coefficients
        are listed with the last axis (k) varying slowest and the first axis (i)
        fastest.
    """
    if level in DFT3_LEVEL_CACHE:
        fft, amplitudes = DFT3_LEVEL_CACHE[level]
    else:
        fft = np.fft.rfftn(data)
        amplitudes = np.abs(fft)
        DFT3_LEVEL_CACHE[level] = fft, amplitudes

    cutoff_amp = np.quantile(amplitudes, quantile)

    # "Skip if below the cutoff". Written as ~(a < c) rather than a >= c so that
    # NaN magnitudes are kept rather than dropped.
    keep = ~(amplitudes < cutoff_amp)

    # np.nonzero walks its argument in C order. Transposing to (k, j, i) therefore
    # yields k outermost and i innermost. The index streams are difference-encoded,
    # so this order affects the encoded payload.
    k_idx, j_idx, i_idx = np.nonzero(keep.transpose(2, 1, 0))
    kept = fft[i_idx, j_idx, k_idx]

    # Scale down before the float16 cast; idft3_at_level multiplies by the same factor.
    fft_real = encode_zlib((kept.real.astype("float32") / 262144).astype("float16"))
    fft_imag = encode_zlib((kept.imag.astype("float32") / 262144).astype("float16"))

    def encode_indices(indices):
        # int16 limits each axis to < 32768 entries.
        return encode_zlib(encode_difference_uint8(indices.astype("int16")))

    fft_i_indices = encode_indices(i_idx)
    fft_j_indices = encode_indices(j_idx)
    fft_k_indices = encode_indices(k_idx)

    return fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices


def idft3_at_level(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices,
                   shape=(365 * 8, 361, 289)):
    """Rebuild a level from the sparse spectrum produced by ``dft3_at_level``.

    Args:
        fft_real, fft_imag: ``encode_zlib`` payloads of the float16 real/imaginary
            parts, stored divided by 262144.
        fft_i_indices, fft_j_indices, fft_k_indices: ``encode_zlib`` payloads of the
            difference-encoded per-axis coefficient indices.
        shape: shape of the half-spectrum to rebuild (``rfftn`` output shape). The
            default is the previously hard-coded (365 * 8, 361, 289).

    Returns:
        ``np.fft.irfftn`` of the spectrum, with every coefficient not listed set to zero.

    Raises:
        ValueError: if the decoded index streams and coefficient stream differ in length.
    """
    fft_real = decode_zlib(fft_real, dtype="float16")
    fft_imag = decode_zlib(fft_imag, dtype="float16")
    # Undo the 1/262144 scaling applied before the float16 cast.
    fft = fft_real.astype("complex64") * 262144 + fft_imag.astype("complex64") * 262144j

    def decode_indices(payload):
        # Integer arrays are required for fancy indexing below.
        return np.asarray(decode_difference_uint8(decode_zlib(payload)), dtype=np.intp)

    i_idx = decode_indices(fft_i_indices)
    j_idx = decode_indices(fft_j_indices)
    k_idx = decode_indices(fft_k_indices)

    spectrum = np.zeros(shape, dtype="complex64")
    # Vectorised scatter of all kept coefficients into the zero spectrum.
    spectrum[i_idx, j_idx, k_idx] = fft

    return np.fft.irfftn(spectrum)

In [4]:
def fit_dft3_at_level(filename: str, variable: str, level: int, verbose: bool = True, **kwargs):
    """Fit (or load from disk) the truncated 3-D DFT model for one level.

    The encoded model is pickled to
    ``models/3D-dft/<year>/<variable>-<level>-<quantile>.bin``. If that file
    already exists it is loaded instead of refitting. With ``verbose`` the level is
    reconstructed and its MAE/RMSE and encoded size are printed.

    Args:
        filename: data filename template passed to the loaders.
        variable: variable name, e.g. "U".
        level: vertical level index.
        verbose: print progress and error statistics (also enables the loader cache).
        **kwargs: forwarded to ``dft3_at_level``. ``quantile`` defaults to 0.75
            there, and the same default is used for the model filename.
    """
    # Mirror dft3_at_level's default so the filename never raises KeyError.
    quantile = kwargs.setdefault("quantile", 0.75)
    output = f"models/3D-dft/{get_year_from_filename(filename)}/{variable}-{level}-{quantile}.bin"

    data = None
    if os.path.exists(output):
        # NOTE(review): pickle.load can execute arbitrary code; only load model files you produced.
        with open(output, "rb") as file:
            fft = pickle.load(file)
    else:
        if verbose:
            print("Loading data...")
        data = load_variable_at_level(filename, variable, level, cache=verbose)

        if verbose:
            print("Performing DFT...")
        fft = dft3_at_level(data, level, **kwargs)

        # A fresh checkout has no models/ tree yet.
        os.makedirs(os.path.dirname(output), exist_ok=True)
        with open(output, "wb") as file:
            pickle.dump(fft, file, protocol=pickle.HIGHEST_PROTOCOL)

    if not verbose:
        return

    if data is None:
        # The model came from disk, so the ground truth has not been loaded yet.
        print("Loading data...")
        data = load_variable_at_level(filename, variable, level, cache=verbose)
    data = data.astype("float32")

    print("Performing IDFT...")
    prediction = idft3_at_level(*fft)

    # encode_zlib's return type isn't visible here; bytes has no .nbytes, so fall back to len().
    size = sum(el.nbytes if hasattr(el, "nbytes") else len(el) for el in fft)
    lines = f"""
        Original Stdev: {data.std()} m/s
        Predicted MAE:  {mae(data, prediction)} m/s
        Predicted RMSE: {rmse(data, prediction)} m/s

        Size/level: {size / (1000 ** 2)} MB
        Size/year: {size * 36 / (1000 ** 2)} MB
        """
    print(lines)


def fit_dft3(filename: str, variable: str, skip_levels: int, **kwargs):
    """Fit and store DFT models for every ``skip_levels``-th level plus the top level (35).

    Levels 0, skip_levels, 2*skip_levels, ... below 36 are fitted in ascending order,
    with 35 appended when it is not already a multiple. ``kwargs`` (e.g. ``quantile``)
    are forwarded to ``fit_dft3_at_level``, which runs non-verbosely.
    ``DFT3_LEVEL_CACHE`` is emptied after every level (presumably to bound memory;
    the cache is keyed by level only).
    """
    chosen = list(range(0, 36, skip_levels))
    if 35 not in chosen:
        chosen.append(35)  # 35 is >= every element, so the list stays sorted

    for current in tqdm(chosen):
        print(current)
        fit_dft3_at_level(filename, variable, current, verbose=False, **kwargs)
        DFT3_LEVEL_CACHE.clear()

In [None]:
# Fit models at every 3rd level (plus the top level). This must use the same
# skip_levels as the evaluation cell below; otherwise it needs model files that
# were never fitted (e.g. level 3), and a fresh Restart & Run All fails.
fit_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "U", quantile=0.9935, skip_levels=3)


In [5]:
def test_fit_dft3(filename: str, variable: str, quantile: float, skip_levels: int):
    data_variance = 0
    mae_error = 0
    mse_error = 0
    predicted_levels = set()
    predicted_window = []
    predicted_window_levels = []

    def predict_from_fft(level):
        nonlocal data_variance, mae_error, mse_error

        predicted_levels.add(level)

        data = load_variable_at_level(filename, variable, level, cache=False, folder="raw").astype("float32")

        log("Loading DFT")
        with open(f"models/3D-dft/{get_year_from_filename(filename)}/{variable}-{level}-{quantile}.bin", "rb") as file:
            fft = pickle.load(file)
        log("Performing IDFT")
        pred = idft3_at_level(*fft)

        log("Calculating Error")
        predicted_window.append(pred)
        predicted_window_levels.append(level)
        if len(predicted_window_levels) > 2:
            predicted_window.pop(0)
            predicted_window_levels.pop(0)

        mae_loss = mae(data, pred)
        mse_loss = mse(data, pred)
        var = data.var()

        data_variance += var
        mae_error += mae_loss
        mse_error += mse_loss

        print(f"""
        Level {level}:
            Original Stdev: {var ** 0.5} m/s
            Predicted MAE:  {mae_loss} m/s
            Predicted RMSE: {mse_loss ** 0.5} m/s
        """)

    def interpolate_from_fft(level):
        nonlocal data_variance, mae_error, mse_error

        predicted_levels.add(level)

        # lev0 = max(0, lev - lev % skip_levels - 1 * skip_levels)
        lev1 = max(0, lev - lev % skip_levels)
        lev2 = min(35, lev - lev % skip_levels + skip_levels)
        # lev3 = min(35, lev - lev % skip_levels + 2 * skip_levels)
        t = (lev - lev1) / (lev2 - lev1)

        for i in (lev1, lev2):
            if i in predicted_window_levels:
                predicted_window.append(predicted_window[predicted_window_levels.index(i)])
                predicted_window_levels.append(i)

                if len(predicted_window_levels) > 2:
                    predicted_window.pop(0)
                    predicted_window_levels.pop(0)
            else:
                predict_from_fft(i)

        print(predicted_window_levels)
        data = load_variable_at_level(filename, variable, lev, cache=False).astype("float32")

        log("Interpolating DFT")
        pred = linear_interpolate(predicted_window, 0, t)

        log("Calculating Error")
        mae_loss = mae(data, pred)
        mse_loss = mse(data, pred)
        var = data.var()

        data_variance += var
        mae_error += mae_loss
        mse_error += mse_loss

        log(f"""
        Level {level}:
            Original Stdev: {var ** 0.5} m/s
            Predicted MAE:  {mae_loss} m/s
            Predicted RMSE: {mse_loss ** 0.5} m/s
        """)

    for lev in tqdm(range(36)):
        if lev in predicted_levels:
            continue

        if lev % skip_levels == 0 or lev == 35:
            predict_from_fft(lev)
            continue

        interpolate_from_fft(lev)

    lines = f"""
    Original Stdev: {(data_variance / 36) ** 0.5} m/s
    Predicted MAE:  {(mae_error / 36)} m/s
    Predicted RMSE: {(mse_error / 36) ** 0.5} m/s
    """
    log(lines)


# Evaluate the stored models: reconstruct every 3rd level (plus the top) and interpolate the rest.
test_fit_dft3(
    filename="MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4",
    variable="U",
    quantile=0.9935,
    skip_levels=3,
)


  0%|          | 0/36 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

[11:58:55] LOG: Loading DFT
[11:58:55] LOG: Performing IDFT
[11:59:22] LOG: Calculating Error

        Level 0:
            Original Stdev: 11.38197934172047 m/s
            Predicted MAE:  0.09965647515987254 m/s
            Predicted RMSE: 0.1341450619757956 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:00:01] LOG: Loading DFT
[12:00:01] LOG: Performing IDFT
[12:00:28] LOG: Calculating Error

        Level 3:
            Original Stdev: 12.885986973387809 m/s
            Predicted MAE:  0.15934444036320133 m/s
            Predicted RMSE: 0.21936457957497246 m/s
        
[0, 3]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:01:06] LOG: Interpolating DFT
[12:01:08] LOG: Calculating Error
[12:01:11] LOG: 
        Level 1:
            Original Stdev: 11.579793338809257 m/s
            Predicted MAE:  0.571019988823665 m/s
            Predicted RMSE: 1.0191353409542854 m/s
        
[0, 3]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:01:46] LOG: Interpolating DFT
[12:01:47] LOG: Calculating Error
[12:01:50] LOG: 
        Level 2:
            Original Stdev: 12.186177186825692 m/s
            Predicted MAE:  0.6394137820287646 m/s
            Predicted RMSE: 1.1452951518171026 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:02:25] LOG: Loading DFT
[12:02:25] LOG: Performing IDFT
[12:02:52] LOG: Calculating Error

        Level 6:
            Original Stdev: 13.372077979223045 m/s
            Predicted MAE:  0.25617333927588615 m/s
            Predicted RMSE: 0.3374529362853957 m/s
        
[3, 6]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:03:29] LOG: Interpolating DFT
[12:03:31] LOG: Calculating Error
[12:03:34] LOG: 
        Level 4:
            Original Stdev: 13.456200650164964 m/s
            Predicted MAE:  0.9342818968548541 m/s
            Predicted RMSE: 1.362134670256098 m/s
        
[3, 6]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:04:08] LOG: Interpolating DFT
[12:04:09] LOG: Calculating Error
[12:04:12] LOG: 
        Level 5:
            Original Stdev: 13.71315865199713 m/s
            Predicted MAE:  0.973156760573604 m/s
            Predicted RMSE: 1.402196353270533 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:04:46] LOG: Loading DFT
[12:04:46] LOG: Performing IDFT
[12:05:13] LOG: Calculating Error

        Level 9:
            Original Stdev: 10.406084305644072 m/s
            Predicted MAE:  0.3136316277531956 m/s
            Predicted RMSE: 0.4043871811373609 m/s
        
[6, 9]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:05:50] LOG: Interpolating DFT
[12:05:51] LOG: Calculating Error
[12:05:54] LOG: 
        Level 7:
            Original Stdev: 12.477996992127657 m/s
            Predicted MAE:  0.6744009841083345 m/s
            Predicted RMSE: 0.8899932087790168 m/s
        
[6, 9]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:06:28] LOG: Interpolating DFT
[12:06:29] LOG: Calculating Error
[12:06:33] LOG: 
        Level 8:
            Original Stdev: 11.344413800955282 m/s
            Predicted MAE:  0.597549811248363 m/s
            Predicted RMSE: 0.779012125926512 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:07:09] LOG: Loading DFT
[12:07:09] LOG: Performing IDFT
[12:07:35] LOG: Calculating Error

        Level 12:
            Original Stdev: 8.645025033203028 m/s
            Predicted MAE:  0.27199249215327037 m/s
            Predicted RMSE: 0.3515439508835796 m/s
        
[9, 12]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:08:12] LOG: Interpolating DFT
[12:08:13] LOG: Calculating Error
[12:08:16] LOG: 
        Level 10:
            Original Stdev: 9.752047592735414 m/s
            Predicted MAE:  0.413306482270061 m/s
            Predicted RMSE: 0.5465748221602307 m/s
        
[9, 12]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:08:50] LOG: Interpolating DFT
[12:08:51] LOG: Calculating Error
[12:08:54] LOG: 
        Level 11:
            Original Stdev: 9.171758127848864 m/s
            Predicted MAE:  0.38795157129075075 m/s
            Predicted RMSE: 0.5125898432878698 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:09:28] LOG: Loading DFT
[12:09:28] LOG: Performing IDFT
[12:09:55] LOG: Calculating Error

        Level 15:
            Original Stdev: 7.371538287304178 m/s
            Predicted MAE:  0.23774022613858423 m/s
            Predicted RMSE: 0.30766597819255903 m/s
        
[12, 15]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:10:32] LOG: Interpolating DFT
[12:10:33] LOG: Calculating Error
[12:10:36] LOG: 
        Level 13:
            Original Stdev: 8.166585753163284 m/s
            Predicted MAE:  0.3166215711177805 m/s
            Predicted RMSE: 0.4128982323556156 m/s
        
[12, 15]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:11:10] LOG: Interpolating DFT
[12:11:11] LOG: Calculating Error
[12:11:14] LOG: 
        Level 14:
            Original Stdev: 7.742530422526604 m/s
            Predicted MAE:  0.30088803782990414 m/s
            Predicted RMSE: 0.3924125346822039 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:11:48] LOG: Loading DFT
[12:11:48] LOG: Performing IDFT
[12:12:15] LOG: Calculating Error

        Level 18:
            Original Stdev: 6.479860062206687 m/s
            Predicted MAE:  0.21961458521849211 m/s
            Predicted RMSE: 0.2849725771076872 m/s
        
[15, 18]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:14:15] LOG: Interpolating DFT
[12:14:16] LOG: Calculating Error
[12:14:19] LOG: 
        Level 16:
            Original Stdev: 7.038211072431521 m/s
            Predicted MAE:  0.27325914752591624 m/s
            Predicted RMSE: 0.360472404848916 m/s
        
[15, 18]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:14:53] LOG: Interpolating DFT
[12:14:54] LOG: Calculating Error
[12:14:57] LOG: 
        Level 17:
            Original Stdev: 6.731332184441273 m/s
            Predicted MAE:  0.269681299378081 m/s
            Predicted RMSE: 0.3553239189447178 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:15:31] LOG: Loading DFT
[12:15:31] LOG: Performing IDFT
[12:15:58] LOG: Calculating Error

        Level 21:
            Original Stdev: 6.0103419501444355 m/s
            Predicted MAE:  0.2117385264290441 m/s
            Predicted RMSE: 0.27567027738023203 m/s
        
[18, 21]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:16:35] LOG: Interpolating DFT
[12:16:37] LOG: Calculating Error
[12:16:40] LOG: 
        Level 19:
            Original Stdev: 6.298925629090293 m/s
            Predicted MAE:  0.22435851642323512 m/s
            Predicted RMSE: 0.2912592896201122 m/s
        
[18, 21]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:17:13] LOG: Interpolating DFT
[12:17:15] LOG: Calculating Error
[12:17:18] LOG: 
        Level 20:
            Original Stdev: 6.139713598275383 m/s
            Predicted MAE:  0.22138783816640845 m/s
            Predicted RMSE: 0.28754704014228155 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:17:56] LOG: Loading DFT
[12:17:56] LOG: Performing IDFT
[12:18:23] LOG: Calculating Error

        Level 24:
            Original Stdev: 5.827491050788622 m/s
            Predicted MAE:  0.2096129413275214 m/s
            Predicted RMSE: 0.2745280839501352 m/s
        
[21, 24]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:19:00] LOG: Interpolating DFT
[12:19:02] LOG: Calculating Error
[12:19:05] LOG: 
        Level 22:
            Original Stdev: 5.914747150656668 m/s
            Predicted MAE:  0.22132072730224261 m/s
            Predicted RMSE: 0.2877784277730654 m/s
        
[21, 24]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:19:38] LOG: Interpolating DFT
[12:19:40] LOG: Calculating Error
[12:19:43] LOG: 
        Level 23:
            Original Stdev: 5.857590548588396 m/s
            Predicted MAE:  0.22472098480186842 m/s
            Predicted RMSE: 0.29190375235447713 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:20:20] LOG: Loading DFT
[12:20:20] LOG: Performing IDFT
[12:20:47] LOG: Calculating Error

        Level 27:
            Original Stdev: 5.775212615091475 m/s
            Predicted MAE:  0.21261141555802435 m/s
            Predicted RMSE: 0.2805476172213472 m/s
        
[24, 27]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:21:28] LOG: Interpolating DFT
[12:21:30] LOG: Calculating Error
[12:21:33] LOG: 
        Level 25:
            Original Stdev: 5.8098867140737855 m/s
            Predicted MAE:  0.20789449091687762 m/s
            Predicted RMSE: 0.27315911388015185 m/s
        
[24, 27]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:22:09] LOG: Interpolating DFT
[12:22:11] LOG: Calculating Error
[12:22:14] LOG: 
        Level 26:
            Original Stdev: 5.794404770806853 m/s
            Predicted MAE:  0.20915556045272807 m/s
            Predicted RMSE: 0.2753863656488652 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:22:50] LOG: Loading DFT
[12:22:50] LOG: Performing IDFT
[12:23:17] LOG: Calculating Error

        Level 30:
            Original Stdev: 5.654282538572067 m/s
            Predicted MAE:  0.2245458467698903 m/s
            Predicted RMSE: 0.29830801319254585 m/s
        
[27, 30]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:23:56] LOG: Interpolating DFT
[12:23:57] LOG: Calculating Error
[12:24:00] LOG: 
        Level 28:
            Original Stdev: 5.747177094510318 m/s
            Predicted MAE:  0.2152625582105414 m/s
            Predicted RMSE: 0.28486553444487517 m/s
        
[27, 30]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:25:02] LOG: Interpolating DFT
[12:25:04] LOG: Calculating Error
[12:25:07] LOG: 
        Level 29:
            Original Stdev: 5.707465038484892 m/s
            Predicted MAE:  0.21967702303747066 m/s
            Predicted RMSE: 0.2915055738317794 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:25:42] LOG: Loading DFT
[12:25:42] LOG: Performing IDFT
[12:26:09] LOG: Calculating Error

        Level 33:
            Original Stdev: 5.381127878679719 m/s
            Predicted MAE:  0.23679652637664722 m/s
            Predicted RMSE: 0.3160377096743778 m/s
        
[30, 33]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:26:48] LOG: Interpolating DFT
[12:26:50] LOG: Calculating Error
[12:26:53] LOG: 
        Level 31:
            Original Stdev: 5.58171910346843 m/s
            Predicted MAE:  0.24310078054063936 m/s
            Predicted RMSE: 0.32627184936140535 m/s
        
[30, 33]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:28:57] LOG: Interpolating DFT
[12:28:59] LOG: Calculating Error
[12:29:02] LOG: 
        Level 32:
            Original Stdev: 5.491343187791445 m/s
            Predicted MAE:  0.2560220573110797 m/s
            Predicted RMSE: 0.34644994256480893 m/s
        


  0%|          | 0/12 [00:00<?, ?it/s]

[12:29:38] LOG: Loading DFT
[12:29:38] LOG: Performing IDFT
[12:30:05] LOG: Calculating Error

        Level 35:
            Original Stdev: 4.529156641991455 m/s
            Predicted MAE:  0.19884312650025285 m/s
            Predicted RMSE: 0.265055891763768 m/s
        
[33, 35]


  0%|          | 0/12 [00:00<?, ?it/s]

[12:30:45] LOG: Interpolating DFT
[12:30:47] LOG: Calculating Error
[12:30:50] LOG: 
        Level 34:
            Original Stdev: 5.166675731692002 m/s
            Predicted MAE:  0.3441599934030505 m/s
            Predicted RMSE: 0.5360250452124493 m/s
        
[12:30:50] LOG: 
    Original Stdev: 8.568464675185558 m/s
    Predicted MAE:  0.3275248175733363 m/s
    Predicted RMSE: 0.5530494016772665 m/s
    
