In [None]:
from tqdm.notebook import tqdm

from plot import *
from fft import *
from maths import rmse, mae


In [None]:
def plot_yavg_dft_at_time_level_and_latitude(filename: str, variable: str, time: int, level: int, latitude: int, **kwargs):
    """Plot one row of data against its DFT reconstruction and print size estimates.

    The row selected by ``time``, ``level`` and ``latitude`` is loaded, transformed with
    ``dft_at_time_level_and_latitude``, reconstructed with
    ``idft_at_time_level_and_latitude``, and both curves are drawn over a
    -180°..180° axis. The figure is saved under ``assets/1D-yavg-graphs/`` and shown.

    Args:
        filename: File name (or template) handed to the loader. The second-to-last
            dot-separated component, minus its first two characters, is used in the
            output file name.
        variable: Variable code to load (e.g. ``"U"``).
        time: Time index passed to the loader and to ``format_time``.
        level: Level index passed to the loader and to ``format_level``.
        latitude: Latitude index passed to the loader and to ``format_latitude``.
        **kwargs: Forwarded unchanged to ``dft_at_time_level_and_latitude``.
    """
    data = load_yavg_variable_at_time_level_and_latitude(filename, variable, time, level, latitude)

    fft = dft_at_time_level_and_latitude(data, **kwargs)
    prediction = idft_at_time_level_and_latitude(*fft)

    # 576 samples spanning the full longitude range; shared by both curves.
    longitudes = np.linspace(-180, 180, 576)

    # Use a dedicated figure so the plot never lands on a figure left open by an earlier cell.
    fig, ax = plt.subplots()

    ax.plot(longitudes, data, linewidth=1)
    ax.plot(longitudes, prediction, linestyle="dashed", linewidth=1)

    ax.xaxis.set_major_formatter(FormatStrFormatter("%d°"))

    title = f"{get_variable_name_from_code(variable)} ({get_units_from_variable(variable)}) " \
            f"at {format_latitude(latitude)}, {format_level(level)} at {format_time(time)}"
    output = f"{variable}-{format_latitude(latitude, filename=True)}-{format_level(level, filename=True)}" \
             f"-yavg-{filename.split('.')[-2][2:]}-{format_time(time)}"

    ax.set_title(title, fontsize=8)
    fig.savefig("assets/1D-yavg-graphs/" + output + ".png", dpi=300)
    plt.show()

    # Bytes needed to store the transform of this single row.
    row_bytes = sum(el.nbytes for el in fft)
    megabyte = 1000 ** 2

    print(f"Original Standard Dev (m/s):  {data.std()}")
    # The value printed here is the RMSE of the reconstruction, so label it as such.
    print(f"Predicted RMSE (m/s): {rmse(data, prediction)}")
    print(f"Frequencies: {len(fft[0])}")
    # Scale up: 361 latitude rows per level, 72 levels per time step,
    # 8 time steps per day, 365 days per year.
    print(f"Size/level: {row_bytes * 361 / megabyte} MB")
    print(f"Size/time: {row_bytes * 361 * 72 / megabyte} MB")
    print(f"Size/day: {row_bytes * 361 * 72 * 8 / megabyte} MB")
    print(f"Size/year: {row_bytes * 361 * 72 * 8 * 365 / megabyte} MB")


In [None]:
%matplotlib notebook

plot_yavg_dft_at_time_level_and_latitude("MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4", "U",
                                         time=0, level=71, latitude=180, quantile=0.965)

