
Commit

Merge 9be7293 into 65fc6ac
senwu committed Dec 2, 2019
2 parents 65fc6ac + 9be7293 commit 5fd2ec7
Showing 12 changed files with 429 additions and 17 deletions.
1 change: 0 additions & 1 deletion src/emmental/learner.py
@@ -224,7 +224,6 @@ def _update_lr_scheduler(self, model: EmmentalModel, step: int) -> None:
            self.lr_scheduler.step()  # type: ignore
        elif (
            opt in ["step", "multi_step"]
-           and step > 0
            and (step + 1) % self.n_batches_per_epoch == 0
        ):
            self.lr_scheduler.step()  # type: ignore
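
For context, a minimal standalone sketch (not part of the diff) of how this condition behaves with n_batches_per_epoch = 1, the setting used in the new tests below; with the "step > 0" guard removed, the step/multi_step schedulers also advance after batch 0:

# Hypothetical check of the stepping condition only; opt is assumed to be
# "step" or "multi_step" and n_batches_per_epoch is 1, as in the tests below.
n_batches_per_epoch = 1
for step in range(4):
    should_step = (step + 1) % n_batches_per_epoch == 0
    print(step, should_step)  # True for every batch, including step 0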
2 changes: 1 addition & 1 deletion src/emmental/utils/parse_args.py
@@ -98,7 +98,7 @@ def parse_args(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
    )

    learner_config.add_argument(
-       "--n_epochs", type=int, default=3, help="Total number of learning epochs"
+       "--n_epochs", type=int, default=1, help="Total number of learning epochs"
    )

    learner_config.add_argument(
Empty file added tests/lr_schedulers/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions tests/lr_schedulers/test_cosine_annealing_scheduler.py
@@ -0,0 +1,57 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_cosine_annealing_scheduler(caplog):
"""Unit test of cosine annealing scheduler"""

caplog.set_level(logging.INFO)

lr_scheduler = "cosine_annealing"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()

Meta.reset()
emmental.init(dirpath)

config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {"lr_scheduler": lr_scheduler},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)

assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

emmental_learner._update_lr_scheduler(model, 0)
assert (
abs(emmental_learner.optimizer.param_groups[0]["lr"] - 8.535533905932738) < 1e-5
)

emmental_learner._update_lr_scheduler(model, 1)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 5) < 1e-5

emmental_learner._update_lr_scheduler(model, 2)
assert (
abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1.4644660940672627)
< 1e-5
)

emmental_learner._update_lr_scheduler(model, 3)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"]) < 1e-5

shutil.rmtree(dirpath)
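
As a reference for where the asserted values come from, a small worked check (a sketch, not part of the commit), assuming the cosine_annealing option maps to PyTorch's CosineAnnealingLR with eta_min=0 and T_max equal to the 4 total scheduler steps:

import math

# lr_t = lr_0 * (1 + cos(pi * t / T_max)) / 2, with lr_0 = 10 and T_max = 4
lr_0, t_max = 10, 4
print([lr_0 * (1 + math.cos(math.pi * t / t_max)) / 2 for t in range(1, 5)])
# approximately [8.535533905932738, 5.0, 1.4644660940672627, 0.0]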
55 changes: 55 additions & 0 deletions tests/lr_schedulers/test_exponential_scheduler.py
@@ -0,0 +1,55 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_exponential_scheduler(caplog):
"""Unit test of exponential scheduler"""

caplog.set_level(logging.INFO)

lr_scheduler = "exponential"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()

Meta.reset()
emmental.init(dirpath)

config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {
"lr_scheduler": lr_scheduler,
"exponential_config": {"gamma": 0.1},
},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)

assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

emmental_learner._update_lr_scheduler(model, 0)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1) < 1e-5

emmental_learner._update_lr_scheduler(model, 1)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.1) < 1e-5

emmental_learner._update_lr_scheduler(model, 2)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.01) < 1e-5

emmental_learner._update_lr_scheduler(model, 3)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.001) < 1e-5

shutil.rmtree(dirpath)
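
For reference, a worked check (a sketch, not part of the commit) of the expected values, assuming the exponential option maps to PyTorch's ExponentialLR, which decays the learning rate by gamma after every step:

# lr_t = lr_0 * gamma ** t, with lr_0 = 10 and gamma = 0.1
lr_0, gamma = 10, 0.1
print([lr_0 * gamma ** t for t in range(1, 5)])
# approximately [1.0, 0.1, 0.01, 0.001]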
52 changes: 52 additions & 0 deletions tests/lr_schedulers/test_linear_scheduler.py
@@ -0,0 +1,52 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_linear_scheduler(caplog):
"""Unit test of linear scheduler"""

caplog.set_level(logging.INFO)

lr_scheduler = "linear"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()

Meta.reset()
emmental.init(dirpath)

config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {"lr_scheduler": lr_scheduler},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)

assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

emmental_learner._update_lr_scheduler(model, 0)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 7.5) < 1e-5

emmental_learner._update_lr_scheduler(model, 1)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 5) < 1e-5

emmental_learner._update_lr_scheduler(model, 2)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 2.5) < 1e-5

emmental_learner._update_lr_scheduler(model, 3)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"]) < 1e-5

shutil.rmtree(dirpath)
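
The asserted values are consistent with a linear decay from the initial learning rate to 0 over the n_epochs * n_batches_per_epoch = 4 total steps; a worked check under that assumption (a sketch, not part of the commit):

# Assumed formula: lr_t = lr_0 * (1 - t / total_steps)
lr_0, total_steps = 10, 4
print([lr_0 * (1 - t / total_steps) for t in range(1, 5)])  # [7.5, 5.0, 2.5, 0.0]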
59 changes: 59 additions & 0 deletions tests/lr_schedulers/test_multi_step_scheduler.py
@@ -0,0 +1,59 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_multi_step_scheduler(caplog):
"""Unit test of multi step scheduler"""

caplog.set_level(logging.INFO)

lr_scheduler = "multi_step"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()

Meta.reset()
emmental.init(dirpath)

config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {
"lr_scheduler": lr_scheduler,
"multi_step_config": {
"milestones": [1, 3],
"gamma": 0.1,
"last_epoch": -1,
},
},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)

assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

emmental_learner._update_lr_scheduler(model, 0)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1) < 1e-5

emmental_learner._update_lr_scheduler(model, 1)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1) < 1e-5

emmental_learner._update_lr_scheduler(model, 2)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.1) < 1e-5

emmental_learner._update_lr_scheduler(model, 3)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.1) < 1e-5

shutil.rmtree(dirpath)
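
For reference, a worked check (a sketch, not part of the commit) of the expected values, assuming the multi_step option maps to PyTorch's MultiStepLR, which multiplies the learning rate by gamma at each milestone:

# lr_t = lr_0 * gamma ** (number of milestones reached by step t)
lr_0, gamma, milestones = 10, 0.1, [1, 3]
print([lr_0 * gamma ** sum(1 for m in milestones if m <= t) for t in range(1, 5)])
# approximately [1.0, 1.0, 0.1, 0.1]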
55 changes: 55 additions & 0 deletions tests/lr_schedulers/test_step_scheduler.py
@@ -0,0 +1,55 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_step_scheduler(caplog):
"""Unit test of step scheduler"""

caplog.set_level(logging.INFO)

lr_scheduler = "step"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()

Meta.reset()
emmental.init(dirpath)

config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {
"lr_scheduler": lr_scheduler,
"step_config": {"step_size": 2, "gamma": 0.1, "last_epoch": -1},
},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)

assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

emmental_learner._update_lr_scheduler(model, 0)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

emmental_learner._update_lr_scheduler(model, 1)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1) < 1e-5

emmental_learner._update_lr_scheduler(model, 2)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 1) < 1e-5

emmental_learner._update_lr_scheduler(model, 3)
assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 0.1) < 1e-5

shutil.rmtree(dirpath)
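
For reference, a worked check (a sketch, not part of the commit) of the expected values, assuming the step option maps to PyTorch's StepLR, which decays the learning rate by gamma every step_size steps:

# lr_t = lr_0 * gamma ** (t // step_size), with step_size = 2 and gamma = 0.1
lr_0, gamma, step_size = 10, 0.1, 2
print([lr_0 * gamma ** (t // step_size) for t in range(1, 5)])
# approximately [10, 1.0, 1.0, 0.1]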
90 changes: 90 additions & 0 deletions tests/lr_schedulers/test_warmup_scheduler.py
@@ -0,0 +1,90 @@
import logging
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner

logger = logging.getLogger(__name__)


def test_warmup_scheduler(caplog):
    """Unit test of warmup scheduler"""

    caplog.set_level(logging.INFO)

    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    # Test warmup steps
    config = {
        "learner_config": {
            "n_epochs": 4,
            "optimizer_config": {"optimizer": "sgd", "lr": 10},
            "lr_scheduler_config": {
                "lr_scheduler": None,
                "warmup_steps": 2,
                "warmup_unit": "batch",
            },
        }
    }
    emmental.Meta.update_config(config)
    emmental_learner.n_batches_per_epoch = 1
    emmental_learner._set_optimizer(model)
    emmental_learner._set_lr_scheduler(model)

    assert emmental_learner.optimizer.param_groups[0]["lr"] == 0

    emmental_learner._update_lr_scheduler(model, 0)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 5) < 1e-5

    emmental_learner._update_lr_scheduler(model, 1)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    emmental_learner._update_lr_scheduler(model, 2)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    emmental_learner._update_lr_scheduler(model, 3)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    Meta.reset()
    emmental.init(dirpath)

    # Test warmup percentage
    config = {
        "learner_config": {
            "n_epochs": 4,
            "optimizer_config": {"optimizer": "sgd", "lr": 10},
            "lr_scheduler_config": {
                "lr_scheduler": None,
                "warmup_percentage": 0.5,
                "warmup_unit": "epoch",
            },
        }
    }
    emmental.Meta.update_config(config)
    emmental_learner.n_batches_per_epoch = 1
    emmental_learner._set_optimizer(model)
    emmental_learner._set_lr_scheduler(model)

    assert emmental_learner.optimizer.param_groups[0]["lr"] == 0

    emmental_learner._update_lr_scheduler(model, 0)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 5) < 1e-5

    emmental_learner._update_lr_scheduler(model, 1)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    emmental_learner._update_lr_scheduler(model, 2)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    emmental_learner._update_lr_scheduler(model, 3)
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 10) < 1e-5

    shutil.rmtree(dirpath)
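
In both configurations the warmup covers 2 of the 4 total steps (2 batches, or 50% of 4 single-batch epochs), so the learning rate ramps linearly from 0 to the base value and then holds; a worked check under that reading (a sketch, not part of the commit):

# Assumed warmup behavior: lr_t = lr_0 * min(t / warmup_steps, 1.0)
lr_0, warmup_steps = 10, 2
print([lr_0 * min(t / warmup_steps, 1.0) for t in range(1, 5)])  # [5.0, 10.0, 10.0, 10.0]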