# Applying Kernel and Pooling on Fields

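Apply a kernel (convolution) and pooling to the variables of a dataset to reduce the grid size, and evaluate how the kernel size affects the residuals with respect to the original fields.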

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell executes so that edits to library code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import a6
import typing as t
import xarray as xr
import math
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import cartopy.crs as ccrs

In [None]:
# Open the pressure-level dataset from disk and keep only the entries at
# level=500; the trailing bare expression displays the result in the notebook.
dataset_path = "/home/fabian/Documents/MAELSTROM/data/pca/pressure_level_500_950_daily_mean_2017_2020.nc"
data = xr.open_dataset(dataset_path).sel(level=500)
data

In [None]:
# Project helper objects that provide the coordinate and variable names used
# by the preprocessing steps below.
coordinates = a6.datasets.coordinates.Coordinates()
variables = a6.datasets.variables.Model()

# Compose the preprocessing steps into a single pipeline with `>>`. The
# per-step notes are inferred from the helper names — confirm against a6.
preprocessing = (
    # Presumably restricts the data to the DWD area.
    a6.datasets.methods.select.select_dwd_area(coordinates=coordinates)
    # Presumably weights the fields by latitude (square-root variant).
    >> a6.features.methods.weighting.weight_by_latitudes(
        latitudes=coordinates.latitude,
        use_sqrt=True,
    )
    # Presumably derives the geopotential height variable that is read
    # later via `variables.geopotential_height`.
    >> a6.features.methods.geopotential.calculate_geopotential_height(
        variables=variables,
    )
    # Presumably derives wind speed from the wind components.
    >> a6.features.methods.wind.calculate_wind_speed(variables=variables)
    # Drop the `z`, `u` and `v` variables from the dataset.
    # NOTE(review): presumably these are only needed as inputs to the
    # derived quantities above — confirm.
    >> a6.features.methods.variables.drop_variables(
        names=[variables.z, variables.u, variables.v]
    )
)
# Run the composed pipeline on the loaded dataset and display the result.
data = preprocessing.apply_to(data)
data

In [None]:
# Pick the geopotential height field and plot its contours at the first
# time index as a reference for the reduced versions below.
d = data[variables.geopotential_height]
a6.plotting.plot_geopotential_height_contours(d.isel(time=0))
plt.show()

# Apply a mean kernel of size 5 to the field and plot the first entry of
# the result.
# NOTE(review): `non_functional=True` appears to make the helper operate
# directly on `d` instead of returning a pipeline step (compare with the
# `>>` usage of `apply_kernel` elsewhere in this notebook) — confirm.
convolved = a6.features.methods.convolution.apply_kernel(
    d, kernel="mean", size=5, non_functional=True
)
a6.plotting.plot_2d_data(convolved[0], flip=True)
plt.show()

# Apply mean pooling of size 5 to the convolved field and plot the first
# entry of the pooled result.
pooled = a6.features.methods.pooling.apply_pooling(
    convolved, mode="mean", size=5, non_functional=True
)
a6.plotting.plot_2d_data(pooled[0], flip=True)
plt.show()

In [None]:
# Odd kernel sizes 5, 7, ..., 39 to evaluate.
sizes = list(range(5, 40, 2))

# For every variable, apply a mean kernel of each size and score the result
# against the unmodified field with the normalized root SSR.
var_ssrs = {}
for var in data.data_vars:
    d = data[var]
    scores = []
    for size in sizes:
        smooth = a6.features.methods.convolution.apply_kernel(
            kernel="mean", size=size
        )
        score = a6.evaluation.residuals.calculate_normalized_root_ssr(
            y_true=d
        )
        scores.append((smooth >> score).apply_to(d))
    label = f"SSR({var})"
    var_ssrs[label] = xr.DataArray(
        scores,
        coords={"size": sizes},
        dims=["size"],
        name=label,
    )

# Collect the per-variable curves in one dataset indexed by kernel size.
ds = xr.Dataset(var_ssrs, coords={"size": sizes})

In [None]:
# Draw one curve per variable on the current axes, labelled by variable name.
for var in ds.data_vars:
    # `ax` is kept as a cell-level name because a later cell reads it.
    ax = ds[var].plot(label=var)
plt.legend()
plt.ylabel("SSR")
# Neither matplotlib nor `open()` expands "~", so the literal path would point
# at a directory named "~" below the working directory. Resolve the home
# directory explicitly before saving.
plt.savefig(
    pathlib.Path("~/Documents/MAELSTROM/gwl/kernel-size-ssr.pdf").expanduser()
)

In [None]:
# `ax` holds the return value of `DataArray.plot` from the plotting cell; for
# 1-D data that is a list of line artists, which have no `savefig` of their
# own. Go through the artist's parent figure to reach the save method.
ax[0].figure.savefig