
Multilingual parser and Cross-lingual ELMo (#2628)
* Multilingual version of biaffine dependency parser with elmo alignment

* alignment addition to elmo embedder

* dataset reader for multiple sources

* iterator that groups samples from the same language

* Clean multilang biaffine

* Inherit unmodified functions from the biaffine parser.

* Also using defaultdict for same_lang_iterator.

* Multi-lang dep config example

* Jsonnet format

* Larger softmax value

The previous one sometimes caused overflow; e-10 is good enough.

* formatting

* reorganize multilang dataset reader to work with a pathname

* factoring biaffine parser to prevent duplicating code

* multilangTokenEmbedder interface

* fix parser factorization and pylint stuff

* more pylint

* inspect embedder and fix params_test

* make mypy happy

* cr comments and doc

* doc

* fix doc

* Multilingual tests (#4)

* Added dependencies data for es, fr and it + json configuration for tests

* Tests for multilingual UD reader and same-language iterator

* Tests for multilingual dependency parser

* fixed some of the tests - not final

* use the non-lazy option in the parser test

* better doc

* test basic text field emb

* pylint

* multilingual embedder test

* cr comments

* new link
TalSchuster authored and matt-gardner committed Jun 12, 2019
1 parent 92ee421 commit da16ad1
Showing 41 changed files with 1,430 additions and 51 deletions.
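
The commit message above outlines the main new pieces: a multilingual Universal Dependencies reader and an iterator that keeps each batch within a single language. As a hedged sketch of how the two newly registered components could be built from params — the languages, sorting keys, and sizes below are illustrative placeholders, only the "type" names come from this commit:

from allennlp.common import Params
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.iterators import DataIterator

# "universal_dependencies_multilang" and "same_language" are the names
# registered in this commit; every other value is a placeholder.
reader = DatasetReader.from_params(Params({
        "type": "universal_dependencies_multilang",
        "languages": ["en", "fr", "it"],
        "alternate": True,
        "instances_per_file": 32}))
iterator = DataIterator.from_params(Params({
        "type": "same_language",
        "sorting_keys": [["words", "num_tokens"]],
        "batch_size": 32,
        "instances_per_epoch": 320}))
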
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
@@ -26,6 +26,7 @@
from allennlp.data.dataset_readers.sequence_tagging import SequenceTaggingDatasetReader
from allennlp.data.dataset_readers.snli import SnliReader
from allennlp.data.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader
from allennlp.data.dataset_readers.universal_dependencies_multilang import UniversalDependenciesMultiLangDatasetReader
from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import (
StanfordSentimentTreeBankDatasetReader)
from allennlp.data.dataset_readers.quora_paraphrase import QuoraParaphraseDatasetReader
192 changes: 192 additions & 0 deletions allennlp/data/dataset_readers/universal_dependencies_multilang.py
@@ -0,0 +1,192 @@
from typing import Dict, Tuple, List, Iterator, Any
import logging
import itertools
import glob
import os
import numpy as np

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.dataset_readers.universal_dependencies import lazy_parse

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def get_file_paths(pathname: str, languages: List[str]) -> List[Tuple[str, str]]:
"""
Gets a list of all files matching the pathname that belong to one of the given languages.
Filenames are assumed to have the language identifier followed by a dash
as a prefix (e.g. en-universal.conll).
Parameters
----------
pathname : ``str``, required.
An absolute or relative pathname (can contain shell-style wildcards)
languages : ``List[str]``, required
The language identifiers to use.
Returns
-------
A list of tuples (language id, file path).
"""
paths = []
for file_path in glob.glob(pathname):
base = os.path.splitext(os.path.basename(file_path))[0]
lang_id = base.split('-')[0]
if lang_id in languages:
paths.append((lang_id, file_path))

if not paths:
raise ConfigurationError("No dataset files to read")

return paths
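
# A hedged usage sketch (hypothetical paths): for a directory holding
# "en-universal.conll" and "fr-universal.conll", a call such as
#     get_file_paths("data/*-universal.conll", ["en", "fr"])
# would return [("en", "data/en-universal.conll"), ("fr", "data/fr-universal.conll")],
# while any file whose language prefix is not listed is skipped.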


@DatasetReader.register("universal_dependencies_multilang")
class UniversalDependenciesMultiLangDatasetReader(DatasetReader):
"""
Reads multiple files in the conllu Universal Dependencies format.
All files should be in the same directory and the filenames should have
the language identifier followed by a dash as a prefix (e.g. en-universal.conll).
When the ``alternate`` option is used, the reader switches randomly between
the files every ``instances_per_file`` instances. The ``is_first_pass_for_vocab`` option
disables this behaviour for the first pass (useful for a single full pass
over the dataset in order to generate the vocabulary).
Note: when using the ``alternate`` option, one should also set the ``instances_per_epoch``
option for the iterator. Otherwise, each epoch will loop infinitely.
Parameters
----------
languages : ``List[str]``, required
The language identifiers to use.
token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
The token indexers to be applied to the words TextField.
use_language_specific_pos : ``bool``, optional (default = False)
Whether to use the language-specific POS tags (xpostag) from the conllu format
instead of the universal POS tags (upostag).
alternate : ``bool``, optional (default = True)
Whether to alternate between input files.
is_first_pass_for_vocab : ``bool``, optional (default = True)
Whether the first pass will be for generating the vocab. If true,
the first pass will run over the entire dataset of each file (even if alternate is on).
instances_per_file : ``int``, optional (default = 32)
The number of consecutive instances to sample from each input file when alternating.
"""
def __init__(self,
languages: List[str],
token_indexers: Dict[str, TokenIndexer] = None,
use_language_specific_pos: bool = False,
lazy: bool = False,
alternate: bool = True,
is_first_pass_for_vocab: bool = True,
instances_per_file: int = 32) -> None:
super().__init__(lazy)
self._languages = languages
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
self._use_language_specific_pos = use_language_specific_pos

self._is_first_pass_for_vocab = is_first_pass_for_vocab
self._alternate = alternate
self._instances_per_file = instances_per_file

self._is_first_pass = True
self._iterators: List[Tuple[str, Iterator[Any]]] = None

def _read_one_file(self, lang: str, file_path: str):
with open(file_path, 'r') as conllu_file:
logger.info("Reading UD instances for %s language from conllu dataset at: %s", lang, file_path)

for annotation in lazy_parse(conllu_file.read()):
# CoNLLU annotations sometimes add back in words that have been elided
# in the original sentence; we remove these, as we're just predicting
# dependencies for the original sentence.
# We filter by None here as elided words have a non-integer word id,
# and are replaced with None by the conllu python library.
annotation = [x for x in annotation if x["id"] is not None]

heads = [x["head"] for x in annotation]
tags = [x["deprel"] for x in annotation]
words = [x["form"] for x in annotation]
if self._use_language_specific_pos:
pos_tags = [x["xpostag"] for x in annotation]
else:
pos_tags = [x["upostag"] for x in annotation]
yield self.text_to_instance(lang, words, pos_tags, list(zip(tags, heads)))

@overrides
def _read(self, file_path: str):
file_paths = get_file_paths(file_path, self._languages)
if (self._is_first_pass and self._is_first_pass_for_vocab) or (not self._alternate):
iterators = [iter(self._read_one_file(lang, file_path))
for (lang, file_path) in file_paths]
self._is_first_pass = False
for inst in itertools.chain(*iterators):
yield inst

else:
if self._iterators is None:
self._iterators = [(lang, iter(self._read_one_file(lang, file_path)))
for (lang, file_path) in file_paths]
num_files = len(file_paths)
while True:
ind = np.random.randint(num_files)
lang, lang_iter = self._iterators[ind]
for _ in range(self._instances_per_file):
try:
yield next(lang_iter)
except StopIteration:
lang, file_path = file_paths[ind]
lang_iter = iter(self._read_one_file(lang, file_path))
self._iterators[ind] = (lang, lang_iter)
yield next(lang_iter)

@overrides
def text_to_instance(self, # type: ignore
lang: str,
words: List[str],
upos_tags: List[str],
dependencies: List[Tuple[str, int]] = None) -> Instance:
# pylint: disable=arguments-differ
"""
Parameters
----------
lang : ``str``, required.
The language identifier.
words : ``List[str]``, required.
The words in the sentence to be encoded.
upos_tags : ``List[str]``, required.
The universal dependencies POS tags for each word.
dependencies : ``List[Tuple[str, int]]``, optional (default = None)
A list of (head tag, head index) tuples. Indices are 1 indexed,
meaning an index of 0 corresponds to that word being the root of
the dependency tree.
Returns
-------
An instance containing words, upos tags, dependency head tags and head
indices as fields. The language identifier is stored in the metadata.
"""
fields: Dict[str, Field] = {}

tokens = TextField([Token(w) for w in words], self._token_indexers)
fields["words"] = tokens
fields["pos_tags"] = SequenceLabelField(upos_tags, tokens, label_namespace="pos")
if dependencies is not None:
# We don't want to expand the label namespace with an additional dummy token, so we'll
# always give the 'ROOT_HEAD' token a label of 'root'.
fields["head_tags"] = SequenceLabelField([x[0] for x in dependencies],
tokens,
label_namespace="head_tags")
fields["head_indices"] = SequenceLabelField([int(x[1]) for x in dependencies],
tokens,
label_namespace="head_index_tags")

fields["metadata"] = MetadataField({"words": words, "pos": upos_tags, "lang": lang})
return Instance(fields)
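
A minimal usage sketch for this reader, assuming a hypothetical data/ directory that follows the naming convention above; with alternate=False every file is read exactly once:

from allennlp.data.dataset_readers import UniversalDependenciesMultiLangDatasetReader

# Hypothetical glob path and language list.
reader = UniversalDependenciesMultiLangDatasetReader(languages=["en", "fr"],
                                                     alternate=False)
for instance in reader.read("data/*-universal.conll"):
    # Each instance records its language id in the metadata field.
    metadata = instance.fields["metadata"].metadata
    print(metadata["lang"], metadata["words"])
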
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
@@ -8,3 +8,4 @@
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
from allennlp.data.iterators.same_language_iterator import SameLanguageIterator
47 changes: 47 additions & 0 deletions allennlp/data/iterators/same_language_iterator.py
@@ -0,0 +1,47 @@
from collections import deque, defaultdict
from typing import Iterable, Deque
import logging
import random

from allennlp.common.util import lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.dataset import Batch

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

def split_by_language(instance_list):
insts_by_lang = defaultdict(list)
for inst in instance_list:
inst_lang = inst.fields['metadata'].metadata['lang']
insts_by_lang[inst_lang].append(inst)

return iter(insts_by_lang.values())

@DataIterator.register("same_language")
class SameLanguageIterator(BucketIterator):
"""
Splits each batch into batches whose instances all come from the same language.
The language of each instance is determined by looking at the 'lang' value
in the metadata.
It takes the same parameters as :class:`allennlp.data.iterators.BucketIterator`
"""
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
# First break the dataset into memory-sized lists:
for instance_list in self._memory_sized_lists(instances):
if shuffle:
random.shuffle(instance_list)
instance_list = split_by_language(instance_list)
for same_lang_batch in instance_list:
iterator = iter(same_lang_batch)
excess: Deque[Instance] = deque()
# Then break each memory-sized list into batches.
for batch_instances in lazy_groups_of(iterator, self._batch_size):
for poss_smaller_batches in self._ensure_batch_is_sufficiently_small( # type: ignore
batch_instances, excess):
batch = Batch(poss_smaller_batches)
yield batch
if excess:
yield Batch(excess)
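
A minimal sketch of wiring the alternating reader to this iterator, assuming the same hypothetical paths as above. Because the alternating reader never exhausts its stream after the vocabulary pass, instances_per_epoch (a standard DataIterator option) bounds each epoch, as noted in the reader's docstring; the concrete values are placeholders:

from allennlp.data import Vocabulary
from allennlp.data.dataset_readers import UniversalDependenciesMultiLangDatasetReader
from allennlp.data.iterators import SameLanguageIterator

reader = UniversalDependenciesMultiLangDatasetReader(languages=["en", "fr"],
                                                     alternate=True,
                                                     instances_per_file=32,
                                                     lazy=True)
train_data = reader.read("data/*-universal.conll")

# The first pass (used to build the vocabulary) runs over every file in full.
vocab = Vocabulary.from_instances(train_data)

iterator = SameLanguageIterator(sorting_keys=[("words", "num_tokens")],
                                batch_size=32,
                                instances_per_epoch=320)
iterator.index_with(vocab)

for batch in iterator(train_data, num_epochs=1):
    # Every tensor batch here comes from a single language.
    pass
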
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -9,6 +9,7 @@
from allennlp.models.biattentive_classification_network import BiattentiveClassificationNetwork
from allennlp.models.constituency_parser import SpanConstituencyParser
from allennlp.models.biaffine_dependency_parser import BiaffineDependencyParser
from allennlp.models.biaffine_dependency_parser_multilang import BiaffineDependencyParserMultiLang
from allennlp.models.coreference_resolution.coref import CoreferenceResolver
from allennlp.models.crf_tagger import CrfTagger
from allennlp.models.decomposable_attention import DecomposableAttention
104 changes: 59 additions & 45 deletions allennlp/models/biaffine_dependency_parser.py
@@ -209,6 +209,64 @@ def forward(self, # type: ignore
raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

mask = get_text_field_mask(words)

predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
embedded_text_input, mask, head_tags, head_indices)

loss = arc_nll + tag_nll

if head_indices is not None and head_tags is not None:
evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
# We calculate attachment scores for the whole sentence
# but excluding the symbolic ROOT token at the start,
# which is why we start from the second element in the sequence.
self._attachment_scores(predicted_heads[:, 1:],
predicted_head_tags[:, 1:],
head_indices,
head_tags,
evaluation_mask)

output_dict = {
"heads": predicted_heads,
"head_tags": predicted_head_tags,
"arc_loss": arc_nll,
"tag_loss": tag_nll,
"loss": loss,
"mask": mask,
"words": [meta["words"] for meta in metadata],
"pos": [meta["pos"] for meta in metadata]
}

return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
heads = output_dict.pop("heads").cpu().detach().numpy()
mask = output_dict.pop("mask")
lengths = get_lengths_from_binary_sequence_mask(mask)
head_tag_labels = []
head_indices = []
for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
instance_heads = list(instance_heads[1:length])
instance_tags = instance_tags[1:length]
labels = [self.vocab.get_token_from_index(label, "head_tags")
for label in instance_tags]
head_tag_labels.append(labels)
head_indices.append(instance_heads)

output_dict["predicted_dependencies"] = head_tag_labels
output_dict["predicted_heads"] = head_indices
return output_dict

def _parse(self,
embedded_text_input: torch.Tensor,
mask: torch.LongTensor,
head_tags: torch.LongTensor = None,
head_indices: torch.LongTensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

embedded_text_input = self._input_dropout(embedded_text_input)
encoded_text = self.encoder(embedded_text_input, mask)

@@ -258,59 +316,15 @@ def forward(self, # type: ignore
head_indices=head_indices,
head_tags=head_tags,
mask=mask)
loss = arc_nll + tag_nll

evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
# We calculate attachment scores for the whole sentence
# but excluding the symbolic ROOT token at the start,
# which is why we start from the second element in the sequence.
self._attachment_scores(predicted_heads[:, 1:],
predicted_head_tags[:, 1:],
head_indices[:, 1:],
head_tags[:, 1:],
evaluation_mask)
else:
arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
child_tag_representation=child_tag_representation,
attended_arcs=attended_arcs,
head_indices=predicted_heads.long(),
head_tags=predicted_head_tags.long(),
mask=mask)
loss = arc_nll + tag_nll

output_dict = {
"heads": predicted_heads,
"head_tags": predicted_head_tags,
"arc_loss": arc_nll,
"tag_loss": tag_nll,
"loss": loss,
"mask": mask,
"words": [meta["words"] for meta in metadata],
"pos": [meta["pos"] for meta in metadata]
}

return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
heads = output_dict.pop("heads").cpu().detach().numpy()
mask = output_dict.pop("mask")
lengths = get_lengths_from_binary_sequence_mask(mask)
head_tag_labels = []
head_indices = []
for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
instance_heads = list(instance_heads[1:length])
instance_tags = instance_tags[1:length]
labels = [self.vocab.get_token_from_index(label, "head_tags")
for label in instance_tags]
head_tag_labels.append(labels)
head_indices.append(instance_heads)

output_dict["predicted_dependencies"] = head_tag_labels
output_dict["predicted_heads"] = head_indices
return output_dict
return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll

def _construct_loss(self,
head_tag_representation: torch.Tensor,
