# transfer learning in traffic flow prediction

# import and configuration

In [68]:
import h5py
import time
from keras.layers import (
    Input,
    Activation,
    merge,
    Dense,
    Reshape
)
from copy import copy
from keras.layers.convolutional import Convolution2D
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.utils.vis_utils import plot_model as plot
from keras import backend as K
from keras.engine.topology import Layer
# from keras.layers import Dense
import numpy as np
from datetime import datetime, timedelta
import os
import pandas as pd
import cPickle as pickle
import math
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
np.random.seed(1337)  # for reproducibility
        
class config():
    """Experiment settings shared by the notebook.

    Create an instance and then call ``NYC()`` or ``BJ()``; each of them
    fills in the dataset-specific attributes and returns the instance so
    the call can be chained (``config().BJ()``).
    """

    def __init__(self):
        self._apply(
            DATAPATH='../Data/CrowdFlows',
            nb_epoch=500,  # number of epochs at training stage
            nb_epoch_cont=100,  # number of epochs at training (cont) stage
            batch_size=32,  # batch size
            lr=0.0002,  # learning rate
            nb_flow=2,  # there are two types of flows: new-flow and end-flow
            meta_data=True,
            holiday_data=False,
            meteorol_data=False,
        )

    def _apply(self, **settings):
        # Store every keyword as an attribute of this configuration.
        for key, value in settings.items():
            setattr(self, key, value)

    def NYC(self):
        """Select the BikeNYC dataset (hourly slots, 16x8 grid)."""
        slots_per_day = 24  # number of time intervals in one day
        # The data is divided into Train & Test; the test set is the last
        # 10 days.
        days_test = 10
        height, width = 16, 8  # grid size
        # For NYC Bike data, there are 81 available grid-based areas, each
        # of which includes at least ONE bike station. Therefore the final
        # RMSE is modified by multiplying it with m_factor.
        nb_area = 81
        self._apply(
            T=slots_per_day,
            len_closeness=3,  # length of closeness dependent sequence
            len_period=4,  # length of period dependent sequence
            len_trend=4,  # length of trend dependent sequence
            nb_residual_unit=4,  # number of residual units
            days_test=days_test,
            len_test=slots_per_day * days_test,
            map_height=height,
            map_width=width,
            nb_area=nb_area,
            m_factor=math.sqrt(1. * height * width / nb_area),
            DIR='BikeNYC',
            file=['NYC14_M16x8_T60_NewEnd.h5'],
        )
        return self

    def BJ(self):
        """Select the TaxiBJ dataset (half-hourly slots, 32x32 grid)."""
        slots_per_day = 48  # number of time intervals in one day
        days_test = 7 * 4
        self._apply(
            T=slots_per_day,
            len_closeness=3,  # length of closeness dependent sequence
            len_period=1,  # length of period dependent sequence
            len_trend=1,  # length of trend dependent sequence
            nb_residual_unit=4,
            days_test=days_test,
            len_test=slots_per_day * days_test,
            map_height=32,  # grid size
            map_width=32,
            m_factor=1,
            DIR='TaxiBJ',
            file=['BJ13_M32x32_T30_InOut.h5',
                  'BJ14_M32x32_T30_InOut.h5',
                  'BJ15_M32x32_T30_InOut.h5',
                  'BJ16_M32x32_T30_InOut.h5'],
            holiday_data=True,
            holiday_file='BJ_Holiday.txt',
            meteorol_data=True,
            meteorol_file='BJ_Meteorology.h5',
        )
        return self

# Select the dataset configuration used by every cell below.
print("load a configuration")
CONFIG = config().BJ()
CACHEDATA = True  # reuse / write the preprocessed HDF5 cache
path_result = 'RET'  # directory for result files
path_model = 'MODEL'  # directory for model files

# The model inputs defined below are shaped (channels, height, width), so the
# backend is asked for channels-first tensors. The call is best effort: a
# failure is reported instead of being silently swallowed, and the notebook
# keeps running as it did before.
try:
    K.set_image_data_format("channels_first")
except Exception as e:
    print("could not set image data format:", e)

# Make sure the output directories exist.
for _path in (path_result, path_model):
    if not os.path.isdir(_path):
        os.mkdir(_path)

load a configuration


# Preprocessing

Define Util Functions

In [69]:
def string2timestamp(strings, T=48):
    """Convert slot strings such as '2013070148' into pandas Timestamps.

    Each string is ``YYYYMMDD`` followed by a 1-based slot number within the
    day. Slot 1 maps to 00:00 and every following slot is ``24 * 60 // T``
    minutes later.

    strings: iterable of slot strings.
    T: number of slots per day. Values below 24 (slots longer than one
       hour) are supported as well.

    Returns a list of ``pd.Timestamp`` in the same order as ``strings``.
    Raises ValueError (from ``datetime``) when a slot number falls outside
    the day.
    """
    # Work in whole minutes so the conversion does not depend on T being a
    # multiple of 24 (the previous ``slot % (T // 24)`` divided by zero for
    # T < 24).
    minutes_per_slot = 24 * 60 // T
    timestamps = []
    for t in strings:
        year, month, day = int(t[:4]), int(t[4:6]), int(t[6:8])
        slot = int(t[8:]) - 1  # 0-based slot index
        hour, minute = divmod(slot * minutes_per_slot, 60)
        timestamps.append(pd.Timestamp(datetime(year, month, day, hour=hour, minute=minute)))

    return timestamps

