In [25]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from utils.imports import *
from tensorflow.keras.layers import *
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model

In [3]:
import os

# Number GPUs in PCI bus order so the IDs agree with nvidia-smi, and expose only GPU 0
# to TensorFlow (use "0, 1" to expose two GPUs).
# NOTE: CUDA reads these variables when its runtime is first initialized, so this cell
# must run before any TensorFlow session/device is created.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [35]:
from utils.transforms import *
PATH = Path('../data')
sz = 256   # height and width of each (square, single-channel) frame fed to the model
nt = 10    # number of timesteps (frames) per sequence
bs = 4     # batch size
MODEL_VERSION = 'prednet_' + str(sz) + '_1'

# GPU count for (currently disabled) multi-GPU training. Derived from CUDA_VISIBLE_DEVICES
# so it cannot disagree with the devices actually exposed to TensorFlow; a hard-coded 2
# contradicted CUDA_VISIBLE_DEVICES="0". An unset variable is treated as a single GPU.
import os
num_gpus = len([dev for dev in os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
                if dev.strip()])

class Slice(Transform):
    """Transform returning ``x[slice(*args, **kwargs)]`` -- a sub-range along the first axis.

    Arguments:
    Forwarded unchanged to the built-in ``slice`` (e.g. ``Slice(stop)`` or
    ``Slice(start, stop[, step])``). Note that ``slice`` itself takes no keyword arguments.
    ``TfmType.NO`` is passed to the ``Transform`` base class; ``is_y`` is ignored.
    """
    def __init__(self, *args, **kwargs):
        self.slice = slice(*args, **kwargs)
        super().__init__(TfmType.NO)

    def do_transform(self, x, is_y):
        window = self.slice
        return x[window]


# Keep x[:nt] of every clip (presumably the first nt frames).
aug_tfms = [Slice(nt)]

In [30]:
from models.prednet_refactored import PredNetCell, PredNet

# Single-channel sz x sz frames; Keras layers expect (height, width, channels).
n_channels, im_height, im_width = 1, sz, sz
input_shape = (im_height, im_width, n_channels)

# Channels per PredNet layer (the bottom layer matches the input channels).
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3,) * 3        # one entry fewer than the number of layers
Ahat_filt_sizes = (3,) * 4
R_filt_sizes = (3,) * 4

# Per-layer loss weights as a (nb_layers, 1) column, used as a frozen Dense kernel:
# "L_0" model = [1, 0, 0, 0] (bottom layer only), "L_all" = [1, 0.1, 0.1, 0.1].
layer_loss_weights = np.array([[1.], [0.], [0.], [0.]])

# Per-timestep loss weights (nt, 1): all timesteps weighted equally except the first,
# which gets zero weight.
time_loss_weights = np.full((nt, 1), 1. / (nt - 1))
time_loss_weights[0] = 0

prednet_cell = PredNetCell(
    stack_sizes=stack_sizes,
    R_stack_sizes=R_stack_sizes,
    A_filt_sizes=A_filt_sizes,
    Ahat_filt_sizes=Ahat_filt_sizes,
    R_filt_sizes=R_filt_sizes,
)
prednet = PredNet(prednet_cell)

In [31]:
# Wrap PredNet in a model whose single output is a fixed weighted sum of its errors.
inputs = tf.keras.Input(shape=(nt,) + input_shape)
errors = prednet(inputs)  # errors will be (batch_size, nt, nb_layers)

# Frozen Dense layers implement the weighting: first across layers at every timestep ...
layer_weighting = TimeDistributed(
    Dense(1, trainable=False),
    weights=[layer_loss_weights, np.zeros(1)],
    trainable=False,
)
errors_by_layer = layer_weighting(errors)        # (batch_size, nt, 1)
errors_by_time = Flatten()(errors_by_layer)      # (batch_size, nt)

# ... then across timesteps.
time_weighting = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)
final_errors = time_weighting(errors_by_time)

