In [1]:
%matplotlib inline

In [3]:
from tensorflow.contrib.keras import preprocessing

In [4]:
pad_sequences = preprocessing.sequence.pad_sequences

In [5]:
import tensorflow as tf
import numpy as np
import os
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp

# Loading Data

In [7]:
# Location of the RMTPP-formatted dataset and the split index to load.
path, idx = 'rmtpp-data/real/so/', 1

# Event-id and timestamp files for the train/test halves of split `idx`.
event_train_file, event_test_file, time_train_file, time_test_file = (
    os.path.join(path, file_name)
    for file_name in (
        f'event-{idx}-train.txt',
        f'event-{idx}-test.txt',
        f'time-{idx}-train.txt',
        f'time-{idx}-test.txt',
    )
)

In [8]:
def _read_sequences(file_name, cast):
    """Parse a file holding one whitespace-separated sequence per line.

    Each token is converted with `cast`. The result is a list with one
    inner list per line of the file.
    """
    with open(file_name, 'r') as in_file:
        return [[cast(tok) for tok in line.split()] for line in in_file]


# Event ids are integers; timestamps are floats.
eventTrain = _read_sequences(event_train_file, int)
eventTest = _read_sequences(event_test_file, int)
timeTrain = _read_sequences(time_train_file, float)
timeTest = _read_sequences(time_test_file, float)

assert len(timeTrain) == len(eventTrain)
assert len(eventTest) == len(timeTest)
# Event k of a line is paired with timestamp k of the matching line, so the
# per-line lengths must agree as well, not only the number of lines.
assert all(len(ev) == len(ts) for ev, ts in zip(eventTrain, timeTrain)), \
    'train event/time sequences differ in length'
assert all(len(ev) == len(ts) for ev, ts in zip(eventTest, timeTest)), \
    'test event/time sequences differ in length'

In [9]:
nb_samples = len(eventTrain)
max_seqlen = max(len(x) for x in eventTrain)

# Every event id occurring in either split (NUM_CATEGORIES is derived from it).
# set.update grows the set in place instead of building a new set per sequence.
unique_samples = set()
for x in itertools.chain(eventTrain, eventTest):
    unique_samples.update(x)

# pad_sequences below pads with 0, and row 0 of the embedding matrix is set
# aside for that value, so a real event id of 0 would look like padding.
assert min(unique_samples) >= 1, 'event ids must be >= 1 (0 is the padding value)'

# Global min/max over all timestamps. Flattening first means an empty
# sequence cannot make max()/min() raise.
all_times = list(itertools.chain.from_iterable(itertools.chain(timeTrain, timeTest)))
maxTime = max(all_times)
minTime = min(all_times)
# minTime, maxTime = 0, 1
# If every timestamp is equal, fall back to 1.0 instead of dividing by zero.
timeSpan = (maxTime - minTime) or 1.0


def _scale_times(seq):
    """Min-max scale a list of timestamps using the global minTime / timeSpan."""
    return [(y - minTime) / timeSpan for y in seq]


# Next-event setup: inputs are all events except the last, targets are all
# events except the first.
eventTrainIn = [x[:-1] for x in eventTrain]
eventTrainOut = [x[1:] for x in eventTrain]
timeTrainIn = [_scale_times(x[:-1]) for x in timeTrain]
timeTrainOut = [_scale_times(x[1:]) for x in timeTrain]

train_event_in_seq = pad_sequences(eventTrainIn, padding='post')
train_event_out_seq = pad_sequences(eventTrainOut, padding='post')
train_time_in_seq = pad_sequences(timeTrainIn, dtype=float, padding='post')
train_time_out_seq = pad_sequences(timeTrainOut, dtype=float, padding='post')

eventTestIn = [x[:-1] for x in eventTest]
eventTestOut = [x[1:] for x in eventTest]
timeTestIn = [_scale_times(x[:-1]) for x in timeTest]
timeTestOut = [_scale_times(x[1:]) for x in timeTest]

