Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add save/load to param store (with some accompanying changes) #47

Merged
merged 10 commits into from Jul 16, 2017

Conversation

martinjankowiak
Copy link
Collaborator

@martinjankowiak martinjankowiak commented Jul 14, 2017

paramstore now has save/load. this is a simple first pass. open to suggestions for improvements.

the usage is as follows.

    # define model, run inference, etc...
    pyro.get_param_store().save('paramstore.save')
    # close session and open new session
    pyro.get_param_store().load('paramstore.save')
    # params are now in paramstore but any modules must be synced before they work
    pyro.sync_module("mymodule", pt_mymodule)
    # now good to go, as if the session had never ended...

-- removed 'tags' from paramstore
-- params from module are named using the pytorch construct and not str(id(object))
-- renamed _clear_cache() to clear()
-- serialization via cloudpickle for now
-- per_param_args for optim changed slightly
-- one unit test for save/load added

@martinjankowiak
Copy link
Collaborator Author

travis build will fail until cloudpickle added to travis environment


# myparam = pyro.param("myparam")
# self.assertFalse(myparam_copy_stale == myparam.data.numpy())
# self.assertTrue(myparam_copy == myparam.data.numpy())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove these comments since its tested above

Copy link
Member

@jpchen jpchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good!

youre failing a python3 test - can you fix that ?
also probably should remove the commented code

param_store_param_to_name = copy(pyro.get_param_store()._param_to_name)

pyro.get_param_store().save('paramstore.unittest.out')
pyro.get_param_store().clear()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add an assert here just to see that clear() cleared the cache store?

pyro/__init__.py Outdated
def user_param_name(param_name):
if module_namespace_divider in param_name:
return param_name.split(module_namespace_divider)[1]
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dont need this else

@@ -109,23 +120,40 @@ def map_data(name, data, observer, *args, **kwargs):
# for now default calls out to pyro.param -- which is handled by poutine


def sync_module(pyro_name, nn_obj):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a short comment here; eventually this will be in the docs

self.assertTrue(param_store_params.keys().sort() ==
pyro.get_param_store()._params.keys().sort())
self.assertTrue(param_store_param_to_name.values().sort() ==
pyro.get_param_store()._param_to_name.values().sort())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of these dict.keys().sort() calls are invalid in Python 3, please fix.

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python 3 test still failing, but otherwise LGTM.

@martinjankowiak
Copy link
Collaborator Author

martinjankowiak commented Jul 16, 2017

added some asserts and incorporated people's suggestions.

thought i dealt with all python3 issues but strangely travis keeps complaining. i made myself a local 3.5 environment, and everything seems fine there. and when i use travis' debugging functionality and log on to the shell and run the tests manually, everything seems fine. no idea what's going on....

@eb8680
Copy link
Member

eb8680 commented Jul 16, 2017

@martinjankowiak Are you running only the param tests manually? Is it failing because of some other tests not clearing the param store?

@eb8680 eb8680 merged commit db45c12 into dev Jul 16, 2017
@eb8680 eb8680 deleted the martin-dev-paramstore branch July 16, 2017 22:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants