In [1]:
import pickle

from plot import *
from fft import *
from maths import rmse, mse, mae, linear_interpolate, cubic_interpolate
from maths import catmull_rom_interpolate, hermite_interpolate, fit_kochanek_bartels_spline


In [2]:
def dft3_at_time(data: np.ndarray, quantile: float = 0.75):
    """Compress a 3-D field by keeping only its strongest real-FFT coefficients.

    The field is transformed with ``np.fft.rfftn``. Every coefficient whose
    amplitude is strictly below the ``quantile`` amplitude quantile is
    discarded. The survivors are returned as scaled float16 real and imaginary
    parts together with their three index arrays, each passed through the
    project's ``encode_*`` helpers.

    Args:
        data: 3-D array to compress. The extents are taken from the array
            itself, so any 3-D grid size is accepted.
        quantile: Amplitude quantile in [0, 1] used as the cut-off.

    Returns:
        Tuple ``(fft_real, fft_imag, fft_i_indices, fft_j_indices,
        fft_k_indices)`` in the form produced by ``encode_zlib``.
    """
    # 131072 == 2**17. Values are divided by it before the float16 cast,
    # presumably to keep them inside float16 range -- TODO confirm.
    scale = 131072

    fft = np.fft.rfftn(data)
    amplitudes = np.abs(fft)
    cutoff_amp = np.quantile(amplitudes, quantile)

    # "Not strictly below the cut-off" rather than ">=" so that NaN amplitudes
    # are treated exactly like a plain `<` test would treat them (kept).
    keep = ~(amplitudes < cutoff_amp)

    # np.nonzero enumerates in row-major order of the array it is given.
    # Swapping the last two axes therefore yields the survivors with i
    # outermost, k in the middle and j innermost; the difference encoder
    # receives the index streams in that order.
    i_idx, k_idx, j_idx = np.nonzero(keep.transpose(0, 2, 1))
    kept = fft[i_idx, j_idx, k_idx]

    fft_real = (kept.real.astype("float32") / scale).astype("float16")
    fft_real = encode_zlib(fft_real)

    fft_imag = (kept.imag.astype("float32") / scale).astype("float16")
    fft_imag = encode_zlib(fft_imag)

    fft_i_indices = encode_zlib(encode_difference_uint8(i_idx.astype("int16")))
    fft_j_indices = encode_zlib(encode_difference_uint8(j_idx.astype("int16")))
    fft_k_indices = encode_zlib(encode_difference_uint8(k_idx.astype("int16")))

    return fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices


def idft3_at_time(fft_real, fft_imag, fft_i_indices, fft_j_indices, fft_k_indices,
                  shape: tuple = (36, 361, 289)):
    """Rebuild a 3-D field from the sparse coefficients made by ``dft3_at_time``.

    Args:
        fft_real: Encoded float16 real parts (already divided by 131072).
        fft_imag: Encoded float16 imaginary parts (already divided by 131072).
        fft_i_indices: Encoded first-axis index of every stored coefficient.
        fft_j_indices: Encoded second-axis index of every stored coefficient.
        fft_k_indices: Encoded third-axis index of every stored coefficient.
        shape: Shape of the half-spectrum to rebuild, i.e. the shape of the
            ``np.fft.rfftn`` output. Defaults to ``(36, 361, 289)``.

    Returns:
        The real-valued field from ``np.fft.irfftn``. Because no output size
        is passed, the last axis has length ``2 * (shape[-1] - 1)``.
    """
    scale = 131072  # must match the divisor applied when the values were stored

    real_part = decode_zlib(fft_real, dtype="float16")
    imag_part = decode_zlib(fft_imag, dtype="float16")
    coefficients = real_part.astype("complex64") * scale + imag_part.astype("complex64") * (scale * 1j)

    i_idx = decode_difference_uint8(decode_zlib(fft_i_indices))
    j_idx = decode_difference_uint8(decode_zlib(fft_j_indices))
    k_idx = decode_difference_uint8(decode_zlib(fft_k_indices))

    # Every coefficient that was not stored stays zero.
    spectrum = np.zeros(shape, dtype="complex64")
    spectrum[i_idx, j_idx, k_idx] = coefficients

    return np.fft.irfftn(spectrum)

