Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from llama_index.embeddings.openai import OpenAIEmbedding

from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter
from unstract.sdk.adapters.embedding.helper import EmbeddingHelper
from unstract.sdk.adapters.exceptions import AdapterError


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from llama_index.embeddings.fastembed import FastEmbedEmbedding

from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter
from unstract.sdk.adapters.embedding.helper import EmbeddingHelper
from unstract.sdk.adapters.exceptions import AdapterError


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
import os
from typing import Any, Optional
from typing import Any

from google.auth.transport import requests as google_requests
from google.oauth2.service_account import Credentials
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import os
from typing import Any, Optional
Expand All @@ -11,7 +10,6 @@
TextExtractionResult,
)
from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.constants import (
HTTPMethod,
WhispererEndpoint,
)
from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.dto import WhispererRequestParams
Expand Down
41 changes: 31 additions & 10 deletions src/unstract/sdk/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,36 +36,50 @@ def __init__(

@log_elapsed(operation="ANSWER_PROMPTS")
def answer_prompt(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
self,
payload: dict[str, Any],
params: Optional[dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
url_path = "answer-prompt"
if self.is_public_call:
url_path = "answer-prompt-public"
return self._post_call(
url_path=url_path,
payload=payload,
params=params,
url_path=url_path, payload=payload, params=params, headers=headers
)

def single_pass_extraction(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
self,
payload: dict[str, Any],
params: Optional[dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return self._post_call(
url_path="single-pass-extraction",
payload=payload,
params=params,
headers=headers,
)

def summarize(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
self,
payload: dict[str, Any],
params: Optional[dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return self._post_call(url_path="summarize", payload=payload, params=params)
return self._post_call(
url_path="summarize",
payload=payload,
params=params,
headers=headers,
)

def _post_call(
self,
url_path: str,
payload: dict[str, Any],
params: Optional[dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
"""Invokes and communicates to prompt service to fetch response for the
prompt.
Expand All @@ -74,6 +88,7 @@ def _post_call(
url_path (str): URL path to the service endpoint
payload (dict): Payload to send in the request body
params (dict, optional): Query parameters to include in the request
headers (dict, optional): Headers to include in the request

Returns:
dict: Response from the prompt service
Expand All @@ -94,13 +109,19 @@ def _post_call(
"status_code": 500,
}
url: str = f"{self.base_url}/{url_path}"
headers: dict[str, str] = {}

default_headers = {}

if not self.is_public_call:
headers = {"Authorization": f"Bearer {self.bearer_token}"}
default_headers = {"Authorization": f"Bearer {self.bearer_token}"}

if headers:
default_headers.update(headers)

response: Response = Response()
try:
response = requests.post(
url=url, json=payload, params=params, headers=headers
url=url, json=payload, params=params, headers=default_headers
)
response.raise_for_status()
result["status"] = "OK"
Expand Down