In [None]:
def plot_yavg_dft_at_time_and_level(filename: str, variable: str, time: int, level: int, **kwargs):
    """Reconstruct a full level row by row with the DFT and plot the result and its error.

    Each of the 361 latitude rows is transformed and reconstructed independently.
    Two figures are produced: original vs. reconstruction, and the absolute error
    next to the original/reconstructed curves of the row with the largest RMSE.

    Args:
        filename: File name (or template) handed to the loader.
        variable: Variable code to load (e.g. ``"U"``).
        time: Time index passed to the loader and to ``format_time``.
        level: Level index passed to the loader and to ``format_level``.
        **kwargs: Forwarded unchanged to ``dft_at_time_level_and_latitude``.
    """
    data = load_yavg_variable_at_time_and_level(filename, variable, time, level)

    prediction = np.zeros((361, 576))

    worst_deviation = 0
    worst_lat = 0

    for lat in tqdm(range(361)):
        subdata = data[lat]

        fft = dft_at_time_level_and_latitude(subdata, **kwargs)
        pred = idft_at_time_level_and_latitude(*fft)
        prediction[lat] = pred

        # Track the row whose reconstruction has the largest RMSE.
        dev = rmse(subdata, pred)
        if dev > worst_deviation:
            worst_deviation = dev
            worst_lat = lat

    units = get_units_from_variable(variable)
    longitudes = np.linspace(-180, 180, 576)

    _, ax1, ax2 = create_1x2_plot(f"{variable} ({units}) at {format_level(level)}"
                                  f" at {format_time(time)}", sharey=True)
    ax1.imshow(data, cmap="viridis", origin="lower")
    ax2.imshow(prediction, cmap="viridis", origin="lower")

    _, ax1, ax2 = create_1x2_plot(f"Absolute Error ({units}) at {format_level(level)}"
                                  f" at {format_time(time)}")
    ax1.imshow(np.abs(prediction - data), cmap="hot", origin="lower")
    ax2.plot(longitudes, data[worst_lat], linewidth=1)
    ax2.plot(longitudes, prediction[worst_lat], linestyle="dashed", linewidth=1)

    plt.show()

    # NOTE: `fft` is the transform of the last latitude processed; its frequency
    # count is used as the per-row figure in the estimates below.
    n_frequencies = len(fft[0])

    print(f"Original Stdev: {data.astype('float32').std()} m/s")
    print(f"Predicted RMSE: {rmse(data, prediction)} m/s")
    print(f"Frequencies/latitude: {n_frequencies}")
    # 361 rows x 72 levels; the factor 5 is presumably bytes per stored frequency — TODO confirm.
    print(f"Size/time: {n_frequencies * 361 * 5 * 72 / (1000 ** 2)} MB")
    print()
    print(f"Stdev at Worst: {data[worst_lat].astype('float32').std()} m/s")
    print(f"Worst RMSE: {worst_deviation} m/s")


In [None]:
%matplotlib notebook

plot_yavg_dft_at_time_and_level("MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4", "U",
                                time=0, level=71, quantile=0.965)


In [None]:
def plot_yavg_dft_at_time(filename: str, variable: str, time: int, **kwargs):
    """Reconstruct every (level, latitude) row at one time step and summarise the error.

    Each row is transformed and reconstructed independently. A histogram of the
    per-row RMSE is plotted, and overall statistics are printed together with the
    row whose RMSE is largest relative to its own standard deviation.

    Args:
        filename: File name (or template) handed to the loader.
        variable: Variable code to load (e.g. ``"U"``).
        time: Time index passed to the loader.
        **kwargs: Forwarded unchanged to ``dft_at_time_level_and_latitude``.
    """
    data = load_yavg_variable_at_time(filename, variable, time)
    prediction = np.zeros((72, 361, 576))

    worst_deviation_ratio = 0
    worst_deviation = 0
    worst_deviation_std = 0

    deviations = []

    for lev in tqdm(range(72)):
        for lat in range(361):
            subdata = data[lev][lat]

            fft = dft_at_time_level_and_latitude(subdata, **kwargs)
            pred = idft_at_time_level_and_latitude(*fft)

            dev = rmse(subdata, pred)
            deviations.append(dev)
            prediction[lev][lat] = pred

            std = subdata.astype("float32").std()
            # A constant row has zero spread, so the relative error is undefined;
            # skip it instead of dividing by zero (which would yield inf/nan and
            # could permanently pin the "worst" ratio at inf).
            if std > 0:
                ratio = dev / std
                if ratio > worst_deviation_ratio:
                    worst_deviation_ratio = ratio
                    worst_deviation = dev
                    worst_deviation_std = std

    plot_histogram([{"data": deviations}], "Distribution of Prediction RMSE (m/s)", bins=100)

    print(f"Original Stdev: {(data.astype('float32')).std()} m/s")
    print(f"Predicted RMSE: {rmse(data, prediction)} m/s")
    # Uses the frequency count of the last row processed; 361 rows x 72 levels.
    # The factor 5 is presumably bytes per stored frequency — TODO confirm.
    print(f"Size/time: {len(fft[0]) * 361 * 72 * 5 / (1000 ** 2)} MB")
    print()
    print(f"Stdev at Worst: {worst_deviation_std} m/s")
    print(f"Worst RMSE: {worst_deviation} m/s")


