
Commit

update test
senwu committed Oct 4, 2021
1 parent 702cb62 commit 4553e8b
Showing 1 changed file with 36 additions and 24 deletions.
tests/e2e/test_e2e_skip_trained.py (60 changes: 36 additions & 24 deletions)
@@ -1,4 +1,4 @@
-"""Emmental e2e with skipping learned."""
+"""Emmental e2e with skipping trained data."""
 import logging
 import shutil
 from functools import partial
@@ -27,7 +27,7 @@ def test_e2e_skip_trained(caplog):
     caplog.set_level(logging.INFO)
 
     dirpath = "temp_test_e2e_skip_trained"
-    use_exact_log_path = False
+    use_exact_log_path = True
     Meta.reset()
     init(dirpath, use_exact_log_path=use_exact_log_path)
 
@@ -149,24 +149,24 @@ def forward(self, input):
     ]
     # Build model
 
-    mtl_model = EmmentalModel(name="all", tasks=tasks)
+    model = EmmentalModel(name="all", tasks=tasks)
 
     # Create learner
     emmental_learner = EmmentalLearner()
 
     config = {
-        "meta_config": {"seed": 0, "verbose": False},
+        "meta_config": {"seed": 0, "verbose": True},
         "learner_config": {
-            "n_steps": 200,
+            "n_steps": 10,
             "epochs_learned": 0,
-            "steps_learned": 130,
-            "skip_learned_data": True,
+            "steps_learned": 0,
+            "skip_learned_data": False,
             "online_eval": True,
             "optimizer_config": {"lr": 0.01, "grad_clip": 100},
         },
         "logging_config": {
-            "counter_unit": "epoch",
-            "evaluation_freq": 0.2,
+            "counter_unit": "batch",
+            "evaluation_freq": 5,
             "writer_config": {"writer": "json", "verbose": True},
             "checkpointing": True,
             "checkpointer_config": {
@@ -177,39 +177,48 @@ def forward(self, input):
                 "checkpoint_runway": 1,
                 "checkpoint_all": False,
                 "clear_intermediate_checkpoints": True,
-                "clear_all_checkpoints": True,
+                "clear_all_checkpoints": False,
             },
         },
     }
     Meta.update_config(config)
 
     # Learning
     emmental_learner.learn(
-        mtl_model,
+        model,
         [train_dataloader, dev_dataloader],
     )
 
-    test_score = mtl_model.score(test_dataloader)
+    test_score = model.score(test_dataloader)
 
-    assert test_score["task1/synthetic/test/loss"] > 0.1
+    assert test_score["task1/synthetic/test/loss"] > 0.4
 
     Meta.reset()
     init(dirpath, use_exact_log_path=use_exact_log_path)
 
     config = {
-        "meta_config": {"seed": 0, "verbose": False},
+        "meta_config": {"seed": 0, "verbose": True},
         "learner_config": {
-            "n_steps": 200,
+            "n_steps": 40,
             "epochs_learned": 0,
-            "steps_learned": 0,
-            "skip_learned_data": False,
+            "steps_learned": 10,
+            "skip_learned_data": True,
             "online_eval": True,
             "optimizer_config": {"lr": 0.01, "grad_clip": 100},
+            "optimizer_path": (
+                f"{dirpath}/" "best_model_model_all_train_loss.optimizer.pth"
+            ),
+            "scheduler_path": (
+                f"{dirpath}/" "best_model_model_all_train_loss.scheduler.pth"
+            ),
         },
+        "model_config": {
+            "model_path": f"{dirpath}/best_model_model_all_train_loss.model.pth"
+        },
         "logging_config": {
-            "counter_unit": "epoch",
-            "evaluation_freq": 0.2,
-            "writer_config": {"writer": "tensorboard", "verbose": True},
+            "counter_unit": "batch",
+            "evaluation_freq": 5,
+            "writer_config": {"writer": "json", "verbose": True},
             "checkpointing": True,
             "checkpointer_config": {
                 "checkpoint_path": None,
@@ -219,20 +228,23 @@ def forward(self, input):
                 "checkpoint_runway": 1,
                 "checkpoint_all": False,
                 "clear_intermediate_checkpoints": True,
-                "clear_all_checkpoints": True,
+                "clear_all_checkpoints": False,
             },
         },
     }
     Meta.update_config(config)
 
+    if Meta.config["model_config"]["model_path"]:
+        model.load(Meta.config["model_config"]["model_path"])
+
     # Learning
     emmental_learner.learn(
-        mtl_model,
+        model,
         [train_dataloader, dev_dataloader],
     )
 
-    test_score = mtl_model.score(test_dataloader)
+    test_score = model.score(test_dataloader)
 
-    assert test_score["task1/synthetic/test/loss"] <= 0.1
+    assert test_score["task1/synthetic/test/loss"] <= 0.4
 
     shutil.rmtree(dirpath)
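
The change turns this file into a two-phase resume test: the first run trains for 10 steps with checkpointing enabled and asserts the model is still undertrained (loss > 0.4); the second run restores the best-train-loss checkpoint, sets steps_learned to 10 with skip_learned_data enabled so the learner fast-forwards past the batches already seen, and continues to 40 total steps, after which the loss should drop to 0.4 or below. Below is a minimal sketch of that flow, not the verbatim test: it assumes the usual Emmental imports and that tasks, train_dataloader, and dev_dataloader are built as in the collapsed portion of the file, and it omits the checkpointer options hidden behind the collapsed hunks.

    # Sketch only, assuming default checkpointer settings produce the
    # best_model_model_all_train_loss.*.pth files referenced in the diff.
    from emmental import Meta, init
    from emmental.learner import EmmentalLearner
    from emmental.model import EmmentalModel

    dirpath = "temp_test_e2e_skip_trained"

    # Phase 1: short run that writes the best-train-loss checkpoint.
    init(dirpath, use_exact_log_path=True)
    Meta.update_config({
        "learner_config": {"n_steps": 10, "skip_learned_data": False},
        "logging_config": {"counter_unit": "batch", "checkpointing": True},
    })
    model = EmmentalModel(name="all", tasks=tasks)  # tasks built as in the test
    EmmentalLearner().learn(model, [train_dataloader, dev_dataloader])

    # Phase 2: reload model/optimizer/scheduler state and fast-forward the data.
    Meta.reset()
    init(dirpath, use_exact_log_path=True)
    Meta.update_config({
        "learner_config": {
            "n_steps": 40,
            "steps_learned": 10,        # steps already covered by the checkpoint
            "skip_learned_data": True,  # skip those batches when resuming
            "optimizer_path": f"{dirpath}/best_model_model_all_train_loss.optimizer.pth",
            "scheduler_path": f"{dirpath}/best_model_model_all_train_loss.scheduler.pth",
        },
        "model_config": {
            "model_path": f"{dirpath}/best_model_model_all_train_loss.model.pth"
        },
    })
    model.load(Meta.config["model_config"]["model_path"])
    EmmentalLearner().learn(model, [train_dataloader, dev_dataloader])

Note that steps_learned only fast-forwards the dataloader; the optimizer and scheduler state come from the saved .pth files, which is why the second config points at all three checkpoint artifacts rather than just the model weights.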
