In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.manifold import TSNE
import tensorflow as tf
from ts2vec import TS2Vec
import datautils
from tools import MinMaxScaler
import pickle
from scipy.linalg import sqrtm

In [2]:
dataset = 'sine_cpx'
full_train_data = np.load('../datasets/'+dataset+'.npy')
N, T, D = full_train_data.shape
valid_perc = 0.0

In [3]:
N_train = int(N * (1 - valid_perc))
N_valid = N - N_train
np.random.shuffle(full_train_data)
train_data = full_train_data[:N_train]
valid_data = full_train_data[N_train:]
scaler = MinMaxScaler()        
x_train = scaler.fit_transform(train_data)
x_valid = scaler.transform(valid_data)

In [4]:
x_gen = np.load('../save_model/gen_'+ dataset + '.npy')

In [5]:
def calculate_fid(act1, act2):
    # calculate mean and covariance statistics
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [6]:
config = dict(
        batch_size=8,
        lr=0.001,
        output_dims=320,
        max_train_length=3000
    )

In [7]:
fid_s = []
for i in range(5):
    model = TS2Vec(
        input_dims=x_train.shape[-1],
        device=0,
        **config
    )
    model.fit(x_train, verbose=False)
    ori_repr = model.encode(x_train, encoding_window='full_series')
    gen_repr = model.encode(x_gen, encoding_window='full_series')
    select = x_gen.shape[0]
    idx = np.random.permutation(select)
    ori = ori_repr[idx]
    gen = gen_repr[idx]
    fid_s.append(calculate_fid(ori, gen))
print("Avg:{}\xB1{}".format(np.mean(fid_s), 1.96*(np.std(fid_s)/np.sqrt(len(fid_s)))))

Avg:1.8230335689306258±0.08929151793637233
