## Imports

In [1]:
# Official Libraries
import os
import time
import logging
import importlib
import numpy as np
from functools import partial
from contextlib import contextmanager

# Local Modules
from modules.constants import ENVIRONMENT_PATH
import modules.calculations as calculations
import modules.plotter as plotter

## Logger

In [11]:
# Configure the root logger: INFO level and above, printing only the message
# text (no timestamp or level prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")


@contextmanager
def time_block(label: str):
    """
    Context manager that times a block of code and logs the elapsed time.

    The elapsed time is logged at INFO level through the root logger. It is
    logged even when the wrapped block raises; the exception then propagates
    to the caller unchanged.

    Parameters
    ----------
    label : str
        A descriptive label for the code block being timed.

    Yields
    ------
    None
        Control is handed to the body of the ``with`` statement.
    """
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()
    try:
        yield  # Execute the code block
    finally:
        # Runs on both normal exit and exception, so failed blocks are timed too.
        elapsed_time = time.perf_counter() - start_time
        logging.info(f"{label} took {elapsed_time:.2f} seconds.")

## Flags

In [12]:
# Switches for the two analysis sections below: each section's cells are
# wrapped in `if DO_PRELIMINARY_ANALYSIS_<n>:` and are skipped when False.
DO_PRELIMINARY_ANALYSIS_1 = True
DO_PRELIMINARY_ANALYSIS_2 = True

## Preliminary Analysis #1

In [13]:
# ==========================
# Preliminary Analysis 1
# ==========================

if DO_PRELIMINARY_ANALYSIS_1:
    # Re-import so edits to modules/calculations.py apply without a kernel restart.
    importlib.reload(calculations)

    # --------------------------------------
    # Indian Monsoon Index & Onset
    # --------------------------------------
    with time_block("Indian Monsoon Index & Onset"):
        indian_monsoon_index = calculations.Indian_Monsoon_Onset_Bin_Wang(
            filepath=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE,
                "sparse_u.nc",
            ),
        )
        years = np.arange(1979, 2022)
        # argmax over a boolean mask gives the first position along axis 1
        # where the index is positive (0 when a row has no positive entry).
        onset_dates = np.argmax(indian_monsoon_index > 0, axis=1)
        indian_monsoon_onset = [
            (year, onset_date) for year, onset_date in zip(years, onset_dates)
        ]

    # --------------------------------------
    # Meridional Mean Zonal Mass Stream Function
    # --------------------------------------
    with time_block("Meridional Mean Zonal Mass Stream Function"):
        (
            streamfunction,
            streamfunction_dimensions,
        ) = calculations.streamfunction_Schwendike(
            filepath=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE,
                "sparse_divergent_wind.nc",
            ),
        )

    # --------------------------------------
    # Meridional Mean Moist Static Energy Vertical Flux
    # --------------------------------------
    with time_block("Meridional Mean Moist Static Energy Vertical Flux"):
        (
            mse_vertical_flux,
            mse_vertical_flux_dimensions,
        ) = calculations.calculate_MSE_vertical_flux(
            mse_filepath=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE,
                "sparse_mse.nc",
            ),
            w_filepath=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE,
                "sparse_w.nc",
            ),
        )

Indian Monsoon Index & Onset took 19.35 seconds.
Meridional Mean Zonal Mass Stream Function took 35.86 seconds.
Meridional Mean Moist Static Energy Vertical Flux took 53.90 seconds.


In [14]:
# ==========================
# Preliminary Analysis 1 - Charts
# ==========================

if DO_PRELIMINARY_ANALYSIS_1:
    # Re-import so edits to modules/plotter.py apply without a kernel restart.
    importlib.reload(plotter)
    # --------------------------------------
    # Zonal Wind Shear Difference Chart
    # --------------------------------------
    with time_block("Zonal Wind Shear Difference Chart"):
        plotter.display_IMI_evolution(
            indian_monsoon_index=indian_monsoon_index,
            output_path=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_IMAGES_PRELIMINARY, "IMI_evolution.png"
            ),
        )

    # --------------------------------------
    # Composite Chart - Early vs Late Onset Analysis
    # --------------------------------------
    with time_block("Composite Chart"):
        # Positions of the years ordered from earliest to latest onset.
        # NOTE: `sorted_indices` is also read as a global by
        # WK_mse_v_seasonal_unittest further down the notebook.
        sorted_indices = np.argsort([key[1] for key in indian_monsoon_onset])
        early_onset_indices = sorted_indices[:10]  # 10 earliest-onset years
        late_onset_indices = sorted_indices[-10:]  # 10 latest-onset years
        # Every second index from 129 to 167; presumably day-of-year positions
        # bracketing the typical onset - TODO confirm the meaning of 138/139.
        calendar_indices = np.arange(start=139 - 10, stop=138 + 30, step=2)

        # ---------------------------------
        # Compute Early and Late Onset Averages
        # ---------------------------------
        # The composite means over the selected years do not depend on the
        # calendar index, so compute them once here rather than once per
        # plotted index inside the loop.
        streamfunction_early_composite = np.mean(
            streamfunction[early_onset_indices], axis=0
        )
        streamfunction_late_composite = np.mean(
            streamfunction[late_onset_indices], axis=0
        )
        mse_vertical_flux_early_composite = np.mean(
            mse_vertical_flux[early_onset_indices], axis=0
        )
        mse_vertical_flux_late_composite = np.mean(
            mse_vertical_flux[late_onset_indices], axis=0
        )

        for calendar_index in calendar_indices:
            # Slice the precomputed composites at the current calendar index.
            streamfunction_early = streamfunction_early_composite[calendar_index]
            streamfunction_late = streamfunction_late_composite[calendar_index]
            mse_vertical_flux_early = mse_vertical_flux_early_composite[
                calendar_index
            ]
            mse_vertical_flux_late = mse_vertical_flux_late_composite[calendar_index]

            # ---------------------------------
            # Plot Streamfunction Composite (Shading and Contour)
            # ---------------------------------
            shading_levels = np.linspace(-2e10, 2e10, 11, endpoint=True)
            contour_levels = np.linspace(-2e10, 2e10, 11, endpoint=True)
            colorbar_levels = np.linspace(-2e10, 2e10, 5, endpoint=True)

            figure_object = plotter.display_early_late_composite(
                shading_early=streamfunction_early,
                shading_late=streamfunction_late,
                contour_early=streamfunction_early,
                contour_late=streamfunction_late,
                grids=streamfunction_dimensions,
                calendar_index=calendar_index,
                shading_levels=shading_levels,
                # Drop the middle level (the zero contour of the symmetric range).
                contour_levels=np.delete(contour_levels, len(contour_levels) // 2),
                colorbar_levels=colorbar_levels,
                plt_title=r"Shading: $\Psi$, Contour: $\Psi$, ",
                plt_label=r"(kg/s)",
                output_path=ENVIRONMENT_PATH.ABSOLUTE_PATH_IMAGES_PRELIMINARY,
                filename="streamfunction",
            )

            # ---------------------------------
            # Plot MSE Vertical Flux Composite (Shading) with Streamfunction Contour
            # ---------------------------------
            shading_levels = np.linspace(-2e4, 2e4, 41, endpoint=True)
            contour_levels = np.linspace(-2e10, 2e10, 11, endpoint=True)
            colorbar_levels = np.linspace(-2e4, 2e4, 5, endpoint=True)

            plotter.display_early_late_composite(
                shading_early=mse_vertical_flux_early,
                shading_late=mse_vertical_flux_late,
                contour_early=streamfunction_early,
                contour_late=streamfunction_late,
                grids=streamfunction_dimensions,
                calendar_index=calendar_index,
                shading_levels=shading_levels,
                # Drop the middle level (the zero contour of the symmetric range).
                contour_levels=np.delete(contour_levels, len(contour_levels) // 2),
                colorbar_levels=colorbar_levels,
                cmap="RdBu",
                plt_title=r"Shading: $\omega*MSE$, Contour: $\Psi$, ",
                plt_label=r"(Pa/s)*(J/kg)",
                output_path=ENVIRONMENT_PATH.ABSOLUTE_PATH_IMAGES_PRELIMINARY,
                filename="streamfunction_and_MSE_flux",
            )

Zonal Wind Shear Difference Chart took 0.45 seconds.
Composite Chart took 8.90 seconds.


## Preliminary Analysis #2

In [15]:
def reader(
    file_path: str, variable_name: str, pressure_level: int
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
    """
    Load one variable and its coordinate arrays from a netCDF file.

    Parameters
    ----------
    file_path : str
        Path of the netCDF file to open read-only.
    variable_name : str
        Name of the variable to read.
    pressure_level : int
        Index taken along the "plev" dimension. Ignored when the variable
        has no dimension named "plev".

    Returns
    -------
    variable
        The variable's data, with the "plev" axis (if any) reduced to the
        requested index and every other axis read in full.
    dimensions : dict
        Maps each dimension name of the variable to the values read from the
        same-named variable in the file; the "plev" entry, if present, holds
        only the selected level.
    """
    from netCDF4 import Dataset

    with Dataset(file_path, mode="r") as nc_file:
        axis_names = nc_file[variable_name].dimensions
        dimensions = {axis: nc_file[axis][:] for axis in axis_names}

        # Start with a full slice on every axis, then pin "plev" to one index.
        selector = [slice(None) for _ in dimensions]
        for position, axis in enumerate(axis_names):
            if axis != "plev":
                continue
            selector[position] = pressure_level
            dimensions["plev"] = dimensions["plev"][pressure_level]

        variable = nc_file[variable_name][tuple(selector)]

    return variable, dimensions


def __segment_data(
    data, axis, year_index, reference_index, segment_length, days_per_year=365
):
    """
    Cut one fixed-length window per requested year out of ``data`` along ``axis``.

    For each entry ``year`` of ``year_index`` the window covers positions
    ``year * days_per_year + reference_index`` up to (but excluding) that
    start plus ``segment_length``. All other axes are kept in full.

    Parameters
    ----------
    data : array-like with ``ndim`` and slice indexing (e.g. numpy array)
        Array to segment.
    axis : int
        Axis of ``data`` along which the windows are cut.
    year_index : sequence of int
        Zero-based year offsets selecting which years to extract.
    reference_index : int
        Offset of the window start within each year.
    segment_length : int
        Number of positions in each window.
    days_per_year : int, optional
        Number of positions along ``axis`` that make up one year. The default
        of 365 assumes a daily series without leap days - TODO confirm
        against the input data.

    Returns
    -------
    num_iterations : int
        Number of windows, i.e. ``len(year_index)``.
    segments : generator
        Lazily yields the windows in the order given by ``year_index``.
        Standard slicing applies, so a window that runs past the end of
        ``axis`` is truncated rather than raising.
    """
    num_iterations = len(year_index)

    def _segments():
        for year in year_index:
            start = year * days_per_year + reference_index
            window = [slice(None)] * data.ndim
            window[axis] = slice(start, start + segment_length)
            yield data[tuple(window)]

    return num_iterations, _segments()


def WK_olr_unittest():
    """
    Run the Wheeler-Kiladis power-spectrum calculation on the OLR test file.

    Reads the "olr" variable from "OLR_sparse.nc" in the satellite test
    directory and passes it to
    ``calculations.power_spectrum_Wheeler_Kiladis`` with boundaries
    15 / -15 (north / south) and 360 / 0 (east / west). Both stages are timed
    with ``time_block``. The ``calculations`` module is reloaded first.

    Returns
    -------
    tuple
        ``(variable_name, symmetric_PSD, antisymmetric_PSD, background_PSD,
        dimensions)`` where the last four items are whatever the
        power-spectrum routine returned.
    """
    importlib.reload(calculations)

    variable_name = "olr"

    with time_block("read #1"):
        olr_path = os.path.join(
            ENVIRONMENT_PATH.ABSOLUTE_PATH_SATELLITE_TEST, "OLR_sparse.nc"
        )
        variable, dimensions = reader(
            file_path=olr_path,
            variable_name=variable_name,
            # Placeholder value: reader only uses it for a "plev" dimension.
            pressure_level=-9999,
        )

    with time_block("calculate #2"):
        spectra = calculations.power_spectrum_Wheeler_Kiladis(
            variable,
            dimensions,
            north_boundary=15,
            south_boundary=-15,
            east_boundary=360,
            west_boundary=0,
        )
        symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions = spectra

    return variable_name, symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions


def WK_mse_v_unittest():
    """
    Run the Wheeler-Kiladis power-spectrum calculation on the product v * mse.

    Reads "v" from "sparse_v.nc" and "mse" from "sparse_mse.nc" (both at
    pressure-level index 2) in the sparse ERA5 directory, multiplies them,
    and passes the product to
    ``calculations.power_spectrum_Wheeler_Kiladis`` with boundaries
    15 / -15 (north / south) and 360 / 0 (east / west). The dimensions handed
    to the calculation are those read from the MSE file. The ``calculations``
    module is reloaded first.

    Returns
    -------
    tuple
        ``("meridional MSE flux", symmetric_PSD, antisymmetric_PSD,
        background_PSD, dimensions)``.
    """
    importlib.reload(calculations)

    era5_directory = ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE
    level_index = 2

    with time_block("read #1"):
        # Coordinates of the wind file are discarded; the MSE file's are kept.
        meridional_wind, _ = reader(
            file_path=os.path.join(era5_directory, "sparse_v.nc"),
            variable_name="v",
            pressure_level=level_index,
        )
        moist_static_energy, dimensions = reader(
            file_path=os.path.join(era5_directory, "sparse_mse.nc"),
            variable_name="mse",
            pressure_level=level_index,
        )
        variable_name = "meridional MSE flux"
        variable = meridional_wind * moist_static_energy

    with time_block("calculate #2"):
        spectra = calculations.power_spectrum_Wheeler_Kiladis(
            variable,
            dimensions,
            north_boundary=15,
            south_boundary=-15,
            east_boundary=360,
            west_boundary=0,
        )
        symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions = spectra

    return variable_name, symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions


def WK_mse_v_seasonal_unittest(year_index=None, reference_index=110, segment_length=50):
    """
    Run the Wheeler-Kiladis power-spectrum calculation on v * mse using one
    fixed-length window per selected year.

    Reads "v" from "sparse_v.nc" and "mse" from "sparse_mse.nc" (both at
    pressure-level index 2) in the sparse ERA5 directory, multiplies them,
    and passes the product to
    ``calculations.power_spectrum_Wheeler_Kiladis`` with a segmentation
    callable built from ``__segment_data``. The ``calculations`` module is
    reloaded first.

    Parameters
    ----------
    year_index : sequence of int, optional
        Zero-based year offsets to include. When omitted, the notebook-level
        ``sorted_indices`` (created by the "Preliminary Analysis 1 - Charts"
        cell) is used, i.e. every year. Pass ``sorted_indices[:10]`` or
        ``sorted_indices[-10:]`` to restrict the spectrum to the early- or
        late-onset years.
    reference_index : int, optional
        Offset of each window's start within its year (default 110).
    segment_length : int, optional
        Length of each window (default 50). Also sets the cutoff frequency
        passed to the calculation, ``1 / segment_length``.

    Returns
    -------
    tuple
        ``("meridional MSE flux", symmetric_PSD, antisymmetric_PSD,
        background_PSD, dimensions)``.

    Raises
    ------
    NameError
        If ``year_index`` is omitted and ``sorted_indices`` has not been
        defined yet.
    """
    # Resolve the year selection before the slow file reads so that a missing
    # prerequisite is reported immediately and with an actionable message.
    if year_index is None:
        try:
            year_index = sorted_indices
        except NameError as error:
            raise NameError(
                "`sorted_indices` is not defined: run the 'Preliminary Analysis 1"
                " - Charts' cell first or pass `year_index` explicitly."
            ) from error

    importlib.reload(calculations)
    with time_block("read #1"):
        # Coordinates of the wind file are discarded; the MSE file's are kept.
        meridional_wind, _ = reader(
            file_path=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE, "sparse_v.nc"
            ),
            variable_name="v",
            pressure_level=2,
        )
        moist_static_energy, dimensions = reader(
            file_path=os.path.join(
                ENVIRONMENT_PATH.ABSOLUTE_PATH_ERA5_SPARSE, "sparse_mse.nc"
            ),
            variable_name="mse",
            pressure_level=2,
        )
        variable_name = "meridional MSE flux"
        variable = meridional_wind * moist_static_energy
    with time_block("calculate #2"):
        # The remaining positional arguments of __segment_data (data, axis)
        # are left for the power-spectrum routine to supply.
        seasonal_segmentation = partial(
            __segment_data,
            year_index=year_index,
            reference_index=reference_index,
            segment_length=segment_length,
        )
        symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions = (
            calculations.power_spectrum_Wheeler_Kiladis(
                variable,
                dimensions,
                cutoff_frequency=1 / segment_length,
                segment_length=segment_length,
                overlap_length=0,
                segment_method=seasonal_segmentation,
            )
        )
    return variable_name, symmetric_PSD, antisymmetric_PSD, background_PSD, dimensions

In [16]:
if DO_PRELIMINARY_ANALYSIS_2:
    with time_block("Wavenumber-Frequency Diagram"):
        # Exactly one driver is active at a time. The alternatives return the
        # same five-item tuple and can be swapped in for the call below:
        #   WK_olr_unittest()
        #   WK_mse_v_unittest()
        (
            variable_name,
            symmetric_PSD,
            antisymmetric_PSD,
            background_PSD,
            dimensions,
        ) = WK_mse_v_seasonal_unittest()

read #1 took 28.91 seconds.
calculate #2 took 7.19 seconds.
Wavenumber-Frequency Diagram took 36.11 seconds.


In [17]:
# Plot the spectra produced by the previous cell (symmetric_PSD,
# antisymmetric_PSD, background_PSD, dimensions, variable_name).
# NOTE(review): the saved output of this cell is
# "ModuleNotFoundError: No module named 'constants'". The failing import is
# not visible in this notebook - presumably it is raised while reloading or
# running modules/plotter.py; confirm and fix there.
if DO_PRELIMINARY_ANALYSIS_2:
    with time_block("Plotting of Stochastic Power Spectra Density Ratio"):
        # The reload happens inside the timed block, so its cost is included
        # in the logged duration.
        importlib.reload(plotter)
        # filter_wnfr = [(3.5, 0.15), 2, 0.05]
        # wk_filter = [(-4.5, 0.02), 1.5, 0.03]
        plotter.display_wavenumber_frequency_diagram(
            symmetric_PSD,
            antisymmetric_PSD,
            background_PSD,
            dimensions,
            variable_name=variable_name,
            # cmap="jet",
            dispersion_line_order="atypical",
        )

ModuleNotFoundError: No module named 'constants'

In [None]:
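# NOTE: disabled draft. It references `filter_wnfr`, `filepath` and `varname`,
# none of which are assigned in the visible cells of this notebook; define
# them before re-enabling this call.
# filtered_symmetric_components, filtered_antisymmetric_components, _ = (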
#     calculations.calculate_filtered_signal(
#         file_path=filepath,
#         zonal_wavenumber_limit=np.array(
#             [filter_wnfr[0][0], filter_wnfr[0][0] + filter_wnfr[1]]
#         ),
#         segmentation_frequency_limit=np.array(
#             [filter_wnfr[0][1], filter_wnfr[0][1] + filter_wnfr[2]]
#         ),
#         variable_name=varname,
#     )
# )