In [None]:
%matplotlib notebook

import functools
import pathlib

import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np
import xarray as xr

import lifetimes

In [None]:
def plot_gallery(images, height, width, rows=3, cols=4, titles=None):
    """Plot a sequence of images as tiles on a ``rows`` x ``cols`` grid.

    Parameters
    ----------
    images : sequence of array-like
        Indexable collection of images. Each image must contain
        ``height * width`` values so it can be reshaped to ``(height, width)``.
    height, width : int
        Shape every image is reshaped to before plotting.
    rows, cols : int, optional
        Grid size; at most ``rows * cols`` images are shown.
    titles : sequence of str, optional
        Per-tile titles, matched to ``images`` by index.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created gallery figure.
    """
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    # Only fill as many tiles as there are images, so a grid larger than
    # ``len(images)`` does not raise IndexError.
    n_tiles = min(rows * cols, len(images))
    for i, ax in enumerate(axes.flat):
        if i >= n_tiles:
            # Hide unused grid cells instead of leaving empty framed axes.
            ax.set_axis_off()
            continue
        ax.imshow(np.reshape(images[i], (height, width)))
        if titles is not None and i < len(titles):
            ax.set_title(titles[i])
        ax.set_xticks(())
        ax.set_yticks(())
    fig.tight_layout()
    return fig

def plot_pca_components_gallery(pca, height, width, rows, cols):
    """Plot PCA components as an image gallery titled by explained variance ratio.

    Parameters
    ----------
    pca : object
        Fitted PCA result exposing ``components_`` and
        ``explained_variance_ratio_`` (sklearn-style attributes — assumed;
        confirm for the object returned by ``lifetimes``).
    height, width : int
        Shape each flattened component is reshaped to.
    rows, cols : int
        Tile grid size.

    Returns
    -------
    matplotlib.figure.Figure
        The gallery figure, consistent with ``plot_gallery``. End the calling
        cell with ``;`` to suppress the extra figure repr in Jupyter.
    """
    titles = [f"EVR: {evr:.2f}" for evr in pca.explained_variance_ratio_]
    return plot_gallery(
        pca.components_,
        height=height,
        width=width,
        rows=rows,
        cols=cols,
        titles=titles,
    )

def weight_by_longitude(data, longitude_dimension_name="longitude"):
    """Return a copy of ``data`` scaled by ``1 / cos(longitude)``.

    The weights are also stored on the result under ``"longitudinal_weights"``
    (as a coordinate for a DataArray, as a data variable for a Dataset).

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Input holding a ``longitude_dimension_name`` coordinate, assumed to be
        in degrees (it is passed through a degree-to-radian conversion).
    longitude_dimension_name : str, optional
        Name of the longitude coordinate.

    Returns
    -------
    Same type as ``data``
        Weighted copy; ``data`` itself is not modified.

    Notes
    -----
    NOTE(review): area weighting is usually a function of *latitude*
    (cf. ``lifetimes.utils.weight_by_latitudes``), and ``1 / cos`` diverges at
    ±90°. Confirm that weighting by longitude is intended.
    NOTE(review): the in-place multiplication keeps ``data``'s dtype, so integer
    input may raise a casting error; cast to float first if needed.
    """
    weights = 1 / np.cos(_degrees_to_radians(data[longitude_dimension_name]))
    weighted = data.copy()
    # Multiply *before* attaching the weights: for a Dataset the weights become
    # a data variable, and attaching them first would also scale (square) them.
    weighted *= weights
    weighted["longitudinal_weights"] = weights
    return weighted
    
def _degrees_to_radians(d):
    return np.radians(d)



In [None]:
# Load the ECMWF IFS HRES daily temperature averages (Jan-Dec 2020) if the file
# is available locally; otherwise fall back to a fake dataset of two temporally
# varying elliptical data regions on a grid, so the notebook runs anywhere.
# NOTE(review): absolute local path — consider making it configurable.
DATA_PATH = pathlib.Path(
    "/home/fabian/Documents/MAELSTROM/data/pca/"
    "ecmwf_ifs_hres_daily_temperature_averages_jan_dec_2020.nc"
)
width = 10
height = 10

