# Application of Fourier Flows to Real Data

In this notebook, we show how to train a Fourier Flow model using the setup of the second experiment (Section 5.2) in our paper "Generative Time-series Modeling with Fourier Flows" by Ahmed M. Alaa, Alex Chan and Mihaela van der Schaar, published at ICLR 2021.

In [4]:
from __future__ import absolute_import, division, print_function
from data_loading import real_data_loading
import numpy as np
import pickle

from utils.spectral import *
from SequentialFlows import *

from metrics.PRcurve import *
from metrics.MAE import *

from matplotlib import pyplot as plt
%matplotlib inline
import seaborn as sns

### Example: training on the stocks data set

In [18]:
# Options for data_name: "stock", "energy", "lung"

T         = 100
data_name = "stock"

X         = real_data_loading(data_name=data_name, seq_len=T)
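The `data_loading` module is not shown in this notebook. As a rough illustration of the kind of preprocessing such a loader typically performs, here is a minimal sketch that min-max scales each feature and then cuts the series into overlapping windows of length `T`. The helper `sliding_windows` is hypothetical; `real_data_loading` may differ in ordering, shuffling, scaling, and in which features it keeps.

```python
import numpy as np

def sliding_windows(series, seq_len, eps=1e-7):
    """Hypothetical helper: min-max scale a (time, features) array per feature,
    then cut it into overlapping windows of length seq_len."""
    series = np.asarray(series, dtype=float)
    lo, hi = series.min(axis=0), series.max(axis=0)
    scaled = (series - lo) / (hi - lo + eps)  # eps guards against constant features
    n_windows = len(scaled) - seq_len + 1
    # Window i covers time steps i, i + 1, ..., i + seq_len - 1
    return np.stack([scaled[i:i + seq_len] for i in range(n_windows)])
```

Applied to an array of shape `(time, features)` with `seq_len=100`, this returns an array of shape `(time - 99, 100, features)`.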

In [19]:
FF_model  = FourierFlow(hidden=200, fft_size=T + 1, n_flows=5, normalize=True) 
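Fourier Flows model the series in the frequency domain, so the map into that domain has to be invertible without losing dimensions. The minimal NumPy sketch below demonstrates this property for a single real-valued series of length `T`. The real FFT yields `T // 2 + 1` complex coefficients; after dropping imaginary parts that are always zero, exactly `T` real numbers remain, and the signal can be recovered exactly. The function names are hypothetical. This is not the repository's implementation, and it does not explain the library's `fft_size=T + 1` argument.

```python
import numpy as np

def real_dft_features(x):
    """Map a real 1-D signal of length T to exactly T real frequency-domain numbers."""
    x = np.asarray(x, dtype=float)
    T = len(x)
    X = np.fft.rfft(x)  # T // 2 + 1 complex coefficients
    # The DC bin (and the Nyquist bin when T is even) is purely real, so its
    # imaginary part carries no information and is dropped.
    imag = X.imag[1:-1] if T % 2 == 0 else X.imag[1:]
    return np.concatenate([X.real, imag])

def inverse_real_dft_features(f, T):
    """Invert real_dft_features, recovering the length-T time-domain signal."""
    n_bins = T // 2 + 1
    imag = np.zeros(n_bins)
    if T % 2 == 0:
        imag[1:-1] = f[n_bins:]
    else:
        imag[1:] = f[n_bins:]
    return np.fft.irfft(f[:n_bins] + 1j * imag, n=T)
```

For `T = 100` this gives 51 real parts and 49 imaginary parts, for 100 numbers in total.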

In [20]:
FF_losses = FF_model.fit(X, epochs=500, batch_size=500, 
                         learning_rate=1e-3, display_step=100)

step: 0 	/ 500 	-----	loss: 179.473
step: 100 	/ 500 	|----	loss: 42.948
step: 200 	/ 500 	||---	loss: 30.412
step: 300 	/ 500 	|||--	loss: 27.624
step: 400 	/ 500 	||||-	loss: 23.632
step: 499 	/ 500 	|||||	loss: 22.163
Finished training!
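Normalizing flows such as Fourier Flows are trained by maximum likelihood using the change-of-variables formula $\log p_X(x) = \log p_Z(f(x)) + \log\lvert\det \partial f(x)/\partial x\rvert$. The printed `loss` is therefore presumably an average negative log-likelihood, although the notebook does not say so. To make the formula concrete, here is a toy NumPy flow with a single elementwise affine layer and a standard normal base distribution. It shows both directions used in this notebook: density evaluation, as in `fit`, and the inverse map from noise to data, as in `sample`. It is not FourierFlow's architecture, and the class name is hypothetical.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)

class ElementwiseAffineFlow:
    """Toy flow z = (x - shift) * exp(-log_scale) with a standard normal base."""

    def __init__(self, shift, log_scale):
        self.shift = np.asarray(shift, dtype=float)
        self.log_scale = np.asarray(log_scale, dtype=float)

    def forward(self, x):
        z = (x - self.shift) * np.exp(-self.log_scale)
        log_det = -self.log_scale.sum()  # log|det dz/dx|, identical for every row
        return z, log_det

    def log_prob(self, x):
        z, log_det = self.forward(x)
        log_pz = -0.5 * (z ** 2 + LOG_2PI).sum(axis=-1)  # standard normal log-density
        return log_pz + log_det

    def nll(self, x):
        # Average negative log-likelihood: the quantity a flow minimizes in training
        return -self.log_prob(x).mean()

    def sample(self, n, rng):
        # Draw base noise and push it through the inverse map x = shift + z * scale
        z = rng.standard_normal((n, self.shift.size))
        return self.shift + z * np.exp(self.log_scale)
```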


In [21]:
X_synthetic = FF_model.sample(10000)

In [22]:
computeF1(X, X_synthetic) 

0.9916293741563228
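`computeF1` lives in `metrics/PRcurve`, which suggests a score derived from a precision/recall curve; its internals are not shown here. One well-known construction is the precision and recall for distributions (PRD) curve of Sajjadi et al. (2018). You cluster the union of real and synthetic samples (for example with k-means) and compare the two cluster histograms, $P$ for real and $Q$ for synthetic. You then trace $\alpha(\lambda) = \sum_\omega \min(\lambda P(\omega), Q(\omega))$ (precision) and $\beta(\lambda) = \alpha(\lambda)/\lambda$ (recall) over $\lambda \in (0, \infty)$. The sketch below assumes the cluster labels are already available. Collapsing the curve to its largest F1 value is an assumption made here; `computeF1` may summarize it differently.

```python
import numpy as np

def cluster_histograms(real_labels, synth_labels, n_clusters):
    """Normalized histograms of cluster assignments (labels in 0..n_clusters-1)."""
    p = np.bincount(real_labels, minlength=n_clusters).astype(float)
    q = np.bincount(synth_labels, minlength=n_clusters).astype(float)
    return p / p.sum(), q / q.sum()

def prd_curve(p, q, num_angles=1001, eps=1e-10):
    """Precision/recall pairs for reference histogram p (real) and model histogram q."""
    angles = np.linspace(eps, np.pi / 2 - eps, num_angles)
    lambdas = np.tan(angles)  # slopes sweeping (0, inf)
    precision = np.minimum(lambdas[:, None] * p[None, :], q[None, :]).sum(axis=1)
    recall = precision / lambdas
    return np.clip(precision, 0.0, 1.0), np.clip(recall, 0.0, 1.0)

def max_f1(precision, recall):
    """Largest F1 = 2PR / (P + R) along the curve (0 where P + R == 0)."""
    denom = precision + recall
    f1 = np.where(denom > 0, 2 * precision * recall / np.maximum(denom, 1e-12), 0.0)
    return float(f1.max())
```

Identical histograms give a maximum F1 of 1, and disjoint histograms give 0.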

In [None]:
computeMAE(X, X_synthetic)

Epoch:  0 | train loss: 0.1324
Epoch:  1 | train loss: 0.0016
Epoch:  2 | train loss: 0.0012
Epoch:  3 | train loss: 0.0007
Epoch:  4 | train loss: 0.0007
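The epoch log printed by `computeMAE` shows that it trains a predictive model. This is presumably the train-on-synthetic, test-on-real (TSTR) protocol: fit a next-step predictor on synthetic sequences and report its mean absolute error on real ones. The network it trains is not shown here, so the sketch below swaps in a linear autoregressive predictor fit by least squares. It handles univariate series of shape `(n_series, T)`; `tstr_mae` and `lags` are hypothetical names.

```python
import numpy as np

def tstr_mae(real, synthetic, lags=5):
    """Fit a linear next-step predictor on synthetic series; return its MAE on real series.

    real, synthetic: arrays of shape (n_series, T) with T > lags.
    """
    def make_xy(data):
        data = np.asarray(data, dtype=float)
        # Inputs: the previous `lags` values; target: the value at time t
        X = np.concatenate([data[:, t - lags:t] for t in range(lags, data.shape[1])])
        y = np.concatenate([data[:, t] for t in range(lags, data.shape[1])])
        return np.hstack([X, np.ones((len(X), 1))]), y  # append a bias column

    X_syn, y_syn = make_xy(synthetic)
    w, *_ = np.linalg.lstsq(X_syn, y_syn, rcond=None)  # train on synthetic only
    X_real, y_real = make_xy(real)
    return float(np.mean(np.abs(X_real @ w - y_real)))  # test on real only
```

A lower value means that a model trained only on synthetic data transfers better to the real series.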


### Save model

In [None]:
with open("results/FF_model_stock.p", "wb") as f:
    pickle.dump(FF_model, f)
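`open` does not create missing directories, so saving fails with `FileNotFoundError` if `results/` does not exist. Below is a small pair of helpers, with hypothetical names, that creates the parent directory and always closes the file handle. Loading the model back requires `SequentialFlows` to be importable. As with any pickle, only load files from trusted sources.

```python
import pickle
from pathlib import Path

def save_pickle(obj, path):
    """Pickle obj to path, creating parent directories (e.g. results/) if needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)

def load_pickle(path):
    """Load a pickled object; the classes it references must be importable."""
    with Path(path).open("rb") as f:
        return pickle.load(f)
```

For example, `save_pickle(FF_model, "results/FF_model_stock.p")` saves the model, and `load_pickle("results/FF_model_stock.p")` restores it in a later session.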

### Notes on error bars and robustness of evaluation

Because the training/testing paradigm does not apply to generative models, we cannot use cross-validation to estimate the variability of model performance as we would for predictive models. There are two ways to assess the stability of the results. The first is to draw a large number of synthetic samples and check that the performance metric converges, i.e., that $M_N$, the metric computed on $N$ synthetic samples, approaches $\lim_{N\to \infty} M_N$. This removes variation due to the finite synthetic sample. To evaluate variation due to the finite training sample, we can train multiple models and average the performance metric, $\mathbb{E}[M]$, as shown below.
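The first check, convergence in the number of synthetic samples $N$, can be sketched as a loop over increasing sample sizes. This is a minimal sketch: the stopping rule (the last two values differ by less than `tol`) and the function name are hypothetical. The toy usage replaces the real metric with the gap between two sample means.

```python
import numpy as np

def metric_convergence(metric, real, sample_fn, sizes, tol=1e-2):
    """Evaluate M_N = metric(real, sample_fn(N)) for each N in `sizes`.

    Returns the list of values and a flag that is True when the last two
    values differ by less than `tol` (a simple, hypothetical stopping rule).
    """
    values = [float(metric(real, sample_fn(n))) for n in sizes]
    converged = len(values) >= 2 and abs(values[-1] - values[-2]) < tol
    return values, converged

# Toy usage: the "metric" is the gap between sample means of Gaussian data.
rng = np.random.default_rng(0)
real = rng.standard_normal(1000)
values, converged = metric_convergence(
    metric=lambda r, s: abs(r.mean() - s.mean()),
    real=real,
    sample_fn=lambda n: rng.standard_normal(n),
    sizes=[1_000, 500_000, 1_000_000],
)
```

With the objects in this notebook, the same call could take `computeF1`, `X`, and `FF_model.sample` in place of the toy metric, data, and sampler.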

In [23]:
FF_models = []
num_runs  = 5

for k in range(num_runs):
    
    print("Experiment number: ", k)
    FF_model_  = FourierFlow(hidden=200, fft_size=T + 1, n_flows=5, normalize=True) 
    _          = FF_model_.fit(X, epochs=500, batch_size=500, 
                               learning_rate=1e-3, display_step=100)
    
    FF_models.append(FF_model_)
    print("F1 score", computeF1(X, FF_model_.sample(10000)))

Experiment number:  0
step: 0 	/ 500 	-----	loss: 192.278
step: 100 	/ 500 	|----	loss: 52.626
step: 200 	/ 500 	||---	loss: 33.135
step: 300 	/ 500 	|||--	loss: 30.056
step: 400 	/ 500 	||||-	loss: 26.186
step: 499 	/ 500 	|||||	loss: 26.628
Finished training!
F1 score 0.9904076736916793
Experiment number:  1
step: 0 	/ 500 	-----	loss: 196.948
step: 100 	/ 500 	|----	loss: 49.698
step: 200 	/ 500 	||---	loss: 32.530
step: 300 	/ 500 	|||--	loss: 27.203
step: 400 	/ 500 	||||-	loss: 24.339
step: 499 	/ 500 	|||||	loss: 32.403
Finished training!
F1 score 0.989612473934336
Experiment number:  2
step: 0 	/ 500 	-----	loss: 182.618
step: 100 	/ 500 	|----	loss: 48.284
step: 200 	/ 500 	||---	loss: 30.790
step: 300 	/ 500 	|||--	loss: 26.690
step: 400 	/ 500 	||||-	loss: 23.271
step: 499 	/ 500 	|||||	loss: 21.366
Finished training!
F1 score 0.9937573102036994
Experiment number:  3
step: 0 	/ 500 	-----	loss: 190.327
step: 100 	/ 500 	|----	loss: 50.147
step: 200 	/ 500 	||---	loss: 29.676

In [24]:
X_synthetic_samples = [FF_models[k].sample(10000) for k in range(len(FF_models))]

In [25]:
F1_scores           = np.array([computeF1(X, X_synthetic_samples[k]) for k in range(len(X_synthetic_samples))])

(np.mean(F1_scores), np.std(F1_scores))

(0.9851363855153196, 0.010091262499501728)
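`np.std` uses `ddof=0` by default, so the second number above is the population standard deviation of the five F1 scores. With this few runs, the sample standard deviation (`ddof=1`) is larger by a factor of $\sqrt{5/4} \approx 1.118$. If the error bar is meant to describe uncertainty in $\mathbb{E}[M]$ itself, the standard error of the mean is the usual choice. A small helper with a hypothetical name:

```python
import numpy as np

def summarize_runs(scores):
    """Mean, sample std (ddof=1) and standard error of the mean over independent runs."""
    scores = np.asarray(scores, dtype=float)
    mean = scores.mean()
    std = scores.std(ddof=1)  # np.std defaults to ddof=0 (population std)
    sem = std / np.sqrt(len(scores))
    return mean, std, sem
```

Calling `summarize_runs(F1_scores)` reports all three numbers for the runs above.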