def load_stdata(fname):
    """Read the 'data' and 'date' datasets of an HDF5 file into memory.

    fname: path of the HDF5 file.

    Returns ``(data, timestamps)`` with the full content of both datasets.
    """
    # The context manager closes the file even when a read fails;
    # ``dset[()]`` reads the whole dataset and replaces the deprecated
    # ``dset.value`` accessor.
    with h5py.File(fname, 'r') as f:
        data = f['data'][()]
        timestamps = f['date'][()]
    return data, timestamps

def load_holiday(timeslots, fname):
    """Build a holiday indicator column for the given timeslots.

    timeslots: sequence of slot strings whose first 8 characters are the
               date (``YYYYMMDD``).
    fname: text file with one holiday date per line.

    Returns an array of shape ``(len(timeslots), 1)`` holding 1 where the
    slot's date is listed in the file and 0 elsewhere. The number of
    holiday slots is printed as a side effect.
    """
    # Read the dates through a context manager so the file is closed again.
    with open(fname, 'r') as f:
        holidays = set(line.strip() for line in f)

    H = np.zeros(len(timeslots))
    for i, slot in enumerate(timeslots):
        if slot[:8] in holidays:
            H[i] = 1
    print(H.sum())
    return H[:, None]

def load_meteorol(timeslots, fname):
    '''
    Build the meteorology feature matrix for the predicted timeslots.

    timeslots: the predicted timeslots.
    fname: HDF5 file with the datasets 'date', 'WindSpeed', 'Weather' and
           'Temperature'.

    In the real world the meteorology of the predicted timeslot is not
    available, so the record of the previous slot in the file is used
    instead (forecast data could be used as well). For the very first
    record of the file there is no previous slot; its own record is used.

    Returns the horizontal stack ``[Weather, WindSpeed, Temperature]`` with
    one row per timeslot; wind speed and temperature are min-max scaled to
    [0, 1] over the selected rows. ``np.hstack`` requires Weather to be
    2-D (presumably a one-hot encoding -- confirm against the data file).

    Raises KeyError when a timeslot is missing from the file.
    '''
    with h5py.File(fname, 'r') as f:
        Timeslot = f['date'][()]
        WindSpeed = f['WindSpeed'][()]
        Weather = f['Weather'][()]
        Temperature = f['Temperature'][()]

    # map timeslot to index
    M = dict((slot, i) for i, slot in enumerate(Timeslot))

    # Index of the record preceding each predicted slot. Clamp at 0: an
    # index of -1 would silently wrap around to the LAST record of the file.
    cur_ids = [max(M[slot] - 1, 0) for slot in timeslots]

    WS = np.asarray([WindSpeed[i] for i in cur_ids])  # WindSpeed
    WR = np.asarray([Weather[i] for i in cur_ids])  # Weather
    TE = np.asarray([Temperature[i] for i in cur_ids])  # Temperature

    # 0-1 scale
    WS = 1. * (WS - WS.min()) / (WS.max() - WS.min())
    TE = 1. * (TE - TE.min()) / (TE.max() - TE.min())

    print("shape: ", WS.shape, WR.shape, TE.shape)

    # concatenate all these attributes
    merge_data = np.hstack([WR, WS[:, None], TE[:, None]])

    return merge_data

def remove_incomplete_days(data, timestamps, T=48):
    """Keep only the days for which all T timeslots are present.

    A day counts as complete when a timestamp with slot number 1 is
    followed, T - 1 positions later, by a timestamp with slot number T.
    Days that start with slot 1 but fail this test are printed as
    incomplete.

    Returns the filtered ``(data, timestamps)`` pair.
    """
    complete_days = set()
    incomplete_days = []
    total = len(timestamps)
    pos = 0
    while pos < total:
        stamp = timestamps[pos]
        if int(stamp[8:]) != 1:
            # not the first slot of a day: keep scanning
            pos += 1
            continue
        last = pos + T - 1
        if last < total and int(timestamps[last][8:]) == T:
            complete_days.add(stamp[:8])
            pos += T  # jump over the whole day
        else:
            incomplete_days.append(stamp[:8])
            pos += 1
    print("incomplete days: ", incomplete_days)

    keep = [k for k, stamp in enumerate(timestamps) if stamp[:8] in complete_days]
    return data[keep], [timestamps[k] for k in keep]