In [3]:
def fit_dft3_at_time(filename: str, variable: str, time: int, verbose: bool = True, predict: bool = True, **kwargs):
    """Fit (or reuse) the compressed DFT model of one variable at one time step.

    The model is pickled to
    ``models/2D-dft/<variable>/<quantile>/<date>-<time>.bin``. If that file
    already exists and ``verbose`` is false, the stored model is loaded
    instead of being recomputed. With ``verbose=True`` the model is always
    recomputed and rewritten.

    Args:
        filename: Data file to read the variable from.
        variable: Name of the variable to compress.
        time: Time-step index inside the file.
        verbose: If true, print error and size statistics and return ``None``.
        predict: Only used when ``verbose`` is false; selects the return value.
        **kwargs: Forwarded to ``dft3_at_time``. ``quantile`` defaults to 0.75
            so that the output path and the transform always agree.

    Returns:
        ``None`` when ``verbose`` is true; ``(data, prediction, fft)`` when
        ``predict`` is true; otherwise just ``fft``.
    """
    quantile = kwargs.setdefault("quantile", 0.75)
    output = f"models/2D-dft/{variable}/{quantile}/{format_date(filename, for_output=True)}-{format_time(time, filename)}.bin"

    data = None
    if os.path.exists(output) and not verbose:
        # Cache hit: reuse the stored model so every caller still gets the
        # return value it asked for.
        # NOTE(security): pickle.load can execute arbitrary code; this is only
        # acceptable for model files written by this function.
        with open(output, "rb") as file:
            fft = pickle.load(file)
    else:
        data = load_variable_at_time(filename, variable, time, cache=verbose)
        fft = dft3_at_time(data, **kwargs)
        # The per-variable / per-quantile directory may not exist yet.
        os.makedirs(os.path.dirname(output), exist_ok=True)
        with open(output, "wb") as file:
            pickle.dump(fft, file, protocol=pickle.HIGHEST_PROTOCOL)

    if verbose:
        data = data.astype("float32")
        prediction = idft3_at_time(*fft)

        size = sum(el.nbytes for el in fft)
        # The factor 8 is the number of time steps per day used by the day-level helpers.
        lines = f"""
        Original Stdev: {data.std()} m/s
        Predicted MAE:  {mae(data, prediction)} m/s
        Predicted RMSE: {rmse(data, prediction)} m/s

        Size/time: {size / (1000 ** 2)} MB
        Size/day: {size * 8 / (1000 ** 2)} MB
        Size/year: {size * 8 * 365 / (1000 ** 2)} MB
        """
        print(lines)
    elif predict:
        if data is None:
            # Cache hit: the original field is still needed for the comparison.
            data = load_variable_at_time(filename, variable, time, cache=verbose)
        return data.astype("float32"), idft3_at_time(*fft), fft
    else:
        return fft


def fit_dft3_on_day(filename: str, variable: str, **kwargs):
    """Fit all 8 time steps of one day and print the day's error and size summary.

    Args:
        filename: Data file of the day.
        variable: Name of the variable to compress.
        **kwargs: Forwarded to ``fit_dft3_at_time`` (e.g. ``quantile``).
    """
    variances = []
    abs_errors = []
    sq_errors = []
    encoded_models = []

    for step in tqdm(range(8)):
        truth, approx, coeffs = fit_dft3_at_time(filename, variable, step, verbose=False, **kwargs)
        encoded_models.append(coeffs)

        variances.append(truth.var())
        abs_errors.append(mae(truth, approx))
        sq_errors.append(mse(truth, approx))

    # Average the per-step statistics over the day before taking roots.
    stdev = (sum(variances) / 8) ** 0.5
    mean_abs = (sum(abs_errors) / 8).mean()
    root_mean_sq = (sum(sq_errors) / 8).mean() ** 0.5

    size = sum(part.nbytes for coeffs in encoded_models for part in coeffs)
    print(f"""
    Original Stdev: {stdev} m/s
    Predicted MAE:  {mean_abs} m/s
    Predicted RMSE: {root_mean_sq} m/s

    Size/day: {size / (1000 ** 2)} MB
    Size/year: {size * 365 / (1000 ** 2)} MB
    """)


def fit_dft3(filename: str, variable: str, times: list[int] = list(range(8)), **kwargs):
    """Fit and store the DFT model for every day of the year at the given time steps.

    Args:
        filename: Template formatted with ``(month, day)`` to get each day's file.
        variable: Name of the variable to compress.
        times: Time-step indices to fit for every day.
        **kwargs: Forwarded to ``fit_dft3_at_time`` (e.g. ``quantile``).
    """
    for month in tqdm(range(1, 13)):
        # 2001 is used as the calendar year for the month lengths.
        days_in_month = monthrange(2001, month)[1]
        for day in tqdm(range(1, days_in_month + 1)):
            day_file = filename.format(month, day)
            for step in times:
                fit_dft3_at_time(day_file, variable, step, verbose=False, predict=False, **kwargs)


