In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from ripser import ripser
from persim import plot_diagrams
from scipy.spatial.distance import pdist, squareform
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.svm import LinearSVC, SVC, SVR
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from gtda.time_series import TakensEmbedding
from PyEMD import EMD
from statsmodels.tsa.stattools import adfuller
from pylab import mpl

%matplotlib qt

In [4]:
# Load the raw 3-D statistics array and collapse its first 7 channels (axis 1) into 6:
# raw channels 2 and 3 are summed into one; channels 0-1 and 4-6 are copied unchanged.
# NOTE(review): the meaning of each channel index is not recorded here — TODO document
# what raw channels 2 and 3 are and why they are merged.
more_data_path = "../dataset/top20.npy"
more_data = np.load(more_data_path)
# Take the entity / time dimensions from the file instead of hard-coding (20, 6, 12),
# so a dataset with a different number of rows or time steps does not break this cell.
n_entities, n_raw_channels, n_steps = more_data.shape
assert n_raw_channels >= 7, f"expected at least 7 raw channels, got {n_raw_channels}"
selected_data = np.zeros((n_entities, 6, n_steps))
selected_data[:, 0:2, :] = more_data[:, 0:2, :]
selected_data[:, 2, :] = more_data[:, 2, :] + more_data[:, 3, :]
selected_data[:, 3:6, :] = more_data[:, 4:7, :]

In [9]:
def generate_trend_matrix(statis_matrix: np.ndarray, dot_prsv: int = 2, zero_fill: float = 10) -> np.ndarray:
    """Return the row-wise ratio of each value to its predecessor.

    For a 2-D input of shape (m, n), element ``[r, c]`` of the result is
    ``statis_matrix[r, c + 1] / statis_matrix[r, c]`` rounded to ``dot_prsv``
    decimal places, so the result has shape (m, n - 1).

    Args:
        statis_matrix: 2-D array (or array-like) of values.
        dot_prsv: number of decimal places kept in each ratio.
        zero_fill: value written wherever the denominator is exactly zero
            (the ratio is undefined there). Defaults to 10, matching the
            previous hard-coded sentinel; it is not rounded.

    Returns:
        float ndarray of shape (m, n - 1).

    Raises:
        AssertionError: if the input is not 2-D.
    """
    statis_matrix = np.asarray(statis_matrix, dtype=float)
    assert statis_matrix.ndim == 2
    previous = statis_matrix[:, :-1]
    following = statis_matrix[:, 1:]
    nonzero = previous != 0
    # Divide only where the denominator is non-zero (no divide-by-zero warnings);
    # the other cells are overwritten with the sentinel below.
    ratios = np.divide(following, previous, out=np.zeros_like(previous), where=nonzero)
    return np.where(nonzero, np.round(ratios, dot_prsv), zero_fill)

# Z-score each series (one row of `reshaped`) independently, then slide a window of
# WINDOW consecutive values over it: every window becomes one point of a small point
# cloud per series, which is what the persistence step consumes.
WINDOW = 3  # number of consecutive values per point (i.e. point dimension)

enhanced_data = []
reshaped = selected_data.reshape(-1, selected_data.shape[-1])
for series in reshaped:
    series_std = np.std(series)
    if series_std == 0:
        # A constant series has zero spread; dividing by it would yield NaNs that
        # then poison the persistence computation. Map it to all zeros instead.
        normalized = np.zeros_like(series)
    else:
        normalized = (series - np.mean(series)) / series_std
    # All full windows (len - WINDOW + 1 of them); no trailing window is dropped.
    enhanced_data.append([normalized[j:j + WINDOW] for j in range(len(normalized) - WINDOW + 1)])
enhanced_data = np.array(enhanced_data)
# Training clouds use every window but the last; "pred" clouds are the same
# windows shifted one step later (they include the most recent value).
data_train = enhanced_data[:, :-1, :]
data_pred = enhanced_data[:, 1:, :]

In [10]:
def _h0_lifetime_features(point_cloud: np.ndarray) -> np.ndarray:
    """Summarise the finite H0 persistence bars of one point cloud.

    Runs ``ripser`` on ``point_cloud`` and keeps the death column of the
    0-dimensional diagram, dropping the bar with infinite death. H0 bars of a
    Rips filtration on a point cloud are born at 0, so the death value is used
    as the bar's lifetime.

    Returns an ndarray of 8 values, each rounded to 2 decimals:
    [count, sum, mean, std, max, min,
     count(lifetime > 0.5 * max), count(lifetime > mean)].
    """
    lifetimes = ripser(point_cloud)['dgms'][0][:, 1]
    # Filter by finiteness rather than assuming the infinite bar is the last row.
    lifetimes = lifetimes[np.isfinite(lifetimes)]
    max_life = np.max(lifetimes)
    mean_life = np.mean(lifetimes)
    stats = [
        lifetimes.shape[0],
        np.sum(lifetimes),
        mean_life,
        np.std(lifetimes),
        max_life,
        np.min(lifetimes),
        # Vectorised counts; max/mean are computed once, not once per element.
        np.count_nonzero(lifetimes > 0.5 * max_life),
        np.count_nonzero(lifetimes > mean_life),
    ]
    return np.array([np.round(s, 2) for s in stats])


