Extend NeMo AutoTokenizer & SentencePieceTokenizer API for TensorRT-LLM & AMMO evaluation scripts usage #8818

Closed · wants to merge 8 commits
25 changes: 25 additions & 0 deletions nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
@@ -218,6 +218,18 @@ def ids_to_text(self, ids):
text = self.tokens_to_text(tokens_clean)
return text

def encode(self, *args, **kwargs):
return self.tokenizer.encode(*args, **kwargs)

def batch_encode_plus(self, *args, **kwargs):
return self.tokenizer.batch_encode_plus(*args, **kwargs)

def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)

def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)

@property
def vocab(self):
id2vocab = {v: k for k, v in self.tokenizer.vocab.items()}
@@ -241,6 +253,19 @@ def eos_id(self):
return None
return self.tokens_to_ids([getattr(self, 'eos_token')])[0]

@property
def pad_token_id(self):
return self.pad_id

@pad_token_id.setter
def pad_token_id(self, value: int):
self.pad_token = self.ids_to_tokens(value)
self.add_special_tokens({'pad_token': self.pad_token})

@property
def eos_token_id(self):
return self.eos_id

@property
def eod(self):
"""Returns EOS token id. Exact copy of the eos_id function. Required for megatron-core."""
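Taken together, these wrappers let the TensorRT-LLM and AMMO evaluation scripts drive a NeMo AutoTokenizer through the familiar HuggingFace surface. A minimal sketch, where the "gpt2" checkpoint name is only a placeholder for whatever HuggingFace model the wrapper was built from:

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

tokenizer = AutoTokenizer("gpt2")  # placeholder checkpoint name, fetched from the HF hub

ids = tokenizer.encode("Hello NeMo!")                # delegates to the wrapped HF encode()
text = tokenizer.decode(ids)                         # delegates to the wrapped HF decode()
batch = tokenizer.batch_encode_plus(["one", "two"])  # HF-style dict containing "input_ids"

# The new setter registers a pad token from an id, e.g. reusing EOS as PAD,
# a common move in evaluation scripts when a model ships without a pad token:
tokenizer.pad_token_id = tokenizer.eos_token_id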
53 changes: 52 additions & 1 deletion nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -17,6 +17,7 @@

import numpy as np
import sentencepiece
import torch

from nemo.collections.common.parts.utils import if_exist
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -28,7 +29,7 @@
class SentencePieceTokenizer(TokenizerSpec):
"""
Sentencepiecetokenizer https://github.com/google/sentencepiece.

Args:
model_path: path to sentence piece tokenizer model. To create the model use create_spt_model()
special_tokens: either list of special tokens or dictionary of token name to token value
@@ -165,6 +166,46 @@ def tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
ids.append(self.token_to_id(token))
return ids

def encode(
self,
text: Union[str, List[str]],
return_tensors: Optional[str] = None,
max_length: Optional[int] = None,
*args,
**kwargs,
):
# Note: keyword arguments other than return_tensors and max_length are ignored.
assert not self.legacy, "Legacy implementation is not available."
assert return_tensors in {None, "pt"}, "Only returning plain list or PyTorch tensor is enabled"
output = self.tokenizer.encode_as_ids(text)
if max_length is not None:
if isinstance(text, str):
output = output[:max_length]
if isinstance(text, list):
output = [x[:max_length] for x in output]
if return_tensors == "pt":
# Only single-string input is supported here: a list of strings would need padding to form a rectangular tensor
assert isinstance(text, str), "Returning 'pt' tensors is only supported for simple text input"
output = torch.LongTensor(output).reshape((1, -1))
return output

def batch_encode_plus(self, texts, *args, **kwargs):
# Note: keyword arguments are ignored.
assert not self.legacy, "Legacy implementation is not available."
assert isinstance(texts, list), f"Expected list of texts, got {type(texts).__name__}: {texts}"
return {"input_ids": self.tokenizer.encode_as_ids(texts)}

def decode(self, ids: Union[List[int], List[List[int]], np.ndarray, torch.Tensor], *args, **kwargs):
# Note: keyword arguments are ignored.
assert not self.legacy, "Legacy implementation is not available."
if isinstance(ids, np.ndarray) or torch.is_tensor(ids):
ids = ids.tolist()
return self.tokenizer.decode(ids)

def batch_decode(self, ids: List[List[int]], *args, **kwargs):
assert not self.legacy, "Legacy implementation is not available."
return self.decode(ids, **kwargs)

def add_special_tokens(self, special_tokens):
if not self.legacy:
raise AttributeError("Special Token addition does not work when legacy is set to False.")
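Together these methods cover exactly the slice of the HuggingFace API that the evaluation scripts call. A sketch of the intended usage, with the model path as a placeholder for any trained SentencePiece model file:

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer("/path/to/tokenizer.model")  # placeholder path

ids = tokenizer.encode("Hello world!", max_length=8)                # plain list, truncated to 8 ids
pt = tokenizer.encode("Hello world!", return_tensors="pt")          # LongTensor of shape (1, seq_len)
batch = tokenizer.batch_encode_plus(["first text", "second text"])  # {"input_ids": [[...], [...]]}
texts = tokenizer.batch_decode(batch["input_ids"])                  # list of decoded strings

Note that, unlike HuggingFace tokenizers, no padding or attention masks are produced, which is why return_tensors="pt" is restricted to single-string input.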
@@ -197,6 +238,11 @@ def pad_id(self):
pad_id = self.tokenizer.pad_id()
return pad_id

@property
def pad_token_id(self):
# pad_token_id introduced for consistency with HF tokenizers
return self.pad_id

@property
def bos_id(self):
if self.legacy:
@@ -213,6 +259,11 @@ def eos_id(self):
eos_id = self.tokenizer.eos_id()
return eos_id

@property
def eos_token_id(self):
# eos_token_id introduced for consistency with HF tokenizers
return self.eos_id

@property
def sep_id(self):
if self.legacy:
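The pad_token_id and eos_token_id aliases are what allow a single evaluation helper to accept either tokenizer class unchanged. A hypothetical helper in the spirit of the TensorRT-LLM / AMMO scripts (the function is illustrative, not part of the PR):

def trim_and_decode(output_ids, tokenizer):
    # Relies only on the shared HF-style surface: eos_token_id and decode(),
    # so it works with both AutoTokenizer and SentencePieceTokenizer.
    eos = tokenizer.eos_token_id
    texts = []
    for seq in output_ids:
        seq = list(seq)
        if eos in seq:
            seq = seq[: seq.index(eos)]  # cut generation at the first EOS
        texts.append(tokenizer.decode(seq))
    return texts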
17 changes: 16 additions & 1 deletion nemo/collections/common/tokenizers/tokenizer_spec.py
@@ -22,6 +22,9 @@
class TokenizerSpec(ABC):
"""
Inherit this class to implement a new tokenizer.

The encode, batch_encode_plus, decode, and batch_decode methods
are provided for consistency with the HuggingFace tokenizers API.
"""

@abstractmethod
@@ -48,8 +51,20 @@
def ids_to_text(self, ids):
pass

def encode(self, text, *args, **kwargs):
raise NotImplementedError

def batch_encode_plus(self, texts, *args, **kwargs):
raise NotImplementedError

def decode(self, ids, *args, **kwargs):
raise NotImplementedError

def batch_decode(self, ids, *args, **kwargs):
raise NotImplementedError

def add_special_tokens(self, special_tokens: List[str]):
raise NotImplementedError("To be implemented")
raise NotImplementedError

@property
def name(self):
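Since the new base-class methods default to raising NotImplementedError, existing TokenizerSpec subclasses are unaffected and opt in by overriding them. An illustrative toy subclass (not part of the PR):

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


class WhitespaceTokenizer(TokenizerSpec):
    """Toy tokenizer showing how a subclass opts into the HF-style hooks."""

    def __init__(self, vocab):
        self.word2id = {w: i for i, w in enumerate(vocab)}
        self.id2word = dict(enumerate(vocab))

    def text_to_tokens(self, text):
        return text.split()

    def tokens_to_text(self, tokens):
        return " ".join(tokens)

    def tokens_to_ids(self, tokens):
        return [self.word2id[t] for t in tokens]

    def ids_to_tokens(self, ids):
        return [self.id2word[i] for i in ids]

    def text_to_ids(self, text):
        return self.tokens_to_ids(self.text_to_tokens(text))

    def ids_to_text(self, ids):
        return self.tokens_to_text(self.ids_to_tokens(ids))

    # Optional HF-style hooks: overriding them avoids NotImplementedError.
    def encode(self, text, *args, **kwargs):
        return self.text_to_ids(text)

    def decode(self, ids, *args, **kwargs):
        return self.ids_to_text(ids)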
63 changes: 63 additions & 0 deletions tests/collections/common/test_spc_tokenizer.py
@@ -13,6 +13,7 @@
# limitations under the License.

import pytest
import torch

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

@@ -193,3 +194,65 @@ def test_ids_to_tokens(self, test_data_dir):

for i in range(len(result)):
assert result[i] == tokens[i]

@pytest.mark.unit
def test_encode(self, test_data_dir):
tokenizer = SentencePieceTokenizer(test_data_dir + self.model_name)

text = "This text should encode to sth more than `max_length` tokens..."
result = tokenizer.encode(text)
assert isinstance(result, list)

max_length = 5
result = tokenizer.encode(text, max_length=max_length)
assert len(result) == max_length

n = 2
texts = [text for _ in range(n)]
tokens_list = tokenizer.encode(texts, max_length=max_length)
assert len(tokens_list) == n
assert all(len(tokens) == max_length for tokens in tokens_list)

result = tokenizer.encode(text, max_length=max_length, return_tensors="pt")
assert isinstance(result, torch.LongTensor)
assert result.size() == (1, max_length)

with pytest.raises(AssertionError):
tokenizer.encode(text, return_tensors="np") # Only "pt" option implemented

@pytest.mark.unit
def test_decode(self, test_data_dir):
tokenizer = SentencePieceTokenizer(test_data_dir + self.model_name)

text = "ole ole [SEP] ole ola [SEP]"
tokens = tokenizer.encode(text)
assert text == tokenizer.decode(tokens)

n = 8
texts = [text for _ in range(n)]
tokens_list = tokenizer.encode(texts)
assert isinstance(tokens_list, list)
assert len(tokens_list) == n
for tokens in tokens_list:
assert text == tokenizer.decode(tokens)

@pytest.mark.unit
def test_batch_encode_plus(self, test_data_dir):
tokenizer = SentencePieceTokenizer(test_data_dir + self.model_name)

texts = ["Welcome to NeMo!", "This is fun"]
with pytest.raises(AssertionError):
tokenizer.batch_encode_plus(texts[0]) # Input should be List[str]

tokens_dict = tokenizer.batch_encode_plus(texts)
assert isinstance(tokens_dict, dict)
assert "input_ids" in tokens_dict
assert tokens_dict["input_ids"] == tokenizer.encode(texts)

@pytest.mark.unit
def test_batch_decode(self, test_data_dir):
tokenizer = SentencePieceTokenizer(test_data_dir + self.model_name)

texts = ["Jaki to jest język?", "Kropka."]
tokens = tokenizer.encode(texts)
assert texts == tokenizer.batch_decode(tokens)