This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

PoC: Revamp optimizer and scheduler experience using registries #777

Conversation

@karthikrangasai (Contributor) commented Sep 20, 2021

What does this PR do?

Fixes #752

The optimizer and scheduler can currently be passed as instances, but that doesn't match how the code works: you need the model to instantiate the optimizer, and the optimizer to instantiate the scheduler. The main idea is to use the FlashRegistry class for this. It blends in well with the idea of Flash being a library for fast prototyping of ML models.

This PR fixes the issue using a few ideas:

  • Pass a partial function which instantiates the optimizer or scheduler, given the object it needs to wrap.
  • Use registries, so that in theory you can select an optimizer or scheduler with a string.

How the API looks after this PR is applied:

The optimizer of choice can be passed as a (the snippets below assume ``import functools``, ``import torch``, and ``from flash.image import ImageClassifier``):

# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)

# - Callable
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None)

# - Tuple[string, dict]: (The dict takes in the optimizer kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("AdaDelta", {"eps": 0.5}), lr_scheduler=None)

The scheduler of choice can be passed as a:

# - String value
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")

# - Callable (note: CyclicLR also requires base_lr and max_lr)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(torch.optim.lr_scheduler.CyclicLR, base_lr=1e-4, max_lr=1e-2, step_size_up=1500, mode='exp_range', gamma=0.5))

# - Tuple[string, dict]: (The dict takes in the scheduler kwargs)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10}))
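
For illustration, here is a minimal sketch of how the three accepted forms might be normalized internally (shown for optimizers; schedulers would follow the same pattern). OPTIMIZERS and resolve_optimizer are hypothetical names for this sketch, not Flash's actual FlashRegistry machinery:

from functools import partial
from typing import Callable, Dict, Tuple, Union

import torch

# Stand-in for a FlashRegistry: maps string keys to optimizer classes.
OPTIMIZERS: Dict[str, Callable] = {
    "Adam": torch.optim.Adam,
    "AdaDelta": torch.optim.Adadelta,
}

OptimizerSpec = Union[str, Callable, Tuple[str, dict]]


def resolve_optimizer(spec: OptimizerSpec, params, lr: float) -> torch.optim.Optimizer:
    """Turn a str / callable / (str, kwargs) spec into an optimizer instance."""
    if isinstance(spec, str):
        return OPTIMIZERS[spec](params, lr=lr)
    if isinstance(spec, tuple):
        name, kwargs = spec
        return OPTIMIZERS[name](params, lr=lr, **kwargs)
    # Otherwise assume a callable, e.g. functools.partial(torch.optim.Adadelta, eps=0.5)
    return spec(params, lr=lr)


# Each form resolves to an Adadelta instance:
net = torch.nn.Linear(4, 2)
opt_str = resolve_optimizer("AdaDelta", net.parameters(), lr=1e-3)
opt_tup = resolve_optimizer(("AdaDelta", {"eps": 0.5}), net.parameters(), lr=1e-3)
opt_fn = resolve_optimizer(partial(torch.optim.Adadelta, eps=0.5), net.parameters(), lr=1e-3)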

You can also register your own custom scheduler recipes beforehand and use them as shown above:

@ImageClassifier.lr_schedulers
def my_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")

This keeps the work of instantiating and setting up the optimizer and scheduler inside the Flash library, mitigating errors a user might make if we allowed classes to be passed directly.
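
For illustration, a stripped-down registry that supports the bare-decorator registration shown above might look like this; SimpleRegistry is a hypothetical sketch, not Flash's actual FlashRegistry:

from typing import Callable, Dict

import torch


class SimpleRegistry:
    """Registers functions under their __name__, for lookup by string."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._fns: Dict[str, Callable] = {}

    def __call__(self, fn: Callable) -> Callable:
        # Used as a bare decorator: @registry
        self._fns[fn.__name__] = fn
        return fn

    def get(self, key: str) -> Callable:
        return self._fns[key]


lr_schedulers = SimpleRegistry("lr_schedulers")


@lr_schedulers
def my_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)


# The Task resolves the string to the recipe and calls it with the optimizer
# it has already built:
optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
scheduler = lr_schedulers.get("my_steplr_recipe")(optimizer)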

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@karthikrangasai changed the title [WIP] PoC: Revamp optimizer and scheduler experience using registries → [skip CI][WIP] PoC: Revamp optimizer and scheduler experience using registries Sep 20, 2021
@tchaton (Contributor) commented Sep 27, 2021

Dear @karthikrangasai,

I added some modifications. Still WIP.

Do you think we should remove optimizer_kwargs and scheduler_kwargs entirely?
Adding support for a tuple seems slightly counter-intuitive, as users might want to control arguments through keyword arguments.

Best,
T.C
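
For context, the two styles being weighed look roughly like this (the first reflects the existing optimizer plus optimizer_kwargs arguments; the second is the tuple form this PR adds):

# Existing style on master: optimizer class plus separate kwargs
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=torch.optim.Adadelta, optimizer_kwargs={"eps": 0.5})

# Tuple style introduced by this PR
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("AdaDelta", {"eps": 0.5}))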

codecov bot commented Sep 27, 2021

Codecov Report

Merging #777 (35f3834) into master (a94ed6c) will decrease coverage by 6.68%.
The diff coverage is 97.12%.


@@            Coverage Diff             @@
##           master     #777      +/-   ##
==========================================
- Coverage   85.18%   78.49%   -6.69%     
==========================================
  Files         228      230       +2     
  Lines       12566    12666     +100     
==========================================
- Hits        10704     9942     -762     
- Misses       1862     2724     +862     
Flag Coverage Δ
unittests 78.49% <97.12%> (-6.69%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
flash/core/model.py 88.15% <93.93%> (+0.57%) ⬆️
flash/audio/speech_recognition/model.py 100.00% <100.00%> (ø)
flash/core/optimizers/__init__.py 100.00% <100.00%> (ø)
flash/core/optimizers/optimizers.py 100.00% <100.00%> (ø)
flash/core/optimizers/schedulers.py 100.00% <100.00%> (ø)
flash/core/utilities/imports.py 90.90% <100.00%> (+0.06%) ⬆️
flash/core/utilities/types.py 100.00% <100.00%> (ø)
flash/graph/classification/model.py 100.00% <100.00%> (ø)
flash/image/classification/model.py 70.45% <100.00%> (-6.15%) ⬇️
flash/image/detection/model.py 73.33% <100.00%> (-26.67%) ⬇️
... and 61 more

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a94ed6c...35f3834.

@karthikrangasai changed the title [skip CI][WIP] PoC: Revamp optimizer and scheduler experience using registries → [WIP] PoC: Revamp optimizer and scheduler experience using registries Sep 29, 2021
flash/core/model.py, README.md — outdated review threads (resolved).

Review comment on flash/core/model.py (docstring context):

    optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
    scheduler: The scheduler or scheduler class to use.
    scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
    lr_scheduler: The scheduler or scheduler class to use.
    learning_rate: Learning rate to use for training, defaults to ``1e-3``.

Contributor: Let's remove the learning rate, which is obsolete now.

Additional review threads on flash/core/model.py (resolved).
@tchaton (Contributor) commented Oct 12, 2021

Hey @SeanNaren @ethanwharris,

Should we drop learning_rate from the Tasks and force kwargs directly?

With @karthikrangasai, we are currently thinking of removing it entirely. It will make accessibility worse, but it will be fully consistent across the codebase.

Best,
T.C

@SeanNaren (Contributor) commented, quoting @tchaton:

    Should we drop learning_rate from the Tasks and force kwargs directly?

    With @karthikrangasai, we are currently thinking of removing it entirely. It will make accessibility worse, but it will be fully consistent across the codebase.

Similar to the discussion of deprecating arguments in the Trainer, it's a convenience for the user. Do we feel it isn't warranted to add to the Task itself? Do we think asking users to just use kwargs all the time is intuitive enough to drop it? I'm personally unsure, but curious what others have to say!
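
Concretely, the trade-off looks something like this (illustrative call shapes only, assuming the tuple form from this PR):

# With the convenience argument:
model = ImageClassifier(backbone="resnet18", num_classes=2, learning_rate=1e-3)

# Without it, the learning rate travels with the optimizer spec:
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adam", {"lr": 1e-3}))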

mergify bot removed the "has conflicts" label Oct 13, 2021
mergify bot removed the "has conflicts" label Oct 15, 2021
@tchaton (Contributor) left a review comment:

After those updates, it should be good to go.

Review threads on docs/source/general/optimization.rst, flash/core/model.py, and flash/core/optimizers/schedulers.py (resolved).
mergify bot removed the "has conflicts" label Oct 18, 2021
Review thread on flash/core/model.py (resolved).
@tchaton merged commit b41722a into Lightning-Universe:master Oct 18, 2021
Merging this pull request may close: [RFC] Enable Lightning Flash to support scheduler on step or monitor

3 participants