test_event_in_seq = pad_sequences(eventTestIn, padding='post')
test_event_out_seq = pad_sequences(eventTestOut, padding='post')
test_time_in_seq = pad_sequences(timeTestIn, dtype=float, padding='post')
test_time_out_seq = pad_sequences(timeTestOut, dtype=float, padding='post')

In [3]:
# Not doing one-hot encoding because TF provides tf.nn.embedding_lookup

# nb_events = len(unique_samples)

# train_event_out_hot_seq = np.zeros((nb_samples, max_seqlen - 1, nb_events), dtype=int)

# for ii, evs in enumerate(eventTrainOut):
#     for jj, x in enumerate(evs):
#         train_event_out_hot_seq[ii, jj, x - 1] = 1
        
# nb_tests = len(eventTest)

# max_test_seqlen = max(len(x) for x in eventTest)
# test_event_out_hot_seq = np.zeros((nb_tests, max_test_seqlen - 1, nb_events), dtype=int)

# for ii, evs in enumerate(eventTestOut):
#     for jj, x in enumerate(evs):
#         test_event_out_hot_seq[ii, jj, x - 1] = 1

In [4]:
# assert np.sum(test_event_out_hot_seq) == sum(len(x) for x in eventTestOut)
# assert np.sum(train_event_out_hot_seq) == sum(len(x) for x in eventTrainOut)

# Params

In [26]:
# Model sizes. Trailing comments list other values to try.
EMBED_SIZE = 100         # event-embedding width; TODO pick a value (was marked "??")
HIDDEN_LAYER_SIZE = 128  # alternatives: 64, 128, 256, 512, 1024
BATCH_SIZE = 28          # alternatives: 16, 32, 64 -- NOTE(review): 28 is not among them

# Optimizer settings. NOTE(review): the model cell does not read any of these;
# it hard-codes tf.train.AdamOptimizer(learning_rate=0.001).
LEARNING_RATE = 0.1      # alternatives: 0.1, 0.01, 0.001
MOMENTUM = 0.9
L2_PENALTY = 0.001

In [27]:
# Settings derived from the data or fixed by the graph construction below.
FLOAT_TYPE = tf.float32
NUM_CATEGORIES = len(unique_samples)     # number of distinct event ids over train + test
BPTT = 10                                # time steps unrolled per training window
RNN_CELL_TYPE = tf.contrib.rnn.GRUCell   # not used: the model cell writes its own recurrence

# Model construction

In [28]:
tf.reset_default_graph()
seed = 42
RS = np.random.RandomState(seed)
scope = "RMTPP"