model = Model(inputs=inputs, outputs=final_errors)
# MAE against all-zero targets minimizes the weighted error itself.
model.compile(optimizer='adam', loss='mean_absolute_error')
# model.compile(loss='mean_absolute_error', optimizer=tf.train.AdamOptimizer())

In [13]:
# Inspect layer output shapes and parameter counts.
# NOTE(review): the output saved below comes from an earlier run with sz=128 (see the
# input layer shape) and duplicates the later summary cell; re-run or delete it.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 10, 128, 128, 1)   0         
_________________________________________________________________
pred_net (PredNet)           (None, 10, 4)             6909818   
_________________________________________________________________
time_distributed_1 (TimeDist (None, 10, 1)             5         
_________________________________________________________________
flatten_1 (Flatten)          (None, 10)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 11        
Total params: 6,909,834
Trainable params: 6,909,818
Non-trainable params: 16
_________________________________________________________________


In [32]:
# Inspect layer output shapes and parameter counts (saved output below is for sz=256).
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 10, 256, 256, 1)   0         
_________________________________________________________________
pred_net_2 (PredNet)         (None, 10, 4)             6909818   
_________________________________________________________________
time_distributed_3 (TimeDist (None, 10, 1)             5         
_________________________________________________________________
flatten_3 (Flatten)          (None, 10)                0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 11        
Total params: 6,909,834
Trainable params: 6,909,818
Non-trainable params: 16
_________________________________________________________________


In [14]:
def input_fn(fns=['../data/tfrecords/train_1.tfrecords'],
             sz=128, nt=10, aug_tfms=None,
             stats_fn='stat.csv', stats_sep=','):
    """Build (x, y) training tensors that stream clips from TFRecord files.

    Each record is a SequenceExample with a string context feature 'time_stamp' and a
    sequence feature 'raw_png' holding one PNG-encoded frame per timestep. Frames are
    decoded, scaled to [0, 1], reduced to their first channel, passed through the
    transforms built by ``tfms_from_stats``, and served forever in batches of the
    global ``bs``.

    Arguments:
    fns: TFRecord file path(s) to read.
    sz: size passed to ``tfms_from_stats``; should match the model's input size.
    nt: clip length; used to build the default ``aug_tfms`` (``[Slice(nt)]``, i.e. x[:nt]).
    aug_tfms: augmentation transforms for ``tfms_from_stats``; ``None`` -> ``[Slice(nt)]``.
    stats_fn, stats_sep: file and separator for the statistics read with ``np.fromfile``.

    Returns:
    x: the next batch of clips from a one-shot iterator.
    y: ``tf.zeros([bs, 1])`` dummy targets (the model's output is already a weighted error).
    """
    if aug_tfms is None:
        # Honor `nt`: the previous default was the global [Slice(nt)] list bound at
        # definition time, so passing nt had no effect.
        aug_tfms = [Slice(nt)]

    # Dummy targets. The same tensor is attached to every parsed example (so `tfms`
    # receives (x, y) pairs) and is returned to the caller as the fit() target.
    y = tf.zeros([bs, 1])

    def parser_train(serialized_example):
        # experimental. TODO: read only needed samples
        shape = (61, 501, 501, 3)  # static shape asserted for every decoded clip
        context_features = {
            'time_stamp': tf.FixedLenFeature([], tf.string),
        }
        sequence_features = {
            "raw_png": tf.FixedLenSequenceFeature([], dtype=tf.string),
        }
        # The context is parsed (and therefore required) but not otherwise used.
        _, parsed_sequence = tf.parse_single_sequence_example(
            serialized_example,
            context_features=context_features,
            sequence_features=sequence_features)

        x = tf.map_fn(tf.image.decode_png, parsed_sequence['raw_png'], dtype=tf.uint8,
                      back_prop=False, swap_memory=False, infer_shape=False)
        x = tf.cast(x, tf.float32) / 255
        x.set_shape(shape)
        x = tf.expand_dims(x[:, :, :, 0], axis=3)  # keep only the first channel
        return x, y

    stats = np.fromfile(stats_fn, sep=stats_sep)
    tfms, _ = tfms_from_stats(stats, sz, aug_tfms=aug_tfms, crop_type=CropType.NO)

    dataset = tf.data.TFRecordDataset(fns)
    dataset = dataset.map(parser_train)
    dataset = dataset.map(tfms)
    # Repeat BEFORE batching: the stream is infinite, so every batch holds exactly `bs`
    # clips and matches y's fixed [bs, 1] shape. Batch-then-repeat emits a short final
    # batch whenever the record count is not a multiple of bs.
    dataset = dataset.repeat()
    dataset = dataset.batch(bs)

    iterator = dataset.make_one_shot_iterator()
    x, _ = iterator.get_next()
    return x, y

In [15]:
# Build input tensors at the same frame size and clip length the model was constructed
# with. input_fn's own defaults (sz=128) do not match the global sz used for the model's
# Input layer, which caused a shape mismatch on a fresh run.
x, y = input_fn(sz=sz, nt=nt)

In [10]:
# Training callbacks. fit() receives no validation data, so checkpointing and early
# stopping monitor the training `loss` instead of the default `val_loss`.
checkpoint_dir = PATH/'models'
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # make sure the save location exists

callbacks = [
    # NOTE: write_grads only has an effect when histogram_freq > 0.
    keras.callbacks.TensorBoard(write_grads=True, write_images=True),
    keras.callbacks.History(),
    # `filepath` is a required argument of ModelCheckpoint. Weights only, so the custom
    # PredNet layers do not need to be serializable.
    keras.callbacks.ModelCheckpoint(str(checkpoint_dir/(MODEL_VERSION + '.h5')),
                                    monitor='loss', save_weights_only=True),
    keras.callbacks.EarlyStopping(monitor='loss'),
]

In [11]:
# Train on the tensors from input_fn: 10 steps per epoch and the default number of epochs
# (the saved run below shows "Epoch 1/1").
model.fit(x, y, steps_per_epoch=10, callbacks=callbacks)

Epoch 1/1


<tensorflow.python.keras.callbacks.History at 0x7f58385499b0>

In [40]:
# Reload a saved model, registering the custom PredNet classes so Keras can rebuild them.
# NOTE(review): this rebinds `model`, discarding the one built/trained above; the 'keras'
# path presumably points at a model saved in an earlier session -- confirm.
custom_objects = {'PredNetCell': PredNetCell, 'PredNet': PredNet}
model = tf.keras.models.load_model('keras', custom_objects=custom_objects)

In [37]:
# Persist the current weights as PATH/models/<MODEL_VERSION>.
# (The original `PATH/'models'/` was a syntax error, and `weights_path.exists` without
# parentheses is a bound method -- always truthy -- so mkdir never ran.)
weights_path = PATH/'models'
weights_path.mkdir(parents=True, exist_ok=True)  # idempotent; creates ../data too if needed
model.save_weights(str(weights_path/MODEL_VERSION))

In [23]:
# NOTE(review): debugging leftover. This creates and installs a new default session; the
# model's variables were presumably initialized in Keras' own session, so evaluating them
# here would likely fail -- K.get_session() is probably what was intended. Confirm or remove.
sess = tf.InteractiveSession()

In [34]:
# Display the first model variable (name, shape, dtype) -- this shows the tf.Variable
# object, not its values.
model.weights[0]

<tf.Variable 'pred_net_2/layer_a_0/kernel:0' shape=(3, 3, 2, 48) dtype=float32>

In [33]:
# NOTE(review): '1.h5' is a hard-coded relative path, unrelated to the
# PATH/'models'/MODEL_VERSION location used by the save_weights cell -- confirm which
# checkpoint is meant.
model.load_weights('1.h5')