In [None]:
# Automatically reload edited modules (e.g. the local `a6` package) before
# each cell is executed, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import math
import os
import pathlib
import typing as t

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import a6

In [None]:
# Root of the local MAELSTROM working directory. Set the MAELSTROM_DIR
# environment variable to run this notebook on another machine instead of
# editing the hard-coded default below.
MAELSTROM_DIR = pathlib.Path(
    os.environ.get("MAELSTROM_DIR", "/home/fabian/Documents/MAELSTROM")
)

# Output directory for the figures saved by later cells; create it up front
# so `savefig` does not fail on a machine where it does not exist yet.
path = MAELSTROM_DIR / "gwl"
path.mkdir(parents=True, exist_ok=True)

# GWL labels (variable "GWL") and the mode lifetimes derived from them.
gwl = xr.open_dataset("../data/gwl.nc")
modes = a6.modes.methods.determine_lifetimes_of_modes(gwl["GWL"])
scores = xr.open_dataset("../data/scores.nc")

# Pressure-level data (daily means 2017-2020, per the file name), restricted
# to level 500.
data = xr.open_dataset(
    MAELSTROM_DIR / "data/pca/pressure_level_500_950_daily_mean_2017_2020.nc"
).sel(level=500)

In [None]:
# Derive geopotential height from the geopotential `z` and store it in `data`
# as the new variable "z_h" so later cells can select it by name.
geopotential = data["z"]
geopotential_height = (
    a6.features.methods.geopotential.calculate_geopotential_height(geopotential)
)
data["z_h"] = geopotential_height

In [None]:
# Take the geopotential height at the first time step over the DWD area and
# show it (1) as-is, (2) after smoothing with a mean kernel of size 10, and
# (3) after additionally applying mean pooling of size 10.
snapshot = data.isel(time=0)["z_h"]
snapshot = a6.datasets.methods.select.select_dwd_area(snapshot)
a6.plotting.plot_geopotential_height_contours(snapshot)
plt.show()

smoothed = a6.features.methods.convolution.apply_kernel(
    snapshot, kernel="mean", size=10
)
# Wrap the smoothed values back into the snapshot's coordinates for plotting.
smoothed_field = snapshot.copy(data=smoothed, deep=True)
a6.plotting.plot_geopotential_height_contours(smoothed_field)
plt.show()

downsampled = a6.features.methods.pooling.apply_pooling(
    smoothed, mode="mean", size=10
)
a6.plotting.plot_2d_data(downsampled, flip=True)
plt.show()

In [None]:
# Sweep the size of a Gaussian smoothing kernel (sigma fixed at 10) at the
# first time step and record the normalized root sum of squared residuals
# (SSR) between the original and the smoothed data for each size.
# NOTE(review): unlike the per-variable sweep in the next cell, this one does
# not restrict the data with `a6.datasets.methods.select.select_dwd_area` and
# passes the whole Dataset (all variables) to `apply_kernel` -- confirm that
# `apply_kernel` supports Datasets, or select a single variable here.
d = data.isel(time=0)
ssrs = []
sizes = list(range(5, 40, 5))
for size in sizes:
    convolved = a6.features.methods.convolution.apply_kernel(
        d, kernel="gaussian", size=size, sigma=10
    )
    ssrs.append(
        a6.evaluation.residuals.calculate_normalized_root_ssr(d, convolved)
    )
# Pass `dims` explicitly: some xarray versions raise a ValueError instead of
# inferring dimensions from dict-like `coords`. The next cell does the same.
xr.DataArray(ssrs, coords={"size": sizes}, dims=["size"], name="SSR").plot()

In [None]:
# For every variable at the first time step, restricted to the DWD area, sweep
# the size of a mean smoothing kernel and record the normalized root SSR
# between the original and the smoothed field. The curves are collected into
# one Dataset `ds` holding a variable "SSR(<name>)" per input variable,
# indexed by the kernel `size`.
sizes = list(range(5, 40, 1))
first_step = data.isel(time=0)

var_ssrs = {}
for var_name in first_step.data_vars:
    field = a6.datasets.methods.select.select_dwd_area(first_step[var_name])
    curve = [
        a6.evaluation.residuals.calculate_normalized_root_ssr(
            field,
            a6.features.methods.convolution.apply_kernel(
                field, kernel="mean", size=size
            ),
        )
        for size in sizes
    ]
    label = f"SSR({var_name})"
    var_ssrs[label] = xr.DataArray(
        curve, coords={"size": sizes}, dims=["size"], name=label
    )

ds = xr.Dataset(var_ssrs, coords={"size": sizes})

In [None]:
# Overlay the SSR-vs-kernel-size curve of every variable in `ds` on one
# explicitly created axes, so the curves never land on a stale open figure.
fig, ax = plt.subplots()
for var in ds.data_vars:
    ds[var].plot(ax=ax, label=var)
ax.set(
    xlabel="Kernel size",
    ylabel="SSR",
    title="Residuals of mean-kernel smoothing (DWD area, first time step)",
)
ax.legend()
plt.show()

In [None]:
# Plot the durations of the GWL modes (via `plot_modes_durations`) and save
# the figure as a PDF in the output directory `path`.
durations_fig, _durations_axes = a6.plotting.plot_modes_durations(modes)
output_file = path / "gwls.pdf"
durations_fig.savefig(output_file)

In [None]:
# Evaluate the loaded `scores` separately for each GWL mode; the bare name on
# the last line displays the result.
score_datasets = [scores]
scores_per_mode = a6.evaluation.modes.evaluate_scores_per_mode(
    modes,
    scores=score_datasets,
)
scores_per_mode

In [None]:
# For each GWL mode, plot `data` on the dates of that mode that fall within
# the data's time range and save one PNG per mode label into `path`.
first_time, last_time = data["time"][0], data["time"][-1]
for mode in modes:
    datetimes = list(mode.get_dates(start=first_time, end=last_time))
    fig, _ = a6.plotting.plot_combined(
        data=data,
        dates=datetimes,
    )
    fig.savefig(path / f"gwl_{mode.label}.png")
    # One figure is created per mode: close it once saved so figures do not
    # pile up in memory (matplotlib warns beyond `figure.max_open_warning`,
    # 20 by default) and are not all rendered inline at the end of the cell.
    plt.close(fig)