def test_time_interpolated_dft3_on_day(filename: str, variable: str, times, quantile: float,
                                       interpolation: callable = linear_interpolate,
                                       verbose: bool = True):
    """Measure one day's reconstruction error when only some time steps are stored.

    Time steps listed in ``times`` are reconstructed from their stored models.
    Each of the remaining steps in ``range(8)`` is interpolated in time from
    the neighbouring reconstructions. Neighbouring days are consulted where
    the day itself has no anchor on one side.

    Args:
        filename: Data file of the day under test.
        variable: Name of the variable.
        times: Time-step indices with a stored model. The caller's sequence is
            not modified.
        quantile: Quantile the stored models were fitted with (part of the path).
        interpolation: Called as ``interpolation(preds, pred_idx, t)``, where
            ``t`` is the fractional position between anchors ``pred_idx`` and
            ``pred_idx + 1``.
        verbose: If true, show progress bars and print a summary; otherwise
            return the statistics.

    Returns:
        ``None`` when ``verbose`` is true, else
        ``(data_variance, mae_error, mse_error)`` averaged over the 8 steps.

    Raises:
        ValueError: If ``times`` is empty.
    """
    def load_model(source_file, time):
        # The time label is always formatted against ``filename``, also for
        # models that belong to a neighbouring day.
        path = f"{variable}/{quantile}/{format_date(source_file, for_output=True)}-{format_time(time, filename)}"
        # NOTE(security): pickle.load can execute arbitrary code; only read
        # model files that this project wrote itself.
        with open("models/2D-dft/" + path + ".bin", "rb") as file:
            return pickle.load(file)

    data_variance = 0
    mae_error = 0
    mse_error = 0
    all_dfts = []
    preds = []

    times = sorted(times)  # sorted copy: the caller's sequence stays untouched
    if not times:
        raise ValueError("times must contain at least one time step")
    first_time = times[0]
    last_time = times[-1]

    for time in (tqdm(times) if verbose else times):
        fft = load_model(filename, time)

        data = load_variable_at_time(filename, variable, time).astype("float32")
        pred = idft3_at_time(*fft)

        preds.append(pred)
        all_dfts.append(fft)

        data_variance += data.var()
        mae_error += mae(data, pred)
        mse_error += mse(data, pred)

    # A left neighbour from the previous day is needed by the four-point
    # schemes and whenever the day's first stored step is not step 0.
    # Without it, steps before `first_time` would get pred_idx == -1 and
    # wrap around to the last sample of the same day.
    needs_previous_day = first_time != 0 or interpolation in {
        cubic_interpolate, catmull_rom_interpolate, hermite_interpolate}
    if needs_previous_day:
        fft = load_model(get_previous_file(filename), last_time)
        preds.insert(0, idft3_at_time(*fft))
        times.insert(0, last_time - 8)

    # A right neighbour from the next day is needed when step 7 is not stored.
    if times[-1] != 7:
        fft = load_model(get_next_file(filename), first_time)
        preds.append(idft3_at_time(*fft))
        times.append(8 + first_time)

    for time in (tqdm(range(8)) if verbose else range(8)):
        if time in times:
            continue

        data = load_variable_at_time(filename, variable, time).astype("float32")

        # `times` is sorted, so the left anchor is the last entry strictly before `time`.
        pred_idx = sum(1 for anchor in times if anchor < time) - 1

        t = (time - times[pred_idx]) / (times[pred_idx + 1] - times[pred_idx])
        pred = interpolation(preds, pred_idx, t)

        data_variance += data.var()
        mae_error += mae(data, pred)
        mse_error += mse(data, pred)

    data_variance /= 8
    mae_error /= 8
    mse_error /= 8

    if verbose:
        size = sum(el.nbytes for dft in all_dfts for el in dft)
        lines = f"""
        Original Stdev: {data_variance ** 0.5} m/s
        Predicted MAE:  {mae_error} m/s
        Predicted RMSE: {mse_error ** 0.5} m/s

        Size/day: {size / (1000 ** 2)} MB
        Size/year: {size * 365 / (1000 ** 2)} MB
        """
        print(lines)
    else:
        return data_variance, mae_error, mse_error


