In [1]:
import random

import numpy as np

from plot import *
from fft import *
from maths import rmse, mae, mse


In [2]:
# Filename template with two `{}` placeholders; it is handed unformatted to
# load_yavg_variable_at_time, which presumably fills them in — TODO confirm.
filename = "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4"
# Name of the variable to load (second argument of the loader).
variable = "U"
# Time index passed to the loader (third argument).
time = 0
# Number of levels that get reconstructed; first axis length of `prediction`.
levels = 36
# Forwarded as `quantile=` to dft2_at_time_and_level; presumably the cutoff that
# decides which coefficients are kept — TODO confirm against the fft module.
quantile = 0.99044

# Stride between the levels that are compressed directly; the levels in
# between are filled in by Hermite interpolation further down.
skip_levels = 4

In [14]:
# Request `skip_levels` more levels than are reconstructed: the anchor loop
# below indexes `data` up to and including index `levels`.
raw_levels = load_yavg_variable_at_time(
    filename,
    variable,
    time,
    levels=levels + skip_levels,
)
# Reverse the order along the first axis.
data = raw_levels[::-1]

# Output buffer for the reconstruction, one slab per level.
prediction = np.zeros(shape=(levels, 361, 576), dtype="float32")

In [16]:
# Every `skip_levels`-th level is an "anchor": it is pushed through the
# forward transform and straight back, and the round-tripped field is kept.
# The range deliberately runs one anchor past `levels` so that the last
# interpolated levels have a right-hand neighbour.
anchor_levels = range(0, levels + skip_levels, skip_levels)
preds = []

for lev in anchor_levels:
    # `fft` stays bound after the loop; the size report at the end of the
    # notebook reads the value from the final iteration.
    fft = dft2_at_time_and_level(data[lev], quantile=quantile)
    pred = idft2_at_time_and_level(*fft)
    preds.append(pred)

    # The extra trailing anchor has no slot in `prediction`.
    if lev >= levels:
        continue
    prediction[lev] = pred

# Duplicate the first and last anchors so the four-point interpolation
# always has a neighbour on both sides.
preds = [preds[0], *preds, preds[-1]]

In [17]:
def hermite_interpolate(values, index, mu, tens, b):
    """Cubic Hermite interpolation with tension and bias over four samples.

    The four consecutive samples ``values[index] .. values[index + 3]`` are
    used; the returned value lies on the segment between the two middle
    samples ``values[index + 1]`` (at ``mu == 0``) and ``values[index + 2]``
    (at ``mu == 1``).

    Args:
        values: Indexable sequence of samples (scalars or arrays that
            support arithmetic); must hold at least ``index + 4`` items.
        index: Position of the first of the four samples.
        mu: Interpolation parameter between the two middle samples.
        tens: Tension. Tangents are scaled by ``1 - tens``, so ``1`` removes
            them entirely and ``0`` leaves them unscaled.
        b: Bias. Weights the incoming difference by ``1 + b`` and the
            outgoing difference by ``1 - b``; ``0`` weights both equally.

    Returns:
        The interpolated sample, same kind as the elements of ``values``.
    """
    before, left, right, after = (values[index + k] for k in range(4))

    # Scalar weights shared by both tangents.
    toward = 1 + b
    away = 1 - b
    looseness = 1 - tens

    # Tangent at the left middle sample.
    m0 = (left - before) * toward * looseness
    m0 += (right - left) * away * looseness
    m0 /= 2

    # Tangent at the right middle sample.
    m1 = (right - left) * toward * looseness
    m1 += (after - right) * away * looseness
    m1 /= 2

    mu2 = mu * mu
    mu3 = mu2 * mu

    # Hermite basis polynomials evaluated at mu.
    w_left = 2 * mu3 - 3 * mu2 + 1
    w_m0 = mu3 - 2 * mu2 + mu
    w_m1 = mu3 - mu2
    w_right = -2 * mu3 + 3 * mu2

    return w_left * left + w_m0 * m0 + w_m1 * m1 + w_right * right

In [71]:
# One fitted tension per interpolated (non-anchor) level, in level order.
tensions = []

for lev in tqdm(range(levels)):
    pred_idx, offset = divmod(lev, skip_levels)
    if offset == 0:
        # Anchor levels were reconstructed directly; nothing to fit.
        continue

    # Fractional position of this level between its two anchors.
    t = offset / skip_levels

    best_loss, best_tension = float("inf"), 0

    # Multi-start 1-D search: five starting tensions spread over [-1, 1].
    for start in np.linspace(-1, 1, 5):
        candidate = start
        step = 0.1

        for _ in range(20):
            interpolated = hermite_interpolate(preds, pred_idx, t, candidate, 0)
            loss = mse(data[lev], interpolated)

            if loss < best_loss:
                # New overall best (shared across all starts): keep direction.
                best_loss, best_tension = loss, candidate
            else:
                # No improvement: turn around and halve the step.
                step /= -2

            candidate += step

    tensions.append(best_tension)


  0%|          | 0/36 [00:00<?, ?it/s]

In [72]:
# Fill every non-anchor level by Hermite interpolation between its anchors,
# using the tension fitted for that level.
for lev in range(levels):
    pred_idx, offset = divmod(lev, skip_levels)
    if offset == 0:
        # Anchor levels were already written into `prediction`.
        continue

    t = offset / skip_levels
    # `tensions` holds entries only for non-anchor levels, so subtract the
    # number of anchors at or below this level (pred_idx + 1) to index it.
    prediction[lev] = hermite_interpolate(preds, pred_idx, t, tensions[lev - pred_idx - 1], 0)

data_float32 = data[:levels].astype("float32")

print(f"Original Stdev: {data_float32.std()} m/s")
print(f"Predicted RMSE: {rmse(data_float32, prediction)} m/s")
print(f"Predicted MAE:  {mae(data_float32, prediction)} m/s")

# Storage estimate. `fft` is whatever the anchor loop left bound on its last
# iteration, so this assumes every anchor level has the same byte size
# — TODO confirm. `len(preds) - 2` removes the two padding duplicates.
BYTES_PER_MB = 1000 ** 2
STEPS_PER_DAY = 8
DAYS_PER_YEAR = 365

bytes_per_time = sum(el.nbytes for el in fft) * (len(preds) - 2)

# NOTE(review): the "mB" label presumably means megabytes (MB); the text is
# left unchanged so the output matches earlier runs.
print(f"Size/time: {bytes_per_time / BYTES_PER_MB} mB")
print(f"Size/day: {bytes_per_time * STEPS_PER_DAY / BYTES_PER_MB} mB")
print(f"Size/year: {bytes_per_time * STEPS_PER_DAY * DAYS_PER_YEAR / BYTES_PER_MB} mB")


Original Stdev: 10.018913269042969 m/s
Predicted RMSE: 0.7954508078474939 m/s
Predicted MAE:  0.5190367698669434 m/s
Size/time: 0.06986 mB
Size/day: 0.55888 mB
Size/year: 203.9912 mB