def stat(fname, T=48):
    """Print summary statistics of a flow HDF5 file.

    fname: HDF5 file with the datasets 'data' and 'date'.
    T: number of timeslots per day (defaults to 48, i.e. half-hour slots,
       which is what was hard-coded before).

    The report lists the data shape, the covered date range, the expected
    and available number of timeslots, the missing ratio and the extreme
    values of 'data'.
    """
    def get_nb_timeslot(f):
        # Expected slot count between the first and the last day (both
        # inclusive), derived from the day difference in local time.
        s = f['date'][0]
        e = f['date'][-1]
        year, month, day = map(int, [s[:4], s[4:6], s[6:8]])
        ts = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        year, month, day = map(int, [e[:4], e[4:6], e[6:8]])
        te = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        seconds_per_slot = 24.0 * 3600 / T
        nb_timeslot = (time.mktime(te) - time.mktime(ts)) / seconds_per_slot + T
        ts_str, te_str = time.strftime("%Y-%m-%d", ts), time.strftime("%Y-%m-%d", te)
        return nb_timeslot, ts_str, te_str

    # Open read-only: this function only inspects the file and must never
    # create or modify it.
    with h5py.File(fname, 'r') as f:
        nb_timeslot, ts_str, te_str = get_nb_timeslot(f)
        nb_day = int(nb_timeslot / T)
        data = f['data'][()]  # read once for both extremes
        mmax = data.max()
        mmin = data.min()
        nb_available = f['date'].shape[0]
        report = '=' * 5 + 'stat' + '=' * 5 + '\n' + \
                 'data shape: %s\n' % str(f['data'].shape) + \
                 '# of days: %i, from %s to %s\n' % (nb_day, ts_str, te_str) + \
                 '# of timeslots: %i\n' % int(nb_timeslot) + \
                 '# of timeslots (available): %i\n' % nb_available + \
                 'missing ratio of timeslots: %.1f%%\n' % ((1. - float(nb_available / nb_timeslot)) * 100) + \
                 'max: %.3f, min: %.3f\n' % (mmax, mmin) + \
                 '=' * 5 + 'stat' + '=' * 5
        print(report)
        
def timestamp2vec(timestamps):
    """Encode the date of every timestamp as an 8-element 0/1 vector.

    The first seven entries are a one-hot encoding of the weekday
    (``tm_wday``: Monday is 0, Sunday is 6); the eighth entry is 1 on
    weekdays and 0 on weekends.

    Note: the timestamps are passed to ``time.strptime`` as-is (the
    Python 2 string behaviour).
    """
    rows = []
    for stamp in timestamps:
        weekday = time.strptime(stamp[:8], '%Y%m%d').tm_wday
        one_hot = [1 if d == weekday else 0 for d in range(7)]
        one_hot.append(1 if weekday < 5 else 0)  # weekday flag
        rows.append(one_hot)
    return np.asarray(rows)

def read_cache(fname):
    """Load a dataset previously written by ``cache``.

    fname: path of the HDF5 cache file.

    Returns ``(X_train, Y_train, X_test, Y_test, mmn, external_dim,
    timestamp_train, timestamp_test)``. ``X_train`` / ``X_test`` are lists
    of arrays, ``mmn`` is the normalizer unpickled from
    'preprocessing.pkl', and ``external_dim`` is None when the cache was
    written without external features (stored as -1 by ``cache``).
    """
    # NOTE(security): pickle executes code while loading; only read
    # 'preprocessing.pkl' files that this notebook produced itself.
    with open('preprocessing.pkl', 'rb') as fpkl:
        mmn = pickle.load(fpkl)

    with h5py.File(fname, 'r') as f:
        num = int(f['num'][()])  # number of input arrays per split
        X_train = [f['X_train_%i' % i][()] for i in range(num)]
        X_test = [f['X_test_%i' % i][()] for i in range(num)]
        Y_train = f['Y_train'][()]
        Y_test = f['Y_test'][()]
        external_dim = int(f['external_dim'][()])
        timestamp_train = f['T_train'][()]
        timestamp_test = f['T_test'][()]

    # Undo the "-1 means no external features" encoding used by cache().
    if external_dim < 0:
        external_dim = None

    return X_train, Y_train, X_test, Y_test, mmn, external_dim, timestamp_train, timestamp_test


def cache(fname, X_train, Y_train, X_test, Y_test, external_dim, timestamp_train, timestamp_test):
    """Write a prepared dataset to an HDF5 cache file.

    fname: target path; an existing file is overwritten and a missing
           parent directory is created.
    X_train, X_test: lists of input arrays, stored as 'X_train_<i>' and
                     'X_test_<i>'; the list length is stored as 'num'.
    Y_train, Y_test: target arrays.
    external_dim: width of the external feature vector; None is stored
                  as -1.
    timestamp_train, timestamp_test: timestamps of the targets, stored as
                                     'T_train' and 'T_test'.
    """
    cache_dir = os.path.dirname(fname)
    if cache_dir and not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)

    external_dim = -1 if external_dim is None else int(external_dim)

    # The context manager closes the file even if one of the writes fails.
    with h5py.File(fname, 'w') as h5:
        h5.create_dataset('num', data=len(X_train))
        for i, data in enumerate(X_train):
            h5.create_dataset('X_train_%i' % i, data=data)
        for i, data in enumerate(X_test):
            h5.create_dataset('X_test_%i' % i, data=data)
        h5.create_dataset('Y_train', data=Y_train)
        h5.create_dataset('Y_test', data=Y_test)
        h5.create_dataset('external_dim', data=external_dim)
        h5.create_dataset('T_train', data=timestamp_train)
        h5.create_dataset('T_test', data=timestamp_test)

define normalizer

