# testing MSSA

based on the tutorial:

_https://www.kaggle.com/jdarcy/introducing-ssa-for-time-series-decomposition#Decomposing-Time-Series-Data-With-Singular-Spectrum-Analysis_

and the MSSA package:

_https://github.com/kieferk/pymssa_

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from pymssa import MSSA
matplotlib.rc_file('../rc_file')
%matplotlib widget

# get it to work with xarray 

In [None]:
N = 200  # length of the toy series (number of time 'moments')
t = np.arange(0, N)

# Deterministic parts: a parabolic trend and two sinusoids of periods p1, p2.
trend = 0.001 * (t - 100) ** 2
p1, p2 = 20, 30
periodic1 = 2 * np.sin(2 * np.pi * t / p1)
periodic2 = 0.75 * np.sin(2 * np.pi * t / p2)

# Uniform noise in [-1, 1); the fixed seed makes the series reproducible.
np.random.seed(123)
noise = 2 * (np.random.rand(N) - 0.5)
F = trend + periodic1 + periodic2 + noise

# Draw the composite series first (thick line), then each ingredient,
# in the same order as the legend labels below.
curves = [
    (F, {"lw": 2.5}),
    (trend, {"alpha": 0.75}),
    (periodic1, {"alpha": 0.75}),
    (periodic2, {"alpha": 0.75}),
    (noise, {"alpha": 0.5}),
]
for series, style in curves:
    plt.plot(t, series, **style)
plt.legend(["Toy Series ($F$)", "Trend", "Periodic #1", "Periodic #2", "Noise"])
plt.xlabel("$t$")
plt.ylabel("$F(t)$")
plt.title("The Toy Time Series and its Components");

In [None]:
# Wrap the toy series in a 1-D DataArray. The time coordinate is sized from F
# itself (not a hard-coded 200), so it stays consistent if N is changed.
da = xr.DataArray(data=F, coords={'time': np.arange(len(F))}, dims='time')

In [None]:
# Fit MSSA to the single toy series with a 70-sample embedding window.
# n_components=None presumably keeps every component -- confirm in the pymssa docs.
mssa_F = MSSA(n_components=None, window_size=70, verbose=True)
mssa_F.fit(da)
# Judging by how it is indexed later, components_ is (series, time, component).
print(np.shape(mssa_F.components_))

In [None]:
f, ax = plt.subplots(1, 2)

# All components of the only time series (index 0) in the toy model.
first_series = mssa_F.components_[0]

# Left: the first 10 reconstructed components over time.
for comp_idx in range(10):
    ax[0].plot(first_series[:, comp_idx])

# Right: w-correlation matrix between all components of that series.
ax[1].imshow(mssa_F.w_correlation(first_series))

## making a test dataset

In [None]:
# Test-movie size: N grid points per spatial axis, T time frames.
N = 11
T = 200

