In [1]:
from plot import *
from fft import *
from maths import rmse, mae, linear_interpolate, cosine_interpolate, cubic_interpolate
from maths import catmull_rom_interpolate, hermite_interpolate, fit_kochanek_bartels_spline


In [None]:
def dft2_at_time_and_level(data: np.ndarray, quantile: float = 0.75):
    fft = np.fft.rfft2(data)
    amplitudes = np.abs(fft)

    fft_real = []
    fft_imag = []
    fft_i_indices = []
    fft_j_indices = []

    cutoff_amp = np.quantile(amplitudes, quantile)

    last_i = 0
    last_j = 0

    for i in range(361):
        for j in range(256):
            if amplitudes[i, j] < cutoff_amp:
                continue

            fft_real.append(fft[i, j].real)
            fft_imag.append(fft[i, j].imag)

            while (di := i - last_i) > 253:
                fft_i_indices.append(254)
                last_i += 253

            while (dj := j - last_j) > 253:
                fft_j_indices.append(254)
                last_j += 253

            if di < 0:
                fft_i_indices.append(255)
                last_i = 0
            else:
                fft_i_indices.append(di)
                last_i += di

            if dj < 0:
                fft_j_indices.append(255)
                last_j = 0
            else:
                fft_j_indices.append(dj)
                last_j += dj

    fft_real = np.array(fft_real, dtype="float32") / 512
    fft_imag = np.array(fft_imag, dtype="float32") / 512

    fft_i_indices = np.array(fft_i_indices, dtype="uint8")
    fft_j_indices = np.array(fft_j_indices, dtype="uint8")

    fft_i_indices = rlen_encode_array(fft_i_indices, 0)
    fft_j_indices = rlen_encode_array(fft_j_indices, 1)

    return fft_real.astype("float16"), fft_imag.astype("float16"), \
        fft_i_indices, fft_j_indices


def idft2_at_time_and_level(fft_real, fft_imag, fft_i_indices, fft_j_indices):
    ifft = np.zeros((361, 289), dtype="complex64")
    fft = fft_real.astype("complex64") * 512 + fft_imag.astype("complex64") * 512j

    fft_i_indices = rlen_decode_array(fft_i_indices, 0)
    fft_j_indices = rlen_decode_array(fft_j_indices, 1)

    i = 0
    j = 0

    i_index = 0
    j_index = 0

    for idx in range(len(fft)):
        while (di := fft_i_indices[i_index]) == 254:
            i += 253
            i_index += 1

        while (dj := fft_j_indices[j_index]) == 254:
            j += 253
            j_index += 1

        if di == 255:
            i = 0
        else:
            i += di

        if dj == 255:
            j = 0
        else:
            j += dj

        ifft[i, j] = fft[idx]

        i_index += 1
        j_index += 1

    return np.fft.irfft2(ifft)


