Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 13 additions & 23 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Any, Optional
import logging
from typing import Any, Callable, Optional

from llama_index.core import Document
from llama_index.core.node_parser import SimpleNodeParser
Expand All @@ -25,6 +26,8 @@
from unstract.sdk.vector_db import VectorDB
from unstract.sdk.x2txt import X2Text

logger = logging.getLogger(__name__)


class Constants:
TOP_K = 5
Expand Down Expand Up @@ -101,27 +104,6 @@ def query_index(
finally:
vector_db.close()

def _cleanup_text(self, full_text):
# Remove text which is not required
full_text_lines = full_text.split("\n")
new_context_lines = []
empty_line_count = 0
for line in full_text_lines:
if line.strip() == "":
empty_line_count += 1
else:
if empty_line_count >= 3:
empty_line_count = 3
for i in range(empty_line_count):
new_context_lines.append("")
empty_line_count = 0
new_context_lines.append(line.rstrip())
self.tool.stream_log(
f"Old context length: {len(full_text_lines)}, "
f"New context length: {len(new_context_lines)}"
)
return "\n".join(new_context_lines)

def index(
self,
tool_id: str,
Expand All @@ -136,6 +118,7 @@ def index(
output_file_path: Optional[str] = None,
enable_highlight: bool = False,
usage_kwargs: dict[Any, Any] = {},
process_text: Optional[Callable[[str], str]] = None,
) -> str:
"""Indexes an individual file using the passed arguments.

Expand Down Expand Up @@ -276,10 +259,17 @@ def index(
except AdapterError as e:
# Wrapping AdapterErrors with SdkError
raise IndexingError(str(e)) from e
if process_text:
try:
result = process_text(extracted_text)
if isinstance(result, str):
extracted_text = result
except Exception as e:
logger.error(f"Error occured inside function 'process_text': {e}")
full_text.append(
{
"section": "full",
"text_contents": self._cleanup_text(extracted_text),
"text_contents": extracted_text,
}
)

Expand Down
32 changes: 29 additions & 3 deletions src/unstract/sdk/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from typing import Any, Optional
from typing import Any, Callable, Optional

from llama_index.core.base.llms.types import CompletionResponseGen
from llama_index.core.llms import LLM as LlamaIndexLLM
Expand Down Expand Up @@ -69,15 +69,41 @@ def _initialise(self):
def complete(
    self,
    prompt: str,
    retries: int = 3,
    process_text: Optional[Callable[..., Any]] = None,
    **kwargs: Any,
) -> Optional[dict[str, Any]]:
    """Generates a completion response for the given prompt.

    Args:
        prompt (str): The input text prompt for generating the completion.
        retries (int, optional): Retained for backward compatibility;
            not used in this code path. Defaults to 3.
        process_text (Optional[Callable[..., Any]], optional): A callable
            invoked as ``process_text(response, LLM.json_regex)`` that may
            return a dict of extra keys to merge into the result. A non-dict
            return value is discarded, and any exception it raises is logged
            and ignored so a failing hook cannot break the completion.
            Defaults to None.
        **kwargs (Any): Additional arguments passed to the completion function.

    Returns:
        Optional[dict[str, Any]]: A dictionary containing the completion
        response under ``LLM.RESPONSE``, merged with any processed output,
        or None if the completion fails.

    Raises:
        Any: If an error occurs during the completion process, it will be
        raised after being processed by `parse_llm_err`.
    """
    try:
        response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
        process_text_output: dict[str, Any] = {}
        if process_text:
            try:
                # Hook receives the raw response object and the shared JSON
                # regex so it can extract structured data from the text.
                process_text_output = process_text(response, LLM.json_regex)
                if not isinstance(process_text_output, dict):
                    process_text_output = {}
            except Exception as e:
                # Best-effort hook: log (lazily formatted) and continue.
                logger.error("Error occurred inside function 'process_text': %s", e)
                process_text_output = {}
        # Trim the response text down to the first JSON object, if present.
        match = LLM.json_regex.search(response.text)
        if match:
            response.text = match.group(0)
        # NOTE(review): a hook returning a key equal to LLM.RESPONSE would
        # shadow the response entry here — confirm hooks never do that.
        return {LLM.RESPONSE: response, **process_text_output}
    except Exception as e:
        raise parse_llm_err(e) from e

Expand Down