In [70]:
class MinMaxNormalization(object):
    '''Linear rescaling of data to the interval [-1, 1].

    ``fit`` records the global minimum and maximum of its argument.
    ``transform`` maps that minimum to -1 and that maximum to 1;
    ``inverse_transform`` maps scaled values back to the original range.
    '''

    def __init__(self):
        pass

    def fit(self, X):
        """Remember (and print) the smallest and largest value of X."""
        self._min, self._max = X.min(), X.max()
        print("min:", self._min, "max:", self._max)

    def transform(self, X):
        """Scale X with the fitted extremes: first to [0, 1], then to [-1, 1]."""
        value_range = self._max - self._min
        unit = 1. * (X - self._min) / value_range
        return unit * 2. - 1.

    def fit_transform(self, X):
        """Fit on X and return the scaled X."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        """Map values from [-1, 1] back to the fitted data range."""
        unit = (X + 1.) / 2.
        return 1. * unit * (self._max - self._min) + self._min

define a data structure for ST data

In [56]:
class STMatrix(object):
    """Spatio-temporal frames indexed by timestamp.

    Holds a sequence of frames (``data``) together with their slot strings
    (``timestamps``) and builds (closeness, period, trend) -> target
    training samples from them.

    data: frames, indexable by integer position; ``len(data)`` must equal
          ``len(timestamps)``.
    timestamps: slot strings understood by ``string2timestamp``.
    T: number of timeslots per day.
    CheckComplete: when true, assert that consecutive timestamps are
                   exactly one slot apart.
    """

    def __init__(self, data, timestamps, T=48, CheckComplete=True):
        super(STMatrix, self).__init__()
        assert len(data) == len(timestamps)
        self.data = data
        self.timestamps = timestamps
        self.T = T
        self.pd_timestamps = string2timestamp(timestamps, T=self.T)
        if CheckComplete:
            self.check_complete()
        # index
        self.make_index()

    def make_index(self):
        """Build ``get_index``: pandas timestamp -> position in ``data``."""
        self.get_index = dict((ts, i) for i, ts in enumerate(self.pd_timestamps))

    def check_complete(self):
        """Print every gap between consecutive timestamps and assert there is none."""
        offset = pd.DateOffset(minutes=24 * 60 // self.T)
        stamps = self.pd_timestamps
        missing_timestamps = ["(%s -- %s)" % (prev, cur)
                              for prev, cur in zip(stamps, stamps[1:])
                              if prev + offset != cur]
        for v in missing_timestamps:
            print(v)
        assert len(missing_timestamps) == 0

    def get_matrix(self, timestamp):
        """Return the frame stored for ``timestamp`` (KeyError if unknown)."""
        return self.data[self.get_index[timestamp]]

    def save(self, fname):
        pass

    def check_it(self, depends):
        """Return True when every timestamp in ``depends`` is available."""
        # Test membership on the dict itself: ``.keys()`` builds a list on
        # Python 2, which turned every lookup into a linear scan.
        return all(d in self.get_index for d in depends)

    def create_dataset(self, len_closeness=3, len_trend=3, TrendInterval=7, len_period=3, PeriodInterval=1):
        """Build the training samples.

        For every target frame the inputs are the frames located
          * 1 .. len_closeness slots earlier (closeness),
          * PeriodInterval * T * j slots earlier, j = 1 .. len_period (period),
          * TrendInterval * T * j slots earlier, j = 1 .. len_trend (trend).
        A target is skipped when any of those frames is missing.

        Returns ``(XC, XP, XT, Y, timestamps_Y)``: the stacked closeness,
        period and trend inputs (an empty array for a component of length
        0), the target frames and the slot strings of the targets.
        """
        # offset_frame: one-frame offset
        offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)
        XC = []
        XP = []
        XT = []
        Y = []
        timestamps_Y = []
        # Lags, in slots, of the frames each sample depends on.
        depends = [range(1, len_closeness + 1),
                   [PeriodInterval * self.T * j for j in range(1, len_period + 1)],
                   [TrendInterval * self.T * j for j in range(1, len_trend + 1)]]

        # Earlier positions cannot have the largest lag available.
        start = max(self.T * TrendInterval * len_trend, self.T * PeriodInterval * len_period, len_closeness)
        for i in range(start, len(self.pd_timestamps)):
            target = self.pd_timestamps[i]
            # Compute the lagged timestamps once and reuse them for both the
            # availability check and the frame lookup.
            lagged = [[target - j * offset_frame for j in depend] for depend in depends]
            if not all(self.check_it(stamps) for stamps in lagged):
                continue
            x_c, x_p, x_t = [[self.get_matrix(ts) for ts in stamps] for stamps in lagged]
            if len_closeness > 0:
                XC.append(np.vstack(x_c))
            if len_period > 0:
                XP.append(np.vstack(x_p))
            if len_trend > 0:
                XT.append(np.vstack(x_t))
            Y.append(self.get_matrix(target))
            timestamps_Y.append(self.timestamps[i])
        XC = np.asarray(XC)
        XP = np.asarray(XP)
        XT = np.asarray(XT)
        Y = np.asarray(Y)
        return XC, XP, XT, Y, timestamps_Y

define load data function
return X_train, Y_train, X_test, Y_test, mmn, metadata_dim, timestamp_train, timestamp_test

X_train:[XC_train,XP_train,XT_train,meta_train]
X_test:[XC_test,XP_test,XT_test,meta_test]
XCPT_train/test.shape:(# of seq, len of seq, 16, 8)
meta:(# of seq, 8), [:7]: one-hot day of week; [7]: is weekday
Y_train,Y_test:(#of seq, 2, 16, 8)

mmn: minmax normalizer

external_dim: 8
timestamp_train, timestamp_test: [2014042901...]

In [57]:
def load_data(T=48, nb_flow=2, len_closeness=None, len_period=None, len_trend=None,
              len_test=None, preprocess_name='preprocessing.pkl',
              meta_data=True, meteorol_data=True, holiday_data=True):
    """Load the raw flow files of CONFIG and build train/test samples.

    T: number of timeslots per day.
    nb_flow: number of flow channels kept from the raw data.
    len_closeness, len_period, len_trend: lengths of the three input
        sequences; their sum must be positive.
    len_test: number of trailing samples used as the test set.
    preprocess_name: file the fitted normalizer is pickled to.
    meta_data, meteorol_data, holiday_data: switches for the day-of-week,
        meteorology and holiday features. The holiday and meteorology file
        names are taken from CONFIG.

    Returns ``(X_train, Y_train, X_test, Y_test, mmn, metadata_dim,
    timestamp_train, timestamp_test)``. ``X_train`` / ``X_test`` list the
    enabled sequence inputs followed by the external feature matrix (when
    any external feature is enabled); ``metadata_dim`` is the width of
    that matrix or None.
    """
    assert(len_closeness + len_period + len_trend > 0)
    data_all = []
    timestamps_all = list()
    for f in CONFIG.file:
        fname = os.path.join(CONFIG.DATAPATH, CONFIG.DIR, f)
        print("file name: ", fname)
        stat(fname)
        data, timestamps = load_stdata(fname)
        data, timestamps = remove_incomplete_days(data, timestamps, T)
        data = data[:, :nb_flow]
        data[data < 0] = 0.
        data_all.append(data)
        timestamps_all.append(timestamps)
        print("\n")
    # The normalizer is fitted on the training part only.
    data_train = np.vstack(copy(data_all))[:-len_test]

    print('train_data shape: ', data_train.shape)
    mmn = MinMaxNormalization()
    mmn.fit(data_train)
    data_all_mmn = [mmn.transform(d) for d in data_all]

    # Honour preprocess_name (it used to be ignored in favour of a
    # hard-coded file name that equals the default).
    with open(preprocess_name, 'wb') as fpkl:
        pickle.dump(mmn, fpkl)

    # generate feature sequences
    XC, XP, XT = [], [], []
    Y = []
    timestamps_Y = []
    for data, timestamps in zip(data_all_mmn, timestamps_all):
        # instance-based dataset --> sequences with format as (X, Y) where X is a sequence of images and Y is an image.
        st = STMatrix(data, timestamps, T, CheckComplete=False)
        _XC, _XP, _XT, _Y, _timestamps_Y = st.create_dataset(len_closeness=len_closeness, len_period=len_period, len_trend=len_trend)
        XC.append(_XC)
        XP.append(_XP)
        XT.append(_XT)
        Y.append(_Y)
        timestamps_Y += _timestamps_Y

    # load meta feature
    # All three switches come from the function arguments now; previously
    # two of them read CONFIG while the third read the argument.
    meta_feature = []
    if meta_data:
        time_feature = timestamp2vec(timestamps_Y)
        meta_feature.append(time_feature)
    if holiday_data:
        holiday_feature = load_holiday(timestamps_Y, fname=os.path.join(CONFIG.DATAPATH, CONFIG.DIR, CONFIG.holiday_file))
        meta_feature.append(holiday_feature)
    if meteorol_data:
        meteorol_feature = load_meteorol(timestamps_Y, fname=os.path.join(CONFIG.DATAPATH, CONFIG.DIR, CONFIG.meteorol_file))
        meta_feature.append(meteorol_feature)

    meta_feature = np.hstack(meta_feature) if len(meta_feature) > 0 else np.asarray(meta_feature)
    metadata_dim = meta_feature.shape[1] if len(meta_feature.shape) > 1 else None
    # Guard the comparison: ordering None against an int only works on
    # Python 2.
    if metadata_dim is not None and metadata_dim < 1:
        metadata_dim = None
    if meta_data and holiday_data and meteorol_data:
        print('time feature:', time_feature.shape, 'holiday feature:', holiday_feature.shape,
              'meteorol feature: ', meteorol_feature.shape, 'mete feature: ', meta_feature.shape)

    XC = np.vstack(XC)
    XP = np.vstack(XP)
    XT = np.vstack(XT)
    Y = np.vstack(Y)
    print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape)
    # The last len_test samples form the test set.
    XC_train, XP_train, XT_train, Y_train = XC[:-len_test], XP[:-len_test], XT[:-len_test], Y[:-len_test]
    XC_test, XP_test, XT_test, Y_test = XC[-len_test:], XP[-len_test:], XT[-len_test:], Y[-len_test:]
    timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:]

    # Only components with a positive length become model inputs.
    X_train = []
    X_test = []
    for l, X_ in zip([len_closeness, len_period, len_trend], [XC_train, XP_train, XT_train]):
        if l > 0:
            X_train.append(X_)
    for l, X_ in zip([len_closeness, len_period, len_trend], [XC_test, XP_test, XT_test]):
        if l > 0:
            X_test.append(X_)
    print('train shape:', XC_train.shape, Y_train.shape, 'test shape: ', XC_test.shape, Y_test.shape)

    if metadata_dim is not None:
        meta_feature_train, meta_feature_test = meta_feature[:-len_test], meta_feature[-len_test:]
        X_train.append(meta_feature_train)
        X_test.append(meta_feature_test)
    for _X in X_train:
        print(_X.shape, )
    print()
    for _X in X_test:
        print(_X.shape, )
    print()
    return X_train, Y_train, X_test, Y_test, mmn, metadata_dim, timestamp_train, timestamp_test

In [59]:
print("loading data...")
ts = time.time()
# Name the cache after the selected dataset (CONFIG.DIR). The name used to be
# hard-coded to 'TaxiBJ', so switching to another dataset with the same
# sequence lengths would have reused the wrong cache file. For the TaxiBJ
# configuration the file name is unchanged.
fname = os.path.join(CONFIG.DATAPATH, 'CACHE', '{}_C{}_P{}_T{}.h5'.format(
    CONFIG.DIR, CONFIG.len_closeness, CONFIG.len_period, CONFIG.len_trend))
if os.path.exists(fname) and CACHEDATA:
    X_train, Y_train, X_test, Y_test, mmn, external_dim, timestamp_train, timestamp_test = read_cache(fname)
    print("load %s successfully" % fname)
else:
    X_train, Y_train, X_test, Y_test, mmn, external_dim, timestamp_train, timestamp_test = load_data(
        T=CONFIG.T, nb_flow=CONFIG.nb_flow, len_closeness=CONFIG.len_closeness, len_period=CONFIG.len_period,
        len_trend=CONFIG.len_trend, len_test=CONFIG.len_test, preprocess_name='preprocessing.pkl',
        meta_data=CONFIG.meta_data, meteorol_data=CONFIG.meteorol_data, holiday_data=CONFIG.holiday_data)
    if CACHEDATA:
        cache(fname, X_train, Y_train, X_test, Y_test, external_dim, timestamp_train, timestamp_test)

loading data...
('file name: ', '../Data/CrowdFlows/TaxiBJ/BJ13_M32x32_T30_InOut.h5')
=====stat=====
data shape: (4888, 2, 32, 32)
# of days: 121, from 2013-07-01 to 2013-10-29
# of timeslots: 5808
# of timeslots (available): 4888
missing ratio of timeslots: 15.8%
max: 1230.000, min: 0.000
=====stat=====
('incomplete days: ', ['20130926'])


('file name: ', '../Data/CrowdFlows/TaxiBJ/BJ14_M32x32_T30_InOut.h5')
=====stat=====
data shape: (4780, 2, 32, 32)
# of days: 119, from 2014-03-01 to 2014-06-27
# of timeslots: 5712
# of timeslots (available): 4780
missing ratio of timeslots: 16.3%
max: 1292.000, min: 0.000
=====stat=====
('incomplete days: ', ['20140304', '20140313', '20140323', '20140326', '20140401', '20140402', '20140409', '20140410', '20140412', '20140422', '20140501', '20140526', '20140618', '20140627'])


('file name: ', '../Data/CrowdFlows/TaxiBJ/BJ15_M32x32_T30_InOut.h5')
=====stat=====
data shape: (5596, 2, 32, 32)
# of days: 122, from 2015-03-01 to 2015-06-30
# of timesl

# build model

define elements

In [71]:
class iLayer(Layer):
    """Element-wise weighting layer: ``output = x * W``.

    ``W`` is a trainable variable with the shape of one input sample
    (``input_shape[1:]``), initialised with uniform random values from
    ``np.random.random``. The output shape equals the input shape.
    """

    def __init__(self, **kwargs):
        super(iLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # One weight per element of a single sample (batch axis dropped).
        initial_weight_value = np.random.random(input_shape[1:])
        self.W = K.variable(initial_weight_value)
        self.trainable_weights = [self.W]

    def call(self, x, mask=None):
        return x * self.W

    def compute_output_shape(self, input_shape):
        # Shape hook queried by Keras 2; the layer does not change the shape.
        return input_shape

    def get_output_shape_for(self, input_shape):
        # Keras 1 name of the same hook, kept for backward compatibility.
        return self.compute_output_shape(input_shape)

def _shortcut(input, residual):
    """Join a block's input and its residual branch with a 'sum' merge."""
    branches = [input, residual]
    return merge(branches, mode='sum')