In [3]:
def plot_dft2_at_time_and_level(filename: str, variable: str, time: int, level: int, **kwargs):
    data = load_variable_at_time_and_level(filename, variable, time, level)

    fft = dft2_at_time_and_level(data, **kwargs)
    prediction = idft2_at_time_and_level(*fft)

    title = f"{format_variable(variable)} at {format_level(level)}" \
            f" on {format_date(filename)} at {format_time(time, filename)}"
    output = f"{variable}/{kwargs['quantile']}/{format_level(level, for_output=True)}" \
             f"-{format_date(filename, for_output=True)}-{format_time(time, filename)}"

    fig, ax1, ax2 = create_1x2_plot(title, sharey=True, sharex=True)
    ax1.imshow(data, cmap=cmr.arctic, origin="lower", extent=[-180, 180, -90, 90], aspect="auto")
    ax2.imshow(prediction, cmap=cmr.arctic, origin="lower", extent=[-180, 180, -90, 90], aspect="auto")

    ax1.xaxis.set_major_formatter(FormatStrFormatter("%d°"))
    ax1.yaxis.set_major_formatter(FormatStrFormatter("%d°"))

    fig.suptitle(title, fontsize=8, y=0.96)

    plt.savefig("assets/2D-dft-graphs/" + output + ".png", dpi=300)
    plt.show()

    title = f"Error ({get_units_from_variable(variable)}) at {format_level(level)} " \
            f" on {format_date(filename)} at {format_time(time, filename)}"

    fig, ax1, ax2 = create_1x2_plot(title, sharey=True, sharex=True)
    ax1.imshow(np.abs(prediction - data), cmap="hot", origin="lower", extent=[-180, 180, -90, 90], aspect="auto")
    ax2.imshow((prediction - data) ** 2, cmap="hot", origin="lower", extent=[-180, 180, -90, 90], aspect="auto")

    ax1.xaxis.set_major_formatter(FormatStrFormatter("%d°"))
    ax1.yaxis.set_major_formatter(FormatStrFormatter("%d°"))

    fig.suptitle(title, fontsize=8, y=0.96)
    ax1.set_title("Absolute Error (m/s)", fontsize=8)
    ax2.set_title("Squared Error (m²/s²)", fontsize=8)

    output += "-error"
    plt.savefig("assets/2D-dft-graphs/" + output + ".png", dpi=300)
    plt.show()

    lines = f"""
    Original Stdev: {data.astype('float32').std()} m/s
    Predicted MAE:  {mae(data, prediction)} m/s
    Predicted RMSE: {rmse(data, prediction)} m/s

    Frequencies: {len(fft[0])}
    Size/time: {sum(el.nbytes for el in fft) * 36 / (1000 ** 2)} MB
    Size/day: {sum(el.nbytes for el in fft) * 36 * 8 / (1000 ** 2)} MB
    Size/year: {sum(el.nbytes for el in fft) * 36 * 8 * 365 / (1000 ** 2)} MB
    """
    print(lines)
    with open("assets/2D-dft-graphs/" + output + ".txt", "w") as file:
        file.writelines(lines)


In [None]:
%matplotlib notebook

for var in ["U", "V"]:
    plot_dft2_at_time_and_level("MERRA2.tavg3_3d_asm_Nv.YAVG0101.nc4", var,
                                time=0, level=35, quantile=0.9935)


In [5]:
def fit_dft2_at_time(filename: str, variable: str, time: int, verbose: bool = True, **kwargs):
    data = load_variable_at_time(filename, variable, time)
    prediction = np.zeros((36, 361, 576), dtype="float32")

    dfts = []
    for lev in (tqdm(range(36)) if verbose else range(36)):
        fft = dft2_at_time_and_level(data[lev], **kwargs)
        dfts.append(fft)
        prediction[lev] = idft2_at_time_and_level(*fft)

    data = data.astype("float32")
    if verbose:
        size = sum(el.nbytes for dft in dfts for el in dft)
        lines = f"""
        Original Stdev: {data.std()} m/s
        Predicted MAE:  {mae(data, prediction)} m/s
        Predicted RMSE: {rmse(data, prediction)} m/s

        Frequencies: {len(fft[0])}
        Size/time: {size / (1000 ** 2)} MB
        Size/day: {size * 8 / (1000 ** 2)} MB
        Size/year: {size * 8 * 365 / (1000 ** 2)} MB
        """
        print(lines)
    else:
        return data, prediction, dfts


def fit_dft2_on_day(filename: str, variable: str, **kwargs):
    data_variance = 0
    mae_error = np.zeros((36, 361, 576), dtype="float32")
    mse_error = np.zeros((36, 361, 576), dtype="float32")
    all_dfts = []

    for time in tqdm(range(8)):
        data, pred, dfts = fit_dft2_at_time(filename, variable, time, verbose=False, **kwargs)
        all_dfts.append(dfts)

        data_variance += data.var()
        mae_error += abs(data - pred)
        mse_error += (data - pred) ** 2

    size = sum(el.nbytes for dfts in all_dfts for dft in dfts for el in dft)
    lines = f"""
    Original Stdev: {(data_variance / 8) ** 0.5} m/s
    Predicted MAE:  {(mae_error / 8).mean()} m/s
    Predicted RMSE: {(mse_error / 8).mean() ** 0.5} m/s

    Size/day: {size / (1000 ** 2)} MB
    Size/year: {size * 365 / (1000 ** 2)} MB
    """
    print(lines)


