25 changes: 9 additions & 16 deletions baselines/BDS/bds.py
@@ -1,15 +1,15 @@
import argparse
import asyncio
import json
-import os
-from dataclasses import dataclass
from typing import List

import networkx as nx
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as tqdm_async

-from graphgen.models import NetworkXStorage, OpenAIClient, Tokenizer
+from graphgen.bases import BaseLLMWrapper
+from graphgen.models import NetworkXStorage
+from graphgen.operators import init_llm
from graphgen.utils import create_event_loop

QA_GENERATION_PROMPT = """
@@ -52,10 +52,12 @@ def _post_process(text: str) -> dict:
    return {}


-@dataclass
class BDS:
-    llm_client: OpenAIClient = None
-    max_concurrent: int = 1000
+    def __init__(self, llm_client: BaseLLMWrapper = None, max_concurrent: int = 1000):
+        self.llm_client: BaseLLMWrapper = llm_client or init_llm(
+            "synthesizer"
+        )
+        self.max_concurrent: int = max_concurrent
Comment on lines +56 to +60
Contributor

Severity: high

This refactoring from a @dataclass to a regular class is a good step towards better dependency management. However, there are a couple of improvements that could be made:

  1. Robustness: The init_llm function can return None if the synthesizer configuration is missing from the environment variables. This would cause self.llm_client to be None, leading to a runtime AttributeError later. It's best to add a check and fail early if the client can't be initialized.
  2. Configurability: The max_concurrent attribute is now hardcoded to 1000. The previous dataclass implementation allowed this to be configured at instantiation. It would be beneficial to restore this flexibility by making it an __init__ parameter.
    def __init__(
        self,
        synthesizer_llm_client: BaseLLMWrapper = None,
        max_concurrent: int = 1000,
    ):
        self.llm_client: BaseLLMWrapper = synthesizer_llm_client or init_llm(
            "synthesizer"
        )
        if not self.llm_client:
            raise ValueError(
                "LLM client for synthesizer could not be initialized. "
                "Check your environment variables for the SYNTHESIZER backend."
            )
        self.max_concurrent: int = max_concurrent
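
For illustration, a usage sketch of the suggested constructor; my_client stands in for any BaseLLMWrapper implementation and is hypothetical, not something defined in this PR:

    # Usage sketch (my_client is hypothetical; not part of this PR):
    bds = BDS(synthesizer_llm_client=my_client)  # explicit dependency injection
    bds = BDS(max_concurrent=200)  # falls back to init_llm("synthesizer"), now
                                   # raising ValueError if the SYNTHESIZER
                                   # backend is not configured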


    def generate(self, tasks: List[dict]) -> List[dict]:
        loop = create_event_loop()
@@ -102,16 +104,7 @@ async def job(item):

load_dotenv()

-tokenizer_instance: Tokenizer = Tokenizer(
-    model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
-)
-llm_client = OpenAIClient(
-    model_name=os.getenv("SYNTHESIZER_MODEL"),
-    api_key=os.getenv("SYNTHESIZER_API_KEY"),
-    base_url=os.getenv("SYNTHESIZER_BASE_URL"),
-    tokenizer_instance=tokenizer_instance,
-)
-bds = BDS(llm_client=llm_client)
+bds = BDS()

graph = NetworkXStorage.load_nx_graph(args.input_file)