def test_optimized_time_interpolated_dft3_on_day(filename: str, variable: str, times, quantile: float,
                                                 verbose: bool = True):
    """Measure one day's error using Hermite interpolation with fitted tension and bias.

    Time steps listed in ``times`` are reconstructed from their stored models.
    For every other step in ``range(8)``, ``fit_kochanek_bartels_spline`` is
    given the true field to choose ``tension`` and ``bias``, and the step is
    then predicted with ``hermite_interpolate``.

    Args:
        filename: Data file of the day under test.
        variable: Name of the variable.
        times: Time-step indices with a stored model. The caller's sequence is
            not modified, so one list can be reused for many days.
        quantile: Quantile the stored models were fitted with (part of the path).
        verbose: If true, show progress bars and print a summary; otherwise
            return the statistics.

    Returns:
        ``None`` when ``verbose`` is true, else
        ``(data_variance, mae_error, mse_error)`` averaged over the 8 steps.

    Raises:
        ValueError: If ``times`` is empty.
    """
    def load_model(source_file, time):
        # The time label is always formatted against ``filename``, also for
        # models that belong to a neighbouring day.
        path = f"{variable}/{quantile}/{format_date(source_file, for_output=True)}-{format_time(time, filename)}"
        # NOTE(security): pickle.load can execute arbitrary code; only read
        # model files that this project wrote itself.
        with open("models/2D-dft/" + path + ".bin", "rb") as file:
            return pickle.load(file)

    data_variance = 0
    mae_error = 0
    mse_error = 0
    all_dfts = []
    preds = []

    # Work on a sorted copy. The insert/append calls below would otherwise
    # modify the caller's list, which the year-wide driver passes again for
    # every day.
    times = sorted(times)
    if not times:
        raise ValueError("times must contain at least one time step")
    first_time = times[0]
    last_time = times[-1]

    for time in (tqdm(times) if verbose else times):
        fft = load_model(filename, time)

        data = load_variable_at_time(filename, variable, time).astype("float32")
        pred = idft3_at_time(*fft)

        preds.append(pred)
        all_dfts.append(fft)

        data_variance += data.var()
        mae_error += mae(data, pred)
        mse_error += mse(data, pred)

    # The previous day's last stored step is always added as the left neighbour.
    fft = load_model(get_previous_file(filename), last_time)
    preds.insert(0, idft3_at_time(*fft))
    times.insert(0, last_time - 8)

    # A right neighbour from the next day is needed when step 7 is not stored.
    if times[-1] != 7:
        fft = load_model(get_next_file(filename), first_time)
        preds.append(idft3_at_time(*fft))
        times.append(8 + first_time)

    for time in (tqdm(range(8)) if verbose else range(8)):
        if time in times:
            continue

        data = load_variable_at_time(filename, variable, time).astype("float32")

        # `times` is sorted, so the left anchor is the last entry strictly before `time`.
        pred_idx = sum(1 for anchor in times if anchor < time) - 1

        t = (time - times[pred_idx]) / (times[pred_idx + 1] - times[pred_idx])
        # The spline parameters are fitted against the true field, so the
        # reported error comes from parameters chosen with access to the target.
        tension, bias = fit_kochanek_bartels_spline(data, preds, pred_idx, t)
        pred = hermite_interpolate(preds, pred_idx, t, tension, bias)

        data_variance += data.var()
        mae_error += mae(data, pred)
        mse_error += mse(data, pred)

    data_variance /= 8
    mae_error /= 8
    mse_error /= 8

    if verbose:
        size = sum(el.nbytes for dft in all_dfts for el in dft)
        lines = f"""
        Original Stdev: {data_variance ** 0.5} m/s
        Predicted MAE:  {mae_error} m/s
        Predicted RMSE: {mse_error ** 0.5} m/s

        Size/day: {size / (1000 ** 2)} MB
        Size/year: {size * 365 / (1000 ** 2)} MB
        """
        print(lines)
    else:
        return data_variance, mae_error, mse_error


def test_dft3(*args, **kwargs):
    """Run the year-wide evaluation with all 8 daily time steps taken from stored models.

    Positional and keyword arguments are forwarded unchanged to
    ``test_time_interpolated_dft3``; ``times`` is supplied here.
    """
    every_step = list(range(8))
    test_time_interpolated_dft3(*args, times=every_step, **kwargs)


