Skip to content

Commit

Permalink
Implement save method
Browse files Browse the repository at this point in the history
  • Loading branch information
Hironsan committed Mar 4, 2018
1 parent ae46f6e commit e41a7a5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
12 changes: 10 additions & 2 deletions anago/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
Model definition.
"""
import json
import os

import keras.backend as K
from keras.layers import Dense, LSTM, Bidirectional, Embedding, Input, Dropout, Lambda, Activation
from keras.layers.merge import Concatenate
Expand All @@ -18,8 +21,13 @@ def predict(self, X, *args, **kwargs):
y_pred = self.model.predict(X, batch_size=1)
return y_pred

def save(self, filepath):
self.model.save_weights(filepath)
def save_weights(self, file_path):
self.model.save_weights(file_path)

def save_params(self, file_path):
with open(file_path, 'w') as f:
params = {name: val for name, val in vars(self).items() if name not in {'_loss', 'model'}}
json.dump(params, f, sort_keys=True, indent=4)

def load(self, filepath):
self.model.load_weights(filepath=filepath)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,24 @@ def test_predict(self):
pass

def test_save(self):
pass
weight_file = os.path.join(SAVE_ROOT, 'weights.h5')
param_file = os.path.join(SAVE_ROOT, 'hyperparameters.h5')
model = BiLSTMCRF(char_vocab_size=100,
word_vocab_size=10000,
ntags=10)
model.build_model()

self.assertFalse(os.path.exists(weight_file))
self.assertFalse(os.path.exists(param_file))

model.save_weights(weight_file)
model.save_params(param_file)

self.assertTrue(os.path.exists(weight_file))
self.assertTrue(os.path.exists(param_file))

os.remove(weight_file)
os.remove(param_file)

def test_load(self):
pass

0 comments on commit e41a7a5

Please sign in to comment.