Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
senwu committed Dec 2, 2019
1 parent b23e701 commit 9be7293
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
5 changes: 3 additions & 2 deletions tests/lr_schedulers/test_warmup_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def test_step_scheduler(caplog):
Meta.reset()
emmental.init(dirpath)

# Test default Adam setting
# Test warmup steps
config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {
"lr_scheduler": None,
"warmup_steps": 2,
"warmup_unit": "epoch",
"warmup_unit": "batch",
},
}
}
Expand All @@ -56,6 +56,7 @@ def test_step_scheduler(caplog):
Meta.reset()
emmental.init(dirpath)

# Test warmup percentage
config = {
"learner_config": {
"n_epochs": 4,
Expand Down
66 changes: 53 additions & 13 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@ def output(module_name, immediate_output_dict):
scorer=Scorer(metrics=["accuracy"]),
)

model = EmmentalModel(name="test", tasks=task1)

assert model.name == "test"
assert model.task_names == set(["task_1"])
assert model.module_pool["m1"].module.weight.data.size() == (10, 10)
assert model.module_pool["m2"].module.weight.data.size() == (2, 10)

task1 = EmmentalTask(
new_task1 = EmmentalTask(
name="task_1",
module_pool=nn.ModuleDict(
{"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
Expand All @@ -67,11 +60,6 @@ def output(module_name, immediate_output_dict):
scorer=Scorer(metrics=["accuracy"]),
)

model.update_task(task1)

assert model.module_pool["m1"].module.weight.data.size() == (5, 10)
assert model.module_pool["m2"].module.weight.data.size() == (2, 5)

task2 = EmmentalTask(
name="task_2",
module_pool=nn.ModuleDict(
Expand All @@ -86,6 +74,58 @@ def output(module_name, immediate_output_dict):
scorer=Scorer(metrics=["accuracy"]),
)

# Test w/ dataparallel
model = EmmentalModel(name="test", tasks=task1)

assert model.name == "test"
assert model.task_names == set(["task_1"])
assert model.module_pool["m1"].module.weight.data.size() == (10, 10)
assert model.module_pool["m2"].module.weight.data.size() == (2, 10)

model.update_task(new_task1)

assert model.module_pool["m1"].module.weight.data.size() == (5, 10)
assert model.module_pool["m2"].module.weight.data.size() == (2, 5)

model.update_task(task2)

assert model.task_names == set(["task_1"])

model.add_task(task2)

assert model.task_names == set(["task_1", "task_2"])

model.remove_task("task_1")
assert model.task_names == set(["task_2"])

model.save(f"{dirpath}/saved_model.pth")

model.load(f"{dirpath}/saved_model.pth")

# Test w/o dataparallel

Meta.reset()
emmental.init(dirpath)

config = {"model_config": {"dataparallel": False}}
emmental.Meta.update_config(config)

model = EmmentalModel(name="test", tasks=task1)

assert model.name == "test"
assert model.task_names == set(["task_1"])
assert model.module_pool["m1"].weight.data.size() == (10, 10)
assert model.module_pool["m2"].weight.data.size() == (2, 10)

model.update_task(new_task1)

assert model.module_pool["m1"].weight.data.size() == (5, 10)
assert model.module_pool["m2"].weight.data.size() == (2, 5)

model.update_task(task2)

assert model.task_names == set(["task_1"])

model.add_task(task2)

assert model.task_names == set(["task_1", "task_2"])
Expand Down

0 comments on commit 9be7293

Please sign in to comment.