def _bn_relu_conv(nb_filter, nb_row, nb_col, subsample=(1, 1), bn=False):
    """Return a function applying (optional batch norm) -> ReLU -> convolution.

    nb_filter, nb_row, nb_col, subsample: forwarded to ``Convolution2D``,
        which is created with border_mode "same".
    bn: when true, a ``BatchNormalization(mode=0, axis=1)`` layer is
        applied before the activation.
    """
    def f(input):
        normalized = BatchNormalization(mode=0, axis=1)(input) if bn else input
        activated = Activation('relu')(normalized)
        conv = Convolution2D(nb_filter=nb_filter, nb_row=nb_row, nb_col=nb_col,
                             subsample=subsample, border_mode="same")
        return conv(activated)
    return f


def _residual_unit(nb_filter, init_subsample=(1, 1)):
    """Return a function building one residual unit.

    The unit chains two 3x3 ``_bn_relu_conv`` stages and merges the result
    with the unit's input through ``_shortcut``. ``init_subsample`` is
    accepted but not used.
    """
    def f(input):
        branch = input
        for _ in range(2):
            branch = _bn_relu_conv(nb_filter, 3, 3)(branch)
        return _shortcut(input, branch)
    return f


def ResUnits(residual_unit, nb_filter, repetations=1):
    """Return a function chaining ``repetations`` residual units.

    residual_unit: factory called as
        ``residual_unit(nb_filter=..., init_subsample=(1, 1))``; it must
        return a callable that is applied to the running tensor.
    nb_filter: forwarded to every unit.
    repetations: number of units applied one after another.
    """
    def f(input):
        tensor = input
        for _ in range(repetations):
            unit = residual_unit(nb_filter=nb_filter, init_subsample=(1, 1))
            tensor = unit(tensor)
        return tensor
    return f

