In [40]:
import random

from plot import *
from fft import *
from maths import rmse, mae, mse


In [13]:
filename = "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4"
variable = "U"
time = 0
levels = 36
quantile = 0.99044

skip_levels = 4

In [14]:
data = load_yavg_variable_at_time(filename, variable, time, levels=levels + skip_levels)[::-1]
prediction = np.zeros((levels, 361, 576), dtype="float32")

In [27]:
preds = []

for lev in range(0, levels + skip_levels, skip_levels):
    fft = dft2_at_time_and_level(data[lev], quantile=quantile)
    pred = idft2_at_time_and_level(*fft)

    preds.append(pred)
    if lev < levels:
        prediction[lev] = pred

preds.insert(0, preds[0])
preds.append(preds[-1])

In [28]:
def hermite_interpolate(values, index, mu, tens, b):
    mu2 = mu * mu
    mu3 = mu2 * mu

    m0 = (values[index + 1] - values[index]) * (1 + b) * (1 - tens)
    m0 += (values[index + 2] - values[index + 1]) * (1 - b) * (1 - tens)

    m1 = (values[index + 2] - values[index + 1]) * (1 + b) * (1 - tens)
    m1 += (values[index + 3] - values[index + 2]) * (1 - b) * (1 - tens)
    m0 /= 2
    m1 /= 2

    a0 = 2 * mu3 - 3 * mu2 + 1
    a1 = mu3 - 2 * mu2 + mu
    a2 = mu3 - mu2
    a3 = -2 * mu3 + 3 * mu2
    
    return a0 * values[index + 1] + a1 * m0 + a2 * m1 + a3 * values[index + 2]

In [274]:
tensions = []

for lev in tqdm(range(levels)):
    if lev % skip_levels == 0:
        continue

    pred_idx = lev // skip_levels

    best_loss = float("inf")
    best_tension = 0
    
    mu = (lev % skip_levels) / skip_levels
    mu2 = mu * mu
    mu3 = mu2 * mu

    a0 = (mu3 - 2 * mu2 + mu) / 2
    a1 = (mu3 - mu2) / 2
    m0 = preds[pred_idx + 1] - preds[pred_idx] + preds[pred_idx + 2] - preds[pred_idx + 1]
    m1 = preds[pred_idx + 2] - preds[pred_idx + 1] + preds[pred_idx + 3] - preds[pred_idx + 2]

    const = (2 * mu3 - 3 * mu2 + 1) * preds[pred_idx + 1] + (-2 * mu3 + 3 * mu2) * preds[pred_idx + 2]
    const2 = a0 * m0 + a1 * m1
    baseline = preds[pred_idx] - const

    lr = 0.00015
    tension = 0

    for _ in range(200):
        dt = (2 * const2 * (baseline - tension * const2)).mean()
        tension = np.tanh(tension + lr * dt)

        if (loss := ((baseline - tension * const2) ** 2).mean()) < best_loss:
            best_loss = loss
            best_tension = tension
        else:
            pass
            # lr /= 1.1
            # tension -= dt

    tensions.append(best_tension)


  0%|          | 0/36 [00:00<?, ?it/s]

In [276]:
tensions

[-0.0015125482023553816,
 0.0016443568920793783,
 0.012293766372655184,
 -0.018805677238908698,
 -0.011609435947663177,
 0.00627802418289664,
 -0.008313515111077674,
 0.002961827156883978,
 0.019782678100450934,
 -0.013203696516491038,
 0.009298508283036015,
 0.04239836596741894,
 -0.028864238347298312,
 0.021687662315774927,
 0.08358335454335546,
 -0.051796761937004214,
 0.05968494591088164,
 0.14152534669296252,
 -0.1354084540727688,
 -0.10777314669780115,
 0.10592215892872528,
 -0.16821314816810048,
 -0.18897192877243438,
 -0.16725325072268823,
 -0.11448276469201626,
 -0.07089362001046359,
 0.1610290594750893]

In [277]:
for lev in range(levels):
    if lev % skip_levels == 0:
        continue

    pred_idx = lev // skip_levels
    t = (lev % skip_levels) / skip_levels

    prediction[lev] = hermite_interpolate(preds, pred_idx, t, tensions[pred_idx], 0)

data_float32 = data[:levels].astype("float32")

print(f"Original Stdev: {data_float32.std()} m/s")
print(f"Predicted RMSE: {rmse(data_float32, prediction)} m/s")
print(f"Predicted MAE:  {mae(data_float32, prediction)} m/s")
print(f"Size/time: {sum(el.nbytes for el in fft) * (len(preds) - 2) / (1000 ** 2)} mB")
print(f"Size/day: {sum(el.nbytes for el in fft) * (len(preds) - 2) * 8 / (1000 ** 2)} mB")
print(f"Size/year: {sum(el.nbytes for el in fft) * (len(preds) - 2) * 8 * 365 / (1000 ** 2)} mB")


Original Stdev: 10.071374893188477 m/s
Predicted RMSE: 0.9160949412414103 m/s
Predicted MAE:  0.5961885452270508 m/s
Size/time: 0.06986 mB
Size/day: 0.55888 mB
Size/year: 203.9912 mB
