In [None]:
%matplotlib notebook

import functools

import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import sklearn.decomposition
import xarray as xr

import lifetimes


def animate(data: xr.DataArray, time_coordinate: str = "time") -> matplotlib.animation.FuncAnimation:
    """Animate ``data`` as a sequence of images, one frame per step along ``time_coordinate``.

    Parameters
    ----------
    data : xr.DataArray
        Data to animate. Each slice ``data.isel({time_coordinate: i})`` is passed to
        ``imshow``, so every slice is expected to be image-like (e.g. a 2-D field).
    time_coordinate : str
        Name of the dimension/coordinate to step through. datetime64 values are
        formatted with ``np.datetime_as_string``; any other dtype is shown via ``str``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation. Keep a (distinct) reference to it, e.g. ``anim = animate(...)``:
        matplotlib stops an animation once its object is garbage-collected.
    """
    coord_values = data[time_coordinate].values
    if np.issubdtype(coord_values.dtype, np.datetime64):
        times = np.datetime_as_string(coord_values)
    else:
        times = coord_values.astype(str)

    fig, ax = plt.subplots()
    # Fix the colour scale over ALL frames so that frames are comparable;
    # DataArray.min/max skip NaNs, and float() hands matplotlib plain scalars.
    im_plot = ax.imshow(
        data.isel({time_coordinate: 0}).values,
        vmin=float(data.min()),
        vmax=float(data.max()),
    )
    # Title placed just above the axes (axes-fraction coordinates).
    title = ax.text(0.5, 1.100, f"time step {times[0]}", transform=ax.transAxes, ha="center")

    def update(frame: int):
        im_plot.set_data(data.isel({time_coordinate: frame}).values)
        title.set_text(f"time step {times[frame]}")
        return [im_plot, title]

    anim = matplotlib.animation.FuncAnimation(fig, update, frames=len(times))
    plt.show()
    return anim

