Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor XLNet interface #208

Merged
merged 1 commit into from
Sep 14, 2019
Merged

Refactor XLNet interface #208

merged 1 commit into from
Sep 14, 2019

Conversation

gpengzhi
Copy link
Collaborator

@gpengzhi gpengzhi commented Sep 13, 2019

Requested by Zhiting's comments in PR in texar-tf

Make the interface of XLNetEncoder be consistent with other modules.

Copy link
Collaborator

@huzecong huzecong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I'm thinking, would this problem be solved if we have a mechanism that can:

  • Add methods that is called prior to __init__ (of the uppermost subclass).
  • Add methods that is called after __init__ (of the uppermost subclass).

I feel that such mechanism could solve both this problem and the repeated-hparams initialization problem we've discussed before.

A hacky way of doing this is via metaclasses. When an instance of a class is created, the metaclass __call__ method is invoked, which internally calls the class __new__ and class __init__ methods. We can create a custom metaclass for ModuleBase that constructs the hparams object before __init__, and allows sub-metaclasses to append method calls after __init__.

A diagram of instance creation is as follows, taken from this very good article that explains it all:
`

@gpengzhi
Copy link
Collaborator Author

Looks good. I'm thinking, would this problem be solved if we have a mechanism that can:

  • Add methods that is called prior to __init__ (of the uppermost subclass).
  • Add methods that is called after __init__ (of the uppermost subclass).

I feel that such mechanism could solve both this problem and the repeated-hparams initialization problem we've discussed before.

A hacky way of doing this is via metaclasses. When an instance of a class is created, the metaclass __call__ method is invoked, which internally calls the class __new__ and class __init__ methods. We can create a custom metaclass for ModuleBase that constructs the hparams object before __init__, and allows sub-metaclasses to append method calls after __init__.

A diagram of instance creation is as follows, taken from this very good article that explains it all:
`

Thanks a lot. I think this modification is related to #41 I will think about it.

@gpengzhi gpengzhi merged commit 1bde817 into asyml:master Sep 14, 2019
@ZhitingHu
Copy link
Member

Why is this PR related to "repeated-hparams initialization"? In particular, what's the init argument in the XLNetEncoder __init__ for in the first place?

@huzecong
Copy link
Collaborator

It's because these two problems have similar root causes.

The init argument. Our pre-trained mixin interfaces requires the concrete classes to call init_pretrained_weights() at the end of their __init__ methods. This method is used to initialize the weights of the registered module parameters from the checkpoint file.

XLNetEncoder, which inherits PretrainedXLNetMixin, follows this pattern. However, XLNetDecoder inherits XLNetEncoder (which is not a good design; but it works at least for now), but has an extra parameter lm_bias that has to be registered before calling init_pretrained_weights(), which happens with the super().__init__ call. Due to PyTorch implementations, we can't register parameters before super().__init__ (which internally calls nn.Module.__init__), so we had this init argument to control whether we call init_pretrained_weights() in XLNetEncoder.__init__.

Relation to repeated-hparams creation. The hparams issue was that: hparams is created in ModuleBase.__init__, which will be called in super().__init__ of our built-in modules. However, certain base classes require additional arguments in their constructors (e.g., DataBase require a data_source argument), and some arguments cannot be constructed without knowing hparams values. As a results, the hparams object is constructed multiple times.

The init issue above could benefit from being able to add method calls (a call to init_pretrained_weights()) at the end of object initialization. The hparams issue could similarly benefit from being able to add calls before initialization (a call to construct hparams).

@ZhitingHu
Copy link
Member

Thanks for explaining, and the possible solution.

The solution can be a candidate we can apply at some point. Since it's a bit intricate, for now (when the issue has not widely occurred), let's just work around case by case.

One of our interface design "principles" could to some extent (not completely) avoid the issue: put hyperparameters in hparams instead of constructor arguments whenever possible. In this way, it's in general unlikely (though not impossible) that an hyperparam is in the hparams of a class but at the same time is an argument of super().__init__

It looks the CharCNN example in the issue #41 is out of the scope of the above principle though. It's reasonable that in_channels is Conv1DEncoder.__init__ argument and char_embed_dim is in CharCNN hparams.

@huzecong
Copy link
Collaborator

Yes, we could leave this for the future.

While the issue in the CharCNN example could be avoided with this principle, it only applies to values that fits in hparams -- and a data source is probably not one of them. There could also be logic that has to be done before calling the super class constructor.

@gpengzhi gpengzhi deleted the xlnet-interface branch November 1, 2019 02:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants