Fix problematic behavior of optimizer/scheduler in FeatureInversionTask#101
ShuntaroAoki merged 3 commits into `dev`
Conversation
Based on Otsuka-san's suggestion, I have revised the type definitions as follows:

```diff
  build_optimizer_factory: (type[Optimizer], _GetParamsFnType) -> _OptimizerFactoryType
- _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> Iterator[Parameter]
+ _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> _ParamsT
  _OptimizerFactoryType: TypeAlias = (BaseGenerator, BaseLatent) -> Optimizer
+ _ParamsT: TypeAlias = Iterable[Tensor] | Iterable[Dict[str, Any]] | Iterable[Tuple[str, Tensor]]
```

Reasons behind this modification

The previous type annotations were not compatible with uses of `build_optimizer_factory` such as:

```python
optimizer_factory = build_optimizer_factory(
    SGD,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": latent_lr},
        {"params": generator.parameters(), "lr": generator_lr},
    ],
    lr=base_lr, momentum=0.9
)
```

Why did we redefine the same concept in our codebase instead of just importing it from PyTorch? I decided to define the type in our own codebase:

bdpy/bdpy/recon/torch/modules/optimizer.py, lines 14 to 18 in c23afe7
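To make the factory-builder discussion above concrete, here is a minimal sketch of what a `build_optimizer_factory` helper could look like. The stand-in classes (`FakeModule`, `FakeOptimizer`) are hypothetical substitutes for `BaseGenerator`/`BaseLatent` and `torch.optim.Optimizer`, used only so the sketch runs without PyTorch; this is not the actual bdpy implementation.

```python
from typing import Any, Callable, Iterable, List

# Hypothetical PyTorch-free stand-ins for illustration only.
class FakeModule:
    def __init__(self, n_params: int) -> None:
        self._params = [0.0] * n_params

    def parameters(self) -> List[float]:
        return self._params

class FakeOptimizer:
    def __init__(self, params: Iterable[Any], **defaults: Any) -> None:
        # Accept both a plain iterable of params and a list of param-group
        # dicts, mirroring torch.optim.Optimizer's constructor contract.
        self.param_groups = [
            g if isinstance(g, dict) else {"params": g} for g in params
        ]
        self.defaults = defaults

def build_optimizer_factory(
    optimizer_cls: type,
    get_params_fn: Callable[[Any, Any], Any],
    **optimizer_kwargs: Any,
) -> Callable[[Any, Any], Any]:
    """Bind an optimizer class to a parameter-selection function.

    The returned factory (generator, latent) -> optimizer lets the task
    re-create a fresh optimizer whenever generator/latent are reset.
    """
    def factory(generator: Any, latent: Any) -> Any:
        return optimizer_cls(get_params_fn(generator, latent), **optimizer_kwargs)
    return factory

optimizer_factory = build_optimizer_factory(
    FakeOptimizer,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": 0.1},
        {"params": generator.parameters(), "lr": 0.01},
    ],
    momentum=0.9,
)
opt = optimizer_factory(FakeModule(3), FakeModule(2))
```

Note how `get_params_fn` returns a list of param-group dicts here, which is exactly the case the widened `_ParamsT` alias is meant to admit.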
Problem
The current implementation of `FeatureInversionTask` has several limitations/problems in its use of the optimizer/scheduler. Here are concrete examples:

- Initialization using param_groups works only one time
- A learning rate scheduler cannot be used
Cause
The cause of the problem is in the implementation of `reset_states()`:

bdpy/bdpy/recon/torch/task/inversion.py, lines 216 to 226 in 9ffe7bc
Originally this method was implemented based on the following assumptions:
In reality, neither of these assumptions held. In addition, since optimizers generally hold references to the generator and latent instances, and learning rate schedulers hold references to the optimizer instance, whenever any of these dependencies is re-instantiated the references must be replaced accordingly.
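The reference problem can be shown with a toy example (no PyTorch; all names here are hypothetical): an optimizer built around one latent instance keeps mutating the old object after the latent is re-instantiated, so the "reset" latent never receives updates.

```python
# Toy illustration of the stale-reference problem described above.
class Latent:
    def __init__(self):
        self.value = 0.0

class Optimizer:
    def __init__(self, latent):
        self.latent = latent  # reference captured at construction time

    def step(self):
        self.latent.value += 1.0

latent = Latent()
optimizer = Optimizer(latent)
optimizer.step()

latent = Latent()   # "reset" by re-instantiation; optimizer is NOT rebuilt
optimizer.step()    # still mutates the old, now-orphaned instance
```

After the second `step()`, the optimizer's (old) latent has been updated twice while the fresh latent is untouched, which is exactly why the optimizer must be re-created together with its dependencies.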
Solution
Instead of receiving instances of the optimizer and learning rate scheduler themselves, `FeatureInversionTask` receives factory methods for creating them. The following is an example use of the newly designed API:
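The original example snippet did not survive extraction, so here is a minimal sketch of the factory-based design, using PyTorch-free stand-ins; the class and attribute names (`_Optimizer`, `_Scheduler`, the underscore-prefixed task attributes) are assumptions based on the description above, not the actual bdpy code.

```python
# Sketch: reset_states() rebuilds the latent -> optimizer -> scheduler chain.
class _Module:
    def parameters(self):
        return [0.0]

class _Optimizer:
    def __init__(self, params, lr=0.1):
        self.param_groups = [{"params": params, "lr": lr}]

class _Scheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

class FeatureInversionTask:
    def __init__(self, generator, latent, optimizer_factory, scheduler_factory):
        self._generator = generator
        self._latent = latent
        self._optimizer_factory = optimizer_factory
        self._scheduler_factory = scheduler_factory
        self.reset_states()

    def reset_states(self):
        # Re-create the whole dependency chain so every reference stays
        # consistent: the optimizer sees the current generator/latent,
        # and the scheduler sees the current optimizer.
        self._optimizer = self._optimizer_factory(self._generator, self._latent)
        self._scheduler = self._scheduler_factory(self._optimizer)

task = FeatureInversionTask(
    _Module(), _Module(),
    optimizer_factory=lambda generator, latent: _Optimizer(latent.parameters(), lr=0.1),
    scheduler_factory=lambda optimizer: _Scheduler(optimizer),
)
first_optimizer = task._optimizer
task.reset_states()
```

Because the task owns the factories rather than the instances, every `reset_states()` call yields a fresh optimizer, and the scheduler is always rebuilt against that fresh optimizer.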
Breaking changes in API
- `FeatureInversionTask` takes `optimizer_factory: (BaseGenerator, BaseLatent) -> Optimizer` instead of `optimizer: Optimizer` as an input argument
- `FeatureInversionTask` takes `scheduler_factory: Optimizer -> LRScheduler` instead of `scheduler: LRScheduler` as an input argument
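The migration for callers can be sketched as follows. The `SGD`, `StepLR`, and `_Latent` classes below are stand-ins (real code would use `torch.optim.SGD` and `torch.optim.lr_scheduler.StepLR`); the point is only the shape of the two factory arguments.

```python
from functools import partial

# Hypothetical stand-ins so the sketch runs without PyTorch.
class SGD:
    def __init__(self, params, lr=0.01):
        self.param_groups = [{"params": list(params), "lr": lr}]

class StepLR:
    def __init__(self, optimizer, step_size=10):
        self.optimizer = optimizer
        self.step_size = step_size

class _Latent:
    def parameters(self):
        return [1.0, 2.0]

# Old API (removed): pass ready-made instances, e.g.
#   FeatureInversionTask(..., optimizer=SGD(latent.parameters(), lr=0.1),
#                        scheduler=StepLR(optimizer, step_size=10))
# New API: pass factories, so the task can rebuild both on every reset.
optimizer_factory = lambda generator, latent: SGD(latent.parameters(), lr=0.1)
scheduler_factory = partial(StepLR, step_size=10)

opt = optimizer_factory(None, _Latent())
sched = scheduler_factory(opt)
```

Using `functools.partial` for `scheduler_factory` is one convenient way to pre-bind scheduler hyperparameters while leaving the `Optimizer -> LRScheduler` call signature intact.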