if DATA_PATH.exists():
    ds = xr.open_dataset(DATA_PATH)
else:
    dataset = lifetimes.testing.create_dummy_ecmwf_ifs_hres_dataset(
        grid_size=(width, height)
    )
    ds = dataset.as_xarray()

# NOTE(review): assumes the dummy dataset also names its variable "t" — confirm.
data = ds["t"]

# Determine modes (perform spatio-temporal PCA)
modes = [lifetimes.modes.Modes(feature=data)]
pca_partial_method = functools.partial(
    lifetimes.modes.methods.spatio_temporal_principal_component_analysis,
    time_coordinate="time",
    latitude_coordinate="latitude",
)
[pca] = lifetimes.modes.determine_modes(modes=modes, method=pca_partial_method)


In [None]:
# Animate the raw field `data` over time. The result is kept in `anim`;
# presumably it is a matplotlib animation, which must stay referenced to keep
# playing — confirm.
anim = lifetimes.plotting.animate_timeseries(data)

In [None]:
# Scree plot of the PCA; variance_ratio=0.95 is presumably a 95 % explained-
# variance reference threshold — confirm in lifetimes.plotting.
lifetimes.plotting.plot_scree_test(pca, variance_ratio=0.95)
# Number of principal components used for clustering below.
# NOTE(review): presumably read off the scree plot above — document the
# rationale or derive it from the explained variance ratios.
n_components = 6

In [None]:
# Plot the time series of the first three principal components (per the
# helper's name — see lifetimes.plotting for details).
lifetimes.plotting.plot_first_three_components_timeseries(pca)

In [None]:
# Cluster the PCA result using the first `n_components` components into
# `n_clusters` groups, then plot the first three components together with the
# clusters.
n_clusters = 4
clusters = lifetimes.modes.methods.find_principal_component_clusters(
    pca, n_components=n_components, n_clusters=n_clusters,
)
lifetimes.plotting.plot_first_three_components_timeseries_clusters(clusters)

In [None]:
# Plot the cluster labels (presumably one label per time step — confirm).
# The trailing `;` suppresses the repr of the plot call's return value.
clusters.labels.plot();

In [None]:
# Weight the raw data by latitude (coordinate "latitude"), then standardize it.
# NOTE(review): the exact weighting and standardization schemes live in
# lifetimes.utils and are not visible here.
data_weighted = lifetimes.utils.weight_by_latitudes(
    data, latitudes="latitude"
)
data_standardized = lifetimes.utils.standardize(data_weighted)

In [None]:
# Reshape the PCA components to the shape of the original data array and show
# the resulting shape.
# NOTE(review): numpy's reshape only succeeds if pca.components_.size equals
# the number of elements in `data`; also, `original_shape` holds an array, not
# a shape — consider renaming.
original_shape = pca.components_.reshape(data.shape)
original_shape.shape

In [None]:
# Reshape the standardized values (a plain numpy array via `.values`) with the
# library helper; presumably it flattens the spatial dimensions per time step
# — TODO confirm in lifetimes.utils.
data_reshaped = lifetimes.utils.reshape_spatio_temporal_numpy_array(
    data_standardized.values
)

In [None]:
# Animate the principal components. `animate` is not defined in this notebook,
# so use the same library helper as the raw-data animation.
# TODO(review): the loaded dataset may not contain a "pcs" variable; the PCs
# probably need to come from `pca` instead — confirm.
anim = lifetimes.plotting.animate_timeseries(ds["pcs"])
plt.show()

In [None]:
# Weight by latitude and standardize. The latitude coordinate is named
# "latitude" (as in the PCA setup and the earlier weighting cell); "lat" is
# presumably absent and would fail.
# NOTE(review): this duplicates the earlier weighting cell — consider removing one.
data_weighted = lifetimes.utils.weight_by_latitudes(data, latitudes="latitude")
data_standardized = lifetimes.utils.standardize(data_weighted)

In [None]:
# Animate the standardized data. `animate` is not defined in this notebook, so
# use the library helper; `plt.show()` (not `plt.plot()`) renders the figure.
anim = lifetimes.plotting.animate_timeseries(data_standardized)
plt.show()