feature-enhancement: make trainer registrable (#1884)

* Update trainer.py

I make `Trainer` a `Registrable` class and register it as its own `default_implementation`.

* Update trainer.py

I modify `Trainer`'s `from_params` classmethod to return `cls(...)` instead of `Trainer(...)`, so that subclasses are constructed correctly.

* decide the Trainer class from the config file

* minor-bug-fix: do not pop trainer twice

* minor-bug-fix: do not pop `trainer` twice.

* add `# type: ignore` and `# pylint: disable` comments to satisfy mypy and pylint

* call Trainer.from_params using keyword arguments

* call Trainer.from_params using keyword arguments

* remove unnecessary whitespace

* Pylint and mypy, minor other fixes
WrRan authored and matt-gardner committed Oct 10, 2018
1 parent 91bfb4c commit 24e5547b55e26d8a6b36ef32c1e965f24f0970a1
Showing with 54 additions and 36 deletions.
  1. +16 −6 allennlp/commands/fine_tune.py
  2. +10 −7 allennlp/commands/train.py
  3. +28 −23 allennlp/training/trainer.py
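
Before the diffs, a minimal, hypothetical sketch of the trainer section of a training config after this change, written as the Params object that the train commands pop it from. Only the "type" key is new in this commit; the other keys (`num_epochs`, `patience`, `optimizer`) are existing `Trainer.from_params` options shown purely for illustration.

# Hypothetical trainer section of a config, expressed as a Params object.
from allennlp.common import Params

trainer_params = Params({
    "type": "default",               # new optional key: picks the registered Trainer class
    "num_epochs": 40,                # illustrative existing options
    "patience": 10,
    "optimizer": {"type": "adam"},
})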
allennlp/commands/fine_tune.py
@@ -195,6 +195,12 @@ def fine_tune_model(model: Model,
 
     iterator = DataIterator.from_params(params.pop("iterator"))
     iterator.index_with(model.vocab)
+    validation_iterator_params = params.pop("validation_iterator", None)
+    if validation_iterator_params:
+        validation_iterator = DataIterator.from_params(validation_iterator_params)
+        validation_iterator.index_with(vocab)
+    else:
+        validation_iterator = None
 
     train_data = all_datasets['train']
     validation_data = all_datasets.get('validation')
@@ -215,12 +221,16 @@ def fine_tune_model(model: Model,
     for name in tunable_parameter_names:
         logger.info(name)
 
-    trainer = Trainer.from_params(model,
-                                  serialization_dir,
-                                  iterator,
-                                  train_data,
-                                  validation_data,
-                                  trainer_params)
+    trainer_choice = trainer_params.pop_choice("type",
+                                               Trainer.list_available(),
+                                               default_to_first_choice=True)
+    trainer = Trainer.by_name(trainer_choice).from_params(model=model,
+                                                          serialization_dir=serialization_dir,
+                                                          iterator=iterator,
+                                                          train_data=train_data,
+                                                          validation_data=validation_data,
+                                                          params=trainer_params,
+                                                          validation_iterator=validation_iterator)
 
     evaluate_on_test = params.pop_bool("evaluate_on_test", False)
     params.assert_empty('base train command')
allennlp/commands/train.py
@@ -306,13 +306,16 @@ def train_model(params: Params,
     for name in tunable_parameter_names:
         logger.info(name)
 
-    trainer = Trainer.from_params(model,
-                                  serialization_dir,
-                                  iterator,
-                                  train_data,
-                                  validation_data,
-                                  trainer_params,
-                                  validation_iterator=validation_iterator)
+    trainer_choice = trainer_params.pop_choice("type",
+                                               Trainer.list_available(),
+                                               default_to_first_choice=True)
+    trainer = Trainer.by_name(trainer_choice).from_params(model=model,
+                                                          serialization_dir=serialization_dir,
+                                                          iterator=iterator,
+                                                          train_data=train_data,
+                                                          validation_data=validation_data,
+                                                          params=trainer_params,
+                                                          validation_iterator=validation_iterator)
 
     evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    params.assert_empty('base train command')
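
A short sketch of the dispatch that both commands above now perform. Because the class sets `default_implementation = "default"` (see the trainer.py diff below), `Trainer.list_available()` lists "default" first, so a config without a "type" key keeps the old behaviour.

from allennlp.common import Params
from allennlp.training.trainer import Trainer

# A trainer config with no "type" key; all other options omitted for brevity.
trainer_params = Params({})

trainer_choice = trainer_params.pop_choice("type",
                                           Trainer.list_available(),
                                           default_to_first_choice=True)
assert trainer_choice == "default"             # falls back to the first available choice
trainer_cls = Trainer.by_name(trainer_choice)  # base Trainer registers itself as "default"
assert trainer_cls is Trainer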
allennlp/training/trainer.py
@@ -23,7 +23,7 @@
 from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 from tensorboardX import SummaryWriter
 
-from allennlp.common import Params
+from allennlp.common import Params, Registrable
 from allennlp.common.checks import ConfigurationError
 from allennlp.common.util import peak_memory_mb, gpu_memory_mb, dump_metrics
 from allennlp.common.tqdm import Tqdm
@@ -156,7 +156,9 @@ def str_to_time(time_str: str) -> datetime.datetime:
     return datetime.datetime(*pieces)
 
 
-class Trainer:
+class Trainer(Registrable):
+    default_implementation = "default"
+
     def __init__(self,
                  model: Model,
                  optimizer: torch.optim.Optimizer,
@@ -998,15 +1000,15 @@ def _restore_checkpoint(self) -> Tuple[int, List[float]]:
 
     # Requires custom from_params.
     @classmethod
-    def from_params(cls,
+    def from_params(cls,  # type: ignore
                     model: Model,
                     serialization_dir: str,
                     iterator: DataIterator,
                     train_data: Iterable[Instance],
                     validation_data: Optional[Iterable[Instance]],
                     params: Params,
                     validation_iterator: DataIterator = None) -> 'Trainer':
-
+        # pylint: disable=arguments-differ
         patience = params.pop_int("patience", None)
         validation_metric = params.pop("validation_metric", "-loss")
         shuffle = params.pop_bool("shuffle", True)
@@ -1036,22 +1038,25 @@ def from_params(cls,
         should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
 
         params.assert_empty(cls.__name__)
-        return Trainer(model, optimizer, iterator,
-                       train_data, validation_data,
-                       patience=patience,
-                       validation_metric=validation_metric,
-                       validation_iterator=validation_iterator,
-                       shuffle=shuffle,
-                       num_epochs=num_epochs,
-                       serialization_dir=serialization_dir,
-                       cuda_device=cuda_device,
-                       grad_norm=grad_norm,
-                       grad_clipping=grad_clipping,
-                       learning_rate_scheduler=scheduler,
-                       num_serialized_models_to_keep=num_serialized_models_to_keep,
-                       keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
-                       model_save_interval=model_save_interval,
-                       summary_interval=summary_interval,
-                       histogram_interval=histogram_interval,
-                       should_log_parameter_statistics=should_log_parameter_statistics,
-                       should_log_learning_rate=should_log_learning_rate)
+        return cls(model, optimizer, iterator,
+                   train_data, validation_data,
+                   patience=patience,
+                   validation_metric=validation_metric,
+                   validation_iterator=validation_iterator,
+                   shuffle=shuffle,
+                   num_epochs=num_epochs,
+                   serialization_dir=serialization_dir,
+                   cuda_device=cuda_device,
+                   grad_norm=grad_norm,
+                   grad_clipping=grad_clipping,
+                   learning_rate_scheduler=scheduler,
+                   num_serialized_models_to_keep=num_serialized_models_to_keep,
+                   keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
+                   model_save_interval=model_save_interval,
+                   summary_interval=summary_interval,
+                   histogram_interval=histogram_interval,
+                   should_log_parameter_statistics=should_log_parameter_statistics,
+                   should_log_learning_rate=should_log_learning_rate)
+
+
+Trainer.register("default")(Trainer)
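
With `Trainer` registrable, a downstream package can plug in its own trainer and select it from the config; a hypothetical sketch (the subclass name is made up):

from allennlp.training.trainer import Trainer

@Trainer.register("my-trainer")   # hypothetical registration name
class MyTrainer(Trainer):
    """Reuses the inherited from_params, which now returns cls(...) and
    therefore builds this subclass when the config says "type": "my-trainer"."""

The base class registers itself after its own definition (the `Trainer.register("default")(Trainer)` call above) because the decorator form would reference the name `Trainer` before the class statement has bound it.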
