In [None]:
from __future__ import print_function
import os
import tensorflow as tf
import keras

# TF1-style session setup: device_count limits TensorFlow to 1 GPU and 56 CPU devices.
# NOTE(review): the CPU count is hard-coded for one particular machine — adjust as needed.
# The session is registered with Keras *before* the custom_model imports below; keep this
# order in case those modules touch the Keras backend at import time.
config = tf.ConfigProto( device_count = {'GPU': 1 , 'CPU': 56} ) 
sess = tf.Session(config=config) 
keras.backend.set_session(sess)

# Project modules, pulled in with star imports. They presumably provide z_score/z_inverse,
# the metrics (MAE, MAPE, RMSE, rmse), create_embed_model and the training callbacks used
# further down — TODO confirm and replace with explicit imports.
from custom_model.layers_keras import *
from custom_model.model_keras import *
from custom_model.math_utils import *
from keras import metrics

import matplotlib.pyplot as plt
%matplotlib inline
import h5py
import numpy as np

# Suppresses every Python warning for the rest of the kernel session; remove while debugging.
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Length of one time step (presumably minutes: the plotting cell passes dt=2/60 hours).
dt = 2
# Observed history: a 30-unit window expressed in steps, minus one step;
# forecast horizon: a 20-unit window expressed in steps.
OBS, PRED = (30 // dt) - 1, (20 // dt)

In [None]:
# Load the pre-split train/test arrays. np.array(...) copies each dataset into memory,
# so the HDF5 file can be closed as soon as the `with` block ends.
with h5py.File('Datasets/RRot_cc2_20.h5', 'r') as Data:
    x_train = np.array(Data['Speed_obs_train'])
    y_train = np.array(Data['Speed_pred_train'])
    e_train = np.array(Data['E_train'])
    x_test = np.array(Data['Speed_obs_test'])
    y_test = np.array(Data['Speed_pred_test'])
    e_test = np.array(Data['E_test'])

x_size = x_train.shape[1:]
y_size = y_train.shape[1:]

print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print('e_train shape:', e_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(y_train.shape[1:], 'output size')

# Keep the last OBS observed steps and the first PRED target steps (slicing along axis 1).
x_train = x_train[:, -OBS:]
x_test = x_test[:, -OBS:]
y_train = y_train[:, :PRED]
e_train = e_train[:, :PRED]
y_test = y_test[:, :PRED]
e_test = e_test[:, :PRED]

# z-score every split with statistics of the TRAINING split only: the test inputs are then
# scaled exactly as the model saw during training, and no test-set statistics leak in.
x_mean, x_std = np.mean(x_train), np.std(x_train)
y_mean, y_std = np.mean(y_train), np.std(y_train)
e_mean, e_std = np.mean(e_train), np.std(e_train)

X_train = z_score(x_train, x_mean, x_std)
X_test = z_score(x_test, x_mean, x_std)
Y_train = z_score(y_train, y_mean, y_std)
# The test targets are y_test (this line used to normalise x_test by mistake).
Y_test = z_score(y_test, y_mean, y_std)
E_train = z_score(e_train, e_mean, e_std)
E_test = z_score(e_test, e_mean, e_std)

print(x_train.shape, y_train.shape)

In [None]:
# Build the embedding model for OBS input steps and PRED output steps, then print its layers.
model = create_embed_model(
    obs_timesteps=OBS,
    pred_timesteps=PRED,
    nb_nodes=35,  # NOTE(review): hard-coded node count — confirm it matches the dataset
    k=1,
)
model.summary()

In [None]:
# Load pretrained weights (optional). Each array is stored as its own dataset named
# 'weight0', 'weight1', ... — the layout written by the save-weights cell — which avoids
# TensorFlow-version conflicts with Keras' own weight format.
# NOTE(review): this file is named 'cc2-k4' while the model above is built with k=1;
# confirm the architectures match, otherwise set_weights presumably fails on shape mismatch.
with h5py.File('pretrained/cc2-k4.h5', 'r') as weights_file:
    # `[:]` reads each dataset into memory before the file is closed.
    weight = [weights_file['weight' + str(i)][:] for i in range(len(weights_file.keys()))]
model.set_weights(weight)

In [None]:
# Let's train the model.
# Adam with learning-rate decay; `lr`/`decay` are the legacy Keras argument names used here.
opt = keras.optimizers.Adam(lr=0.001, decay=1e-3)

model.compile(loss=rmse,
              optimizer=opt,
              metrics=['mae', 'mape'])

# NOTE(review): EarlyStopping, ModelCheckpoint and ScheduledSampling are not imported
# explicitly; they presumably come from the custom_model star imports — confirm.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5),
    # Relative path with forward slashes, like every other 'pretrained/...' path here.
    # A backslash path such as '\pretrained\test.h5' silently turns '\t' into a TAB.
    ModelCheckpoint(filepath='pretrained/test.h5', save_weights_only=True,
                    monitor='val_loss', save_best_only=True),
    # NOTE(review): k=4 here while create_embed_model was built with k=1 — confirm intent.
    ScheduledSampling(k=4),
]

history = model.fit([X_train, E_train], Y_train,
                    epochs=256,
                    batch_size=128,
                    callbacks=callbacks,
                    # Optionally pass validation_data=([x_val, e_val], y_val) instead of validation_split.
                    shuffle='batch',
                    validation_split=0.2)

In [None]:
# Predict on the test set and map the outputs back to the original scale.
y = model.predict([X_test, E_test])
# The model was trained on targets z-scored with the training-target statistics
# (Y_train = z_score(y_train, mean(y_train), std(y_train))), so de-normalise with those,
# not with the statistics of the test inputs.
y = z_inverse(y, np.mean(y_train), np.std(y_train))

# NOTE(review): 120 is presumably the factor that maps stored speeds back to real units — confirm.
speed_scale = 120
y_true_scaled = y_test * speed_scale
y_pred_scaled = y * speed_scale

# Cumulative horizons: metrics over steps 1..i together (MAE, MAPE in %, RMSE).
for i in range(1, PRED + 1):
    t, p = y_true_scaled[:, :i], y_pred_scaled[:, :i]
    print(MAE(t, p), ' ', 100 * MAPE(t, p), ' ', RMSE(t, p))
print('##########')
# Individual horizons: metrics for step i alone.
for i in range(PRED):
    t, p = y_true_scaled[:, i], y_pred_scaled[:, i]
    print(MAE(t, p), ' ', 100 * MAPE(t, p), ' ', RMSE(t, p))

In [None]:
# Save weights as plain HDF5 datasets 'weight0', 'weight1', ... (one per array returned by
# model.get_weights()) to avoid TensorFlow-version conflicts. The `with` block guarantees
# the file is closed even if a write fails.
with h5py.File('pretrained/cl5-k9.h5', 'w') as weights_file:
    for idx, weight_array in enumerate(model.get_weights()):
        weights_file.create_dataset('weight' + str(idx), data=weight_array)

In [None]:
# Compare predictions against ground truth over one day of the test set.
day = 2
j = 5  # prediction steps kept from each forecast
# NOTE(review): 150 is presumably the number of test samples per day — confirm.
start = 150 * (day - 1)
end = start + 150
# Non-overlapping windows: every PRED-th sample contributes its full PRED-step target, and
# every j-th sample contributes its first j predicted steps. Assuming consecutive samples
# are one step apart, both lists tile the same time range.
gt = [y_test[i] for i in range(start, end, PRED)]
pred = [y[i][:j] for i in range(start, end, j)]

ground_truth = np.concatenate(gt, axis=0)
prediction = np.concatenate(pred, axis=0)


def _show_heatmap(values, title):
    """Draw `values` (squeezed, then transposed) as a heatmap with a colorbar."""
    fig, ax = plt.subplots(figsize=(20, 7))
    image = ax.imshow(values.squeeze().transpose(), aspect='auto')
    ax.set_title(title)
    fig.colorbar(image, ax=ax)


_show_heatmap(ground_truth, 'ground truth')
_show_heatmap(prediction, 'prediction (first {} steps of each forecast)'.format(j))

In [None]:
# Save figs if necessary.
from utils_vis import *

plot_figure(
    prediction,
    title1='10min prediction of DGCN on RotCC2 (k=0)',
    title2=' ',
    nb=208,  # NOTE(review): hard-coded — confirm meaning with utils_vis.plot_figure
    figtitle='k0.PNG',
    time=150,
    # Step length as a float fraction (dt / 60, i.e. presumably hours). Written with 60.0 so
    # it is not truncated to 0 by integer division under Python 2, which this file targets.
    dt=dt / 60.0,
    color='autumn',
)