In [None]:
def plot_gallery(images, height, width, rows=3, cols=4, titles=None):
    """Plot a sequence of images as tiles on a ``rows`` x ``cols`` grid.

    Parameters
    ----------
    images : sequence of array-like
        Images indexed by position; each one is reshaped to ``(height, width)``.
    height, width : int
        Shape every image is reshaped to.
    rows, cols : int
        Grid layout. If there are fewer images than ``rows * cols`` tiles, the
        surplus tiles are left empty instead of raising ``IndexError``.
    titles : sequence of str, optional
        Per-tile titles, indexed in parallel with ``images``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    fig = plt.figure()
    n_tiles = min(rows * cols, len(images))
    for i in range(n_tiles):
        ax = fig.add_subplot(rows, cols, i + 1)
        # np.reshape also accepts plain lists, not only ndarrays.
        ax.imshow(np.reshape(images[i], (height, width)))
        if titles is not None:
            ax.set_title(titles[i])
        ax.set_xticks(())
        ax.set_yticks(())
    fig.tight_layout()
    return fig

def plot_pca_components_gallery(pca, height, width, rows, cols):
    """Tile the fitted PCA components as images, each titled with its explained-variance ratio.

    ``pca`` must expose ``components_`` and ``explained_variance_ratio_``. The gallery
    is drawn with ``plot_gallery``; this function returns ``None`` (the figure that
    ``plot_gallery`` creates is not passed back).
    """
    evr_titles = list(map("EVR: {:.2f}".format, pca.explained_variance_ratio_))
    plot_gallery(
        pca.components_,
        height=height,
        width=width,
        rows=rows,
        cols=cols,
        titles=evr_titles,
    )

def weight_by_longitude(data, longitude_dimension_name="longitude"):
    """Return a copy of ``data`` multiplied by ``1 / cos(longitude)``.

    Parameters
    ----------
    data : xr.DataArray or xr.Dataset
        Data carrying a coordinate ``longitude_dimension_name`` in degrees.
    longitude_dimension_name : str
        Name of the longitude coordinate.

    Returns
    -------
    Object of the same type as ``data``; ``data`` itself is not modified.

    NOTE(review): area weighting of gridded fields is conventionally based on
    latitude (``cos(lat)`` or ``sqrt(cos(lat))``), and ``1 / cos`` diverges near
    +/-90 degrees. Confirm the intended formula. The notebook itself uses
    ``lifetimes.utils.weight_by_latitudes``.
    """
    weights = 1 / np.cos(_degrees_to_radians(data[longitude_dimension_name]))
    # Apply the weights to a copy without storing them inside it. The result then
    # carries no extra "longitudinal_weights" entry, and for a Dataset the weights
    # are not themselves multiplied along with the data variables.
    weighted = data.copy()
    weighted *= weights
    return weighted
    
def _degrees_to_radians(d):
    return np.radians(d)



In [None]:
# Create a fake dataset of two temporally varying elliptical data regions on a grid
width = 10
height = 10
dataset = lifetimes.testing.create_dummy_ecmwf_ifs_hres_dataset(
    grid_size=(width, height)
)
ds = dataset.as_xarray()

# Alternatively, load real data from a local NetCDF file (adjust the path, and
# select the variable of interest below instead of "ellipse"):
# ds = xr.open_dataset("path/to/ml_20190101_00.nc").isel(level=0)

data = ds["ellipse"]

# Determine the modes (perform a spatio-temporal PCA)
modes = [lifetimes.modes.Modes(feature=data)]
pca_partial_method = functools.partial(
    lifetimes.modes.methods.spatio_temporal_principal_component_analysis,
    time_coordinate="time",
    x_coordinate=None,
    y_coordinate=None,
    variance_ratio=None,  # presumably keeps all components — TODO confirm
)
# Exactly one Modes object was passed in, so unpack the single result.
[pca] = lifetimes.modes.determine_modes(modes=modes, method=pca_partial_method)


In [None]:
data  # display the rich repr of the input DataArray

In [None]:
# Use a distinct name per animation: rebinding a shared name would let the earlier
# FuncAnimation be garbage-collected, which stops it.
anim_data = animate(data)

In [None]:
# Pass height/width by keyword: the signature is (pca, height, width, rows, cols),
# so the positional (width, height) order was swapped.
plot_pca_components_gallery(pca, height=height, width=width, rows=2, cols=2)

In [None]:
# Varimax rotation of the PCA components.
# NOTE: `_ortho_rotation` is a private scikit-learn helper and may change between
# releases. It takes (n_features, n_components), hence the transpose, and returns
# the rotated components as (n_components, n_features).
rotated = sklearn.decomposition._factor_analysis._ortho_rotation(
    pca.components_.T, method="varimax"
)

In [None]:
# Pass height/width by keyword (signature: images, height, width, ...); the
# trailing `;` suppresses the repr of the returned Figure.
plot_gallery(rotated, height=height, width=width, rows=2, cols=2);

In [None]:
# NOTE: despite its name, `original_shape` holds the component array reshaped to
# `data.shape`, not a shape tuple; the second line displays its shape.
original_shape = np.reshape(pca.components_, data.shape)
np.shape(original_shape)

In [None]:
# Label the reshaped array with the dimensions of `data`, in the same order as
# `data.shape`, which was used for the reshape (a hard-coded dim list can disagree).
# NOTE(review): the leading axis is the component index, not real time steps.
ds["pcs"] = (data.dims, original_shape)
pca.components_.shape

In [None]:
# `animate` already calls plt.show(). Use a distinct name so earlier animations
# stay referenced and keep running.
# NOTE(review): frames are PCA components placed on the "time" dimension, so the
# "time step" titles are not meaningful time stamps.
anim_components = animate(ds["pcs"])

In [None]:
# Weight the field by latitude, then standardize the weighted result.
data_weighted = lifetimes.utils.weight_by_latitudes(
    data,
    latitudes="lat",
)
data_standardized = lifetimes.utils.standardize(
    data_weighted
)

In [None]:
# `animate` shows the figure itself, so no further pyplot call is needed. A
# distinct name keeps this animation referenced independently of the others.
anim_standardized = animate(data_standardized)