Add option to support model.load_state_dict(model_state, strict=False) in `evaluate` & `predict`. #2054
Is your feature request related to a problem? Please describe.
I'd like to train a new model on top of ELMo and save only the trainable parameters by overriding the `torch.nn.Module.state_dict` method, since the trainer currently saves all the parameters returned by `state_dict()`.
This should be fine for saving checkpoints, but `evaluate` and `predict` then fail to load them, because `load_state_dict` defaults to `strict=True`.
Describe the solution you'd like
I wonder if we could add an option to `evaluate` and `predict` that calls `model.load_state_dict(model_state, strict=False)`.
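The requested behavior can be sketched with plain PyTorch (this is a minimal illustration, not the AllenNLP API; the two-layer model is a stand-in for a frozen ELMo plus a trainable head):

```python
import torch

# Stand-in model: layer 0 plays the frozen ELMo, layer 1 the new head.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 3),
    torch.nn.Linear(3, 2),
)

# Simulate a checkpoint that contains only the trainable head's parameters.
partial_state = {k: v for k, v in model.state_dict().items() if k.startswith("1.")}

# With strict=False, the frozen layer's keys are reported as missing
# instead of raising a RuntimeError.
result = model.load_state_dict(partial_state, strict=False)
print(sorted(result.missing_keys))
```

With the default `strict=True`, the same call raises a `RuntimeError` listing the missing keys.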
Describe alternatives you've considered
I came up with two ideas:
Which one do you think is better?
If all you need is to add the `strict=False` option, that seems straightforward.
I was more worried about the model saving piece of this - I'm not sure at all how to make the model save only its trainable parameters.
I believe something like this should work.
```python
@overrides
def state_dict(self, destination=None, prefix='', keep_vars=False):
    original = super().state_dict(destination, prefix, keep_vars)
    if not keep_vars:
        return original
    for key in list(original.keys()):
        param = original[key]
        if not param.requires_grad:
            original.pop(key)
    return original
```
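For context, here is a self-contained sketch of how such an override behaves. The `FrozenBackboneModel` class and its layer names are illustrative, not from this thread, and the `@overrides` decorator is omitted to keep it stdlib-plus-torch only:

```python
import torch

class FrozenBackboneModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(4, 3)  # frozen, e.g. a pretrained encoder
        self.head = torch.nn.Linear(3, 2)      # trainable
        for p in self.backbone.parameters():
            p.requires_grad = False

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        original = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        if not keep_vars:
            # Without keep_vars the values carry no requires_grad
            # information, so everything is returned unchanged.
            return original
        # Drop parameters that are not being trained.
        for key in list(original.keys()):
            if not original[key].requires_grad:
                original.pop(key)
        return original

model = FrozenBackboneModel()
trimmed = model.state_dict(keep_vars=True)
print(sorted(trimmed))  # only the trainable head's parameters remain
```

Note that the filtering only kicks in when `keep_vars=True`; a plain `model.state_dict()` call still returns every parameter.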
If you want to put that on your model and it works for you, great, go ahead. This is not something that we can add upstream, though, as it's incredibly non-obvious and it's the wrong thing to do in a whole lot of cases (e.g., if you have fixed word embeddings; you don't want to throw away your embedding layer).