Skip to content

Commit

Permalink
fix: improve testcase (georgia-tech-db#1294)
Browse files: browse the repository at this point in the history
Test default values of `chunk_size` and `chunk_overlap`
  • Loading branch information
gaurav274 authored and a0x8o committed Oct 30, 2023
1 parent d4c650b commit f9e9f8b
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 16 deletions.
2 changes: 2 additions & 0 deletions evadb/configuration/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@
S3_DOWNLOAD_DIR = "s3_downloads"
TMP_DIR = "tmp"
DEFAULT_TRAIN_TIME_LIMIT = 120
DEFAULT_DOCUMENT_CHUNK_SIZE = 4000
DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200
7 changes: 6 additions & 1 deletion evadb/optimizer/statement_to_opr_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def visit_table_ref(self, table_ref: TableRef):
if table_ref.is_table_atom():
# Table
catalog_entry = table_ref.table.table_obj
self._plan = LogicalGet(table_ref, catalog_entry, table_ref.alias)
self._plan = LogicalGet(
table_ref,
catalog_entry,
table_ref.alias,
chunk_params=table_ref.chunk_params,
)

elif table_ref.is_table_valued_expr():
tve = table_ref.table_valued_expr
Expand Down
8 changes: 4 additions & 4 deletions evadb/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,18 @@ def chunk_params(self, tree):
assert len(chunk_params) == 2 or len(chunk_params) == 4
if len(chunk_params) == 4:
return {
"chunk_size": ConstantValueExpression(chunk_params[1]),
"chunk_overlap": ConstantValueExpression(chunk_params[3]),
"chunk_size": chunk_params[1],
"chunk_overlap": chunk_params[3],
}

elif len(chunk_params) == 2:
if chunk_params[0] == "CHUNK_SIZE":
return {
"chunk_size": ConstantValueExpression(chunk_params[1]),
"chunk_size": chunk_params[1],
}
elif chunk_params[0] == "CHUNK_OVERLAP":
return {
"chunk_overlap": ConstantValueExpression(chunk_params[1]),
"chunk_overlap": chunk_params[1],
}
else:
assert f"incorrect keyword found {chunk_params[0]}"
Expand Down
4 changes: 2 additions & 2 deletions evadb/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def parse_predicate_expression(expr: str):

def parse_table_clause(expr: str, chunk_size: int = None, chunk_overlap: int = None):
mock_query_parts = [f"SELECT * FROM {expr}"]
if chunk_size:
if chunk_size is not None:
mock_query_parts.append(f"CHUNK_SIZE {chunk_size}")
if chunk_overlap:
if chunk_overlap is not None:
mock_query_parts.append(f"CHUNK_OVERLAP {chunk_overlap}")
mock_query_parts.append(";")
mock_query = " ".join(mock_query_parts)
Expand Down
10 changes: 8 additions & 2 deletions evadb/readers/document/document_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from typing import Dict, Iterator

from evadb.catalog.sql_config import ROW_NUM_COLUMN
from evadb.configuration.constants import (
DEFAULT_DOCUMENT_CHUNK_OVERLAP,
DEFAULT_DOCUMENT_CHUNK_SIZE,
)
from evadb.readers.abstract_reader import AbstractReader
from evadb.readers.document.registry import (
_lazy_import_loader,
Expand All @@ -31,8 +35,10 @@ def __init__(self, *args, chunk_params, **kwargs):

# https://github.com/hwchase17/langchain/blob/5b6bbf4ab2a33ed0d33ff5d3cb3979a7edc15682/langchain/text_splitter.py#L570
# by default we use chunk_size 4000 and overlap 200
self._chunk_size = chunk_params.get("chunk_size", 4000)
self._chunk_overlap = chunk_params.get("chunk_overlap", 200)
self._chunk_size = chunk_params.get("chunk_size", DEFAULT_DOCUMENT_CHUNK_SIZE)
self._chunk_overlap = chunk_params.get(
"chunk_overlap", DEFAULT_DOCUMENT_CHUNK_OVERLAP
)

def _read(self) -> Iterator[Dict]:
ext = Path(self.file_url).suffix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from pandas.testing import assert_frame_equal

from evadb.binder.binder_utils import BinderError
from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
from evadb.configuration.constants import (
DEFAULT_DOCUMENT_CHUNK_OVERLAP,
DEFAULT_DOCUMENT_CHUNK_SIZE,
EvaDB_DATABASE_DIR,
EvaDB_ROOT_DIR,
)
from evadb.executor.executor_utils import ExecutorError
from evadb.interfaces.relational.db import connect
from evadb.models.storage.batch import Batch
Expand Down Expand Up @@ -392,22 +397,47 @@ def test_langchain_split_doc(self):
load_pdf.execute()

result1 = (
cursor.table("docs", chunk_size=2000, chunk_overlap=0).select("data").df()
cursor.table(
"docs", chunk_size=2000, chunk_overlap=DEFAULT_DOCUMENT_CHUNK_OVERLAP
)
.select("data")
.df()
)

result2 = (
cursor.table("docs", chunk_size=4000, chunk_overlap=2000)
cursor.table(
"docs", chunk_size=DEFAULT_DOCUMENT_CHUNK_SIZE, chunk_overlap=2000
)
.select("data")
.df()
)

self.assertEqual(len(result1), len(result2))
result3 = (
cursor.table(
"docs", chunk_size=DEFAULT_DOCUMENT_CHUNK_SIZE, chunk_overlap=0
)
.select("data")
.df()
)

self.assertGreater(len(result1), len(result2))
self.assertGreater(len(result2), len(result3))

# should use default value of chunk_overlap and respect chunk_size
result5 = cursor.table("docs", chunk_size=2000).select("data").df()
self.assertEqual(len(result5), len(result1))

# should use the default value of chunk_size and should respect chunk_overlap
result4 = cursor.table("docs", chunk_overlap=0).select("data").df()
self.assertEqual(len(result3), len(result4))

# should use the default values
result1 = cursor.table("docs").select("data").df()

result2 = cursor.query(
"SELECT data from docs chunk_size 4000 chunk_overlap 200"
f"SELECT data from docs chunk_size {DEFAULT_DOCUMENT_CHUNK_SIZE} chunk_overlap {DEFAULT_DOCUMENT_CHUNK_OVERLAP}"
).df()

self.assertEqual(len(result1), len(result2))

def test_show_relational(self):
Expand Down
9 changes: 7 additions & 2 deletions test/unit_tests/optimizer/test_statement_to_opr_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,16 @@ class StatementToOprTest(unittest.TestCase):
@patch("evadb.optimizer.statement_to_opr_converter.LogicalGet")
def test_visit_table_ref_should_create_logical_get_opr(self, mock_lget):
converter = StatementToPlanConverter()
table_ref = MagicMock(spec=TableRef, alias="alias")
table_ref = MagicMock(spec=TableRef, alias="alias", chunk_params={})
table_ref.is_select.return_value = False
table_ref.sample_freq = None
converter.visit_table_ref(table_ref)
mock_lget.assert_called_with(table_ref, table_ref.table.table_obj, "alias")
mock_lget.assert_called_with(
table_ref,
table_ref.table.table_obj,
"alias",
chunk_params=table_ref.chunk_params,
)
self.assertEqual(mock_lget.return_value, converter._plan)

@patch("evadb.optimizer.statement_to_opr_converter.LogicalFilter")
Expand Down

0 comments on commit f9e9f8b

Please sign in to comment.