# RMTPP graph with a hand-written recurrence, unrolled for BPTT steps. Each run
# consumes one BPTT window of BATCH_SIZE sequences. Steps whose input event id
# is 0 (padding) are masked out of the loss.
with tf.variable_scope(scope):

    with tf.device('/gpu:0'):
        # Inputs for one window; the *_out placeholders hold the next event/time.
        events_in = tf.placeholder(tf.int32, [BATCH_SIZE, BPTT])
        times_in = tf.placeholder(FLOAT_TYPE, [BATCH_SIZE, BPTT])

        events_out = tf.placeholder(tf.int32, [BATCH_SIZE, BPTT])
        times_out = tf.placeholder(FLOAT_TYPE, [BATCH_SIZE, BPTT])

        # Make variables
        with tf.variable_scope('hidden_state'):
            Wt = tf.get_variable(name='Wt', shape=(1, HIDDEN_LAYER_SIZE),
                                 dtype=FLOAT_TYPE)
            # Row 0 of Wem is what the padding id 0 looks up. The loss at padded
            # steps is masked out further down.
            Wem = tf.get_variable(name='Wem', shape=(NUM_CATEGORIES + 1, EMBED_SIZE),
                                  dtype=FLOAT_TYPE)
            Wh = tf.get_variable(name='Wh', shape=(HIDDEN_LAYER_SIZE, HIDDEN_LAYER_SIZE),
                                 dtype=FLOAT_TYPE)
            bh = tf.get_variable(name='bh', shape=(1, HIDDEN_LAYER_SIZE),
                                 dtype=FLOAT_TYPE)

        with tf.variable_scope('output'):
            wt = tf.get_variable(name='wt', shape=(1, 1),
                                 dtype=FLOAT_TYPE)

            # Projects the event embedding into the hidden state.
            Wy = tf.get_variable(name='Wy', shape=(EMBED_SIZE, HIDDEN_LAYER_SIZE),
                                 dtype=FLOAT_TYPE)

            # Column 0 of Vy (and entry 0 of bk) scores the padding id 0. That
            # column is part of the softmax, so it still receives gradient.
            Vy = tf.get_variable(name='Vy', shape=(HIDDEN_LAYER_SIZE, NUM_CATEGORIES + 1),
                                 dtype=FLOAT_TYPE)
            Vt = tf.get_variable(name='Vt', shape=(HIDDEN_LAYER_SIZE, 1),
                                 dtype=FLOAT_TYPE,
                                 initializer=tf.uniform_unit_scaling_initializer())
            bt = tf.get_variable(name='bt', shape=(1, 1),
                                 dtype=FLOAT_TYPE)
            bk = tf.get_variable(name='bk', shape=(1, NUM_CATEGORIES + 1),
                                 dtype=FLOAT_TYPE)

        # Initial hidden state. The training loop feeds the previous window's
        # `final_state` into this tensor to carry state across windows.
        initial_state = state = tf.zeros([BATCH_SIZE, HIDDEN_LAYER_SIZE], dtype=FLOAT_TYPE, name='hidden_state')

        loss = 0.0
        for i in range(BPTT):
            events_embedded = tf.nn.embedding_lookup(Wem, events_in[:, i])
            time = tf.expand_dims(times_in[:, i], axis=-1)

            # h = clip(h Wh + emb Wy + t Wt + bh, 0, 1e6). The (1, H) bias
            # broadcasts over the batch, so no ones-matmul is needed.
            state = tf.clip_by_value(
                tf.matmul(state, Wh) +
                tf.matmul(events_embedded, Wy) +
                tf.matmul(time, Wt) +
                bh,
                0.0, 1e6,
                name='h_t')

            # log lambda = Vt.h + wt * (t_out - t_in) + bt  (bt broadcasts over the batch)
            hidden_contrib = tf.matmul(state, Vt)
            delta_t = tf.expand_dims(times_out[:, i] - times_in[:, i], axis=-1)
            log_lambda_ = hidden_contrib + delta_t * wt + bt

            lambda_ = tf.exp(tf.minimum(50.0, log_lambda_), name='lambda_')
            # Keep |wt| >= 1e-6 so 1/wt stays finite. tf.sign(0) is 0, which
            # would make 1/wt infinite at wt == 0, so 0 is treated as positive.
            wt_sign = tf.where(wt >= 0, tf.ones_like(wt), -tf.ones_like(wt))
            wt_non_zero = wt_sign * tf.maximum(1e-6, tf.abs(wt))
            # log f* = log lambda + (exp(Vt.h + bt) - lambda) / wt, with the
            # exponents capped at 50.
            log_f_star = (log_lambda_ +
                          (1/wt_non_zero) * tf.exp(tf.minimum(50.0, hidden_contrib + bt)) -
                          (1/wt_non_zero) * lambda_)

            events_pred = tf.nn.softmax(tf.minimum(50.0, tf.matmul(state, Vy) + bk),
                                        name='Pr_events')

            time_loss = log_f_star
            # (batch row, observed next event) pairs used to read Pr(next event).
            next_event_idx = tf.stack([tf.range(BATCH_SIZE), events_out[:, i]],
                                      axis=1, name='next_event_idx')
            pr_next_event = tf.gather_nd(events_pred, next_event_idx, name='Pr_next_event')
            mark_loss = tf.expand_dims(tf.log(tf.maximum(1e-6, pr_next_event)),
                                       axis=-1, name='log_Pr_next_event')
            step_loss = time_loss + mark_loss

            # Sequences are post-padded, so some may already have ended at this
            # step. Their event id is 0 and they are left out of the average.
            # TODO Figure out how to do this with RNNCell, LSTM, etc.
            event_mask = events_in[:, i] > 0
            num_events = tf.reduce_sum(tf.cast(event_mask, FLOAT_TYPE), name='num_events')
            # tf.cond runs the division by num_events only when it is non-zero.
            # squeeze(axis=-1) keeps shape (BATCH_SIZE,) even when BATCH_SIZE == 1.
            loss -= tf.cond(num_events > 0,
                            lambda: tf.reduce_sum(
                                tf.where(event_mask,
                                         tf.squeeze(step_loss, axis=-1) / num_events,
                                         tf.zeros(BATCH_SIZE, dtype=FLOAT_TYPE)),
                                name='batch_bptt_loss'),
                            lambda: tf.constant(0.0, dtype=FLOAT_TYPE))

        final_state = state
        # Hard-coded Adam rate. LEARNING_RATE / MOMENTUM / L2_PENALTY from the
        # Params cell are not used by this graph.
        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

        # Manual per-variable gradient clipping (norm <= 100). Variables with no
        # gradient path report None and are skipped, because clip_by_norm(None) fails.
        gvs = optimizer.compute_gradients(loss)
        capped_gvs = [(tf.clip_by_norm(grad, 100.0), var)
                      for grad, var in gvs if grad is not None]
        update = optimizer.apply_gradients(capped_gvs)

        init = tf.global_variables_initializer()
        # Numeric-check ops over the graph's float tensors. Run them together with
        # `update` to locate the first NaN/Inf.
        check_nan = tf.add_check_numerics_ops()