def construct_rotating_Gaussian(N, T, plot=True):
    """Build a movie of a 2-D Gaussian blob circling the centre of a square grid.

    The grid spans [-1, 1] along both axes with ``N`` points per axis. In
    frame ``m`` the blob is centred at ``(0.5*sin(t_m), 0.5*cos(t_m))``,
    with ``t_m`` evenly spaced over ``[0, 10*pi]`` (five full turns) across
    the ``T`` frames. Both widths are 0.5.

    Parameters
    ----------
    N : int
        Number of grid points along each spatial axis.
    T : int
        Number of time frames.
    plot : bool, optional
        If True (default, the original behaviour), draw the X/Y/R coordinate
        maps and a strip of ``N`` snapshots (every 4th frame). Pass False to
        skip all plotting, e.g. inside timing loops.

    Returns
    -------
    np.ndarray, shape (T, N, N)
        ``A[m, r, c]`` is the Gaussian evaluated at grid point ``(r, c)`` in
        frame ``m``.
    """
    def Gaussian_2D(x, y, x0, y0, sigmax, sigmay):
        # Separable Gaussian bump; equals 1 exactly at (x0, y0).
        return np.exp(-((x - x0) / sigmax) ** 2) * np.exp(-((y - y0) / sigmay) ** 2)

    # One shared axis for x and y, so both grids are built the same way
    # (endpoints exactly -1 and 1). X[r, c] = axis[c], Y[r, c] = axis[r].
    axis = np.linspace(-1, 1, N)
    X, Y = np.meshgrid(axis, axis)

    if plot:
        R = np.sqrt(X ** 2 + Y ** 2)
        f, ax = plt.subplots(1, 3)
        for i, (field, cmap) in enumerate(zip([X, Y, R], ['RdBu', 'RdBu', 'plasma'])):
            im = ax[i].imshow(field, cmap=cmap)
            plt.colorbar(im, ax=ax[i], orientation='horizontal', fraction=0.25)

    # Vectorised over frames and grid points. The centres have shape (T, 1, 1)
    # and broadcast against the (N, N) grid to give (T, N, N). This replaces
    # the per-point Python loop, which re-flattened Y on every inner step.
    phase = np.linspace(0, 10 * np.pi, T)[:, np.newaxis, np.newaxis]
    A = Gaussian_2D(x=X, y=Y, x0=0.5 * np.sin(phase), y0=0.5 * np.cos(phase),
                    sigmax=.5, sigmay=.5)

    if plot:
        # squeeze=False keeps `ax` 2-D, so N == 1 works too.
        f, ax = plt.subplots(1, N, figsize=(12, 3), squeeze=False)
        for i in range(N):
            # Every 4th frame; clamp so short movies (small T) don't index past the end.
            frame = min(4 * i, T - 1)
            ax[0, i].imshow(A[frame, :, :])
            ax[0, i].axis('off')

    return A

# Build the rotating-Gaussian test movie; A is indexed A[time, row, col].
A = construct_rotating_Gaussian(N=N, T=T)