define metric

In [72]:
def mean_squared_error(y_true, y_pred):
    """Mean of the squared differences between prediction and target."""
    squared_error = K.square(y_pred - y_true)
    return K.mean(squared_error)


def root_mean_square_error(y_true, y_pred):
    """Square root of ``mean_squared_error``."""
    mse_value = mean_squared_error(y_true, y_pred)
    return mse_value ** 0.5


def rmse(y_true, y_pred):
    """Square root of ``mean_squared_error`` (short-named twin of ``root_mean_square_error``)."""
    mse_value = mean_squared_error(y_true, y_pred)
    return mse_value ** 0.5

# aliases
# mse and MSE are additional names bound to mean_squared_error.
mse = MSE = mean_squared_error
# rmse is a separate function defined above rather than an alias.
# NOTE(review): presumably so the metric is reported under the name 'rmse'
# (the training callbacks monitor 'val_rmse') -- confirm before aliasing.
# rmse = RMSE = root_mean_square_error


def masked_mean_squared_error(y_true, y_pred):
    """Mean squared error restricted to positions where ``y_true > 1e-6``."""
    idx = (y_true > 1e-6).nonzero()
    squared_error = K.square(y_pred[idx] - y_true[idx])
    return K.mean(squared_error)


def masked_rmse(y_true, y_pred):
    """Square root of ``masked_mean_squared_error``."""
    masked_mse = masked_mean_squared_error(y_true, y_pred)
    return masked_mse ** 0.5

