In [43]:
from tqdm.notebook import tqdm

from plot import *
from fft import *
from maths import rmse, mae


In [209]:
def fit_yavg_dft2_at_time(filename: str, variable: str, time: int, skip_levels: int = 2, levels: int = 72, **kwargs):
    """Fit a 2-D DFT to every ``skip_levels``-th level of one time step and score the reconstruction.

    The field returned by ``load_yavg_variable_at_time`` is flipped along axis 0
    (the level axis).

    Only "anchor" levels ``0, skip_levels, 2 * skip_levels, ...`` are passed through
    ``dft2_at_time_and_level`` / ``idft2_at_time_and_level``. Every level in between
    is linearly interpolated from the two enclosing anchor reconstructions.

    The function prints the stdev of the original field, the RMSE/MAE of the
    reconstruction, and the storage needed for the kept coefficients. It returns
    nothing.

    Errors are labeled in m/s, which assumes a wind component such as ``U``.

    Args:
        filename: File pattern forwarded to ``load_yavg_variable_at_time``.
        variable: Variable name forwarded to ``load_yavg_variable_at_time``.
        time: Time index forwarded to ``load_yavg_variable_at_time``.
        skip_levels: Spacing between anchor levels (>= 1; 1 transforms every level).
        levels: Number of levels, counted after the flip, to reconstruct and score.
        **kwargs: Forwarded to ``dft2_at_time_and_level`` (e.g. ``quantile``).

    Raises:
        ValueError: If ``skip_levels`` or ``levels`` is < 1, or the loaded field has
            fewer than ``levels`` levels.
    """
    if skip_levels < 1:
        raise ValueError(f"skip_levels must be >= 1, got {skip_levels}")
    if levels < 1:
        raise ValueError(f"levels must be >= 1, got {levels}")

    # Keep up to one anchor spacing beyond `levels` so the topmost interpolated
    # levels can use an upper anchor whenever the file actually has one.
    data = load_yavg_variable_at_time(filename, variable, time)[::-1]
    data = data[:levels + skip_levels]
    n_available = data.shape[0]
    if n_available < levels:
        raise ValueError(f"requested {levels} levels but only {n_available} are available")

    # Anchors every `skip_levels`, clamped to the last existing level. With the
    # defaults (levels=72, skip_levels=2), the unclamped grid asks for level 72,
    # which raises IndexError when the file holds exactly 72 levels.
    anchors = sorted({min(lev, n_available - 1) for lev in range(0, levels + skip_levels, skip_levels)})

    anchor_preds = {}
    compressed_bytes = 0
    for lev in tqdm(anchors):
        coeffs = dft2_at_time_and_level(data[lev], **kwargs)
        anchor_preds[lev] = idft2_at_time_and_level(*coeffs)
        # Count each anchor's own coefficient size instead of assuming every
        # anchor is as large as the last one.
        compressed_bytes += sum(el.nbytes for el in coeffs)

    # Size the output from an actual reconstruction rather than a hard-coded grid.
    prediction = np.zeros((levels,) + np.shape(anchor_preds[anchors[0]]), dtype="float32")
    for lev in anchors:
        if lev < levels:
            prediction[lev] = anchor_preds[lev]
    # Linear interpolation between consecutive anchors. When anchors are exactly
    # `skip_levels` apart, t == (lev % skip_levels) / skip_levels.
    for lo, hi in zip(anchors, anchors[1:]):
        for lev in range(lo + 1, min(hi, levels)):
            t = (lev - lo) / (hi - lo)
            prediction[lev] = anchor_preds[lo] * (1 - t) + anchor_preds[hi] * t

    truth = data[:levels].astype("float32")
    print(f"Original Stdev: {truth.std()} m/s")
    print(f"Predicted RMSE: {rmse(truth, prediction)} m/s")
    print(f"Predicted MAE:  {mae(truth, prediction)} m/s")

    mib = 1024 ** 2
    steps_per_day = 8  # presumably 3-hourly ("tavg3") data; confirm for other collections
    print(f"Size/time: {compressed_bytes / mib} MiB")
    print(f"Size/day: {compressed_bytes * steps_per_day / mib} MiB")
    print(f"Size/year: {compressed_bytes * steps_per_day * 365 / mib} MiB")


In [217]:
# Reconstruct the first 30 levels (after the flip) of "U" at time index 0,
# transforming every 3rd level; `quantile` is forwarded to dft2_at_time_and_level.
fit_yavg_dft2_at_time(
    "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4",
    "U",
    time=0,
    skip_levels=3,
    levels=30,
    quantile=0.9905,
)


  0%|          | 0/11 [00:00<?, ?it/s]

Original Stdev: 9.237070083618164 m/s
Predicted RMSE: 0.7944099221105478 m/s
Predicted MAE:  0.5580173134803772 m/s
Size/time: 0.072845458984375 mB
Size/day: 0.582763671875 mB
Size/year: 212.708740234375 mB