# One 8-value feature row per series, for the training and the shifted ("pred") clouds.
feature_list_train = [_h0_lifetime_features(data_train[i, :, :]) for i in range(data_train.shape[0])]
feature_list_pred = [_h0_lifetime_features(data_pred[i, :, :]) for i in range(data_pred.shape[0])]

In [14]:
# Regression target per series: the ratio of its last value to the one before it.
# generate_trend_matrix works row-wise, so passing only the last two columns yields
# exactly that ratio as an (n_series, 1) column (10 where the previous value is 0).
# NOTE(review): the 10 sentinel becomes a target value and may dominate the SVR fit —
# consider excluding or flagging those rows.
trends = generate_trend_matrix(reshaped[:, -2:])
feature_train_array = np.array(feature_list_train)
assert feature_train_array.shape[0] == trends.shape[0], "feature rows and targets must align"
feature_pd = pd.DataFrame(
    np.concatenate((feature_train_array, trends), axis=1),
    columns=[
        'Number of TDA barcode points',
        'Sum of lifetime',
        'Average of lifetime',
        'Std of lifetime',
        'Max of lifetime',
        'Min of lifetime',
        'Number of points bigger than 0.5*max',
        'Number of points bigger than average',
        'Trend',
    ],
)

In [17]:
# Leave-one-out forecast: for each series i, fit a linear-kernel SVR on every *other*
# series' (training-cloud features -> trend) pairs, then predict series i's next trend
# from its shifted "pred"-cloud features.
# NOTE(review): features are not standardized; a linear SVR is scale-sensitive and the
# count features can dwarf the lifetime statistics — consider a StandardScaler pipeline.
X = feature_pd.drop(columns=['Trend']).to_numpy()  # the 8 barcode features, in order
Y = feature_pd[['Trend']].to_numpy()               # (n_series, 1) target column

n_series = X.shape[0]
assert len(feature_list_pred) == n_series, "need one pred-feature row per series"

pred = []
for i in range(n_series):
    X_test = feature_list_pred[i]
    X_train = np.delete(X, i, axis=0)
    Y_train = np.delete(Y, i, axis=0)
    svr = SVR(C=1.0, epsilon=0.2, kernel='linear')
    svr.fit(X_train, np.ravel(Y_train))
    # predict() returns a length-1 array; index it rather than calling float() on the
    # array, which NumPy >= 1.25 deprecates for ndim > 0.
    pred.append(float(svr.predict(X_test.reshape(1, -1))[0]))

# Rows of `reshaped` are entity-major, so this restores (entity, channel) layout.
print(np.array(pred).reshape(-1, selected_data.shape[1]))

[[0.77569966 0.83571667 0.62495079 0.64418745 0.90386382 0.88450766]
 [0.81187243 0.85812683 0.58386582 0.7432048  0.89266048 0.87599745]
 [0.8331817  0.81740484 0.73653958 0.78845862 0.91043914 0.86237466]
 [0.80263817 0.79246776 0.59442573 0.75456168 0.76217564 0.88885511]
 [0.86097902 0.86440656 0.60128859 0.72320102 0.75136294 0.90352131]
 [0.76890328 0.78641736 0.82374128 0.68147777 0.83070203 0.91551638]
 [0.82914931 0.80779966 0.82332658 0.90432293 0.88464515 0.79688935]
 [0.80200187 0.85075846 0.82645473 0.75630487 0.89483996 0.860208  ]
 [0.76386031 0.82912354 0.68031938 0.77066209 0.8623548  0.89374094]
 [0.78362542 0.86612149 0.71345049 0.67266642 0.90437083 0.89033003]
 [0.80521026 0.83022381 0.69925699 0.79761517 0.91951518 0.84913006]
 [0.86058703 0.81911029 0.76132269 0.84090674 0.83220646 0.83802161]
 [0.78796672 0.83502795 0.78211941 0.65822784 0.90017277 0.87235794]
 [0.82584882 0.84076009 0.68563121 0.68794788 0.88485247 0.92006976]
 [0.83995422 0.85240155 0.56175359