From db5e61147468f370a6ac71139dcf9c96a6de4588 Mon Sep 17 00:00:00 2001 From: Germey Date: Thu, 11 Oct 2018 00:09:11 +0800 Subject: [PATCH] change examples --- examples/evaluate.py | 4 ---- examples/infer.py | 15 +++++---------- examples/model.py | 1 + examples/train.py | 18 +++++++----------- 4 files changed, 13 insertions(+), 25 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 5507463..5a83f0e 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -1,4 +1,3 @@ -from model import BostonHousingModel from model_zoo.evaluater import BaseEvaluater import tensorflow as tf @@ -6,9 +5,6 @@ class Evaluater(BaseEvaluater): - def __init__(self): - BaseEvaluater.__init__(self) - self.model_class = BostonHousingModel def prepare_data(self): from tensorflow.python.keras.datasets import boston_housing diff --git a/examples/infer.py b/examples/infer.py index b8b1112..b78f339 100644 --- a/examples/infer.py +++ b/examples/infer.py @@ -1,21 +1,16 @@ -from model import BostonHousingModel from model_zoo.inferer import BaseInferer +from model_zoo.preprocess import standardize import tensorflow as tf -from tensorflow.python.keras.datasets import boston_housing -from sklearn.preprocessing import StandardScaler -tf.flags.DEFINE_string('checkpoint_name', 'model.ckpt-38', help='Model name') +tf.flags.DEFINE_string('checkpoint_name', 'model.ckpt-20', help='Model name') + class Inferer(BaseInferer): - def __init__(self): - BaseInferer.__init__(self) - self.model_class = BostonHousingModel def prepare_data(self): + from tensorflow.python.keras.datasets import boston_housing (x_train, y_train), (x_test, y_test) = boston_housing.load_data() - ss = StandardScaler() - ss.fit(x_train) - x_test = ss.transform(x_test) + _, x_test = standardize(x_train, x_test) return x_test diff --git a/examples/model.py b/examples/model.py index 08110e2..6b1b7e2 100644 --- a/examples/model.py +++ b/examples/model.py @@ -1,6 +1,7 @@ from model_zoo.model import BaseModel import tensorflow as tf + class BostonHousingModel(BaseModel): def __init__(self, config): super(BostonHousingModel, self).__init__(config) diff --git a/examples/train.py b/examples/train.py index e5b96e9..fb301a9 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,23 +1,19 @@ -from model import BostonHousingModel +import tensorflow as tf from model_zoo.trainer import BaseTrainer -from tensorflow.python.keras.datasets import boston_housing -from sklearn.preprocessing import StandardScaler +from model_zoo.preprocess import standardize + +tf.flags.DEFINE_integer('epochs', 20, 'Max epochs') +tf.flags.DEFINE_string('model_class', 'BostonHousingModel', 'Model class name') class Trainer(BaseTrainer): - def __init__(self): - BaseTrainer.__init__(self) - self.model_class = BostonHousingModel - def prepare_data(self): + from tensorflow.python.keras.datasets import boston_housing (x_train, y_train), (x_eval, y_eval) = boston_housing.load_data() - ss = StandardScaler() - ss.fit(x_train) - x_train, x_eval = ss.transform(x_train), ss.transform(x_eval) + x_train, x_eval = standardize(x_train, x_eval) train_data, eval_data = (x_train, y_train), (x_eval, y_eval) return train_data, eval_data - if __name__ == '__main__': Trainer().run()