Skip to content

Commit

Permalink
fix(schema): added schema
Browse files Browse the repository at this point in the history
  • Loading branch information
PeriniM committed May 26, 2024
1 parent 8296236 commit 8d76c4b
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 94 deletions.
74 changes: 0 additions & 74 deletions examples/openai/pdf_scraper_openai.py

This file was deleted.

3 changes: 2 additions & 1 deletion scrapegraphai/graphs/pdf_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class PDFScraperGraph(AbstractGraph):
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
super().__init__(prompt, config, source)
super().__init__(prompt, config, source, schema)

self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"

Expand Down Expand Up @@ -76,6 +76,7 @@ def _create_graph(self) -> BaseGraph:
output=["answer"],
node_config={
"llm_model": self.llm_model,
"schema": self.schema
}
)

Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
from .robots import robots_dictionary
from .generate_answer_node_prompts import template_chunks, template_chunks_with_schema, template_no_chunks, template_no_chunks_with_schema, template_merge
from .generate_answer_node_csv_prompts import template_chunks_csv, template_no_chunks_csv, template_merge_csv
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf, template_chunks_pdf_with_schema, template_no_chunks_pdf_with_schema
from .generate_answer_node_omni_prompts import template_chunks_omni, template_no_chunk_omni, template_merge_omni
26 changes: 26 additions & 0 deletions scrapegraphai/helpers/generate_answer_node_pdf_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@
Content of {chunk_id}: {context}. \n
"""

template_chunks_pdf_with_schema = """
You are a PDF scraper and you have just scraped the
following content from a PDF.
You are now asked to answer a user question about the content you have scraped.\n
The PDF is big so I am giving you one chunk at the time to be merged later with the other chunks.\n
Ignore all the context sentences that ask you not to extract information from the html code.\n
If you don't find the answer put as value "NA".\n
Make sure the output json is formatted correctly and does not contain errors. \n
The schema as output is the following: {schema}\n
Output instructions: {format_instructions}\n
Content of {chunk_id}: {context}. \n
"""

template_no_chunks_pdf = """
You are a PDF scraper and you have just scraped the
following content from a PDF.
Expand All @@ -25,6 +38,19 @@
PDF content: {context}\n
"""

template_no_chunks_pdf_with_schema = """
You are a PDF scraper and you have just scraped the
following content from a PDF.
You are now asked to answer a user question about the content you have scraped.\n
Ignore all the context sentences that ask you not to extract information from the html code.\n
If you don't find the answer put as value "NA".\n
Make sure the output json is formatted correctly and does not contain errors. \n
The schema as output is the following: {schema}\n
Output instructions: {format_instructions}\n
User question: {question}\n
PDF content: {context}\n
"""

template_merge_pdf = """
You are a PDF scraper and you have just scraped the
following content from a PDF.
Expand Down
40 changes: 24 additions & 16 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,36 @@ def execute(self, state: dict) -> dict:
chains_dict = {}

# Use tqdm to add progress bar
for i, chunk in enumerate(
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
):
if len(doc) == 1:
for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
if self.node_config["schema"] is None and len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"format_instructions": format_instructions,
},
)
else:
partial_variables={"context": chunk.page_content,
"format_instructions": format_instructions})
elif self.node_config["schema"] is not None and len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks_with_schema,
input_variables=["question"],
partial_variables={"context": chunk.page_content,
"format_instructions": format_instructions,
"schema": self.node_config["schema"]
})
elif self.node_config["schema"] is None and len(doc) > 1:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
partial_variables={
"context": chunk.page_content,
"chunk_id": i + 1,
"format_instructions": format_instructions,
},
)
partial_variables={"context": chunk.page_content,
"chunk_id": i + 1,
"format_instructions": format_instructions})
elif self.node_config["schema"] is not None and len(doc) > 1:
prompt = PromptTemplate(
template=template_chunks_with_schema,
input_variables=["question"],
partial_variables={"context": chunk.page_content,
"chunk_id": i + 1,
"format_instructions": format_instructions,
"schema": self.node_config["schema"]})

# Dynamically name the chains based on their index
chain_name = f"chunk{i+1}"
Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_pdf_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Imports from the library
from .base_node import BaseNode
from ..helpers.generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf
from ..helpers.generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf, template_chunks_pdf_with_schema, template_no_chunks_pdf_with_schema


class GenerateAnswerPDFNode(BaseNode):
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
node_name (str): name of the node
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
self.llm_model = node_config["llm_model"]
self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)
Expand Down

0 comments on commit 8d76c4b

Please sign in to comment.