build model

In [75]:
def stresnet(c_conf=(3, 2, 32, 32), p_conf=(3, 2, 32, 32), t_conf=(3, 2, 32, 32), external_dim=8, nb_residual_unit=3):
    '''
    Build the ST-ResNet model.

    C - Temporal Closeness
    P - Period
    T - Trend
    conf = (len_seq, nb_flow, map_height, map_width); pass None to leave a
           component out. At least one component must be given.
    external_dim: width of the external feature vector; None or a value
                  <= 0 disables the external branch.
    nb_residual_unit: number of residual units per component.

    Each component is Conv -> residual units -> ReLU -> Conv. Several
    components are fused by summing their iLayer-weighted outputs, the
    (optional) external branch is added, and tanh is applied at the end.

    Returns the (uncompiled) Keras Model; its inputs are the enabled
    components in C, P, T order followed by the external input.
    Raises ValueError when all three confs are None.
    '''

    # main input
    main_inputs = []
    outputs = []
    output_shape = None  # (nb_flow, map_height, map_width) of the model output
    for conf in [c_conf, p_conf, t_conf]:
        if conf is None:
            continue
        len_seq, nb_flow, map_height, map_width = conf
        output_shape = (nb_flow, map_height, map_width)
        # The frames of a sequence are stacked along the channel axis.
        seq_input = Input(shape=(nb_flow * len_seq, map_height, map_width))
        main_inputs.append(seq_input)
        # Conv1
        conv1 = Convolution2D(
            nb_filter=64, nb_row=3, nb_col=3, border_mode="same", input_shape=seq_input.shape)(seq_input)
        # [nb_residual_unit] Residual Units
        residual_output = ResUnits(_residual_unit, nb_filter=64,
                                   repetations=nb_residual_unit)(conv1)
        # Conv2
        activation = Activation('relu')(residual_output)
        conv2 = Convolution2D(
            nb_filter=nb_flow, nb_row=3, nb_col=3, border_mode="same")(activation)
        outputs.append(conv2)

    if not outputs:
        # Fail with a clear message instead of an obscure error further down.
        raise ValueError('at least one of c_conf, p_conf, t_conf must be given')

    # parameter-matrix-based fusion
    if len(outputs) == 1:
        main_output = outputs[0]
    else:
        new_outputs = [iLayer()(output) for output in outputs]
        main_output = merge(new_outputs, mode='sum')

    # fusing with external component
    if external_dim is not None and external_dim > 0:
        # The external branch is shaped like the last enabled component.
        nb_flow, map_height, map_width = output_shape
        # external input
        external_input = Input(shape=(external_dim,))
        main_inputs.append(external_input)
        embedding = Dense(output_dim=10)(external_input)
        embedding = Activation('relu')(embedding)
        h1 = Dense(output_dim=nb_flow * map_height * map_width)(embedding)
        activation = Activation('relu')(h1)
        external_output = Reshape((nb_flow, map_height, map_width))(activation)
        print(main_output.shape)
        print(external_output.shape)
        main_output = merge([main_output, external_output], mode='sum')
    else:
        print('external_dim:', external_dim)

    main_output = Activation('tanh')(main_output)
    model = Model(input=main_inputs, output=main_output)

    return model

