From 2b232d116e9222dd00f6b486b964d8cd25484b69 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 4 Feb 2026 22:08:42 +0800 Subject: [PATCH 1/2] fix: update search_service --- graphgen/operators/search/search_service.py | 31 +++++++++++++-------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/graphgen/operators/search/search_service.py b/graphgen/operators/search/search_service.py index 7e25e225..3d2f4de1 100644 --- a/graphgen/operators/search/search_service.py +++ b/graphgen/operators/search/search_service.py @@ -1,5 +1,5 @@ from functools import partial -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple from graphgen.bases import BaseOperator from graphgen.common.init_storage import init_storage @@ -22,8 +22,9 @@ def __init__( data_sources: list = None, **kwargs, ): - super().__init__(working_dir=working_dir, op_name="search_service") - self.working_dir = working_dir + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="search" + ) self.data_sources = data_sources or [] self.kwargs = kwargs self.search_storage = init_storage( @@ -137,20 +138,23 @@ def _process_single_source( return final_results - def process(self, batch: "pd.DataFrame") -> "pd.DataFrame": - import pandas as pd - - docs = batch.to_dict(orient="records") - + def process(self, batch: list) -> Tuple[list, dict]: + """ + Search for items in the batch across multiple data sources. + :return: A tuple of (results, meta_updates) + results: A list of search results from all data sources. + meta_updates: A dict mapping source IDs to lists of trace IDs for the search results. + """ self._init_searchers() - seed_data = [doc for doc in docs if doc and "content" in doc] + seed_data = [item for item in batch if item and "content" in item] if not seed_data: logger.warning("No valid seeds in batch") - return pd.DataFrame([]) + return [], {} all_results = [] + meta_updates = {} for data_source in self.data_sources: if data_source not in self.searchers: @@ -158,9 +162,12 @@ def process(self, batch: "pd.DataFrame") -> "pd.DataFrame": continue source_results = self._process_single_source(data_source, seed_data) - all_results.extend(source_results) + for result in source_results: + if "_trace_id" not in result: + result["_trace_id"] = self.get_trace_id(result) + all_results.append(result) if not all_results: logger.warning("No search results generated for this batch") - return pd.DataFrame(all_results) + return all_results, meta_updates From d4632ed917de7681cd727c64c5176190082753d6 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 4 Feb 2026 22:30:39 +0800 Subject: [PATCH 2/2] refactor: refactor search_service --- .../search/search_dna/search_dna_config.yaml | 2 +- .../search_protein/search_protein_config.yaml | 2 +- .../search/search_rna/search_rna_config.yaml | 2 +- graphgen/operators/search/search_service.py | 158 +++++++----------- 4 files changed, 65 insertions(+), 99 deletions(-) diff --git a/examples/search/search_dna/search_dna_config.yaml b/examples/search/search_dna/search_dna_config.yaml index db87b16e..81bbfb37 100644 --- a/examples/search/search_dna/search_dna_config.yaml +++ b/examples/search/search_dna/search_dna_config.yaml @@ -22,7 +22,7 @@ nodes: batch_size: 10 save_output: true params: - data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral + data_source: ncbi # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral ncbi_params: email: test@example.com 
# NCBI requires an email address tool: GraphGen # tool name for NCBI API diff --git a/examples/search/search_protein/search_protein_config.yaml b/examples/search/search_protein/search_protein_config.yaml index 6e6f085c..bbf42abd 100644 --- a/examples/search/search_protein/search_protein_config.yaml +++ b/examples/search/search_protein/search_protein_config.yaml @@ -22,7 +22,7 @@ nodes: batch_size: 10 save_output: true params: - data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot + data_source: uniprot # data source for searcher, support: wikipedia, google, uniprot uniprot_params: use_local_blast: true # whether to use local blast for uniprot search local_blast_db: /path/to/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot diff --git a/examples/search/search_rna/search_rna_config.yaml b/examples/search/search_rna/search_rna_config.yaml index c19793e8..5c02a484 100644 --- a/examples/search/search_rna/search_rna_config.yaml +++ b/examples/search/search_rna/search_rna_config.yaml @@ -22,7 +22,7 @@ nodes: batch_size: 10 save_output: true params: - data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral + data_source: rnacentral # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral rnacentral_params: use_local_blast: true # whether to use local blast for RNA search local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension) diff --git a/graphgen/operators/search/search_service.py b/graphgen/operators/search/search_service.py index 3d2f4de1..1a599e25 100644 --- a/graphgen/operators/search/search_service.py +++ b/graphgen/operators/search/search_service.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseOperator from graphgen.common.init_storage import init_storage -from graphgen.utils import compute_content_hash, logger, run_concurrent +from graphgen.utils import logger, run_concurrent if TYPE_CHECKING: import pandas as pd @@ -19,43 +19,47 @@ def __init__( self, working_dir: str = "cache", kv_backend: str = "rocksdb", - data_sources: list = None, + data_source: str = None, **kwargs, ): super().__init__( working_dir=working_dir, kv_backend=kv_backend, op_name="search" ) - self.data_sources = data_sources or [] + self.data_source = data_source self.kwargs = kwargs self.search_storage = init_storage( backend=kv_backend, working_dir=working_dir, namespace="search" ) - self.searchers = {} + self.searcher = None - def _init_searchers(self): + def _init_searcher(self): """ - Initialize all searchers (deferred import to avoid circular imports). + Initialize the searcher (deferred import to avoid circular imports). 
""" - for datasource in self.data_sources: - if datasource in self.searchers: - continue - if datasource == "uniprot": - from graphgen.models import UniProtSearch + if self.searcher is not None: + return + + if not self.data_source: + logger.error("Data source not specified") + return + + if self.data_source == "uniprot": + from graphgen.models import UniProtSearch - params = self.kwargs.get("uniprot_params", {}) - self.searchers[datasource] = UniProtSearch(**params) - elif datasource == "ncbi": - from graphgen.models import NCBISearch + params = self.kwargs.get("uniprot_params", {}) + self.searcher = UniProtSearch(**params) + elif self.data_source == "ncbi": + from graphgen.models import NCBISearch - params = self.kwargs.get("ncbi_params", {}) - self.searchers[datasource] = NCBISearch(**params) - elif datasource == "rnacentral": - from graphgen.models import RNACentralSearch + params = self.kwargs.get("ncbi_params", {}) + self.searcher = NCBISearch(**params) + elif self.data_source == "rnacentral": + from graphgen.models import RNACentralSearch - params = self.kwargs.get("rnacentral_params", {}) - self.searchers[datasource] = RNACentralSearch(**params) - else: - logger.error(f"Unknown data source: {datasource}, skipping") + params = self.kwargs.get("rnacentral_params", {}) + self.searcher = RNACentralSearch(**params) + else: + logger.error(f"Unknown data source: {self.data_source}") @staticmethod async def _perform_search( @@ -77,97 +81,59 @@ async def _perform_search( result = searcher_obj.search(query) if result: - result["_doc_id"] = compute_content_hash(str(data_source) + query, "doc-") result["data_source"] = data_source result["type"] = seed.get("type", "text") return result - def _process_single_source( - self, data_source: str, seed_data: list[dict] - ) -> list[dict]: - """ - process a single data source: check cache, search missing, update cache. - """ - searcher = self.searchers[data_source] - - seeds_with_ids = [] - for seed in seed_data: - query = seed.get("content", "") - if not query: - continue - doc_id = compute_content_hash(str(data_source) + query, "doc-") - seeds_with_ids.append((doc_id, seed)) - - if not seeds_with_ids: - return [] - - doc_ids = [doc_id for doc_id, _ in seeds_with_ids] - cached_results = self.search_storage.get_by_ids(doc_ids) - - to_search_seeds = [] - final_results = [] - - for (doc_id, seed), cached in zip(seeds_with_ids, cached_results): - if cached is not None: - if "_doc_id" not in cached: - cached["_doc_id"] = doc_id - final_results.append(cached) - else: - to_search_seeds.append(seed) - - if to_search_seeds: - new_results = run_concurrent( - partial( - self._perform_search, searcher_obj=searcher, data_source=data_source - ), - to_search_seeds, - desc=f"Searching {data_source} database", - unit="keyword", - ) - new_results = [res for res in new_results if res is not None] - - if new_results: - upsert_data = {res["_doc_id"]: res for res in new_results} - self.search_storage.upsert(upsert_data) - logger.info( - f"Saved {len(upsert_data)} new results to {data_source} cache" - ) - - final_results.extend(new_results) - - return final_results - def process(self, batch: list) -> Tuple[list, dict]: """ - Search for items in the batch across multiple data sources. + Search for items in the batch using the configured data source. + + :param batch: List of items with 'content' and '_trace_id' fields :return: A tuple of (results, meta_updates) - results: A list of search results from all data sources. + results: A list of search results. 
meta_updates: A dict mapping source IDs to lists of trace IDs for the search results. """ - self._init_searchers() + self._init_searcher() + + if not self.searcher: + logger.error("Searcher not initialized") + return [], {} - seed_data = [item for item in batch if item and "content" in item] + # Filter seeds with valid content and _trace_id + seed_data = [ + item for item in batch if item and "content" in item and "_trace_id" in item + ] if not seed_data: logger.warning("No valid seeds in batch") return [], {} - all_results = [] - meta_updates = {} + # Perform concurrent searches + results = run_concurrent( + partial( + self._perform_search, + searcher_obj=self.searcher, + data_source=self.data_source, + ), + seed_data, + desc=f"Searching {self.data_source} database", + unit="keyword", + ) - for data_source in self.data_sources: - if data_source not in self.searchers: - logger.error(f"Data source {data_source} not initialized, skipping") + # Filter out None results and add _trace_id from original seeds + final_results = [] + meta_updates = {} + for result, seed in zip(results, seed_data): + if result is None: continue + result["_trace_id"] = self.get_trace_id(result) + final_results.append(result) + # Map from source seed trace ID to search result trace ID + meta_updates.setdefault(seed["_trace_id"], []).append(result["_trace_id"]) - source_results = self._process_single_source(data_source, seed_data) - for result in source_results: - if "_trace_id" not in result: - result["_trace_id"] = self.get_trace_id(result) - all_results.append(result) - - if not all_results: + if not final_results: logger.warning("No search results generated for this batch") - return all_results, meta_updates + return final_results, meta_updates
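
Reviewer note (not part of the patches above): PATCH 2/2 changes the operator contract from returning a pandas DataFrame to returning a (results, meta_updates) tuple keyed by trace IDs, and it drops the per-source cache lookups against search_storage that PATCH 1/2 still performed. The sketch below is a minimal, self-contained illustration of that new contract only. FakeSearcher, make_trace_id, and the stripped-down process() are stand-ins invented for this note; they are not graphgen APIs. The real operator uses UniProtSearch/NCBISearch/RNACentralSearch, BaseOperator.get_trace_id(), and run_concurrent for the parallel fan-out, and (like this sketch) does no caching after the refactor.

from typing import Optional, Tuple
import hashlib


class FakeSearcher:
    """Stand-in for UniProtSearch / NCBISearch / RNACentralSearch (hypothetical)."""

    def search(self, query: str) -> Optional[dict]:
        # A real searcher would query the remote database (or a local BLAST db) here.
        return {"content": f"record for {query}"} if query else None


def make_trace_id(record: dict) -> str:
    # Placeholder for BaseOperator.get_trace_id(); any stable hash works for the sketch.
    return "trace-" + hashlib.md5(repr(sorted(record.items())).encode()).hexdigest()[:8]


def process(batch: list, data_source: str = "ncbi") -> Tuple[list, dict]:
    searcher = FakeSearcher()
    # Keep only seeds that carry both 'content' and '_trace_id', as in PATCH 2/2.
    seeds = [s for s in batch if s and "content" in s and "_trace_id" in s]
    results: list = []
    meta_updates: dict = {}
    for seed in seeds:  # the real operator fans these out with run_concurrent
        hit = searcher.search(seed["content"])
        if hit is None:
            continue
        hit["data_source"] = data_source
        hit["type"] = seed.get("type", "text")
        hit["_trace_id"] = make_trace_id(hit)
        results.append(hit)
        # Map the seed's trace ID to the trace IDs of the results derived from it.
        meta_updates.setdefault(seed["_trace_id"], []).append(hit["_trace_id"])
    return results, meta_updates


if __name__ == "__main__":
    batch = [{"content": "BRCA1", "_trace_id": "seed-1"}, {"content": ""}]
    results, meta = process(batch)
    print(results)  # one result dict carrying data_source, type, and _trace_id
    print(meta)     # {"seed-1": ["trace-..."]}

Expected shape: results is a list of enriched record dicts, and meta_updates maps each seed's _trace_id to the list of result trace IDs produced from that seed, which is what the new process() docstring describes.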