Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,20 @@ print(response.prompt)
By using the `count_tokens` method, you can estimate the billing for a given request.

```python
from ai21 import AI21Client
from ai21.tokenizers import get_tokenizer

client = AI21Client()
client.count_tokens(text="some text") # returns int
tokenizer = get_tokenizer(name="jamba-instruct-tokenizer")
total_tokens = tokenizer.count_tokens(text="some text") # returns int
print(total_tokens)
```

Available tokenizers are:

- `jamba-instruct-tokenizer`
- `j2-tokenizer`

For more information on AI21 Tokenizers, see the [documentation](https://github.com/AI21Labs/ai21-tokenizer).

### File Upload

---
Expand Down
13 changes: 9 additions & 4 deletions ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import warnings
from typing import Optional, Any, Dict

from ai21_tokenizer import PreTrainedTokenizers

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.resources.studio_answer import StudioAnswer
Expand Down Expand Up @@ -63,9 +66,11 @@ def __init__(
self.library = StudioLibrary(self._http_client)
self.segmentation = StudioSegmentation(self._http_client)

def count_tokens(self, text: str) -> int:
    # We might want to cache the tokenizer instance within the class rather than
    # globally, since the global cache is shared across all client instances
def count_tokens(self, text: str, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> int:
warnings.warn(
"Please use the global get_tokenizer() method directly instead of the AI21Client().count_tokens() method.",
DeprecationWarning,
)

tokenizer = get_tokenizer()
tokenizer = get_tokenizer(tokenizer_name)
return tokenizer.count_tokens(text)
4 changes: 4 additions & 0 deletions ai21/tokenizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .ai21_tokenizer import AI21Tokenizer
from .factory import get_tokenizer

__all__ = ["AI21Tokenizer", "get_tokenizer"]
8 changes: 4 additions & 4 deletions ai21/tokenizers/ai21_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Any

from ai21_tokenizer import BaseTokenizer

Expand All @@ -16,7 +16,7 @@ def count_tokens(self, text: str) -> int:

return len(encoded_text)

def tokenize(self, text: str) -> List[str]:
encoded_text = self._tokenizer.encode(text)
def tokenize(self, text: str, **kwargs: Any) -> List[str]:
encoded_text = self._tokenizer.encode(text, **kwargs)

return self._tokenizer.convert_ids_to_tokens(encoded_text)
return self._tokenizer.convert_ids_to_tokens(encoded_text, **kwargs)
16 changes: 8 additions & 8 deletions ai21/tokenizers/factory.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from typing import Optional
from typing import Dict

from ai21_tokenizer import Tokenizer
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers

from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer

_cached_tokenizer: Optional[AI21Tokenizer] = None
_cached_tokenizers: Dict[str, AI21Tokenizer] = {}


def get_tokenizer() -> AI21Tokenizer:
def get_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AI21Tokenizer:
"""
Get the tokenizer instance.

    If a tokenizer for the given name is not cached, it will be created using the Tokenizer.get_tokenizer(name) method.
"""
global _cached_tokenizer
global _cached_tokenizers

if _cached_tokenizer is None:
_cached_tokenizer = AI21Tokenizer(Tokenizer.get_tokenizer())
if _cached_tokenizers.get(name) is None:
_cached_tokenizers[name] = AI21Tokenizer(Tokenizer.get_tokenizer(name))

return _cached_tokenizer
return _cached_tokenizers[name]
2 changes: 1 addition & 1 deletion examples/studio/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
response = client.chat.create(
system=system,
messages=messages,
model="j2-mid",
model="j2-ultra",
count_penalty=Penalty(
scale=0,
apply_to_emojis=False,
Expand Down
7 changes: 3 additions & 4 deletions examples/studio/tokenization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ai21 import AI21Client
from ai21.tokenizers import get_tokenizer

prompt = (
"The following is a conversation between a user of an eCommerce store and a user operation"
Expand Down Expand Up @@ -30,7 +30,6 @@
"- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
"User: Hi, I have a question for you"
)
client = AI21Client()
# This is the new and recommended way to use the Tokenization module. The old AI21Client().count_tokens() method is deprecated.
response = client.count_tokens(prompt)
tokenizer = get_tokenizer(name="jamba-instruct-tokenizer")
response = tokenizer.count_tokens(prompt)
print(response)
417 changes: 359 additions & 58 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.8"
requests = "^2.31.0"
ai21-tokenizer = "^0.3.9"
ai21-tokenizer = "^0.9.0"
boto3 = { version = "^1.28.82", optional = true }
dataclasses-json = "^0.6.3"
typing-extensions = "^4.9.0"
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/clients/studio/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ai21 import AI21Client
from ai21.models import ChatMessage, RoleType, Penalty, FinishReason

_MODEL = "j2-mid"
_MODEL = "j2-ultra"
_MESSAGES = [
ChatMessage(
text="Hello, I need help studying for the coming test, can you teach me about the US constitution? ",
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_chat_when_finish_reason_defined__should_halt_on_expected_reason(
messages=_MESSAGES,
system=_SYSTEM,
max_tokens=max_tokens,
model="j2-mid",
model="j2-ultra",
temperature=1,
top_p=0,
num_results=1,
Expand Down
52 changes: 44 additions & 8 deletions tests/unittests/tokenizers/test_ai21_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from typing import List

import pytest
from ai21.tokenizers.factory import get_tokenizer


class TestAI21Tokenizer:
def test__count_tokens__should_return_number_of_tokens(self):
expected_number_of_tokens = 8
tokenizer = get_tokenizer()
@pytest.mark.parametrize(
ids=[
"when_j2_tokenizer",
"when_jamba_instruct_tokenizer",
],
argnames=["tokenizer_name", "expected_tokens"],
argvalues=[
("j2-tokenizer", 8),
("jamba-instruct-tokenizer", 9),
],
)
def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int):
tokenizer = get_tokenizer(tokenizer_name)

actual_number_of_tokens = tokenizer.count_tokens("Text to Tokenize - Hello world!")

assert actual_number_of_tokens == expected_number_of_tokens

def test__tokenize__should_return_list_of_tokens(self):
expected_tokens = ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]
tokenizer = get_tokenizer()
assert actual_number_of_tokens == expected_tokens

@pytest.mark.parametrize(
ids=[
"when_j2_tokenizer",
"when_jamba_instruct_tokenizer",
],
argnames=["tokenizer_name", "expected_tokens"],
argvalues=[
("j2-tokenizer", ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]),
(
"jamba-instruct-tokenizer",
["<|startoftext|>", "Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"],
),
],
)
def test__tokenize__should_return_list_of_tokens(self, tokenizer_name: str, expected_tokens: List[str]):
tokenizer = get_tokenizer(tokenizer_name)

actual_tokens = tokenizer.tokenize("Text to Tokenize - Hello world!")

Expand All @@ -23,3 +49,13 @@ def test__tokenizer__should_be_singleton__when_called_twice(self):
tokenizer2 = get_tokenizer()

assert tokenizer1 is tokenizer2

def test__get_tokenizer__when_called_with_different_tokenizer_name__should_return_different_tokenizer(self):
tokenizer1 = get_tokenizer("j2-tokenizer")
tokenizer2 = get_tokenizer("jamba-instruct-tokenizer")

assert tokenizer1._tokenizer is not tokenizer2._tokenizer

def test__get_tokenizer__when_tokenizer_name_not_supported__should_raise_error(self):
with pytest.raises(ValueError):
get_tokenizer("some-tokenizer")