In [None]:
# Re-import edited modules (e.g. `utils.py`) automatically before each cell runs, so changes take effect without a kernel restart
%load_ext autoreload
%autoreload 2
# ipympl "widget" backend: interactive figures, required for the click-to-select transect on the topography map below
%matplotlib widget

In [None]:
import os
import pathlib
import xarray as xr
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import cartopy.feature as cfeature

from utils import (
    print_dataset_variables_summary,
    truncate_cmap,
    mark_and_store_points_onclick,
    Plotter, plot_figures,
    estimate_topography_from_dataset,
    get_swiss_projection, calculate_wind_uv,
    GIFWriter,
)

## Load data

In [None]:
# Input directory (relative to the notebook's working directory) holding one NetCDF file per time step
DATA_DIR_PATH = "kenda-ch1-eps_MDR_3D_d_2024112110"
data_dir: pathlib.Path = pathlib.Path(DATA_DIR_PATH)

In [None]:
ds_list = []
# Only pick up NetCDF files: stray entries (hidden/OS files, notes, sub-directories) would otherwise
# be handed to xr.open_dataset and crash the cell.
# Files are processed in filename order (presumably the names encode the valid time — confirm).
for nc_path in sorted(data_dir.glob("*.nc")):
    ds = xr.open_dataset(nc_path)
    # Each file is expected to hold exactly one time step; drop the length-1 time dimension.
    assert len(ds["time"]) == 1, f"{nc_path.name}: expected exactly 1 time step, got {len(ds['time'])}"
    ds = ds.isel(time=0)

    # Presumably adds U/V wind components (project helper; U and V are excluded from the per-variable plots below).
    ds = calculate_wind_uv(ds)
    # Rescale pressure to hPa (divides by 100, i.e. assumes the file stores Pa).
    ds["P"] = ds["P"] / 100
    ds["P"].attrs["units"] = "hPa"

    ds_list.append(ds)

# Fail with a clear message instead of an IndexError on ds_list[0] when the directory has no data.
assert ds_list, f"No .nc files found in {data_dir.resolve()}"
print_dataset_variables_summary(ds_list[0])

In [None]:
# One plotting helper built from all loaded time steps; every plotting cell below reuses this instance.
plotter = Plotter(ds_list)

## Plot topography
The simulated data will be analyzed through vertical cross-sections in the 3D model domain over Switzerland.

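To define the transect for the vertical cross-sections, **click two points** on the topography plot below. The vertical cross-section cells further down use these two points, so select them before running those cells.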

In [None]:
coord_label = widgets.Label(value="Click two points...")
display(coord_label)

# Interactive topography map: clicks on it define the transect used by the vertical cross-sections.
fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': get_swiss_projection()})
# Use only the middle half of the "terrain" colormap (start/stop are fractions of the original map).
cmap_terrain = truncate_cmap(plt.cm.terrain, start=0.25, stop=0.75)
topography = estimate_topography_from_dataset(ds_list[0])
topography.plot(x="x_1", y="y_1", cmap=cmap_terrain, ax=ax)

# country boundaries and lakes
ax.add_feature(cfeature.BORDERS, linewidth=1)
ax.add_feature(cfeature.LAKES, alpha=0.5)

# Title/save through the explicit Axes/Figure objects rather than pyplot's implicit "current" figure,
# so the topography figure is targeted no matter what else is open.
ax.set_title("Topography over KENDA-CH1-EPS domain")
images_dir = pathlib.Path.cwd() / "images"
images_dir.mkdir(exist_ok=True, parents=True)
fig.savefig(images_dir / "topography.png")

selected_points = []
markers = []  # store marker artists
line = None  # store line artist


def onclick(event, topo_ax=ax):
    """Record a click on the topography map and redraw the markers/transect line.

    ``topo_ax`` is bound when the handler is defined, and the redraw goes through
    ``event.canvas`` rather than the global ``fig``: later plotting cells reassign ``fig``,
    which previously made clicks redraw the wrong figure.
    """
    global selected_points, markers, line

    # Ignore clicks outside the map axes (e.g. on a colorbar axes, if one was added).
    if event.inaxes is not topo_ax:
        return

    selected_points, markers, line = mark_and_store_points_onclick(event, coord_label, ds_list[0], markers, line,
                                                                   selected_points)
    event.canvas.draw_idle()


fig.canvas.mpl_connect("button_press_event", onclick)
plt.show()

## Plot variables on a map (horizontal slice at a fixed `z_1` level)

In [None]:
# Horizontal slice ("map") of one variable at a fixed z_1 level, one frame per time step.
# Change the level / variables here.
MAP_Z_LEVEL = 2500  # value on the "z_1" dimension — NOTE(review): units presumably metres, confirm
gif_writer = GIFWriter()

# A comprehension (instead of a for-loop) keeps the per-frame figure variable local, so this cell
# does not overwrite the global `fig` of the interactive topography map.
figures = [
    plotter.plot_cross_section_across_dim(ds, "z_1", MAP_Z_LEVEL,
                                          var="QC", var_contour=None, with_wind=False,
                                          gif_writer=gif_writer)
    for ds in ds_list
]

gif_writer.build_gif()
plot_figures(figures)

## Plot vertical cross-sections

In [None]:
# The transect endpoints come from clicks on the topography map; fail early with a clear message
# instead of an IndexError when fewer than two points have been selected.
assert len(selected_points) >= 2, "Select two points on the topography map before plotting cross-sections."
start_point, end_point = selected_points[0], selected_points[1]

gif_writer = GIFWriter()
# Comprehension keeps the per-frame figure local (does not clobber the topography map's global `fig`).
figures = [
    plotter.plot_cross_section_between_two_points(ds, start_point, end_point,
                                                  # change variables here
                                                  var="RELHUM", var_contour="THETA", with_wind=False,
                                                  gif_writer=gif_writer)
    for ds in ds_list
]

gif_writer.build_gif()
plot_figures(figures)

## Generate and save GIF animations for all variables (horizontal slices and vertical cross-sections)

In [None]:
# Batch export: one GIF of horizontal slices over all time steps per (variable, z_1 level).
MAP_Z_LEVELS = [1500, 2500, 5000]  # values on the "z_1" dimension — NOTE(review): units presumably metres

for var in ds_list[0].data_vars:
    # Skip the grid-mapping metadata variable and wind direction/components (presumably covered by the
    # FF plots with wind overlay — confirm), plus anything without a "z_1" dimension, which cannot be
    # sliced at a z_1 level (mirrors the filter of the vertical cross-section export).
    if var in ["grid_mapping_1", "DD", "U", "V"] or "z_1" not in ds_list[0][var].dims:
        continue

    with_wind = var == "FF"
    for z in MAP_Z_LEVELS:
        gif_writer = GIFWriter()
        figures = [
            plotter.plot_cross_section_across_dim(ds, "z_1", z,
                                                  var=var, var_contour=None, with_wind=with_wind,
                                                  gif_writer=gif_writer)
            for ds in ds_list
        ]
        gif_writer.build_gif()

        # The figures only serve as GIF frames; close them once the GIF is built, otherwise pyplot keeps
        # every one of them alive and memory grows with vars x levels x time steps.
        for frame_fig in figures:
            plt.close(frame_fig)

In [None]:
# Batch export: one GIF of vertical cross-sections along the selected transect per variable.
# Fail early with a clear message instead of an IndexError when the transect has not been selected.
assert len(selected_points) >= 2, "Select two points on the topography map before plotting cross-sections."
start_point, end_point = selected_points[0], selected_points[1]

for var in ds_list[0].data_vars:
    # Cross-sections need the "z_1" dimension; the grid-mapping metadata variable and wind direction/components
    # are skipped (presumably represented by the FF plots with wind overlay — confirm).
    if var in ["grid_mapping_1", "DD", "U", "V"] or "z_1" not in ds_list[0][var].dims:
        continue

    with_wind = var == "FF"
    # Overlay THETA contours, except on THETA itself and on the wind (FF) plot.
    var_contour = "THETA" if var != "THETA" and not with_wind else None

    gif_writer = GIFWriter()
    figures = [
        plotter.plot_cross_section_between_two_points(ds, start_point, end_point,
                                                      var=var, var_contour=var_contour, with_wind=with_wind,
                                                      gif_writer=gif_writer)
        for ds in ds_list
    ]
    gif_writer.build_gif()

    # Frames are no longer needed once the GIF is built; close them so open figures do not accumulate.
    for frame_fig in figures:
        plt.close(frame_fig)