Fixes LC document deserialization #218

Merged · 1 commit · May 11, 2024
scrapegraphai/graphs/smart_scraper_graph_burr.py (82 changes: 57 additions, 25 deletions)
@@ -1,7 +1,7 @@
"""
SmartScraperGraph Module Burr Version
"""
-from typing import Tuple
+from typing import Tuple, Union

from burr import tracking
from burr.core import Application, ApplicationBuilder, State, default, when
@@ -14,6 +14,7 @@
from langchain_community.document_transformers import Html2TextTransformer, EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
+from langchain_core import load as lc_serde
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel
@@ -67,10 +68,10 @@ def parse_node(state: State, chunk_size: int = 4096) -> tuple[dict, State]:

@action(reads=["user_prompt", "parsed_doc", "doc"],
writes=["relevant_chunks"])
def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]:
# bug around input serialization with tracker
llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
embedder_model = OpenAIEmbeddings()
def rag_node(state: State, llm_model: str, embedder_model: object) -> tuple[dict, State]:
# bug around input serialization with tracker -- so instantiate objects here:
llm_model = OpenAI({"model_name": llm_model})
embedder_model = OpenAIEmbeddings() if embedder_model == "openai" else None
user_prompt = state["user_prompt"]
doc = state["parsed_doc"]

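The `object` → `str` change in the signature is the heart of the workaround the comment mentions: Burr's tracker serializes action inputs, and a live client object presumably fails to serialize while a model-name string does not. A hypothetical sketch of the pattern, with stand-in names that are not from this PR:

```python
from burr.core import State, action

class FakeLLMClient:
    """Stand-in for a non-JSON-serializable client such as an OpenAI wrapper."""
    def __init__(self, model_name: str):
        self.model_name = model_name

@action(reads=["user_prompt"], writes=["answer"])
def answer_node(state: State, llm_model: str) -> tuple[dict, State]:
    # Build the client inside the action, so the tracker only ever
    # sees the JSON-friendly string that was passed in.
    client = FakeLLMClient(llm_model)
    answer = f"(stub) answered with {client.model_name}"
    return {"answer": answer}, state.update(answer=answer)
```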
@@ -104,8 +105,10 @@ def rag_node(state: State, llm_model: object, embedder_model: object) -> tuple[dict, State]:

@action(reads=["user_prompt", "relevant_chunks", "parsed_doc", "doc"],
        writes=["answer"])
-def generate_answer_node(state: State, llm_model: object) -> tuple[dict, State]:
-    llm_model = OpenAI({"model_name": "gpt-3.5-turbo"})
+def generate_answer_node(state: State, llm_model: str) -> tuple[dict, State]:
+    # bug around input serialization with tracker -- so instantiate objects here:
+    llm_model = OpenAI({"model_name": llm_model})
+
    user_prompt = state["user_prompt"]
    doc = state.get("relevant_chunks",
                    state.get("parsed_doc",
@@ -207,21 +210,49 @@ def post_run_step(
    ):
        print(f"Finishing action: {action.name}")

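Only the tail of `PrintLnHook` is visible in this hunk; its assumed shape, reconstructed from the fragment and Burr's lifecycle hook API (an assumption, not shown in this diff):

```python
from burr.lifecycle import PostRunStepHook, PreRunStepHook

class PrintLnHook(PostRunStepHook, PreRunStepHook):
    def pre_run_step(self, *, action, **kwargs):
        print(f"Starting action: {action.name}")

    def post_run_step(self, *, action, **kwargs):
        print(f"Finishing action: {action.name}")
```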
+import json

+def _deserialize_document(x: Union[str, dict]) -> Document:
+    if isinstance(x, dict):
+        return lc_serde.load(x)
+    elif isinstance(x, str):
+        try:
+            return lc_serde.loads(x)
+        except json.JSONDecodeError:
+            return Document(page_content=x)
+    raise ValueError("Couldn't deserialize document")


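Hypothetical spot checks of the three branches the helper handles, assuming the imports above:

```python
from langchain_core.load import dumpd, dumps

doc = Document(page_content="hello")
assert _deserialize_document(dumpd(doc)).page_content == "hello"     # dict branch
assert _deserialize_document(dumps(doc)).page_content == "hello"     # JSON-string branch
assert _deserialize_document("not json").page_content == "not json"  # plain-text fallback
```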
def run(prompt: str, input_key: str, source: str, config: dict) -> str:
    # these configs aren't really used yet.
    llm_model = config["llm_model"]

    embedder_model = config["embedder_model"]
-    open_ai_embedder = OpenAIEmbeddings()
+    # open_ai_embedder = OpenAIEmbeddings()
    chunk_size = config["model_token"]

-    tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")
    app_instance_id = "testing-12345678919"
    initial_state = {
        "user_prompt": prompt,
        input_key: source,
    }
+    from burr.core import expr
+    tracker = tracking.LocalTrackingClient(project="smart-scraper-graph")

+    entry_point = "fetch_node"
+    if app_instance_id:
+        persisted_state = tracker.load(None, app_id=app_instance_id, sequence_no=None)
+        if not persisted_state:
+            print(f"Warning: No persisted state found for app_id {app_instance_id}.")
+        else:
+            initial_state = persisted_state["state"]
+            # for now we need to manually deserialize LangChain messages into LangChain objects
+            # i.e. we know which objects need to be LC objects
+            initial_state = initial_state.update(**{
+                "doc": _deserialize_document(initial_state["doc"])
+            })
+            docs = [_deserialize_document(doc) for doc in initial_state["relevant_chunks"]]
+            initial_state = initial_state.update(**{
+                "relevant_chunks": docs
+            })
+            entry_point = persisted_state["position"]

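Only `doc` and `relevant_chunks` are rehydrated because those are the state keys known to hold LangChain `Document`s here. A hypothetical generalization (the key list and helper name are assumptions, not part of this PR):

```python
# Hypothetical: rehydrate every state key known to carry LangChain Documents.
LC_DOCUMENT_KEYS = ("doc", "parsed_doc", "relevant_chunks")

def _rehydrate(state: State) -> State:
    for key in LC_DOCUMENT_KEYS:
        value = state.get(key)
        if isinstance(value, list):
            state = state.update(**{key: [_deserialize_document(v) for v in value]})
        elif value is not None:
            state = state.update(**{key: _deserialize_document(value)})
    return state
```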
    app = (
        ApplicationBuilder()
@@ -236,16 +267,17 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
("parse_node", "rag_node", default),
("rag_node", "generate_answer_node", default)
)
# .with_entrypoint("fetch_node")
# .with_state(**initial_state)
.initialize_from(
tracker,
resume_at_next_action=True, # always resume from entrypoint in the case of failure
default_state=initial_state,
default_entrypoint="fetch_node",
)
# .with_identifiers(app_id="testing-123456")
.with_tracker(project="smart-scraper-graph")
.with_entrypoint(entry_point)
.with_state(**initial_state)
# this will work once we get serialization plugin for langchain objects done
# .initialize_from(
# tracker,
# resume_at_next_action=True, # always resume from entrypoint in the case of failure
# default_state=initial_state,
# default_entrypoint="fetch_node",
# )
.with_identifiers(app_id=app_instance_id)
.with_tracker(tracker)
.with_hooks(PrintLnHook())
.build()
)
@@ -270,8 +302,8 @@ def run(prompt: str, input_key: str, source: str, config: dict) -> str:
    source = "https://en.wikipedia.org/wiki/Paris"
    input_key = "url"
    config = {
-        "llm_model": "rag-token",
-        "embedder_model": "foo",
+        "llm_model": "gpt-3.5-turbo",
+        "embedder_model": "openai",
        "model_token": "bar",
    }
-    run(prompt, input_key, source, config)
+    print(run(prompt, input_key, source, config))