In [None]:
%matplotlib notebook

plot_yavg_dft_at_time("MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4", "U",
                 time=0, quantile=0.965)

In [None]:
def compare_dft_at_time_vs_quantile(filename: str, variable: str, time: int, quantiles: list[float], **kwargs):
    """Compare the per-row reconstruction RMSE distribution across DFT quantiles.

    For every quantile, each (level, latitude) row is transformed with that
    quantile, reconstructed, and its RMSE recorded. The resulting distributions
    are drawn as histograms on a 2x2 grid, one panel per quantile.

    Args:
        filename: File name handed to ``load_variable_at_time``.
        variable: Variable code to load (e.g. ``"U"``).
        time: Time index passed to the loader.
        quantiles: Quantile values to compare. Only the first four are plotted,
            because the grid has four panels; fewer than four leaves panels empty.
        **kwargs: Extra arguments forwarded to ``dft_at_time_level_and_latitude``.
            A ``quantile`` entry here is overridden by the value under comparison.
    """
    data = load_variable_at_time(filename, variable, time)
    data_dict = []

    for quant in tqdm(quantiles):
        # The quantile being compared must actually reach the transform; merging it
        # into kwargs also avoids a duplicate-keyword error if the caller passed one.
        dft_kwargs = {**kwargs, "quantile": quant}

        deviations = []

        for lev in range(72):
            for lat in range(361):
                subdata = data[lev][lat]

                fft = dft_at_time_level_and_latitude(subdata, **dft_kwargs)
                pred = idft_at_time_level_and_latitude(*fft)

                deviations.append(rmse(subdata, pred))

        data_dict.append({"data": deviations, "label": f"{quant * 100:0>1}%"})

    fig, ax11, ax12, ax21, ax22 = create_2x2_plot("Distribution of Prediction RMSE (m/s)", sharex=True, sharey=True)

    # One histogram per panel; zip stops at whichever runs out first.
    for panel, entry in zip((ax11, ax12, ax21, ax22), data_dict):
        plt.sca(panel)
        plot_histogram([entry], "", bins=50)


In [None]:
%matplotlib notebook

compare_dft_at_time_vs_quantile("MERRA2_100.tavg3_3d_asm_Nv.19800101.nc4", "U",
                                time=0, quantiles=[0.95, 0.965, 0.975, 0.99])