def fit_interpolated_dft2_at_time(filename: str, variable: str, time: int, skip_levels: int = 2,
                                  interpolation: callable = linear_interpolate, verbose: bool = True,
                                  **kwargs):
    data = load_variable_at_time(filename, variable, time)
    prediction = np.zeros((36, 361, 576), dtype="float32")

    preds = []
    dfts = []
    levels = sorted(set(range(0, 36, skip_levels)).union({35}))
    for lev in (tqdm(levels) if verbose else levels):
        fft = dft2_at_time_and_level(data[lev], **kwargs)
        dfts.append(fft)
        preds.append(idft2_at_time_and_level(*fft))

    for lev in range(36):
        if lev in levels:
            prediction[lev] = preds[levels.index(lev)]
            continue

        pred_idx = levels.index(lev - lev % skip_levels)
        t = (lev - levels[pred_idx]) / (levels[pred_idx + 1] - levels[pred_idx])
        prediction[lev] = interpolation(preds, pred_idx, t)

    data = data.astype("float32")
    if verbose:
        size = sum(el.nbytes for dft in dfts for el in dft)
        print(f"Original Stdev: {data.std()} m/s")
        print(f"Predicted RMSE: {rmse(data, prediction)} m/s")
        print(f"Predicted MAE:  {mae(data, prediction)} m/s")
        print(f"Size/time: {size / (1000 ** 2)} MB")
        print(f"Size/day: {size * 8 / (1000 ** 2)} MB")
        print(f"Size/year: {size * 8 * 365 / (1000 ** 2)} MB")
    else:
        return data, prediction, dfts


def fit_interpolated_dft2_on_day(filename: str, variable: str, quantile: float, skip_levels: int,
                                 interpolation: callable = linear_interpolate):
    data_variance = 0
    mae_error = np.zeros((36, 361, 576), dtype="float32")
    mse_error = np.zeros((36, 361, 576), dtype="float32")
    all_dfts = []

    for time in tqdm(range(8)):
        data, pred, dfts = fit_interpolated_dft2_at_time(filename, variable, time, skip_levels, interpolation,
                                                         quantile=quantile, verbose=False)
        all_dfts.append(dfts)

        data_variance += data.var()
        mae_error += abs(data - pred)
        mse_error += (data - pred) ** 2

    size = sum(el.nbytes for dfts in all_dfts for dft in dfts for el in dft)
    lines = f"""
    Original Stdev: {(data_variance / 8) ** 0.5} m/s
    Predicted MAE:  {(mae_error / 8).mean()} m/s
    Predicted RMSE: {(mse_error / 8).mean() ** 0.5} m/s

    Size/day: {size / (1000 ** 2)} MB
    Size/year: {size * 365 / (1000 ** 2)} MB
    """
    print(lines)


def fit_dft2_hermite_interpolated_at_time(filename: str, variable: str, time: int, quantile: float, skip_levels: int,
                                          verbose: bool = True):
    data = load_variable_at_time(filename, variable, time)
    prediction = np.zeros((36, 361, 576), dtype="float32")

    preds = []
    dfts = []
    levels = sorted(set(range(0, 36, skip_levels)).union({35}))
    for lev in (tqdm(levels) if verbose else levels):
        fft = dft2_at_time_and_level(data[lev], quantile=quantile)
        dfts.append(fft)
        preds.append(idft2_at_time_and_level(*fft))

    tensions = []
    biases = []

    for lev in range(36):
        if lev % skip_levels == 0:
            continue

        pred_idx = levels.index(lev - lev % skip_levels)
        t = (lev - levels[pred_idx]) / (levels[pred_idx + 1] - levels[pred_idx])

        best_tension, best_bias = fit_kochanek_bartels_spline(data[lev], preds, pred_idx, t)
        tensions.append(best_tension)
        biases.append(best_bias)

    for lev in range(36):
        if lev in levels:
            prediction[lev] = preds[levels.index(lev)]
            continue

        pred_idx = levels.index(lev - lev % skip_levels)
        t = (lev - levels[pred_idx]) / (levels[pred_idx + 1] - levels[pred_idx])
        prediction[lev] = hermite_interpolate(preds, pred_idx, t, tensions[lev - pred_idx - 1], 0)

    data = data.astype("float32")
    if verbose:
        size = sum(el.nbytes for dft in dfts for el in dft)
        print(f"Original Stdev: {data.std()} m/s")
        print(f"Predicted RMSE: {rmse(data, prediction)} m/s")
        print(f"Predicted MAE:  {mae(data, prediction)} m/s")
        print(f"Size/time: {size / (1000 ** 2)} MB")
        print(f"Size/day: {size * 8 / (1000 ** 2)} MB")
        print(f"Size/year: {size * 8 * 365 / (1000 ** 2)} MB")
    else:
        return data, prediction, dfts


