In [7]:
from estimator.estimator.Estimator import Estimator, Mode
from estimator.estimator.Hook import Hook

import numpy as np
import tensorflow as tf

# Start from an empty default graph so re-running this cell does not add a
# second copy of the model's ops to the graph left over from a previous run.
tf.reset_default_graph()

# Seed the random number generators so the synthetic dataset below is the
# same on every Restart & Run All.
SEED = 42
np.random.seed(SEED)
# NOTE(review): this sets the graph-level seed of the *default* graph only;
# confirm that Estimator builds the model in the default graph, otherwise the
# weight initialisation stays non-deterministic.
tf.set_random_seed(SEED)

EPOCHS = 100

# create the dataset: (features, labels) tuples of uniform random values in
# [0, 1), with 2 features and 1 target per sample
train_data = (np.random.rand(1000, 2), np.random.rand(1000, 1))
test_data = (np.random.rand(100, 2), np.random.rand(100, 1))

# define the input function; the shapes match the arrays above, with the
# leading dimension left open
input_fn = Estimator.create_input_fn(input_shape=[None, 2], output_shape=[None, 1])
# define the model builder
def model_builder(x, y, config):
    """Build a small one-hidden-layer network and its train/predict/eval ops.

    Args:
        x: input tensor fed to the first dense layer.
        y: label tensor the predictions are compared against.
        config: dictionary that can be passed to the estimator; not used here.

    Returns:
        A dictionary keyed by ``Mode`` holding, for each mode, a dictionary
        of the named operations to run (training step, predictions, metrics).
    """
    # 16 ReLU units followed by a single sigmoid output unit
    hidden = tf.layers.dense(x, 16, activation=tf.nn.relu)
    predictions = tf.layers.dense(hidden, 1, activation=tf.nn.sigmoid)

    # mean squared error, minimised with Adam at a learning rate of 0.01
    loss = tf.losses.mean_squared_error(labels=y, predictions=predictions)
    optimizer = tf.train.AdamOptimizer(0.01)
    train_step = optimizer.minimize(loss)

    # it must return a dictionary containing the operations to train,
    # predict and evaluate
    ops_by_mode = {}
    ops_by_mode[Mode.TRAIN] = {'train_step': train_step}
    ops_by_mode[Mode.PREDICT] = {'predictions': predictions}
    ops_by_mode[Mode.EVAL] = {'loss': loss}  # used as metrics
    return ops_by_mode

class Logger(Hook):
    """Hook that prints the mean of every evaluation metric after an epoch."""

    def after_run_epoch(self, estimator, epoch, data, batch_n, tot_res):
        """Average the per-batch results of each metric and print them.

        Only ``estimator`` and ``tot_res`` are used; the remaining arguments
        are accepted and ignored.
        """
        # tot_res holds the results for each batch, in this case { 'loss' : [...] }
        mean_res = {}
        for metric_name in estimator.metrics[Mode.EVAL].keys():
            # collapse the per-batch values of this metric into one mean
            mean_res[metric_name] = np.mean(tot_res[metric_name])
        print(mean_res)


# Assemble the estimator from the model builder and the input function, with
# a Logger hook attached.
estimator = Estimator(
    model_builder,
    input_fn,
    hooks=[Logger()],
)



In [8]:
# The batch size can be chosen when training starts (the default is one).
estimator.train_and_evaluate(
    data=train_data,
    validation=test_data,
    epochs=EPOCHS,
    batch_size=64,
    batch_size_eval=64,
)

# Evaluate the estimator on the test data and show the result.
res = estimator.evaluate(data=test_data)
print(res)

# Predict for a single two-feature sample.
pred = estimator.predict(np.array([[2, 1]]))
print(pred)

{'loss': 0.08661313}
{'loss': 0.07500905}
{'loss': 0.08483291}
{'loss': 0.07544698}
{'loss': 0.08530777}
{'loss': 0.07563208}
{'loss': 0.085121274}
{'loss': 0.07546256}
{'loss': 0.08599944}
{'loss': 0.07576929}
{'loss': 0.08478665}
{'loss': 0.075698465}
{'loss': 0.08553242}
{'loss': 0.07656501}
{'loss': 0.08644321}
{'loss': 0.07520376}
{'loss': 0.08580646}
{'loss': 0.07590302}
{'loss': 0.08545609}
{'loss': 0.075246856}
{'loss': 0.08515479}
{'loss': 0.07596704}
{'loss': 0.085531145}
{'loss': 0.0757267}
{'loss': 0.08577366}
{'loss': 0.075311095}
{'loss': 0.0845397}
{'loss': 0.0752019}
{'loss': 0.08492048}
{'loss': 0.07593995}
{'loss': 0.085196905}
{'loss': 0.07542121}
{'loss': 0.08489785}
{'loss': 0.07609701}
{'loss': 0.085247636}
{'loss': 0.075919345}
{'loss': 0.086019635}
{'loss': 0.075492516}
{'loss': 0.08555007}
{'loss': 0.07563598}
{'loss': 0.08537464}
{'loss': 0.07560341}
{'loss': 0.08532715}
{'loss': 0.07573288}
{'loss': 0.085173495}
{'loss': 0.075998634}
{'loss': 0.085409135}
{'l