Add option to support model.load_state_dict(model_state, strict=False) in `evaluate` & `predict`. #2054

Closed · huntzhan opened this issue Nov 15, 2018 · 7 comments

@huntzhan (Contributor) commented Nov 15, 2018

Is your feature request related to a problem? Please describe.

I'd like to train a new model on top of ELMo and save only the trainable parameters by overriding the torch.nn.Module.state_dict method, since the trainer currently saves all of the parameters returned by state_dict:

model_state = self.model.state_dict()

This works fine for allennlp train, but leads to an error during allennlp evaluate and allennlp predict, since the _load method calls load_state_dict with the default strict=True, so a RuntimeError (... Missing key(s) in state_dict: ...) is raised:

model = Model.load(config.duplicate(),
                   weights_file=weights_path,
                   serialization_dir=serialization_dir,
                   cuda_device=cuda_device)

model = Model.from_params(vocab=vocab, params=model_params)
model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
model.load_state_dict(model_state)
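
For context, here is a minimal sketch in plain PyTorch (the module and keys are illustrative, not from AllenNLP) of the difference between strict and non-strict loading:

    import torch

    # Tiny stand-in model; the saved state is missing the "bias" key.
    model = torch.nn.Linear(4, 2)
    partial_state = {"weight": torch.zeros(2, 4)}

    # Default strict=True: raises RuntimeError (... Missing key(s) in state_dict: "bias" ...).
    try:
        model.load_state_dict(partial_state)
    except RuntimeError as err:
        print("strict load failed:", err)

    # strict=False: loads the keys that are present and leaves "bias" at its
    # initialized value instead of raising.
    model.load_state_dict(partial_state, strict=False)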

Describe the solution you'd like

I wonder if we could add an option to allennlp evaluate & allennlp predict to support model.load_state_dict(model_state, strict=False)?

Describe alternatives you've considered

N/A

Additional context

During initialization, _ElmoCharacterEncoder(torch.nn.Module) and ElmoLstm(_EncoderBase) load the options_file and weight_file regardless of whether the model is being trained or evaluated.

options_file and weight_file are archived after training, so removing those parameters from the weights.th dump should be safe:

params.add_file_to_archive('options_file')
params.add_file_to_archive('weight_file')

.
├── config.json
├── files_to_archive.json
├── fta
│   ├── model.text_field_embedder.elmo.options_file
│   └── model.text_field_embedder.elmo.weight_file
├── vocabulary
│   ├── labels.txt
│   └── non_padded_namespaces.txt
└── weights.th
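
For illustration only, a rough sketch of how the ELMo parameters could be stripped from an existing weights.th dump; the key prefix is an assumption and depends on the model configuration, so check the actual keys in your dump first:

    import torch

    # Hypothetical prefix for the archived ELMo parameters.
    ELMO_PREFIX = "text_field_embedder.elmo."

    def strip_elmo_weights(weights_file: str, output_file: str) -> None:
        state = torch.load(weights_file, map_location="cpu")
        slimmed = {key: value for key, value in state.items()
                   if not key.startswith(ELMO_PREFIX)}
        torch.save(slimmed, output_file)

    # strip_elmo_weights("weights.th", "weights_slim.th")
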
@matt-gardner (Member) commented Nov 15, 2018

This seems to me like a reasonable thing to do. I'm not sure how to do this in a clean way without thinking about it a lot more, but if there's a way, go for it. PRs welcome.

@huntzhan (Contributor, Author) commented Nov 16, 2018

I could come up with two ideas:

  1. Add a new parameter disable_strict_load to load_archive, Model.load, and Model._load.
  2. Add a new key to config and extract its value in Model._load.

Which one do you think is better?

Call stack:

evaluate & predict:

archive = load_archive(args.archive_file, args.cuda_device, args.overrides, args.weights_file)

load_archive:

model = Model.load(config.duplicate(),
                   weights_file=weights_path,
                   serialization_dir=serialization_dir,
                   cuda_device=cuda_device)

Model.load:

return cls.by_name(model_type)._load(config, serialization_dir, weights_file, cuda_device)

Model._load:

model = Model.from_params(vocab=vocab, params=model_params)
model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
model.load_state_dict(model_state)
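
For the first option, a rough sketch of what the end of Model._load might look like with the flag threaded through (the parameter name and default are only suggestions; load_archive and Model.load would forward it unchanged):

    # Sketch only: a hypothetical `strict` parameter, defaulting to True so
    # current behaviour is preserved.
    @classmethod
    def _load(cls, config, serialization_dir, weights_file=None,
              cuda_device=-1, strict=True):
        ...
        model = Model.from_params(vocab=vocab, params=model_params)
        model_state = torch.load(weights_file,
                                 map_location=util.device_mapping(cuda_device))
        model.load_state_dict(model_state, strict=strict)
        ...

The evaluate and predict commands would then need a corresponding command-line option to set it; the exact flag name is left open here.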

@matt-gardner (Member) commented Nov 16, 2018

If all you need is to add strict=False to model.load_state_dict(), your first option seems like the right solution, and pretty straightforward. A PR for that would be fine.

I was more worried about the model saving piece of this - I'm not sure at all how to make the Trainer save only the right parts of the model in a way that's generally useful. You definitely don't want to tell it to only save the trainable weights, because that's not the right thing to do all the time.

@huntzhan (Contributor, Author) commented Dec 3, 2018

I believe something like this should work.

    @overrides
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        original = super().state_dict(destination, prefix, keep_vars)
        if not keep_vars:
            return original

        for key in list(original.keys()):
            param = original[key]
            if not param.requires_grad:
                original.pop(key)

        return original
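
One detail worth noting about the early return above: with the default keep_vars=False, state_dict() returns detached tensors whose requires_grad is always False, so filtering there would drop every parameter. The filter therefore only takes effect when the caller passes keep_vars=True. A quick illustration in plain PyTorch:

    import torch

    layer = torch.nn.Linear(4, 2)
    layer.weight.requires_grad_(False)  # pretend this is a frozen (e.g. ELMo) weight

    detached = layer.state_dict()              # keep_vars=False (default)
    print(detached["bias"].requires_grad)      # False, even though bias is trainable

    live = layer.state_dict(keep_vars=True)    # returns the actual Parameters
    print(live["bias"].requires_grad)          # True
    print(live["weight"].requires_grad)        # False -> this entry would be filtered
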
@matt-gardner (Member) commented Dec 3, 2018

If you want to put that on your model and it works for you, great, go ahead. This is not something that we can add upstream, though, as it's incredibly non-obvious and it's the wrong thing to do in a whole lot of cases (e.g., if you have fixed word embeddings; you don't want to throw away your embedding layer).

@huntzhan (Contributor, Author) commented Dec 3, 2018

This is not something that we can add upstream

Fully agree. I don't intend to change the Trainer at all.
I'll submit a PR to support strict=False later.

@matt-gardner (Member) commented Dec 17, 2019

I think the right solution here is to make it easier to customize / write your own training loop, where you can put in custom code to do whatever you want to your model (e.g., with callbacks). I'm closing this issue.