def fit_dft2_hermite_interpolated_on_day(filename: str, variable: str, quantile: float, skip_levels: int):
    data_variance = 0
    mae_error = np.zeros((36, 361, 576), dtype="float32")
    mse_error = np.zeros((36, 361, 576), dtype="float32")
    all_dfts = []

    for time in tqdm(range(8)):
        data, pred, dfts = fit_dft2_hermite_interpolated_at_time(filename, variable, time, quantile, skip_levels,
                                                                 verbose=False)
        all_dfts.append(dfts)

        data_variance += data.var()
        mae_error += abs(data - pred)
        mse_error += (data - pred) ** 2

    size = sum(el.nbytes for dfts in all_dfts for dft in dfts for el in dft)
    lines = f"""
    Original Stdev: {(data_variance / 8) ** 0.5} m/s
    Predicted MAE:  {(mae_error / 8).mean()} m/s
    Predicted RMSE: {(mse_error / 8).mean() ** 0.5} m/s

    Size/day: {size / (1000 ** 2)} MB
    Size/year: {size * 365 / (1000 ** 2)} MB
    """
    print(lines)

In [7]:
_, _, dfts = fit_dft2_at_time("MERRA2.tavg3_3d_asm_Nv.YAVG0101.nc4", "U", time=0, quantile=0.9935, verbose=False)

In [8]:
dfts[0][2]

array([ 23,   1,  17,   1,  19,   1,  22,   1,  20,   1,  19,   1,  16,
         1,  19,   1,  20,   1,  16,   1,  18,   1,  15,   1,  13,   1,
        10,   1,  11,   1,  10,   1,   7,   1,   9,   1,   5,   1,   7,
         1,   6,   1,   4,   1,   1,   1,   5,   1,   1,   1,   3,   1,
         1,   1,   1,   1,   0,   1,   0,   2,   1,   4,   0, 254,   0,
        39,   0,   1,   0,   2,   0,   1,   0,   3,   0,   1,   1,   1,
         1,   1,   3,   1,   3,   1,   0,   1,   4,   1,   4,   1,   5,
         1,   9,   1,   9,   1,  10,   1,  10,   1,  12,   1,  14,   1,
        14,   1,  18,   1,  16,   1,  17,   1,  17,   1,  18,   1,  19,
         1,  16,   1,  20,   1,  21,   1,  16,   1,  21], dtype=uint8)

In [14]:
rlen_decode_array(dfts[0][3], 1) - 1

array([255,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,
         0,   0,   0,   0,   0,   0,   0,   3,   0,   4, 254,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   1,   1,   0,   0,   0,   1,
         4,   6, 254,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   1,   0,   1,   0,   1,   0,   2,   5, 254,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   1,   0,   1,   5,   9, 254,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   2,   0,   1,
         4, 254,   0,   0,   0,   0,   0,   0,   0,   0,   1,   0,   0,
         0,   0,   0,   0,   0,   1,   0,   6, 254,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   1,   0,   0,   1,   2, 254,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         1,   0,   1,   1,   2,   0, 254,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   1,   0,   0,   0,   2,   2,   

In [15]:
np.count_nonzero(rlen_decode_array(dfts[0][3], 1) - 1) * 3

708

In [23]:
fit_dft2_on_day("MERRA2.tavg3_3d_asm_Nv.YAVG0101.nc4", "U", quantile=0.9935)


  0%|          | 0/8 [00:00<?, ?it/s]


    Original Stdev: 10.1023820749644 m/s
    Predicted MAE:  0.45432770252227783 m/s
    Predicted RMSE: 0.6424456984583848 m/s

    Size/day: 0.963742 MB
    Size/year: 351.76583 MB
    


In [24]:
fit_dft2_hermite_interpolated_on_day("MERRA2.tavg3_3d_asm_Nv.YAVG0101.nc4", "U",
                                     quantile=0.99044, skip_levels=3)


  0%|          | 0/8 [00:00<?, ?it/s]


    Original Stdev: 10.1023820749644 m/s
    Predicted MAE:  0.44000789523124695 m/s
    Predicted RMSE: 0.6435797632915197 m/s

    Size/day: 0.508316 MB
    Size/year: 185.53534 MB
    