def build_model(external_dim):
    """Create and compile the ST-ResNet described by CONFIG.

    external_dim: forwarded to ``stresnet``.

    A component whose sequence length in CONFIG is not positive is passed
    to ``stresnet`` as None. The model is compiled with the 'mse' loss, an
    Adam optimizer using CONFIG.lr, and ``rmse`` as metric.
    """
    def _conf(len_seq):
        # (len_seq, nb_flow, map_height, map_width), or None when disabled
        if len_seq > 0:
            return (len_seq, CONFIG.nb_flow, CONFIG.map_height, CONFIG.map_width)
        return None

    model = stresnet(c_conf=_conf(CONFIG.len_closeness),
                     p_conf=_conf(CONFIG.len_period),
                     t_conf=_conf(CONFIG.len_trend),
                     external_dim=external_dim,
                     nb_residual_unit=CONFIG.nb_residual_unit)
    optimizer = Adam(lr=CONFIG.lr)
    model.compile(loss='mse', optimizer=optimizer, metrics=[rmse])
    return model

plot model structure

In [76]:
# Build and compile the network for the loaded data; external_dim comes from
# the data loading cell above.
print("compiling model...")
model = build_model(external_dim)
# Write a diagram of the architecture to ST-ResNet.png and print the
# layer-by-layer summary.
plot(model, to_file='ST-ResNet.png', show_shapes=True)
model.summary()

compiling model...




(?, 2, 32, 32)
(?, 2, 32, 32)
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_5 (InputLayer)             (None, 6, 32, 32)     0                                            
____________________________________________________________________________________________________
input_6 (InputLayer)             (None, 2, 32, 32)     0                                            
____________________________________________________________________________________________________
input_7 (InputLayer)             (None, 2, 32, 32)     0                                            
____________________________________________________________________________________________________
conv2d_31 (Conv2D)               (None, 64, 32, 32)    3520        input_5[0][0]                    
_____________________________________________________________

# Run

train

In [79]:
# Output files are named after the hyper-parameters of this run.
hyperparams_name = 'c{}.p{}.t{}.resunit{}.lr{}'.format(CONFIG.len_closeness, CONFIG.len_period, CONFIG.len_trend, CONFIG.nb_residual_unit, CONFIG.lr)
# Write into path_model / path_result, the directories created at the top of
# the notebook; the previously hard-coded '../model' and '../result' are
# never created by this notebook.
fname_param = os.path.join(path_model, '{}.best.h5'.format(hyperparams_name))
# Stop after 5 epochs without improvement of 'val_rmse' and keep the weights
# of the best epoch in fname_param.
early_stopping = EarlyStopping(monitor='val_rmse', patience=5, mode='min')
model_checkpoint = ModelCheckpoint(fname_param, monitor='val_rmse', verbose=0, save_best_only=True, mode='min')
print("training model...")
history = model.fit(X_train, Y_train,
                    nb_epoch=CONFIG.nb_epoch,
                    batch_size=CONFIG.batch_size,
                    validation_split=0.1,
                    callbacks=[early_stopping, model_checkpoint],
                    verbose=1)
# Weights of the final epoch (as opposed to the best checkpoint above).
model.save_weights(os.path.join(path_model, '{}.h5'.format(hyperparams_name)), overwrite=True)
with open(os.path.join(path_result, '{}.history.pkl'.format(hyperparams_name)), 'wb') as fhistory:
    pickle.dump((history.history), fhistory)

training model...


  


Train on 12355 samples, validate on 1373 samples
Epoch 1/500
   64/12355 [..............................] - ETA: 4951s - loss: 0.4391 - rmse: 0.6490

KeyboardInterrupt: 

test

In [None]:
print('evaluating using the model that has the best loss on the valid set')
# Restore the checkpoint written by ModelCheckpoint during training.
model.load_weights(fname_param)
# rmse (real) undoes the [-1, 1] normalisation -- half of the fitted data
# range per unit -- and applies the dataset's correction factor. The factor
# lives on CONFIG; the bare name `m_factor` was undefined (NameError).
score = model.evaluate(X_train, Y_train, batch_size=Y_train.shape[0] // 48, verbose=0)
print('Train score: %.6f rmse (norm): %.6f rmse (real): %.6f' % (score[0], score[1], score[1] * (mmn._max - mmn._min) / 2. * CONFIG.m_factor))
score = model.evaluate(X_test, Y_test, batch_size=Y_test.shape[0], verbose=0)
print('Test score: %.6f rmse (norm): %.6f rmse (real): %.6f' % (score[0], score[1], score[1] * (mmn._max - mmn._min) / 2. * CONFIG.m_factor))