diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py
index f27d4fbd3f920e..f8df8fb70c1df3 100644
--- a/src/transformers/models/m2m_100/tokenization_m2m_100.py
+++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -54,7 +54,10 @@
 }
 
 # fmt: off
-FAIRSEQ_LANGUAGE_CODES = ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
+FAIRSEQ_LANGUAGE_CODES = {
+    "m2m100": ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"],
+    "wmt21": ['en', 'ha', 'is', 'ja', 'cs', 'ru', 'zh', 'de']
+}
 # fmt: on
 
 
@@ -86,6 +89,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
             token instead.
         pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
             The token used for padding, for example when batching sequences of different lengths.
+        language_codes (:obj:`str`, `optional`, defaults to :obj:`"m2m100"`):
+            What language codes to use. Should be one of :obj:`"m2m100"` or :obj:`"wmt21"`.
         sp_model_kwargs (:obj:`dict`, `optional`):
             Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for
             SentencePiece <https://github.com/google/sentencepiece/tree/master/python>`__ can be used, among other
             things, to set:
@@ -132,17 +137,21 @@ def __init__(
         sep_token="</s>",
         pad_token="<pad>",
         unk_token="<unk>",
+        language_codes="m2m100",
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        num_madeup_words=8,
         **kwargs,
     ) -> None:
         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
 
-        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in FAIRSEQ_LANGUAGE_CODES}
+        self.language_codes = language_codes
+        fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
+        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in fairseq_language_code}
 
         kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
         kwargs["additional_special_tokens"] += [
             self.get_lang_token(lang_code)
-            for lang_code in FAIRSEQ_LANGUAGE_CODES
+            for lang_code in fairseq_language_code
             if self.get_lang_token(lang_code) not in kwargs["additional_special_tokens"]
         ]
 
@@ -154,7 +163,9 @@ def __init__(
             sep_token=sep_token,
             unk_token=unk_token,
             pad_token=pad_token,
+            language_codes=language_codes,
             sp_model_kwargs=self.sp_model_kwargs,
+            num_madeup_words=num_madeup_words,
             **kwargs,
         )
 
@@ -167,9 +178,9 @@ def __init__(
         self.encoder_size = len(self.encoder)
 
         self.lang_token_to_id = {
-            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)
+            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)
         }
-        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)}
+        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)}
         self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
 
         self._src_lang = src_lang if src_lang is not None else "en"
@@ -177,7 +188,7 @@ def __init__(
         self.cur_lang_id = self.get_lang_id(self._src_lang)
         self.set_src_lang_special_tokens(self._src_lang)
 
-        self.num_madeup_words = 8
+        self.num_madeup_words = num_madeup_words
 
     @property
     def vocab_size(self) -> int:
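For reviewers, a minimal usage sketch of the two new kwargs this patch introduces (`language_codes` and `num_madeup_words`). The WMT21 checkpoint name below is a placeholder assumption, since the diff itself does not name any checkpoint; only the kwargs and the attributes shown (`get_lang_token`, `lang_code_to_id`) come from the code above.

```python
from transformers import M2M100Tokenizer

# Default behaviour is unchanged: the 100-language "m2m100" code set is used.
tok = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
assert tok.get_lang_token("fr") == "__fr__"

# With this patch, a WMT21-style checkpoint can opt into the 8-language code set.
# NOTE: "facebook/wmt21-checkpoint" is a hypothetical name used only for illustration.
wmt_tok = M2M100Tokenizer.from_pretrained(
    "facebook/wmt21-checkpoint",
    language_codes="wmt21",
    num_madeup_words=8,  # the default; spelled out here to show the new kwarg
)
# Only the eight WMT21 language tokens are appended after the SentencePiece vocab.
print(sorted(wmt_tok.lang_code_to_id))  # ['cs', 'de', 'en', 'ha', 'is', 'ja', 'ru', 'zh']
```

Passing both kwargs through `super().__init__()` (as the `@@ -154,7 +163,9 @@` hunk does) keeps them in the tokenizer config, so a saved WMT21 tokenizer reloads with the right code set without the caller re-specifying it.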