8 changes: 4 additions & 4 deletions dataflow/example/KBCleaningPipeline/kbc_test.jsonl
@@ -1,4 +1,4 @@
{"raw_content": "../../example_data/KBCleaningPipeline/bitter_lesson.pdf"}
{"raw_content": "https://trafilatura.readthedocs.io/en/latest/quickstart.html"}
{"raw_content": "https://arxiv.org/pdf/2505.07773"}
{"raw_content": "https://arxiv.org/pdf/2503.09516"}
{"source": "../../example_data/KBCleaningPipeline/bitter_lesson.pdf"}
{"source": "https://trafilatura.readthedocs.io/en/latest/quickstart.html"}
{"source": "https://arxiv.org/pdf/2505.07773"}
{"source": "https://arxiv.org/pdf/2503.09516"}
4 changes: 4 additions & 0 deletions dataflow/example/KBCleaningPipeline/kbc_test_1.jsonl
@@ -0,0 +1,4 @@
{"source": "../example_data/KBCleaningPipeline/bitter_lesson.pdf"}
{"source": "https://trafilatura.readthedocs.io/en/latest/quickstart.html"}
{"source": "https://arxiv.org/pdf/2505.07773"}
{"source": "https://arxiv.org/pdf/2503.09516"}
6 changes: 3 additions & 3 deletions dataflow/operators/core_text/__init__.py
@@ -4,10 +4,10 @@
from .generate.prompted_generator import PromptedGenerator
from .generate.paired_prompted_generator import PairedPromptedGenerator
from .generate.random_domain_knowledge_row_generator import RandomDomainKnowledgeRowGenerator
-from .generate.doc2prompt_generator import Doc2PromptGenerator
-from .generate.doc2qa_generator import Doc2QAGenerator
+from .generate.text2qa_generator import Text2QAGenerator
+from .generate.text2multihopqa_generator import Text2MultiHopQAGenerator
from .eval.bench_dataset_evaluator import BenchDatasetEvaluator
-from .eval.doc2qa_sample_evaluator import Doc2QASampleEvaluator
+from .eval.text2qa_sample_evaluator import Text2QASampleEvaluator
from .eval.prompted_eval import PromptedEvaluator
from .filter.prompted_filter import PromptedFilter
from .filter.kcentergreedy_filter import KCenterGreedyFilter
dataflow/operators/core_text/eval/text2qa_sample_evaluator.py
@@ -7,21 +7,21 @@
from dataflow.core import LLMServingABC
from dataflow.core.prompt import prompt_restrict

from dataflow.prompts.doc2qa import (
Doc2QAQuestionQualityPrompt,
Doc2QAAnswerAlignmentPrompt,
Doc2QAAnswerVerifiabilityPrompt,
Doc2QADownstreamValuePrompt
from dataflow.prompts.text2qa import (
Text2QAQuestionQualityPrompt,
Text2QAAnswerAlignmentPrompt,
Text2QAAnswerVerifiabilityPrompt,
Text2QADownstreamValuePrompt
)

@prompt_restrict(
-Doc2QAQuestionQualityPrompt,
-Doc2QAAnswerAlignmentPrompt,
-Doc2QAAnswerVerifiabilityPrompt,
-Doc2QADownstreamValuePrompt
+Text2QAQuestionQualityPrompt,
+Text2QAAnswerAlignmentPrompt,
+Text2QAAnswerVerifiabilityPrompt,
+Text2QADownstreamValuePrompt
)
@OPERATOR_REGISTRY.register()
-class Doc2QASampleEvaluator(OperatorABC):
+class Text2QASampleEvaluator(OperatorABC):
'''
Answer Generator is a class that generates answers for given questions.
'''
@@ -82,16 +82,16 @@ def _build_prompts(self, dataframe):
Reformat the prompts in the dataframe to generate questions.
"""
question_quality_inputs = []
-self.prompts = Doc2QAQuestionQualityPrompt()
+self.prompts = Text2QAQuestionQualityPrompt()
question_quality_prompt = self.prompts.build_prompt()
answer_alignment_inputs = []
-self.prompts = Doc2QAAnswerAlignmentPrompt()
+self.prompts = Text2QAAnswerAlignmentPrompt()
answer_alignment_prompt = self.prompts.build_prompt()
answer_verifiability_inputs = []
-self.prompts = Doc2QAAnswerVerifiabilityPrompt()
+self.prompts = Text2QAAnswerVerifiabilityPrompt()
answer_verifiability_prompt = self.prompts.build_prompt()
downstream_value_inputs = []
-self.prompts = Doc2QADownstreamValuePrompt()
+self.prompts = Text2QADownstreamValuePrompt()
downstream_value_prompt = self.prompts.build_prompt()

for index, row in dataframe.iterrows():
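For orientation, a minimal sketch of how the renamed prompt classes above are used; it assumes only what the diff shows, namely that each Text2QA* prompt class exposes a no-argument build_prompt(), as in _build_prompts:

# Hypothetical standalone sketch mirroring Text2QASampleEvaluator._build_prompts:
# instantiate each evaluation prompt and build its instruction text.
from dataflow.prompts.text2qa import (
    Text2QAQuestionQualityPrompt,
    Text2QAAnswerAlignmentPrompt,
    Text2QAAnswerVerifiabilityPrompt,
    Text2QADownstreamValuePrompt,
)

prompt_texts = {
    "question_quality": Text2QAQuestionQualityPrompt().build_prompt(),
    "answer_alignment": Text2QAAnswerAlignmentPrompt().build_prompt(),
    "answer_verifiability": Text2QAAnswerVerifiabilityPrompt().build_prompt(),
    "downstream_value": Text2QADownstreamValuePrompt().build_prompt(),
}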
88 changes: 0 additions & 88 deletions dataflow/operators/core_text/generate/doc2prompt_generator.py

This file was deleted.

dataflow/operators/core_text/generate/text2multihopqa_generator.py
@@ -1,4 +1,4 @@
-from dataflow.prompts.multihopqa import MultiHopQAGeneratorPrompt
+from dataflow.prompts.text2qa import Text2MultiHopQAGeneratorPrompt
import pandas as pd
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
@@ -12,7 +12,15 @@
from tqdm import tqdm
import re

-class KBCMultiHopQAGenerator(OperatorABC):
+from dataflow.core.prompt import prompt_restrict
+
+import re
+@prompt_restrict(
+Text2MultiHopQAGeneratorPrompt
+)
+
+@OPERATOR_REGISTRY.register()
+class Text2MultiHopQAGenerator(OperatorABC):
r"""A processor for generating multi-hop question-answer pairs from user
data.

@@ -25,7 +33,8 @@ def __init__(self,
llm_serving: LLMServingABC,
seed: int = 0,
lang="en",
-prompt_template = None
+prompt_template = None,
+num_q = 5
):
r"""Initialize the UserDataProcessor.

@@ -37,10 +46,12 @@ def __init__(self,
self.llm_serving = llm_serving
self.lang = lang
self.logger = get_logger()
+self.num_q = num_q

if prompt_template:
self.prompt_template = prompt_template
else:
-self.prompt_template = MultiHopQAGeneratorPrompt()
+self.prompt_template = Text2MultiHopQAGeneratorPrompt()

@staticmethod
def get_desc(lang: str = "zh") -> tuple:
@@ -205,16 +216,22 @@ def _validate_dataframe(self, dataframe: pd.DataFrame):

def run(
self,
-input_key:str='',
-output_key:str='',
+input_key:str='cleaned_chunk',
+output_key:str='QA_pairs',
+output_meta_key:str='QA_metadata',
storage: DataFlowStorage=None,
):
-self.input_key, self.output_key = input_key, output_key
+self.input_key, self.output_key, self.output_meta_key = input_key, output_key, output_meta_key
dataframe = storage.read("dataframe")
self._validate_dataframe(dataframe)
texts = dataframe[self.input_key].tolist()
-qa_pairs=self.process_batch(texts)
-dataframe[self.output_key] = qa_pairs
+outputs=self.process_batch(texts)
+dataframe[self.output_key] = [
+output['qa_pairs'][:self.num_q] if len(output['qa_pairs']) >= self.num_q else output['qa_pairs']
+for output in outputs
+]
+
+dataframe[self.output_meta_key] = [output['metadata'] for output in outputs]
output_file = storage.write(dataframe)
self.logger.info(f"Results saved to {output_file}")

@@ -248,11 +265,11 @@ def __init__(
self.logger = get_logger()
self.max_length = max_text_length
self.min_length = min_text_length
-# self.prompt = MultiHopQAGeneratorPrompt(lang=self.lang)
+# self.prompt = Text2MultiHopQAGeneratorPrompt(lang=self.lang)
if prompt_template:
self.prompt_template = prompt_template
else:
-self.prompt_template = MultiHopQAGeneratorPrompt()
+self.prompt_template = Text2MultiHopQAGeneratorPrompt()

def construct_examples(
self, raw_data: List[Dict[str, Any]]
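A minimal usage sketch for the renamed operator, following the constructor and run() signature shown above; llm_serving and storage are assumed to be configured elsewhere in the pipeline and are placeholders here:

# Hypothetical wiring: llm_serving (an LLMServingABC) and storage
# (a DataFlowStorage) are placeholders created elsewhere.
from dataflow.operators.core_text import Text2MultiHopQAGenerator

generator = Text2MultiHopQAGenerator(
    llm_serving=llm_serving,  # placeholder: an existing LLM serving backend
    lang="en",
    num_q=5,                  # new parameter: cap on QA pairs kept per row
)
generator.run(
    storage=storage,               # placeholder: an existing DataFlowStorage
    input_key="cleaned_chunk",     # new default input column
    output_key="QA_pairs",         # new default output column
    output_meta_key="QA_metadata", # new column holding per-row metadata
)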
dataflow/operators/core_text/generate/text2qa_generator.py
@@ -6,12 +6,17 @@
from dataflow.core import OperatorABC
from dataflow.core import LLMServingABC
from dataflow.core.prompt import prompt_restrict
+import ast
+import json

-from dataflow.prompts.doc2qa import Doc2QASeedQuestionGeneratorPrompt

+from dataflow.prompts.text2qa import Text2QASeedQuestionGeneratorPrompt,Text2QAAutoPromptGeneratorPrompt

+@prompt_restrict(
+Text2QAAutoPromptGeneratorPrompt,
+Text2QASeedQuestionGeneratorPrompt
+)
@OPERATOR_REGISTRY.register()
-class Doc2QAGenerator:
+class Text2QAGenerator:
'''
SeedQAGenerator is a class that uses LLMs to generate QA pairs based on seed input.
'''
@@ -22,13 +27,13 @@ def __init__(self,
):
self.logger = get_logger()
self.llm_serving = llm_serving
-self.prompt_template = Doc2QASeedQuestionGeneratorPrompt()
+self.prompt_template = Text2QAAutoPromptGeneratorPrompt()

@staticmethod
def get_desc(lang: str = "zh"):
if lang == "zh":
return (
"该算子用于生成对应文档片段的QA对。\n\n"
"该算子用于为给定的文档片段生成种子QA对。\n\n"
"输入参数:\n"
"- input_key: 包含文档片段的字段名\n"
"- prompt_key: 包含提示词的字段名\n"
@@ -37,7 +42,7 @@ def get_desc(lang: str = "zh"):
)
elif lang == "en":
return (
"This operator generates QA pairs for given document fragments.\n\n"
"This operator generates seed QA pairs for given document fragments.\n\n"
"Input Parameters:\n"
"- input_key: Field name containing the content\n"
"- prompt_key: Field name containing the generated prompt\n"
@@ -59,35 +64,64 @@ def _validate_dataframe(self, dataframe: pd.DataFrame):
if conflict:
raise ValueError(f"The following column(s) already exist and would be overwritten: {conflict}")

-def _build_prompt(self, df):
-prompts = []
-for index, row in df.iterrows():
-prompts.append(row[self.prompt_key] + self.prompt_template.build_prompt() + row[self.input_key])
-return prompts
+def _build_prompt(self, df, types):
+if types == "prompt":
+self.prompt_template = Text2QAAutoPromptGeneratorPrompt()
+texts = df[self.input_key].tolist()
+output = [self.prompt_template.build_prompt(text) for text in texts]
+elif types == "qa":
+self.prompt_template = Text2QASeedQuestionGeneratorPrompt()
+output = []
+for index, row in df.iterrows():
+output.append(row[self.output_prompt_key] + self.prompt_template.build_prompt() + row[self.input_key])
+return output

def _parse_qa(self, response: str) -> tuple:
lines = response.strip().split('\n')
q = next((line[2:].strip() for line in lines if line.lower().startswith("q:")), "")
a = next((line[2:].strip() for line in lines if line.lower().startswith("a:")), "")
return q, a

+def parse_list_string(self, s: str) -> list:
+# strip the surrounding [ ] brackets
+s = s.strip()[1:-1]
+# drop extra commas and split on ","
+items = [item.strip() for item in s.split(",") if item.strip()]
+return items

def run(
self,
storage: DataFlowStorage,
input_key:str = "text",
+input_question_num:int = 1,
output_prompt_key:str = "generated_prompt",
output_quesion_key:str = "generated_question",
output_answer_key:str = "generated_answer"
):
'''
-Runs the answer generation process, reading from the input file and saving results to output.
+Runs the QA generation process, reading from the input file and saving results to output.
'''

-self.input_key, self.prompt_key, self.output_question_key, self.output_answer_key = input_key, output_prompt_key, output_quesion_key, output_answer_key
+self.input_key, self.input_question_num, self.output_prompt_key, self.output_question_key, self.output_answer_key = input_key, input_question_num, output_prompt_key, output_quesion_key, output_answer_key

dataframe = storage.read("dataframe")
self._validate_dataframe(dataframe)
-formatted_prompts = self._build_prompt(dataframe)
+formatted_prompts = self._build_prompt(dataframe, "prompt")
+prompts = self.llm_serving.generate_from_input(user_inputs=formatted_prompts, system_prompt="")
+prompts = [json.loads(p) for p in prompts]
+
+expanded_rows = []
+expanded_prompts = []
+
+for idx, prompt_list in enumerate(prompts):
+for p in prompt_list[:min(self.input_question_num,len(prompt_list))]:
+expanded_rows.append(dataframe.iloc[idx].to_dict())  # copy this row
+expanded_prompts.append(p)  # the corresponding prompt
+
+dataframe = pd.DataFrame(expanded_rows)
+dataframe[self.output_prompt_key] = expanded_prompts
+
+formatted_prompts = self._build_prompt(dataframe, "qa")
responses = self.llm_serving.generate_from_input(user_inputs=formatted_prompts, system_prompt="")

questions, answers = zip(*[self._parse_qa(r) for r in responses])
@@ -98,4 +132,5 @@ def run(
output_file = storage.write(dataframe)
self.logger.info(f"Results saved to {output_file}")

-return [self.output_question_key, self.output_answer_key]
+
+return [self.output_question_key, self.output_answer_key]
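A minimal usage sketch of the two-stage flow above (auto-prompt generation, then seed QA generation), following the run() signature in the diff; the constructor argument and the llm_serving/storage objects are assumptions:

# Hypothetical wiring: llm_serving and storage are placeholders created elsewhere;
# run() first asks the LLM for a list of prompts per chunk, then one QA pair per prompt.
from dataflow.operators.core_text import Text2QAGenerator

qa_generator = Text2QAGenerator(llm_serving=llm_serving)  # assumed constructor argument
qa_generator.run(
    storage=storage,                           # placeholder DataFlowStorage
    input_key="text",
    input_question_num=1,                      # new: prompts kept per chunk
    output_prompt_key="generated_prompt",
    output_quesion_key="generated_question",   # parameter name as spelled in the diff
    output_answer_key="generated_answer",
)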
2 changes: 1 addition & 1 deletion dataflow/operators/knowledge_cleaning/__init__.py
@@ -9,7 +9,7 @@
from .generate.kbc_text_cleaner import KBCTextCleaner
from .generate.kbc_text_cleaner_batch import KBCTextCleanerBatch
from .generate.mathbook_question_extract import MathBookQuestionExtract
-from .generate.kbc_multihop_qa_generator import KBCMultiHopQAGenerator
+# from .generate.kbc_multihop_qa_generator import KBCMultiHopQAGenerator
from .generate.kbc_multihop_qa_generator_batch import KBCMultiHopQAGeneratorBatch

