Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
make fine-tune not expand vocabulary by default (#1623)
Browse files Browse the repository at this point in the history
* make fine-tune not expand vocabulary by default

* fix comment

* address PR comments + add a test for vocab expansion
  • Loading branch information
joelgrus committed Aug 17, 2018
1 parent 3107a0c commit 58119c0
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
41 changes: 28 additions & 13 deletions allennlp/commands/fine_tune.py
Expand Up @@ -57,6 +57,14 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
help='a JSON structure used to override the training configuration '
'(only affects the config_file, _not_ the model_archive)')

subparser.add_argument('--extend-vocab',
action='store_true',
default=False,
help='if specified, we will use the instances in your new dataset to '
'extend your vocabulary. Currently expansion of embedding layers '
'is not implemented, so if your model has an embedding layer '
'this will probably make fine-tune crash.')

subparser.add_argument('--file-friendly-logging',
action='store_true',
default=False,
Expand All @@ -75,13 +83,15 @@ def fine_tune_model_from_args(args: argparse.Namespace):
config_file=args.config_file,
serialization_dir=args.serialization_dir,
overrides=args.overrides,
extend_vocab=args.extend_vocab,
file_friendly_logging=args.file_friendly_logging)


def fine_tune_model_from_file_paths(model_archive_path: str,
config_file: str,
serialization_dir: str,
overrides: str = "",
extend_vocab: bool = False,
file_friendly_logging: bool = False) -> Model:
"""
A wrapper around :func:`fine_tune_model` which loads the model archive from a file.
Expand Down Expand Up @@ -110,12 +120,14 @@ def fine_tune_model_from_file_paths(model_archive_path: str,
return fine_tune_model(model=archive.model,
params=params,
serialization_dir=serialization_dir,
extend_vocab=extend_vocab,
file_friendly_logging=file_friendly_logging)


def fine_tune_model(model: Model,
params: Params,
serialization_dir: str,
extend_vocab: bool = False,
file_friendly_logging: bool = False) -> Model:
"""
Fine tunes the given model, using a set of parameters that is largely identical to those used
Expand All @@ -136,6 +148,8 @@ def fine_tune_model(model: Model,
The directory in which to save results and logs.
validation_data_path : ``str``, optional
Path to the validation data to use while fine-tuning.
extend_vocab: ``bool``, optional (default=False)
If ``True``, we use the new instances to extend your vocabulary.
file_friendly_logging : ``bool``, optional (default=False)
If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
Expand All @@ -159,27 +173,28 @@ def fine_tune_model(model: Model,
vocabulary_params = params.pop('vocabulary', {})
if vocabulary_params.get('directory_path', None):
logger.warning("You passed `directory_path` in parameters for the vocabulary in "
"your configuration file, but it will be ignored. "
"Vocabulary from the saved model will be extended with current data.")
"your configuration file, but it will be ignored. ")

all_datasets = datasets_from_params(params)
vocab = model.vocab

datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
if extend_vocab:
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")
for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
vocab.extend_from_instances(vocabulary_params,
(instance for key, dataset in all_datasets.items()
for instance in dataset
if key in datasets_for_vocab_creation))

logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
vocab = model.vocab
vocab.extend_from_instances(vocabulary_params,
(instance for key, dataset in all_datasets.items()
for instance in dataset
if key in datasets_for_vocab_creation))
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)
iterator.index_with(model.vocab)

train_data = all_datasets['train']
validation_data = all_datasets.get('validation')
Expand Down
20 changes: 20 additions & 0 deletions allennlp/tests/commands/fine_tune_test.py
Expand Up @@ -28,6 +28,26 @@ def test_fine_tune_model_runs_from_file_paths(self):
config_file=self.config_file,
serialization_dir=self.serialization_dir)

def test_fine_tune_does_not_expand_vocab_by_default(self):
    """Without ``extend_vocab=True``, fine-tuning must leave the model's vocabulary untouched."""
    params = Params.from_file(self.config_file)
    # Point training at snli2, which contains a token absent from the original vocab.
    params["train_data_path"] = str(self.FIXTURES_ROOT / 'data' / 'snli2.jsonl')

    archived_model = load_archive(self.model_archive).model

    # Default behavior is no vocab expansion, so this should complete without error.
    fine_tune_model(archived_model, params, self.serialization_dir)

def test_fine_tune_runtime_errors_with_vocab_expansion(self):
    """Opting into vocab expansion on a model with an embedding layer should fail loudly."""
    params = Params.from_file(self.config_file)
    # snli2 introduces a token that is not in the archived model's vocabulary.
    params["train_data_path"] = str(self.FIXTURES_ROOT / 'data' / 'snli2.jsonl')

    archived_model = load_archive(self.model_archive).model

    # Growing the vocab without resizing the embedding weights raises a RuntimeError.
    with pytest.raises(RuntimeError):
        fine_tune_model(archived_model, params, self.serialization_dir, extend_vocab=True)

def test_fine_tune_runs_from_parser_arguments(self):
raw_args = ["fine-tune",
"-m", self.model_archive,
Expand Down
1 change: 1 addition & 0 deletions allennlp/tests/fixtures/data/snli2.jsonl
@@ -0,0 +1 @@
{"annotator_labels": ["neutral"],"captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a seahorse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN seahorse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his seahorse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN seahorse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"}

0 comments on commit 58119c0

Please sign in to comment.