This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Additional (de)serialization routines.
Deprecated load_weights in favor of the more appropriately named load_params.

Fixes #106
scttl committed Jan 13, 2016
1 parent cac52b7 commit 29ed0eb
Showing 10 changed files with 49 additions and 21 deletions.
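
For orientation, a minimal usage sketch of the renamed API this commit introduces. This is not part of the diff: the `layers` list, the training step, and the 'model.prm' path are placeholders, and a working neon 1.1.x backend is assumed.

    # Sketch only: `layers` and training are assumed, not shown.
    from neon.models import Model

    model = Model(layers=layers)
    # ... model.fit(...) as in the examples changed below ...

    model.save_params('model.prm')      # new helper: save_obj(model.serialize(keep_states=True), path)
    restored = Model(layers=layers)
    restored.load_params('model.prm')   # preferred name going forward
    restored.load_weights('model.prm')  # deprecated alias; logs a warning, then calls load_params
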
2 changes: 1 addition & 1 deletion bin/neon
@@ -142,7 +142,7 @@ if __name__ == "__main__":
                           stochastic_rounding=args.rounding)

     if args.model_file:
-        model.load_weights(args.model_file)
+        model.load_params(args.model_file)
     train, test = load_data(data_dir=args.data_dir)
     # configure callbacks
     callbacks = Callbacks(model, train, eval_set=test, **args.callback_args)
2 changes: 1 addition & 1 deletion examples/babi/demo.py
@@ -52,7 +52,7 @@

 # create model
 model_inference = create_model(babi.vocab_size, args.rlayer_type)
-model_inference.load_weights(args.model_weights)
+model_inference.load_params(args.model_weights)
 model_inference.initialize(dataset=valid_set)

 ex_story, ex_question, ex_answer = babi.test_parsed[0]
2 changes: 1 addition & 1 deletion examples/cifar10_allcnn.py
@@ -83,7 +83,7 @@
 if args.model_file:
     import os
     assert os.path.exists(args.model_file), '%s not found' % args.model_file
-    mlp.load_weights(args.model_file)
+    mlp.load_params(args.model_file)

 # configure callbacks
 callbacks = Callbacks(mlp, train_set, eval_set=valid_set, **args.callback_args)
2 changes: 1 addition & 1 deletion examples/image_caption.py
@@ -87,7 +87,7 @@
 model.fit(train_set, optimizer=opt, num_epochs=num_epochs, cost=cost, callbacks=callbacks)

 # load model (if exited) and evaluate bleu score on test set
-model.load_weights(checkpoint_model_path)
+model.load_params(checkpoint_model_path)
 test_set = ImageCaptionTest(path=data_path)
 sents, targets = test_set.predict(model)
 test_set.bleu_score(sents, targets)
2 changes: 1 addition & 1 deletion examples/imagenet_allcnn.py
@@ -99,7 +99,7 @@
 if args.model_file:
     import os
     assert os.path.exists(args.model_file), '%s not found' % args.model_file
-    mlp.load_weights(args.model_file)
+    mlp.load_params(args.model_file)

 # configure callbacks
 callbacks = Callbacks(mlp, train, eval_set=test, **args.callback_args)
2 changes: 1 addition & 1 deletion examples/text_generation_lstm.py
@@ -104,7 +104,7 @@ def sample(prob):
     Affine(len(train_set.vocab), init, bias=init, activation=Softmax())
 ]
 model_new = Model(layers=layers)
-model_new.load_weights(args.save_path)
+model_new.load_params(args.save_path)
 model_new.initialize(dataset=(train_set.shape[0], time_steps))

 # Generate text
2 changes: 1 addition & 1 deletion examples/timeseries_lstm.py
@@ -332,7 +332,7 @@ def err(y, t):
 ]

 model_new = Model(layers=layers)
-model_new.load_weights(args.save_path)
+model_new.load_params(args.save_path)
 model_new.initialize(dataset=(train_set.nfeatures, seq_len))

 output = np.zeros((train_set.nfeatures, num_predict))
2 changes: 1 addition & 1 deletion neon/callbacks/callbacks.py
@@ -78,7 +78,7 @@ def __init__(self, model, train_set,
             self.callback_data = h5py.File(output_file, "w")

         if model_file:
-            model.load_weights(model_file)
+            model.load_params(model_file)

         self.model = model
         self.train_set = train_set
50 changes: 39 additions & 11 deletions neon/models/model.py
@@ -17,7 +17,7 @@

 from neon import NervanaObject
 from neon.transforms import CrossEntropyBinary, Logistic
-from neon.util.persist import load_obj
+from neon.util.persist import load_obj, save_obj
 from neon.layers import Sequential, Activation, Tree
 import numpy as np

@@ -249,27 +249,56 @@ def get_description(self):
         pdict['optimizer'] = self.optimizer.get_description()
         return pdict

-    def load_weights(self, weight_path):
+    def save_params(self, param_path, keep_states=True):
+        """
+        Serializes and saves model parameters to the path specified.
+
+        Arguments:
+            param_path (str): File to write serialized parameter dict to.
+            keep_states (bool): Whether to save optimizer states too.
+                                Defaults to True.
+        """
+        save_obj(self.serialize(keep_states), param_path)
+
+    def load_params(self, param_path):
         """
-        Loads the layer weights saved in weight_path from serialize().
+        Loads the model parameters (per layer weights, epochs run, optimizer
+        states) saved in param_path from serialize().

         Arguments:
-            weight_path (str): File containing serialized python dict with layer
-                               weights and states.
+            param_path (str): File containing serialized python dict with layer
+                              weights and states.
+        """
+        pdict = load_obj(param_path)
+        self.deserialize(pdict)
+        logger.info('Model weights loaded from %s', param_path)
+
+    def load_weights(self, weight_path):
+        """
+        .. deprecated:: 1.1.4
+           Use :func:`load_params` instead
+        """
+        logger.warning('Calling deprecated load_weights function. Use '
+                       'load_params instead')
+        self.load_params(weight_path)
+
+    def deserialize(self, params):
         """
-        pdict = load_obj(weight_path)
+        Loads per layer (weights, states) and other model parameters from the
+        dictionary passed.

-        self.epoch_index = pdict['epoch_index']
+        Arguments:
+            params (dict): parameters as returned by serialize().
+        """
+        self.epoch_index = params['epoch_index']

         param_layers = [l for l in self.layers_to_optimize]
-        param_dict_list = pdict['layer_params_states']
+        param_dict_list = params['layer_params_states']
         for l, ps in zip(param_layers, param_dict_list):
             l.set_params(ps['params'])
             if 'states' in ps:
                 l.set_states(ps['states'])

-        logger.info('Model weights loaded from %s', weight_path)
-
     # serialize tells how to write out the parameters we've learned so
     # far and associate them with layers. it can ignore layers with no
     # learned parameters. the model stores states to pass to the
@@ -285,7 +314,6 @@ def serialize(self, keep_states=True):
         Returns:
             dict: Model data including layer parameters and epochs complete.
         """
-
         pdict = dict()
         params_states = [l.get_params_serialize(keep_states) for l in self.layers_to_optimize]
         pdict['layer_params_states'] = params_states
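
As a reading aid (not part of the commit), the dict that the new deserialize() consumes, which is the same structure serialize() emits per the hunk above, has roughly this shape; tensor contents are backend-specific and elided here:

    # Illustrative shape only; values elided, epoch count arbitrary.
    pdict = {
        'epoch_index': 10,            # epochs completed when the model was saved
        'layer_params_states': [      # one entry per layer in layers_to_optimize
            {'params': ...,           # layer weights
             'states': ...},          # optimizer state, present only if keep_states=True
        ],
    }
    model.deserialize(pdict)          # load_params() is load_obj() followed by this
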
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -114,11 +114,11 @@ def test_model_serialize(backend_default, data):
             break

     # Serialize model
-    save_obj(mlp.serialize(keep_states=True), tmp_save)
+    mlp.save_params(tmp_save, keep_states=True)

     # Load model
     mlp = Model(layers=layers)
-    mlp.load_weights(tmp_save)
+    mlp.load_params(tmp_save)

     outputs = []
     pdicts = [l.get_params_serialize() for l in mlp.layers_to_optimize]