In [None]:
def fit_yavg_dft_at_time(filename: str, variable: str, time: int, levels: int = 36, **kwargs):
    """Reconstruct one time step row by row with the DFT and print error and size figures.

    Args:
        filename: File name (or template) handed to the loader.
        variable: Variable code to load (e.g. ``"U"``).
        time: Time index passed to the loader.
        levels: Number of levels to load and reconstruct.
        **kwargs: Forwarded unchanged to ``dft_at_time_level_and_latitude``.
    """
    data = load_yavg_variable_at_time(filename, variable, time, levels=levels)
    prediction = np.zeros((levels, 361, 576))

    for level_idx in tqdm(range(levels)):
        for lat_idx in range(361):
            # Transform a single row and write its reconstruction straight back.
            fft = dft_at_time_level_and_latitude(data[level_idx, lat_idx], **kwargs)
            prediction[level_idx, lat_idx] = idft_at_time_level_and_latitude(*fft)

    print(f"Original Stdev: {data.astype('float32').std()} m/s")
    print(f"Predicted MAE:  {mae(data, prediction)} m/s")
    print(f"Predicted RMSE: {rmse(data, prediction)} m/s")

    # `fft` is the transform of the last row processed.
    n_frequencies = len(fft[0])
    # 361 rows per level; the factor 5 is presumably bytes per stored frequency — TODO confirm.
    size_per_time = n_frequencies * 361 * 5 * levels
    megabyte = 1000 ** 2

    print(f"Frequencies/latitude: {n_frequencies}")
    print(f"Size/time: {size_per_time / megabyte} MB")
    # 8 time steps per day, 365 days per year.
    print(f"Size/day: {size_per_time * 8 / megabyte} MB")
    print(f"Size/year: {size_per_time * 8 * 365 / megabyte} MB")


In [None]:
# Fit both U and V at time index 0 with identical settings.
fit_yavg_dft_at_time(
    "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4",
    "U",
    time=0,
    levels=36,
    quantile=0.965,
)

fit_yavg_dft_at_time(
    "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4",
    "V",
    time=0,
    levels=36,
    quantile=0.965,
)


In [None]:
def fit_yavg_dft_on_day(filename: str, variable: str, levels: int = 36, **kwargs):
    """Reconstruct all 8 time steps of a day with the DFT and print averaged error figures.

    Absolute and squared errors are accumulated per grid cell over the 8 time
    steps, then averaged for the printed MAE and RMSE.

    Args:
        filename: File name (or template) handed to the loader.
        variable: Variable code to load (e.g. ``"U"``).
        levels: Number of levels to load and reconstruct.
        **kwargs: Forwarded unchanged to ``dft_at_time_level_and_latitude``.
    """
    grid_shape = (levels, 361, 576)

    variance_sum = 0
    abs_error_sum = np.zeros(grid_shape, dtype="float32")
    sq_error_sum = np.zeros(grid_shape, dtype="float32")

    for step in tqdm(range(8)):
        data = load_yavg_variable_at_time(filename, variable, step, levels=levels)
        prediction = np.zeros(grid_shape)

        for level_idx in range(levels):
            for lat_idx in range(361):
                # Transform a single row and store its reconstruction.
                fft = dft_at_time_level_and_latitude(data[level_idx, lat_idx], **kwargs)
                prediction[level_idx, lat_idx] = idft_at_time_level_and_latitude(*fft)

        variance_sum += data.astype('float32').var()

        # Per-cell error for this time step, accumulated in both forms.
        residual = data - prediction
        abs_error_sum += abs(residual)
        sq_error_sum += residual ** 2

    print(f"Original Stdev: {(variance_sum / 8) ** 0.5} m/s")
    print(f"Predicted MAE:  {(abs_error_sum / 8).mean()} m/s")
    print(f"Predicted RMSE: {(sq_error_sum / 8).mean() ** 0.5} m/s")

    # `fft` is the transform of the last row processed.
    n_frequencies = len(fft[0])
    # 361 rows per level; the factor 5 is presumably bytes per stored frequency — TODO confirm.
    size_per_time = n_frequencies * 361 * 5 * levels
    megabyte = 1000 ** 2

    print(f"Frequencies/latitude: {n_frequencies}")
    print(f"Size/time: {size_per_time / megabyte} MB")
    print(f"Size/day: {size_per_time * 8 / megabyte} MB")
    print(f"Size/year: {size_per_time * 8 * 365 / megabyte} MB")


In [None]:
# Fit both U and V over the whole day with identical settings.
fit_yavg_dft_on_day(
    "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4",
    "U",
    levels=36,
    quantile=0.965,
)

fit_yavg_dft_on_day(
    "MERRA2_{}.tavg3_3d_asm_Nv.{}0101.nc4",
    "V",
    levels=36,
    quantile=0.965,
)
