Fix TextSplitter.from_tiktoken (langchain-ai#4361)
Thanks to @danb27 for the fix! Minor update.

Fixes langchain-ai#4357

---------

Co-authored-by: Dan Bianchini <42096328+danb27@users.noreply.github.com>
2 people authored and EandrewJones committed May 9, 2023
1 parent 894e269 commit ece3770
Showing 2 changed files with 37 additions and 20 deletions.
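
In short: before this change, from_tiktoken_encoder used the tiktoken arguments (encoding_name, model_name, allowed_special, disallowed_special) only to build the length function, so a TokenTextSplitter created through it still tokenized with its constructor defaults. The fix forwards those arguments to the TokenTextSplitter constructor. A minimal usage sketch of the fixed behavior, assuming langchain and tiktoken are installed (chunk sizes and the sample text are illustrative, not from the diff):

from langchain.text_splitter import TokenTextSplitter

# With this fix, the encoding for the requested model (cl100k_base for
# gpt-3.5-turbo) is used both to count tokens and to perform the split.
splitter = TokenTextSplitter.from_tiktoken_encoder(
    model_name="gpt-3.5-turbo", chunk_size=100, chunk_overlap=0
)
chunks = splitter.split_text("some long document text ...")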
20 changes: 16 additions & 4 deletions langchain/text_splitter.py
@@ -14,6 +14,8 @@
     Literal,
     Optional,
     Sequence,
+    Type,
+    TypeVar,
     Union,
 )

@@ -22,6 +24,8 @@

 logger = logging.getLogger(__name__)

+TS = TypeVar("TS", bound="TextSplitter")
+

 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""
@@ -139,13 +143,13 @@ def _huggingface_tokenizer_length(text: str) -> int:

     @classmethod
     def from_tiktoken_encoder(
-        cls,
+        cls: Type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ) -> TextSplitter:
+    ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
             import tiktoken
@@ -161,16 +165,24 @@ def from_tiktoken_encoder(
         else:
             enc = tiktoken.get_encoding(encoding_name)

-        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+        def _tiktoken_encoder(text: str) -> int:
             return len(
                 enc.encode(
                     text,
                     allowed_special=allowed_special,
                     disallowed_special=disallowed_special,
-                    **kwargs,
                 )
             )

+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
         return cls(length_function=_tiktoken_encoder, **kwargs)

     def transform_documents(
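Two changes in this file. First, the new TS TypeVar, bound to TextSplitter, lets the return type of from_tiktoken_encoder track the subclass it is called on instead of being fixed to the base class. A stripped-down illustration of that classmethod typing pattern, with hypothetical names:

from typing import Type, TypeVar

T = TypeVar("T", bound="Base")

class Base:
    @classmethod
    def create(cls: Type[T]) -> T:
        # cls is the class the method was invoked on, so the declared
        # return type follows the subclass.
        return cls()

class Child(Base):
    pass

child = Child.create()  # type checkers infer Child here, not Base

Second, the issubclass branch forwards the tiktoken arguments into the constructor call, so a TokenTextSplitter built via this classmethod actually splits with the requested encoding rather than its defaults.
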
37 changes: 21 additions & 16 deletions tests/integration_tests/test_text_splitter.py
@@ -23,19 +23,24 @@ def test_huggingface_tokenizer() -> None:
     assert output == ["foo", "bar"]


-class TestTokenTextSplitter:
-    """Test token text splitter."""
-
-    def test_basic(self) -> None:
-        """Test no overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
-        assert output == expected_output
-
-    def test_overlap(self) -> None:
-        """Test with overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
-        assert output == expected_output
+def test_token_text_splitter() -> None:
+    """Test no overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_overlap() -> None:
+    """Test with overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_from_tiktoken() -> None:
+    splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
+    expected_tokenizer = "cl100k_base"
+    actual_tokenizer = splitter._tokenizer.name
+    assert expected_tokenizer == actual_tokenizer
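
The new test leans on tiktoken's model-to-encoding mapping: gpt-3.5-turbo resolves to the cl100k_base encoding. A quick way to check that mapping directly, assuming tiktoken is installed:

import tiktoken

# encoding_for_model maps a model name to its tokenizer; this is why the
# test expects splitter._tokenizer.name to equal "cl100k_base".
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
print(enc.name)  # cl100k_base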
