From 9434fd383f909990716af386754dd0863c0d3586 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 30 Sep 2025 14:03:32 +0800
Subject: [PATCH 1/5] refactor: refactor generate pipeline

---
 .env.example                            |  1 +
 graphgen/configs/aggregated_config.yaml | 28 ++++-----
 graphgen/configs/atomic_config.yaml     | 28 ++++-----
 graphgen/configs/cot_config.yaml        | 19 +++---
 graphgen/configs/multi_hop_config.yaml  | 28 ++++-----
 graphgen/generate.py                    | 46 +++++++--------
 graphgen/graphgen.py                    | 77 +++++++++++--------------
 graphgen/models/tokenizer/__init__.py   |  2 +
 8 files changed, 114 insertions(+), 115 deletions(-)

diff --git a/.env.example b/.env.example
index 1a670126..c1102c1c 100644
--- a/.env.example
+++ b/.env.example
@@ -1,3 +1,4 @@
+TOKENIZER_MODEL=
 SYNTHESIZER_MODEL=
 SYNTHESIZER_BASE_URL=
 SYNTHESIZER_API_KEY=
diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml
index 45bd1d9e..828bf662 100644
--- a/graphgen/configs/aggregated_config.yaml
+++ b/graphgen/configs/aggregated_config.yaml
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: aggregated # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 5 # maximum depth for graph traversal
-  max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  ece_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 5 # maximum depth for graph traversal
+    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: aggregated # atomic, aggregated, multi_hop, cot
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml
index 0a58bfc5..39a6dc7d 100644
--- a/graphgen/configs/atomic_config.yaml
+++ b/graphgen/configs/atomic_config.yaml
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: atomic # atomic, aggregated, multi_hop, cot
-output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: true
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 3 # maximum depth for graph traversal
-  max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  ece_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 3 # maximum depth for graph traversal
+    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: atomic # atomic, aggregated, multi_hop, cot
+  data_format: Alpaca # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml
index f6ca8ad2..82340476 100644
--- a/graphgen/configs/cot_config.yaml
+++ b/graphgen/configs/cot_config.yaml
@@ -6,11 +6,14 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: cot # atomic, aggregated, multi_hop, cot
-output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-method_params:
-  method: leiden
-  max_size: 20 # Maximum size of communities
-  use_lcc: false
-  random_seed: 42
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
+  enabled: false
+partition: # graph partition configuration
+  method: leiden # leiden is a community detection algorithm
+  leiden_params:
+    max_size: 20 # Maximum size of communities
+    use_lcc: false
+    random_seed: 42
+generate:
+  mode: cot # atomic, aggregated, multi_hop, cot
+  data_format: Sharegpt # Alpaca, Sharegpt, ChatML
diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml
index 76f5ea06..1d0bd943 100644
--- a/graphgen/configs/multi_hop_config.yaml
+++ b/graphgen/configs/multi_hop_config.yaml
@@ -6,19 +6,21 @@ split:
 search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
-output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
-output_data_format: ChatML # Alpaca, Sharegpt, ChatML
-tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
-quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
+quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
-traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
-  bidirectional: true # whether to traverse the graph in both directions
-  edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-  expand_method: max_width # expand method, support: max_width, max_depth
-  isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-  max_depth: 1 # maximum depth for graph traversal
-  max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-  max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-  loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+partition: # graph partition configuration
+  method: ece # ece is a custom partition method based on comprehension loss
+  ece_params:
+    bidirectional: true # whether to traverse the graph in both directions
+    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
+    expand_method: max_width # expand method, support: max_width, max_tokens
+    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
+    max_depth: 1 # maximum depth for graph traversal
+    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
+    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
+    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+generate:
+  mode: multi_hop # strategy for generating multi-hop QA pairs
+  data_format: ChatML # Alpaca, Sharegpt, ChatML
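
All four configs now share one top-level shape: read/split/search stay as before, quiz_and_judge gates the comprehension check, partition selects how sub-graphs are formed, and generate names the output mode and format. A minimal sketch of consuming the new schema (load_pipeline_config is a hypothetical helper for illustration, not part of this patch):

import yaml

def load_pipeline_config(path: str) -> dict:
    # Load a GraphGen YAML config and sanity-check the refactored layout.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    assert config["generate"]["mode"] in {"atomic", "aggregated", "multi_hop", "cot"}
    assert config["generate"]["data_format"] in {"Alpaca", "Sharegpt", "ChatML"}
    assert config["partition"]["method"] in {"ece", "leiden"}
    return config

Note that the tokenizer key is gone from the configs entirely; it now comes from the TOKENIZER_MODEL environment variable added to .env.example above.
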
diff --git a/graphgen/generate.py b/graphgen/generate.py
index 26dafcce..83303ebd 100644
--- a/graphgen/generate.py
+++ b/graphgen/generate.py
@@ -6,8 +6,8 @@
 import yaml
 from dotenv import load_dotenv
 
-from .graphgen import GraphGen
-from .utils import logger, set_logger
+from graphgen.graphgen import GraphGen
+from graphgen.utils import logger, set_logger
 
 sys_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -50,12 +50,10 @@ def main():
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
 
-    output_data_type = config["output_data_type"]
+    mode = config["generate"]["mode"]
 
     unique_id = int(time.time())
-    output_path = os.path.join(
-        working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
-    )
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}_{mode}")
     set_working_dir(output_path)
 
     set_logger(
@@ -65,35 +63,35 @@ def main():
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(
-            working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
-        ),
+        os.path.join(working_dir, f"{unique_id}.log"),
     )
 
-    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
+    graph_gen = GraphGen(working_dir=working_dir, output_path=output_path)
 
-    graph_gen.insert()
+    graph_gen.insert(read_config=config["read"], split_config=config["split"])
 
-    if config["search"]["enabled"]:
-        graph_gen.search()
+    graph_gen.search(search_config=config["search"])
 
     # Use pipeline according to the output data type
-    if output_data_type in ["atomic", "aggregated", "multi_hop"]:
-        if "quiz_and_judge_strategy" in config and config[
-            "quiz_and_judge_strategy"
-        ].get("enabled", False):
-            graph_gen.quiz()
-            graph_gen.judge()
+    if mode in ["atomic", "aggregated", "multi_hop"]:
+        logger.info("Generation mode set to '%s'. Start generation.", mode)
+        if "quiz_and_judge" in config:
+            graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
             logger.warning(
                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
             )
-            graph_gen.traverse_strategy.edge_sampling = "random"
-        graph_gen.traverse()
-    elif output_data_type == "cot":
-        graph_gen.generate_reasoning(method_params=config["method_params"])
+            # TODO: make edge sampling random
+            # graph_gen.traverse_strategy.edge_sampling = "random"
+    elif mode == "cot":
+        logger.info("Generation mode set to 'cot'. Start generation.")
     else:
-        raise ValueError(f"Unsupported output data type: {output_data_type}")
+        raise ValueError(f"Unsupported output data type: {mode}")
+
+    graph_gen.generate(
+        partition_config=config["partition"],
+        generate_config=config["generate"],
+    )
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
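
Seen end to end, the driver now reduces to one linear pipeline, with each stage taking exactly the slice of the config it needs. A condensed sketch of the flow as of this first patch (PATCH 2 below swaps the constructor arguments to unique_id, so treat the signature as provisional):

from graphgen.graphgen import GraphGen

graph_gen = GraphGen(working_dir=working_dir, output_path=output_path)
graph_gen.insert(read_config=config["read"], split_config=config["split"])
graph_gen.search(search_config=config["search"])  # no-op when search.enabled is false
if "quiz_and_judge" in config:
    graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
graph_gen.generate(
    partition_config=config["partition"],
    generate_config=config["generate"],
)

Note the inversion of control: search() is now always called and checks search_config["enabled"] itself, where the old code tested config["search"]["enabled"] at the call site.
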
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 8abb0b45..5b1323be 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -1,7 +1,7 @@
 import asyncio
 import os
 import time
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Dict, cast
 
 import gradio as gr
@@ -14,7 +14,6 @@
     NetworkXStorage,
     OpenAIClient,
     Tokenizer,
-    TraverseStrategy,
 )
 from graphgen.operators import (
     chunk_documents,
@@ -40,30 +39,24 @@
 
 @dataclass
 class GraphGen:
-    unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    config: Dict = field(default_factory=dict)
+    output_path: str = os.path.join(
+        working_dir, "data", "graphgen", str(int(time.time()))
+    )
 
     # llm
     tokenizer_instance: Tokenizer = None
     synthesizer_llm_client: OpenAIClient = None
     trainee_llm_client: OpenAIClient = None
 
-    # search
-    search_config: dict = field(
-        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
-    )
-
-    # traversal
-    traverse_strategy: TraverseStrategy = None
-
     # webui
     progress_bar: gr.Progress = None
 
     def __post_init__(self):
         self.tokenizer_instance: Tokenizer = Tokenizer(
-            model_name=self.config["tokenizer"]
+            model_name=os.getenv("TOKENIZER_MODEL")
         )
+
         self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
             model_name=os.getenv("SYNTHESIZER_MODEL"),
             api_key=os.getenv("SYNTHESIZER_API_KEY"),
             base_url=os.getenv("SYNTHESIZER_BASE_URL"),
             tokenizer=self.tokenizer_instance,
         )
@@ -76,12 +69,6 @@ def __post_init__(self):
             base_url=os.getenv("TRAINEE_BASE_URL"),
             tokenizer=self.tokenizer_instance,
         )
-        self.search_config = self.config["search"]
-
-        if "traverse_strategy" in self.config:
-            self.traverse_strategy = TraverseStrategy(
-                **self.config["traverse_strategy"]
-            )
 
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
         )
@@ -99,24 +86,17 @@ def __post_init__(self):
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(
-                self.working_dir,
-                "data",
-                "graphgen",
-                f"{self.unique_id}_{self.config['output_data_type']}",
-            ),
+            self.working_dir,
             namespace="qa",
         )
 
     @async_to_sync_method
-    async def insert(self):
+    async def insert(self, read_config: Dict, split_config: Dict):
         """
         insert chunks into the graph
         """
-        input_file = self.config["read"]["input_file"]
-
         # Step 1: Read files
-        data = read_files(input_file)
+        data = read_files(read_config["input_file"])
         if len(data) == 0:
             logger.warning("No data to process")
             return
@@ -141,8 +121,8 @@ async def insert(self):
 
         inserting_chunks = await chunk_documents(
             new_docs,
-            self.config["split"]["chunk_size"],
-            self.config["split"]["chunk_overlap"],
+            split_config["chunk_size"],
+            split_config["chunk_overlap"],
             self.tokenizer_instance,
             self.progress_bar,
         )
@@ -178,6 +158,7 @@ async def insert(self):
             return
 
         await self._insert_done()
+        return _add_entities_and_relations
 
     async def _insert_done(self):
         tasks = []
@@ -193,14 +174,12 @@ async def _insert_done(self):
         await asyncio.gather(*tasks)
 
     @async_to_sync_method
-    async def search(self):
+    async def search(self, search_config: Dict):
         logger.info(
-            "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
+            "Search is %s", "enabled" if search_config["enabled"] else "disabled"
         )
-        if self.search_config["enabled"]:
-            logger.info(
-                "[Search] %s ...", ", ".join(self.search_config["search_types"])
-            )
+        if search_config["enabled"]:
+            logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
             all_nodes = await self.graph_storage.get_all_nodes()
             all_nodes_names = [node[0] for node in all_nodes]
             new_search_entities = await self.full_docs_storage.filter_keys(
@@ -210,7 +189,7 @@ async def search(self):
                 "[Search] Found %d entities to search", len(new_search_entities)
             )
             _add_search_data = await search_all(
-                search_types=self.search_config["search_types"],
+                search_types=search_config["search_types"],
                 search_entities=new_search_entities,
             )
             if _add_search_data:
@@ -230,27 +209,37 @@ async def search(self):
             await self.insert()
 
     @async_to_sync_method
-    async def quiz(self):
-        max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
+    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
+        if quiz_and_judge_config is None or not quiz_and_judge_config.get(
+            "enabled", False
+        ):
+            logger.warning("Quiz and Judge is not used in this pipeline.")
+            return
+        max_samples = quiz_and_judge_config["quiz_samples"]
         await quiz(
             self.synthesizer_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             max_samples,
         )
-        await self.rephrase_storage.index_done_callback()
 
-    @async_to_sync_method
-    async def judge(self):
-        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+        # TODO: assert trainee_llm_client is valid before judge
+        re_judge = quiz_and_judge_config["re_judge"]
        _update_relations = await judge_statement(
             self.trainee_llm_client,
             self.graph_storage,
             self.rephrase_storage,
             re_judge,
         )
+        await self.rephrase_storage.index_done_callback()
         await _update_relations.index_done_callback()
 
+    @async_to_sync_method
+    async def generate(self, partition_config: Dict, generate_config: Dict):
+        # Step 1: partition the graph
+        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
+        pass
+
     @async_to_sync_method
     async def traverse(self):
         output_data_type = self.config["output_data_type"]
diff --git a/graphgen/models/tokenizer/__init__.py b/graphgen/models/tokenizer/__init__.py
index 43c7e258..27559191 100644
--- a/graphgen/models/tokenizer/__init__.py
+++ b/graphgen/models/tokenizer/__init__.py
@@ -39,6 +39,8 @@ class Tokenizer(BaseTokenizer):
     _impl: BaseTokenizer = field(init=False, repr=False)
 
     def __post_init__(self):
+        if not self.model_name:
+            raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.")
         self._impl = get_tokenizer_impl(self.model_name)
 
     def encode(self, text: str) -> List[int]:
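
This guard closes the loop on the move from config["tokenizer"] to the environment: a missing TOKENIZER_MODEL now fails fast with a clear message instead of surfacing later inside the tokenizer backend. A small sketch of both paths (cl100k_base is just an example value, as in the old configs):

import os
from graphgen.models import Tokenizer

os.environ.pop("TOKENIZER_MODEL", None)
try:
    Tokenizer(model_name=os.getenv("TOKENIZER_MODEL"))
except ValueError as err:
    print(err)  # TOKENIZER_MODEL must be specified in the ENV variables.

os.environ["TOKENIZER_MODEL"] = "cl100k_base"  # tiktoken name or local tokenizer path
tokenizer = Tokenizer(model_name=os.getenv("TOKENIZER_MODEL"))
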
From 76b53fa5b18f24ac7eee07322931d1d1f5b8d144 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 30 Sep 2025 15:07:00 +0800
Subject: [PATCH 2/5] fix: implement generate method

---
 graphgen/generate.py                       | 15 +++---
 graphgen/graphgen.py                       | 54 ++++++-----------
 graphgen/models/__init__.py                |  1 -
 graphgen/models/strategy/__init__.py       |  0
 .../models/strategy/travserse_strategy.py  | 28 ----------
 graphgen/operators/build_kg/split_kg.py    | 31 +++++------
 graphgen/operators/traverse_graph.py       | 22 +++-----
 7 files changed, 54 insertions(+), 97 deletions(-)
 delete mode 100644 graphgen/models/strategy/__init__.py
 delete mode 100644 graphgen/models/strategy/travserse_strategy.py

diff --git a/graphgen/generate.py b/graphgen/generate.py
index 83303ebd..74208170 100644
--- a/graphgen/generate.py
+++ b/graphgen/generate.py
@@ -53,20 +53,20 @@ def main():
     mode = config["generate"]["mode"]
 
     unique_id = int(time.time())
-    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}_{mode}")
+    output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
     set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
         if_stream=True,
     )
     logger.info(
         "GraphGen with unique ID %s logging to %s",
         unique_id,
-        os.path.join(working_dir, f"{unique_id}.log"),
+        os.path.join(output_path, f"{unique_id}_{mode}.log"),
     )
 
-    graph_gen = GraphGen(working_dir=working_dir, output_path=output_path)
+    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
 
     graph_gen.insert(read_config=config["read"], split_config=config["split"])
@@ -81,8 +81,11 @@ def main():
             logger.warning(
                 "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
             )
-            # TODO: make edge sampling random
-            # graph_gen.traverse_strategy.edge_sampling = "random"
+            assert (
+                config["partition"]["method"] == "ece"
+                and "ece_params" in config["partition"]
+            ), "Only ECE partition with edge sampling is supported."
+            config["partition"]["ece_params"]["edge_sampling"] = "random"
     elif mode == "cot":
         logger.info("Generation mode set to 'cot'. Start generation.")
     else:
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 5b1323be..374f6384 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -39,10 +39,8 @@
 
 @dataclass
 class GraphGen:
+    unique_id: int = int(time.time())
     working_dir: str = os.path.join(sys_path, "cache")
-    output_path: str = os.path.join(
-        working_dir, "data", "graphgen", str(int(time.time()))
-    )
 
     # llm
     tokenizer_instance: Tokenizer = None
@@ -86,7 +84,7 @@ def __post_init__(self):
             self.working_dir, namespace="rephrase"
         )
         self.qa_storage: JsonListStorage = JsonListStorage(
-            self.working_dir,
+            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
             namespace="qa",
         )
@@ -238,59 +236,49 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
         # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
-        pass
-
-    @async_to_sync_method
-    async def traverse(self):
-        output_data_type = self.config["output_data_type"]
-
-        if output_data_type == "atomic":
+        mode = generate_config["mode"]
+        if mode == "atomic":
             results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "multi_hop":
+        elif mode == "multi_hop":
             results = await traverse_graph_for_multi_hop(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
-        elif output_data_type == "aggregated":
+        elif mode == "aggregated":
             results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
-                self.traverse_strategy,
+                partition_config["ece_params"],
                 self.text_chunks_storage,
                 self.progress_bar,
             )
+        elif mode == "cot":
+            method_params = generate_config.get("method_params", {})
+            results = await generate_cot(
+                self.graph_storage,
+                self.synthesizer_llm_client,
+                method_params=method_params,
+            )
         else:
-            raise ValueError(f"Unknown qa_form: {output_data_type}")
-
-        results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
-        )
-
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate_reasoning(self, method_params):
-        results = await generate_cot(
-            self.graph_storage,
-            self.synthesizer_llm_client,
-            method_params=method_params,
-        )
+            raise ValueError(f"Unknown generation mode: {mode}")
 
+        # Step 2: generate QA pairs
+        # TODO
+
+        # Step 3: format
         results = format_generation_results(
-            results, output_data_format=self.config["output_data_format"]
+            results, output_data_format=generate_config["data_format"]
         )
 
         await self.qa_storage.upsert(results)
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index cea2fc45..f006f481 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -13,5 +13,4 @@
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
 from .storage.json_storage import JsonKVStorage, JsonListStorage
 from .storage.networkx_storage import NetworkXStorage
-from .strategy.travserse_strategy import TraverseStrategy
 from .tokenizer import Tokenizer
diff --git a/graphgen/models/strategy/__init__.py b/graphgen/models/strategy/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/graphgen/models/strategy/travserse_strategy.py b/graphgen/models/strategy/travserse_strategy.py
deleted file mode 100644
index 5739dea8..00000000
--- a/graphgen/models/strategy/travserse_strategy.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from dataclasses import dataclass, fields
-
-
-@dataclass
-class TraverseStrategy:
-    # 生成的QA形式:原子、多跳、聚合型
-    qa_form: str = "atomic"  # "atomic" or "multi_hop" or "aggregated"
-    # 最大边数和最大token数方法中选择一个生效
-    expand_method: str = "max_tokens"  # "max_width" or "max_tokens"
-    # 单向拓展还是双向拓展
-    bidirectional: bool = True
-    # 每个方向拓展的最大边数
-    max_extra_edges: int = 5
-    # 最长token数
-    max_tokens: int = 256
-    # 每个方向拓展的最大深度
-    max_depth: int = 2
-    # 同一层中选边的策略(如果是双向拓展,同一层指的是两边连接的边的集合)
-    edge_sampling: str = "max_loss"  # "max_loss" or "min_loss" or "random"
-    # 孤立节点的处理策略
-    isolated_node_strategy: str = "add"  # "add" or "ignore"
-    loss_strategy: str = "only_edge"  # only_edge, both
-
-    def to_yaml(self):
-        strategy_dict = {}
-        for f in fields(self):
-            strategy_dict[f.name] = getattr(self, f.name)
-        return {"traverse_strategy": strategy_dict}
diff --git a/graphgen/operators/build_kg/split_kg.py b/graphgen/operators/build_kg/split_kg.py
index a3307a86..6033bc85 100644
--- a/graphgen/operators/build_kg/split_kg.py
+++ b/graphgen/operators/build_kg/split_kg.py
@@ -1,9 +1,10 @@
 import random
 from collections import defaultdict
+from typing import Dict
 
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import NetworkXStorage, TraverseStrategy
+from graphgen.models import NetworkXStorage
 from graphgen.utils import logger
 
@@ -247,9 +248,9 @@ async def get_batches_with_strategy(  # pylint: disable=too-many-branches
     nodes: list,
     edges: list,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
 ):
-    expand_method = traverse_strategy.expand_method
+    expand_method = traverse_strategy["expand_method"]
     if expand_method == "max_width":
         logger.info("Using max width strategy")
     elif expand_method == "max_tokens":
         logger.info("Using max tokens strategy")
     else:
         raise ValueError(f"Invalid expand method: {expand_method}")
 
-    max_depth = traverse_strategy.max_depth
-    edge_sampling = traverse_strategy.edge_sampling
+    max_depth = traverse_strategy["max_depth"]
+    edge_sampling = traverse_strategy["edge_sampling"]
 
     # 构建临接矩阵
     edge_adj_list = defaultdict(list)
@@ -275,16 +276,16 @@ async def get_cached_node_info(node_id: str) -> dict:
     for i, (node_name, _) in enumerate(nodes):
         node_dict[node_name] = i
 
-    if traverse_strategy.loss_strategy == "both":
+    if traverse_strategy["loss_strategy"] == "both":
         er_tuples = [
             ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
             for edge in edges
         ]
         edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy.loss_strategy == "only_edge":
+    elif traverse_strategy["loss_strategy"] == "only_edge":
         edges = _sort_edges(edges, edge_sampling)
     else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
+        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")
 
     for i, (src, tgt, _) in enumerate(edges):
         edge_adj_list[src].append(i)
@@ -315,10 +316,10 @@ async def get_cached_node_info(node_id: str) -> dict:
                     nodes,
                     edge,
                     max_depth,
-                    traverse_strategy.bidirectional,
-                    traverse_strategy.max_extra_edges,
+                    traverse_strategy["bidirectional"],
+                    traverse_strategy["max_extra_edges"],
                     edge_sampling,
-                    traverse_strategy.loss_strategy,
+                    traverse_strategy["loss_strategy"],
                 )
             else:
                 level_n_edges = _get_level_n_edges_by_max_tokens(
@@ -328,10 +329,10 @@ async def get_cached_node_info(node_id: str) -> dict:
                     nodes,
                     edge,
                     max_depth,
-                    traverse_strategy.bidirectional,
-                    traverse_strategy.max_tokens,
+                    traverse_strategy["bidirectional"],
+                    traverse_strategy["max_tokens"],
                     edge_sampling,
-                    traverse_strategy.loss_strategy,
+                    traverse_strategy["loss_strategy"],
                 )
 
             for _edge in level_n_edges:
@@ -352,7 +353,7 @@ async def get_cached_node_info(node_id: str) -> dict:
     logger.info("Processing batches: %d", len(processing_batches))
 
     # isolate nodes
-    isolated_node_strategy = traverse_strategy.isolated_node_strategy
+    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
     if isolated_node_strategy == "add":
         processing_batches = await _add_isolated_nodes(
             nodes, processing_batches, graph_storage
diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py
index ff3faab7..dff63b0b 100644
--- a/graphgen/operators/traverse_graph.py
+++ b/graphgen/operators/traverse_graph.py
@@ -1,15 +1,10 @@
 import asyncio
+from typing import Dict
 
 import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async
 
-from graphgen.models import (
-    JsonKVStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-    TraverseStrategy,
-)
+from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
 from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
 from graphgen.templates import (
     ANSWER_REPHRASING_PROMPT,
@@ -164,7 +159,7 @@ async def traverse_graph_for_aggregated(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -240,7 +235,7 @@ async def _process_single_batch(
                     "question": question,
                     "answer": context,
                     "loss": get_average_loss(
-                        _process_batch, traverse_strategy.loss_strategy
+                        _process_batch, traverse_strategy["loss_strategy"]
                     ),
                 }
             }
@@ -272,7 +267,7 @@ async def _process_single_batch(
                 "question": qa["question"],
                 "answer": qa["answer"],
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         return final_results
@@ -313,7 +308,7 @@ async def traverse_graph_for_atomic(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -331,7 +326,6 @@ async def traverse_graph_for_atomic(
 
     :return: question and answer
     """
-    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)
 
     def _parse_qa(qa: str) -> tuple:
@@ -429,7 +423,7 @@ async def traverse_graph_for_multi_hop(
     llm_client: OpenAIClient,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
-    traverse_strategy: TraverseStrategy,
+    traverse_strategy: Dict,
     text_chunks_storage: JsonKVStorage,
     progress_bar: gr.Progress = None,
     max_concurrent: int = 1000,
@@ -517,7 +511,7 @@ async def _process_single_batch(_process_batch: tuple) -> dict:
                 "question": question,
                 "answer": answer,
                 "loss": get_average_loss(
-                    _process_batch, traverse_strategy.loss_strategy
+                    _process_batch, traverse_strategy["loss_strategy"]
                 ),
             }
         }
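
With TraverseStrategy gone, the ECE traversal is configured by a plain dict whose keys match partition.ece_params in the YAML; get_batches_with_strategy and the traverse_graph_for_* helpers index into it directly. For reference, the aggregated-config defaults expressed as that dict:

ece_params = {
    "bidirectional": True,               # expand from both endpoints of the seed edge
    "edge_sampling": "max_loss",         # random | max_loss | min_loss
    "expand_method": "max_width",        # max_width | max_tokens
    "isolated_node_strategy": "ignore",  # ignore | add
    "max_depth": 5,
    "max_extra_edges": 20,
    "max_tokens": 256,
    "loss_strategy": "only_edge",        # only_edge | both
}

One behavioral consequence of the switch: a missing key now raises KeyError at use time instead of silently falling back to a dataclass default, so configs must spell out every field.
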
From 5faf16dc82aa74a642e09cbf218f3421354a6b96 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 30 Sep 2025 15:12:48 +0800
Subject: [PATCH 3/5] fix: fix multi_hop & cot generation

---
 graphgen/generate.py | 2 +-
 graphgen/graphgen.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/graphgen/generate.py b/graphgen/generate.py
index 74208170..527a2def 100644
--- a/graphgen/generate.py
+++ b/graphgen/generate.py
@@ -75,7 +75,7 @@ def main():
     # Use pipeline according to the output data type
     if mode in ["atomic", "aggregated", "multi_hop"]:
         logger.info("Generation mode set to '%s'. Start generation.", mode)
-        if "quiz_and_judge" in config:
+        if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
             graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
         else:
             logger.warning(
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 374f6384..1c9e2a40 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -265,11 +265,10 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
                 self.progress_bar,
             )
         elif mode == "cot":
-            method_params = generate_config.get("method_params", {})
             results = await generate_cot(
                 self.graph_storage,
                 self.synthesizer_llm_client,
-                method_params=method_params,
+                method_params=partition_config["leiden_params"],
             )
         else:
             raise ValueError(f"Unknown generation mode: {mode}")
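
Two small but real fixes: the driver previously treated the mere presence of a quiz_and_judge section as consent to run it, and the cot branch looked for method_params under generate, where it never lives in the new schema. The guard semantics are now equivalent to this sketch (should_quiz_and_judge is a hypothetical name for illustration):

def should_quiz_and_judge(config: dict) -> bool:
    # A missing section and enabled: false are both treated as "skip",
    # matching the internal guard in GraphGen.quiz_and_judge.
    return "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]
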
From 05c827ded3f2ec8be56d372882e1c4b5a367abcb Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 30 Sep 2025 15:16:51 +0800
Subject: [PATCH 4/5] refactor: rename method_params

---
 graphgen/configs/aggregated_config.yaml | 2 +-
 graphgen/configs/atomic_config.yaml     | 2 +-
 graphgen/configs/cot_config.yaml        | 2 +-
 graphgen/configs/multi_hop_config.yaml  | 2 +-
 graphgen/graphgen.py                    | 8 ++++----
 5 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml
index 828bf662..2809ca77 100644
--- a/graphgen/configs/aggregated_config.yaml
+++ b/graphgen/configs/aggregated_config.yaml
@@ -12,7 +12,7 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
-  ece_params:
+  method_params:
     bidirectional: true # whether to traverse the graph in both directions
     edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
     expand_method: max_width # expand method, support: max_width, max_tokens
diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml
index 39a6dc7d..90037ec3 100644
--- a/graphgen/configs/atomic_config.yaml
+++ b/graphgen/configs/atomic_config.yaml
@@ -12,7 +12,7 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
-  ece_params:
+  method_params:
     bidirectional: true # whether to traverse the graph in both directions
     edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
     expand_method: max_width # expand method, support: max_width, max_tokens
diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml
index 82340476..69d1e608 100644
--- a/graphgen/configs/cot_config.yaml
+++ b/graphgen/configs/cot_config.yaml
@@ -10,7 +10,7 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
 partition: # graph partition configuration
   method: leiden # leiden is a community detection algorithm
-  leiden_params:
+  method_params:
     max_size: 20 # Maximum size of communities
     use_lcc: false
     random_seed: 42
diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml
index 1d0bd943..1754cec4 100644
--- a/graphgen/configs/multi_hop_config.yaml
+++ b/graphgen/configs/multi_hop_config.yaml
@@ -12,7 +12,7 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
-  ece_params:
+  method_params:
     bidirectional: true # whether to traverse the graph in both directions
     edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
     expand_method: max_width # expand method, support: max_width, max_tokens
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 1c9e2a40..27cd0958 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -242,7 +242,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
             self.synthesizer_llm_client,
             self.tokenizer_instance,
             self.graph_storage,
-            partition_config["ece_params"],
+            partition_config["method_params"],
             self.text_chunks_storage,
             self.progress_bar,
         )
@@ -251,7 +251,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
             self.synthesizer_llm_client,
             self.tokenizer_instance,
             self.graph_storage,
-            partition_config["ece_params"],
+            partition_config["method_params"],
             self.text_chunks_storage,
             self.progress_bar,
         )
@@ -260,7 +260,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
             self.synthesizer_llm_client,
             self.tokenizer_instance,
             self.graph_storage,
-            partition_config["ece_params"],
+            partition_config["method_params"],
             self.text_chunks_storage,
             self.progress_bar,
         )
@@ -268,7 +268,7 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
             results = await generate_cot(
                 self.graph_storage,
                 self.synthesizer_llm_client,
-                method_params=partition_config["leiden_params"],
+                method_params=partition_config["method_params"],
             )
         else:
             raise ValueError(f"Unknown generation mode: {mode}")
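
After the rename, generate() no longer needs to know which partitioner produced the parameters; every branch reads the same key. The access convention reduces to this pattern (a sketch, not new API):

method = partition_config["method"]          # "ece" or "leiden"
params = partition_config["method_params"]   # uniform key for every method

Keeping the key method-agnostic means a future partitioner only has to register a new method value; the YAML shape and the plumbing through generate() stay unchanged.
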
From fb3fc25fb5a12c81cd05de798e3e84224c185ad9 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 30 Sep 2025 15:56:03 +0800
Subject: [PATCH 5/5] fix: fix webui

---
 graphgen/generate.py |  2 +-
 graphgen/graphgen.py | 18 ++++++-----
 webui/app.py         | 71 +++++++++++++++++++++++++-------------------
 3 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/graphgen/generate.py b/graphgen/generate.py
index 527a2def..506b116e 100644
--- a/graphgen/generate.py
+++ b/graphgen/generate.py
@@ -85,7 +85,7 @@ def main():
                 config["partition"]["method"] == "ece"
-                and "ece_params" in config["partition"]
+                and "method_params" in config["partition"]
             ), "Only ECE partition with edge sampling is supported."
-            config["partition"]["ece_params"]["edge_sampling"] = "random"
+            config["partition"]["method_params"]["edge_sampling"] = "random"
     elif mode == "cot":
         logger.info("Generation mode set to 'cot'. Start generation.")
     else:
diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py
index 27cd0958..d23b5dcd 100644
--- a/graphgen/graphgen.py
+++ b/graphgen/graphgen.py
@@ -51,17 +51,21 @@ class GraphGen:
     progress_bar: gr.Progress = None
 
     def __post_init__(self):
-        self.tokenizer_instance: Tokenizer = Tokenizer(
+        self.tokenizer_instance: Tokenizer = self.tokenizer_instance or Tokenizer(
             model_name=os.getenv("TOKENIZER_MODEL")
         )
 
-        self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
-            model_name=os.getenv("SYNTHESIZER_MODEL"),
-            api_key=os.getenv("SYNTHESIZER_API_KEY"),
-            base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-            tokenizer=self.tokenizer_instance,
+        self.synthesizer_llm_client: OpenAIClient = (
+            self.synthesizer_llm_client
+            or OpenAIClient(
+                model_name=os.getenv("SYNTHESIZER_MODEL"),
+                api_key=os.getenv("SYNTHESIZER_API_KEY"),
+                base_url=os.getenv("SYNTHESIZER_BASE_URL"),
+                tokenizer=self.tokenizer_instance,
+            )
         )
-        self.trainee_llm_client: OpenAIClient = OpenAIClient(
+
+        self.trainee_llm_client: OpenAIClient = self.trainee_llm_client or OpenAIClient(
             model_name=os.getenv("TRAINEE_MODEL"),
             api_key=os.getenv("TRAINEE_API_KEY"),
             base_url=os.getenv("TRAINEE_BASE_URL"),
"max_tokens": params.max_tokens, + "max_depth": params.max_depth, + "edge_sampling": params.edge_sampling, + "isolated_node_strategy": params.isolated_node_strategy, + "loss_strategy": params.loss_strategy, + }, + }, + "generate": { + "mode": params.output_data_type, + "data_format": params.output_data_format, }, } env = { + "TOKENIZER_MODEL": params.tokenizer, "SYNTHESIZER_BASE_URL": params.synthesizer_url, "SYNTHESIZER_MODEL": params.synthesizer_model, "TRAINEE_BASE_URL": params.trainee_url, @@ -128,19 +138,18 @@ def sum_tokens(client): try: # Process the data - graph_gen.insert() + graph_gen.insert(read_config=config["read"], split_config=config["split"]) if config["if_trainee_model"]: - # Generate quiz - graph_gen.quiz() - - # Judge statements - graph_gen.judge() + # Quiz and Judge + graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"]) else: - graph_gen.traverse_strategy.edge_sampling = "random" + config["partition"]["method_params"]["edge_sampling"] = "random" - # Traverse graph - graph_gen.traverse() + graph_gen.generate( + partition_config=config["partition"], + generate_config=config["generate"], + ) # Save output output_data = graph_gen.qa_storage.data @@ -249,6 +258,9 @@ def sum_tokens(client): ) with gr.Accordion(label=_("Model Config"), open=False): + tokenizer = gr.Textbox( + label="Tokenizer", value="cl100k_base", interactive=True + ) synthesizer_url = gr.Textbox( label="Synthesizer URL", value="https://api.siliconflow.cn/v1", @@ -300,9 +312,6 @@ def sum_tokens(client): step=100, interactive=True, ) - tokenizer = gr.Textbox( - label="Tokenizer", value="cl100k_base", interactive=True - ) output_data_type = gr.Radio( choices=["atomic", "multi_hop", "aggregated"], label="Output Data Type",