New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add save/load to param store (with some accompanying changes) #47
Conversation
travis build will fail until cloudpickle added to travis environment |
tests/test_param.py
Outdated
|
||
# myparam = pyro.param("myparam") | ||
# self.assertFalse(myparam_copy_stale == myparam.data.numpy()) | ||
# self.assertTrue(myparam_copy == myparam.data.numpy()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove these comments since its tested above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good!
youre failing a python3 test - can you fix that ?
also probably should remove the commented code
tests/test_param.py
Outdated
param_store_param_to_name = copy(pyro.get_param_store()._param_to_name) | ||
|
||
pyro.get_param_store().save('paramstore.unittest.out') | ||
pyro.get_param_store().clear() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add an assert here just to see that clear()
cleared the cache store?
pyro/__init__.py
Outdated
def user_param_name(param_name): | ||
if module_namespace_divider in param_name: | ||
return param_name.split(module_namespace_divider)[1] | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dont need this else
@@ -109,23 +120,40 @@ def map_data(name, data, observer, *args, **kwargs): | |||
# for now default calls out to pyro.param -- which is handled by poutine | |||
|
|||
|
|||
def sync_module(pyro_name, nn_obj): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a short comment here; eventually this will be in the docs
tests/test_param.py
Outdated
self.assertTrue(param_store_params.keys().sort() == | ||
pyro.get_param_store()._params.keys().sort()) | ||
self.assertTrue(param_store_param_to_name.values().sort() == | ||
pyro.get_param_store()._param_to_name.values().sort()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of these dict.keys().sort()
calls are invalid in Python 3, please fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python 3 test still failing, but otherwise LGTM.
added some asserts and incorporated people's suggestions. thought i dealt with all python3 issues but strangely travis keeps complaining. i made myself a local 3.5 environment, and everything seems fine there. and when i use travis' debugging functionality and log on to the shell and run the tests manually, everything seems fine. no idea what's going on.... |
@martinjankowiak Are you running only the param tests manually? Is it failing because of some other tests not clearing the param store? |
paramstore now has save/load. this is a simple first pass. open to suggestions for improvements.
the usage is as follows.
-- removed 'tags' from paramstore
-- params from module are named using the pytorch construct and not str(id(object))
-- renamed _clear_cache() to clear()
-- serialization via cloudpickle for now
-- per_param_args for optim changed slightly
-- one unit test for save/load added