## Creation of batches and execution

In [29]:
# Close the session left over from an earlier run of the cells below before
# opening a new one. On a fresh kernel nothing exists yet, so nothing is closed.
if 'iterSession' in globals():
    iterSession.close()

In [30]:
# InteractiveSession becomes the default session, which lets later cells call
# `.eval()` on tensors and variables without passing a session.
# allow_soft_placement lets ops pinned to '/gpu:0' in the model cell fall back
# to the CPU when no GPU, or no GPU kernel for an op, is available.
iterSession = tf.InteractiveSession(
    config=tf.ConfigProto(allow_soft_placement=True))

In [None]:
%%time
idxes = list(range(len(eventTrainIn)))

rs = np.random.RandomState(seed=42)

iterSession.run(init)

num_epochs = 5
for epoch in range(num_epochs):
    rs.shuffle(idxes)
    
    print("Starting epoch...", epoch)
    
    for batch_idx in range(len(idxes) // BATCH_SIZE):
        batch_idxes = idxes[batch_idx * BATCH_SIZE:(batch_idx + 1) * BATCH_SIZE]
        batch_event_train_in = train_event_in_seq[batch_idxes, :]
        batch_event_train_out = train_event_out_seq[batch_idxes, :]
        batch_time_train_in = train_time_in_seq[batch_idxes, :]
        batch_time_train_out = train_time_out_seq[batch_idxes, :]
        
        cur_state = np.zeros((BATCH_SIZE, HIDDEN_LAYER_SIZE))
        total_loss = 0.0
        
        for bptt_idx in range(0, len(batch_event_train_in[0]) - BPTT, BPTT):
            bptt_range = range(bptt_idx, (bptt_idx + BPTT))
            bptt_event_in = batch_event_train_in[:, bptt_range]
            bptt_event_out = batch_event_train_out[:, bptt_range]
            bptt_time_in = batch_time_train_in[:, bptt_range]
            bptt_time_out = batch_time_train_out[:, bptt_range]
            
            feed_dict = {
                  initial_state: cur_state,
                  events_in: bptt_event_in,
                  events_out: bptt_event_out,
                  times_in: bptt_time_in,
                  times_out: bptt_time_out
            }
            
#             _, _, cur_state, loss_ = \
#                 iterSession.run([check_nan, update, final_state, loss],
#                                feed_dict=feed_dict)
            _, cur_state, loss_ = \
                iterSession.run([update, final_state, loss],
                                feed_dict=feed_dict)
            total_loss += loss_
        
        if batch_idx % 10 == 0:
            print('Loss on last batch = {}'.format(total_loss))

Starting epoch... 0
Loss on last batch = 825160007.9513018
Loss on last batch = 6.474327160478761e+17
Loss on last batch = 2.0493850487793584e+17
Loss on last batch = 141193462248.3692
Loss on last batch = 2.522614200609171e+18
Loss on last batch = 1.1854313146611139e+18
Loss on last batch = 1.8593636353279263e+18
Loss on last batch = 1296907385177359.2
Loss on last batch = 8.34440814795948e+17
Loss on last batch = 6.682331945587704e+18
Loss on last batch = 1975665688909.9126
Loss on last batch = 443871990385787.2
Loss on last batch = 637247049171641.6
Loss on last batch = 1.920611449516902e+16
Loss on last batch = 7.995072172576022e+16
Loss on last batch = 571.3106479644775
Loss on last batch = 405.51602387428284
Loss on last batch = 940.1500968933105
Loss on last batch = 430496.77731609344
Starting epoch... 1
Loss on last batch = 576.1623139381409
Loss on last batch = 960.4234066009521
Loss on last batch = 424.3694438934326
Loss on last batch = 414.5264720916748
Loss on last batch = 

In [351]:
# Debug helper: fetch an intermediate op output by name. Ops created inside
# `tf.variable_scope(scope)` carry the scope as a name prefix, so the bare
# 'truediv_11:0' is not present in the graph built by the model cell.
# NOTE(review): which expression 'truediv_11' refers to depends on op creation
# order; confirm it in TensorBoard before relying on it.
v = tf.get_default_graph().get_tensor_by_name(f'{scope}/truediv_11:0')

In [352]:
# Number of non-padding events in the last unrolled step of the most recently fed window.
num_events.eval(feed_dict=feed_dict, session=iterSession)

0.0

In [None]:
# Write the session's graph to the 'logs' directory for viewing in TensorBoard.
writer = tf.summary.FileWriter(logdir='logs', graph=iterSession.graph)

In [None]:
# Flush pending events to disk and close the event file opened by the FileWriter cell.
writer.close()

In [284]:
# Current value of the hidden-to-intensity weights Vt (NaNs here indicate divergence).
iterSession.run(Vt)

array([[  4.15257047e+28],
       [  4.13777913e+27],
       [             nan],
       [  4.00074491e+27],
       [             nan],
       [             nan],
       [  6.41804373e+27],
       [  9.25703779e+28],
       [  2.07751234e+28],
       [  1.17640731e+24],
       [  6.19882375e+12],
       [  1.87156031e+12],
       [  4.53088586e+28],
       [  9.79549677e+27],
       [             nan],
       [             nan],
       [  1.94037092e+28],
       [  4.33254788e+28],
       [  6.58284523e+28],
       [  3.65443046e+28],
       [  1.85322862e+28],
       [  5.40412750e+06],
       [  2.81002788e+28],
       [  9.31091716e+28],
       [             nan],
       [  5.79678986e+28],
       [             nan],
       [             nan],
       [             nan],
       [  7.19456973e+28],
       [             nan],
       [  1.02156405e+13],
       [             nan],
       [             nan],
       [             nan],
       [  2.11805504e+28],
       [             nan],
 

In [271]:
# Hidden state after the last step of the most recently fed BPTT window.
iterSession.run(final_state, feed_dict=feed_dict)

array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32)