In [None]:
%%time
def animate_function(A, frames):
    """Build a matplotlib animation that steps through the time slices of A.

    Parameters
    ----------
    A : array_like, shape (time, ny, nx)
        Field to animate; animation frame ``i`` shows ``A[i, :, :]``.
    frames : int or iterable
        Passed straight through to ``FuncAnimation``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    fig, ax = plt.subplots(1, 1)
    # Size the placeholder image from A itself rather than from the notebook's
    # global N, so the function works for any grid. The colour range is fixed
    # to [0, 1], which bounds the Gaussians built in this notebook (peak value 1).
    im = ax.imshow(np.zeros(np.shape(A)[1:]), vmin=0, vmax=1)

    def animate(i):
        im.set_array(A[i, :, :])
        return [im]

    return FuncAnimation(fig, animate, frames=frames)

# Animate only the first T/5 frames (40 of the 200) to keep rendering quick.
ani = animate_function(A, frames=int(T/5))

In [None]:
%%time
# Evaluate `ani` as the cell's last expression so the notebook displays it.
ani

## xarray

In [None]:
def construct_dataarray(A):
    """Wrap a (time, lat, lon) array in an xarray DataArray and a stacked view.

    Parameters
    ----------
    A : array_like, 3-D
        Field indexed as ``A[time, lat, lon]``. The two spatial axes may have
        different lengths.

    Returns
    -------
    da : xr.DataArray
        Dims ``('time', 'lat', 'lon')`` with integer index coordinates.
    stacked : xr.DataArray
        The same data with ``lat`` and ``lon`` stacked into a single
        ``allpoints`` dimension, one column per grid point.

    Raises
    ------
    ValueError
        If ``A`` is not 3-D.
    """
    A = np.asarray(A)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if A.ndim != 3:
        raise ValueError(f"expected a 3-D array (time, lat, lon), got ndim={A.ndim}")
    # Take each axis length from the shape. The old code reused the lat length
    # for lon (square grids only) and failed on empty axes.
    n_time, n_lat, n_lon = A.shape
    coords = {'time': np.arange(n_time), 'lat': np.arange(n_lat), 'lon': np.arange(n_lon)}
    dims = ('time', 'lat', 'lon')
    da = xr.DataArray(data=A, coords=coords, dims=dims)
    # stack lat and lon into a single dimension called allpoints
    return da, da.stack(allpoints=['lat', 'lon'])

In [None]:
da, stacked = construct_dataarray(A)
# Display the unstacked (time, lat, lon) DataArray.
da

In [None]:
%%time
# Observed run time of this cell: ~3 min 03 s.
# Fits one MSSA model jointly on every grid-point time series in `stacked`;
# n_components=None presumably keeps all components (confirm in pymssa docs).
mssa = MSSA(n_components=None, window_size=70, verbose=True)
mssa.fit(stacked)

In [None]:
# Shape of the fitted components; the cells below index it as
# components_[series, time, component].
np.shape(mssa.components_)

In [None]:
f, ax = plt.subplots(3, 2, figsize=(12, 12))
for j, k in enumerate([0, 10, 100]):
    # Left: the first 10 reconstructed components of time series k.
    for i in range(10):
        ax[j, 0].plot(mssa.components_[k, :, i])
    # Right: w-correlation matrix of all components of series k. Computed
    # once per row; it previously sat inside the inner loop and was
    # recomputed and redrawn 10 times.
    ax[j, 1].imshow(mssa.w_correlation(mssa.components_[k, :, :]))

In [None]:
# Zoom in on the w-correlation among the first 30 components of series 100.
plt.imshow(mssa.w_correlation(mssa.components_[100, :, :30]))

In [None]:
# Group the components of one time series by hand:
#   [0], then 13 consecutive pairs [1, 2], [3, 4], ..., [25, 26],
#   then everything from 27 up to (not including) 130 as one residual group.
# Plain Python ints (rather than np.arange scalars) keep the printout readable.
# NOTE(review): 130 is hard-coded; confirm it matches mssa.components_.shape[2]
# so that no trailing component is silently left out of every group.
n_pairs = 13
last_component = 130
ts0_groups = (
    [[0]]
    + [[2 * i + 1, 2 * i + 2] for i in range(n_pairs)]
    + [list(range(2 * n_pairs + 1, last_component))]
)
print(ts0_groups)

In [None]:
# Inspect the grouped components currently stored for series 0.
ts0_grouped = mssa.grouped_components_[0]
ts0_grouped.shape

In [None]:
f, ax = plt.subplots(3, 2, figsize=(8, 8))
n_groups = len(ts0_groups)  # was hard-coded as 15
for j, k in enumerate([0, 10, 100]):
    mssa.set_ts_component_groups(k, ts0_groups)
    grouped = mssa.grouped_components_[k]  # columns are groups (indexed [:, i])
    # Running reconstruction: column i is the sum of groups 0..i inclusive.
    # The old `[:, :i]` slice drew an all-zero first curve and never added the
    # last group, so the full reconstruction was never compared to the data.
    partial_sums = np.cumsum(grouped[:, :n_groups], axis=1)
    for i in range(n_groups):
        ax[j, 0].plot(grouped[:, i])
        ax[j, 1].plot(partial_sums[:, i])
    # The original series k, for comparison with the final partial sum.
    ax[j, 1].plot(stacked[:, k])
        

In [None]:
# Explained variance per component: `explained_variance_` (left) and
# `explained_variance_ratio_` (right).
f, ax = plt.subplots(1,2)
ax[0].plot(mssa.explained_variance_)
ax[1].plot(mssa.explained_variance_ratio_)

## (automatic) grouping of components

In [None]:
# Smaller 7 x 7 grid (49 time series) for the automatic component-selection runs.
N = 7
T = 200
A = construct_rotating_Gaussian(N=N, T=T)
da, stacked = construct_dataarray(A)

In [None]:
# Animate the first T/5 frames of the 7 x 7 movie.
ani = animate_function(A, frames=int(T/5))

In [None]:
# Evaluate `ani` as the cell's last expression so the notebook displays it.
ani

## Recovering the pattern
Three methods are implemented to automatically choose how many components to keep (passed via `n_components`): `svht`, `parallel_analysis`, and `variance_threshold`.

In [None]:
%%time
# Let pymssa pick the number of components with the 'svht' rule
# (presumably singular value hard thresholding -- see pymssa docs).
mssa1 = MSSA(n_components='svht',
            window_size=70,
            verbose=True)
mssa1.fit(stacked)
print(mssa1.components_.shape)

In [None]:
%%time
# Keep enough components to explain 95% of the variance.
# NOTE(review): window_size=None here, while the other two fits use 70 --
# confirm this is intended (pymssa then presumably uses its default window).
mssa2 = MSSA(n_components='variance_threshold',
            variance_explained_threshold=0.95,
            window_size=None,
            verbose=True)
mssa2.fit(stacked)
print(mssa2.components_.shape)

In [None]:
%%time
# Choose the number of components by parallel analysis at the 95th percentile
# (see pymssa docs for the exact criterion).
mssa3 = MSSA(n_components='parallel_analysis',
            pa_percentile_threshold=95,
            window_size=70,
            verbose=True)
mssa3.fit(stacked)
print(mssa3.components_.shape)

In [None]:
def recover_pattern(mssa, stacked):
    """Rebuild the (time, lat, lon) field from the sum of all fitted MSSA components.

    Parameters
    ----------
    mssa : fitted MSSA model
        Its ``components_`` is indexed ``[series, time, component]``.
    stacked : xr.DataArray
        The ``(time, allpoints)`` array the model was fitted on; it supplies
        the coordinates and the stacked index used for unstacking.

    Returns
    -------
    xr.DataArray
        A copy of ``stacked`` holding the reconstruction, unstacked back into
        separate ``lat``/``lon`` dimensions.
    """
    # Sum over components, then transpose (series, time) -> (time, series)
    # so the array lines up with the (time, allpoints) layout of `stacked`.
    reconstruction = mssa.components_.sum(axis=2).T
    return stacked.copy(data=reconstruction).unstack()

In [None]:
# Display `stacked` to check its dims before it is unstacked in recover_pattern.
stacked

In [None]:
# For each automatic selection method, show n recovered snapshots (every 4th
# frame) and annotate each with its summed squared error against the true A.
n = 10
frame_ids = [4 * i for i in range(n)]
for m in (mssa1, mssa2, mssa3):
    da = recover_pattern(m, stacked)
    f, ax = plt.subplots(1, n, figsize=(12,3))
    for i, frame in enumerate(frame_ids):
        snapshot = da[frame, :, :]
        error = ((snapshot - A[frame, :, :]) ** 2).values.sum()
        panel = ax[i]
        panel.imshow(snapshot)
        panel.text(.05, .9, f'{error:2.1e}', transform=panel.transAxes, c='w')
        panel.axis('off')

# testing performance

In [None]:
from datetime import datetime

### sensitivity to `N`

In [None]:
%%time
# Previously observed total run time of this cell: ~11 min 23 s.
import time

T = 200
Ns, times = [], []
for N in [2, 3, 5, 7, 9, 11, 13]:
    A = construct_rotating_Gaussian(N=N, T=T)
    da, stacked = construct_dataarray(A)
    # Time only the MSSA construction + fit. perf_counter is monotonic and
    # high-resolution, unlike wall-clock datetime.now().
    start = time.perf_counter()
    mssa = MSSA(n_components=None, window_size=70, verbose=False)
    mssa.fit(stacked)
    # Measure once, so the recorded and the printed timings are the same
    # number (they used to come from two separate clock reads).
    elapsed = time.perf_counter() - start
    Ns.append(N**2)  # number of time series = grid points
    times.append(elapsed)
    print(N**2, f'{elapsed:.3f} s')

In [None]:
# Least-squares line through (number of series, seconds), drawn over the data.
fit = np.polyfit(Ns, times, 1)
plt.plot(Ns, times)
plt.plot(Ns, np.polyval(fit, Ns))