def test_time_interpolated_dft3(filename: str, *args, **kwargs):
    """Print year-wide error statistics of the time-interpolated DFT models.

    Args:
        filename: Template formatted with ``(month, day)`` to get each day's file.
        *args: Forwarded to ``test_time_interpolated_dft3_on_day``.
        **kwargs: Forwarded to ``test_time_interpolated_dft3_on_day``.
    """
    variance_sum = 0
    mae_sum = 0
    mse_sum = 0

    for month in tqdm(range(1, 13)):
        days_in_month = monthrange(2001, month)[1]
        for day in tqdm(range(1, days_in_month + 1)):
            day_file = filename.format(month, day)
            day_var, day_mae, day_mse = test_time_interpolated_dft3_on_day(day_file, *args,
                                                                           verbose=False, **kwargs)
            variance_sum += day_var
            mae_sum += day_mae
            mse_sum += day_mse

    # Daily values are averaged over the 365 days visited above.
    stdev = (variance_sum / 365) ** 0.5
    mean_abs = mae_sum / 365
    root_mean_sq = (mse_sum / 365) ** 0.5

    print(f"""
    Original Stdev: {stdev} m/s
    Predicted MAE:  {mean_abs} m/s
    Predicted RMSE: {root_mean_sq} m/s
    """)


def test_optimized_time_interpolated_dft3(filename: str, *args, **kwargs):
    """Print year-wide error statistics of the fitted-spline time interpolation.

    Args:
        filename: Template formatted with ``(month, day)`` to get each day's file.
        *args: Forwarded to ``test_optimized_time_interpolated_dft3_on_day``.
        **kwargs: Forwarded to ``test_optimized_time_interpolated_dft3_on_day``.
    """
    variance_sum = 0
    mae_sum = 0
    mse_sum = 0

    for month in tqdm(range(1, 13)):
        days_in_month = monthrange(2001, month)[1]
        for day in tqdm(range(1, days_in_month + 1)):
            day_file = filename.format(month, day)
            day_var, day_mae, day_mse = test_optimized_time_interpolated_dft3_on_day(day_file, *args,
                                                                                     verbose=False, **kwargs)
            variance_sum += day_var
            mae_sum += day_mae
            mse_sum += day_mse

    # Daily values are averaged over the 365 days visited above.
    stdev = (variance_sum / 365) ** 0.5
    mean_abs = mae_sum / 365
    root_mean_sq = (mse_sum / 365) ** 0.5

    print(f"""
    Original Stdev: {stdev} m/s
    Predicted MAE:  {mean_abs} m/s
    Predicted RMSE: {root_mean_sq} m/s
    """)

In [None]:
# Fit and store the compressed DFT model of variable "U" for every day of the year
# (all 8 time steps per day) at quantile 0.9935.
fit_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "U", quantile=0.9935)


In [None]:
# Year-wide reconstruction error for "V" with all 8 daily time steps read from stored models.
# NOTE(review): this reads models for "V", while the fit cell above only fits "U" --
# presumably the "V" models were written by an earlier run; confirm before Run All.
test_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "V", quantile=0.9935)

In [None]:
# Year-wide error for "V" when only time steps 0, 2, 4 and 6 come from stored models
# and the remaining steps are linearly interpolated in time.
test_time_interpolated_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "V", quantile=0.9935, times=[0, 2, 4, 6],
                            interpolation=linear_interpolate)


  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

In [None]:
# Same year-wide evaluation for "U", filling the odd time steps with cubic interpolation.
test_time_interpolated_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "U", quantile=0.9935,
                            times=[0, 2, 4, 6], interpolation=cubic_interpolate)


In [None]:
# Same year-wide evaluation for "U", filling the odd time steps with Catmull-Rom interpolation.
test_time_interpolated_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG{:0>2}{:0>2}.nc4", "U", quantile=0.9935,
                            times=[0, 2, 4, 6], interpolation=catmull_rom_interpolate)


In [None]:
# Evaluate the fitted-spline (tension/bias) time interpolation for "U".
# NOTE(review): the filename has no "{}" placeholders, so filename.format(mm, dd) inside the
# year-wide driver returns this same January 1st file for every day of its loop. Presumably
# either the "{:0>2}{:0>2}" template or the *_on_day variant was intended -- confirm.
test_optimized_time_interpolated_dft3("MERRA2.tavg3_3d_asm_Nv.YAVG0101.nc4", "U", quantile=0.9935, times=[0, 2, 4, 6])
