## Context Encoder

In this notebook we try different network architectures and compare their performance.
Let's begin by importing TensorFlow and the base context encoder network.


In [None]:
import tensorflow as tf
from network.contextEncoder import ContextEncoderNetwork

Next, we define a modifiable subclass of the context encoder. The goal is to make it easy to change the network architecture and try out other ideas.

In [None]:
class ModifiedContextEncoderNetwork(ContextEncoderNetwork):
    """Editable variant of ``ContextEncoderNetwork`` for architecture experiments.

    Only the ``_encoder`` and ``_decoder`` hooks are overridden; everything else
    (including ``train``) is inherited unchanged from the base class.
    """

    def _encoder(self, model, isTraining):
        """Append the encoder stack to ``model``.

        The context (``window_size - gap_length`` samples per example) is
        reshaped to a single-channel signal and then passed through five
        convolutions, each with stride 4, widening the channels from 1 to 4096.

        Args:
            model: Layer builder supplied by the base class; presumably each
                ``add*`` call appends a layer after the previous one — TODO
                confirm against ``ContextEncoderNetwork``.
            isTraining: Forwarded unchanged to every conv layer.
        """
        # All encoder layers share the same stride; only width/channels vary.
        stride = 4
        # (layer name, filter width, input channels, output channels)
        conv_specs = (
            ("First_Conv", 129, 1, 16),
            ("Second_Conv", 65, 16, 64),
            ("Third_Conv", 33, 64, 256),
            ("Fourth_Conv", 17, 256, 1024),
            ("Last_Conv", 9, 1024, 4096),
        )
        with tf.variable_scope("Encoder"):
            context_length = self._window_size - self._gap_length
            model.addReshape((self._batch_size, context_length, 1))
            for layer_name, width, channels_in, channels_out in conv_specs:
                model.addConvLayer(filter_width=width,
                                   input_channels=channels_in,
                                   output_channels=channels_out,
                                   stride=stride,
                                   name=layer_name,
                                   isTraining=isTraining)

    def _decoder(self, model, isTraining):
        """Append the decoder to ``model``: one linear conv followed by a reshape.

        The conv has no non-linearity (``addConvLayerWithoutNonLin``) and maps
        the 4096 encoder channels to 1024. Its output is reshaped to
        ``(batch_size, gap_length)``, presumably one reconstructed gap per
        example.

        NOTE(review): ``output_channels=1024`` is hard-coded, not derived from
        ``gap_length``. The reshape can only succeed when the conv output holds
        exactly ``batch_size * gap_length`` values, which depends on the
        spatial length left by the encoder. Confirm before changing the
        window or gap sizes.
        """
        with tf.variable_scope("Decoder"):
            model.addConvLayerWithoutNonLin(name="Decode_Conv",
                                            filter_width=5,
                                            stride=4,
                                            input_channels=4096,
                                            output_channels=1024,
                                            isTraining=isTraining)
            gap_shape = (self._batch_size, self._gap_length)
            model.addReshape(gap_shape)


In [None]:
# Clear any previously built graph so re-running this cell starts from scratch.
tf.reset_default_graph()

# Pre-generated TFRecord datasets. The "w5120"/"g1024" tags in the filenames
# presumably encode window_size and gap_length; keep them in sync with the
# network arguments below.
train_filename = 'train_full_w5120_g1024_h512_19404621.tfrecords'
valid_filename = 'valid_full_w5120_g1024_h512_ex913967.tfrecords'

aContextEncoderNetwork = ModifiedContextEncoderNetwork(batch_size=256, window_size=5120, gap_length=1024,
                                                       learning_rate=1e-5, name='first_try')
# num_steps is a step count: pass an int, not the float 1e6, so it also works
# if train() uses it with range() or other integer-only APIs.
aContextEncoderNetwork.train(train_filename, valid_filename, num_steps=int(1e6))