support wmt21 tokenizer in m2m100 tokenizer (huggingface#14376)
patil-suraj authored and Alberto Bégué committed Jan 27, 2022
1 parent 3ad9301 commit be98287
Showing 1 changed file with 17 additions and 6 deletions.
src/transformers/models/m2m_100/tokenization_m2m_100.py: 17 additions & 6 deletions
@@ -54,7 +54,10 @@
}

# fmt: off
FAIRSEQ_LANGUAGE_CODES = ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
FAIRSEQ_LANGUAGE_CODES = {
"m2m100": ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"],
"wmt21": ['en', 'ha', 'is', 'ja', 'cs', 'ru', 'zh', 'de']
}
# fmt: on
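For context, this change turns the flat list of language codes into a dict keyed by the new language_codes argument, so the tokenizer can select the right code set at construction time. A minimal standalone sketch of the lookup (the lists are truncated for brevity and the names only mirror the diff above; this is not the module itself):

# Sketch only: mirrors the mapping introduced above, with truncated lists.
FAIRSEQ_LANGUAGE_CODES = {
    "m2m100": ["af", "am", "ar", "de", "en"],  # truncated for illustration
    "wmt21": ["en", "ha", "is", "ja", "cs", "ru", "zh", "de"],
}

language_codes = "wmt21"
fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
lang_code_to_token = {code: f"__{code}__" for code in fairseq_language_code}
print(lang_code_to_token["de"])  # __de__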


@@ -86,6 +89,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
+language_codes (:obj:`str`, `optional`, defaults to :obj:`"m2m100"`):
+    What language codes to use. Should be one of :obj:`"m2m100"` or :obj:`"wmt21"`.
sp_model_kwargs (:obj:`dict`, `optional`):
Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece
<https://github.com/google/sentencepiece/tree/master/python>`__ can be used, among other things, to set:
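From user code, the new argument selects the WMT21 code set when loading the tokenizer. A hedged usage sketch (the checkpoint name is only illustrative; any checkpoint or local vocabulary that uses the WMT21 codes would do):

from transformers import M2M100Tokenizer

# Checkpoint name assumed for illustration; substitute the one you actually use.
tokenizer = M2M100Tokenizer.from_pretrained(
    "facebook/wmt21-dense-24-wide-en-x",
    language_codes="wmt21",
    src_lang="en",
    tgt_lang="de",
)
encoded = tokenizer("Hello world")  # input_ids begin with the __en__ language token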
@@ -132,17 +137,21 @@ def __init__(
sep_token="</s>",
pad_token="<pad>",
unk_token="<unk>",
language_codes="m2m100",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
+num_madeup_words=8,
**kwargs,
) -> None:
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

-self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in FAIRSEQ_LANGUAGE_CODES}
+self.language_codes = language_codes
+fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
+self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in fairseq_language_code}

kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
kwargs["additional_special_tokens"] += [
self.get_lang_token(lang_code)
-for lang_code in FAIRSEQ_LANGUAGE_CODES
+for lang_code in fairseq_language_code
if self.get_lang_token(lang_code) not in kwargs["additional_special_tokens"]
]
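The block above registers every language token as an additional special token unless the caller already supplied it. A small self-contained sketch of that filtering (get_lang_token is reproduced here only for illustration):

def get_lang_token(lang_code: str) -> str:
    return f"__{lang_code}__"

kwargs = {"additional_special_tokens": ["__en__"]}  # suppose the caller already passed __en__
fairseq_language_code = ["en", "de", "ja"]  # truncated for illustration
kwargs["additional_special_tokens"] += [
    get_lang_token(code)
    for code in fairseq_language_code
    if get_lang_token(code) not in kwargs["additional_special_tokens"]
]
print(kwargs["additional_special_tokens"])  # ['__en__', '__de__', '__ja__']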

@@ -154,7 +163,9 @@ def __init__(
sep_token=sep_token,
unk_token=unk_token,
pad_token=pad_token,
+language_codes=language_codes,
sp_model_kwargs=self.sp_model_kwargs,
+num_madeup_words=num_madeup_words,
**kwargs,
)
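Forwarding language_codes and num_madeup_words to super().__init__ keeps them among the tokenizer's init kwargs, so a saved tokenizer can be reloaded with the same settings. A hedged round-trip sketch, assuming the base class persists these kwargs in tokenizer_config.json (the directory name is arbitrary):

tokenizer.save_pretrained("wmt21-tokenizer")  # continuing the usage sketch above
reloaded = M2M100Tokenizer.from_pretrained("wmt21-tokenizer")
print(reloaded.language_codes)  # expected "wmt21", assuming init kwargs round-trip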

@@ -167,17 +178,17 @@ def __init__(
self.encoder_size = len(self.encoder)

self.lang_token_to_id = {
-self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)
}
-self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)}
+self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)}
self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}

self._src_lang = src_lang if src_lang is not None else "en"
self.tgt_lang = tgt_lang
self.cur_lang_id = self.get_lang_id(self._src_lang)
self.set_src_lang_special_tokens(self._src_lang)

-self.num_madeup_words = 8
+self.num_madeup_words = num_madeup_words

@property
def vocab_size(self) -> int:
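Language tokens are appended after the SentencePiece vocabulary, so their ids start at encoder_size, and num_madeup_words (previously hard-coded to 8) is now configurable. A minimal sketch of that id layout, assuming vocab_size is composed of the encoder size, the language tokens, and the madeup words (encoder_size below is a placeholder, not taken from a real checkpoint):

encoder_size = 128_000  # placeholder value, for illustration only
wmt21_codes = ["en", "ha", "is", "ja", "cs", "ru", "zh", "de"]
lang_token_to_id = {f"__{c}__": encoder_size + i for i, c in enumerate(wmt21_codes)}
id_to_lang_token = {v: k for k, v in lang_token_to_id.items()}
num_madeup_words = 8
vocab_size = encoder_size + len(lang_token_to_id) + num_madeup_words  # assumed composition
print(lang_token_to_id["__de__"], vocab_size)  # 128007 128016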
