Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dataflow/operators/conversations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
from dataflow.utils.registry import LazyLoader
from .consistent_chat import ConsistentChatGenerator

cur_path = "dataflow/operators/conversations/"

_import_structure = {
"ConsistentChatGenerator": (cur_path + "consistent_chat.py", "ConsistentChatGenerator"),
}

sys.modules[__name__] = LazyLoader(__name__, "dataflow/operators/conversations/", _import_structure)
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import random
import json
import os
from dataflow.serving import APILLMServing_request
from dataflow.utils.registry import OPERATOR_REGISTRY
from dataflow import get_logger
from dataflow.core import OperatorABC
Expand All @@ -22,7 +20,7 @@ def __init__(self, llm_serving: LLMServingABC = None, num_dialogs_per_intent = 2
self.prompt = ConsistentChatPrompt()
self.logger.info(f'{self.__class__.__name__} initialized.')

def run(self):
def run(self, storage: DataFlowStorage):
all_query_prompts = []

# Step 1: Generate all queries using LLM
Expand All @@ -33,6 +31,7 @@ def run(self):
query_prompt = self.prompt.get_query_prompt(info_flow, topic)
all_query_prompts.append(query_prompt)
# Step 2: Generate queries by calling llm_serving once
self.logger.info("Generating queries...")
queries_list = self.llm_serving.generate_from_input(user_inputs=all_query_prompts)
valid_queries = []
cnt = 0
Expand All @@ -50,6 +49,7 @@ def run(self):
category = queries.get("category")
turns = queries.get("turns")
all_response_prompts.append(self.prompt.get_response_prompt(topic=category, queries=turns))
self.logger.info("Generating responses...")
responses_list = self.llm_serving.generate_from_input(user_inputs=all_response_prompts)

final_queries = []
Expand Down Expand Up @@ -87,9 +87,7 @@ def run(self):
continue
self.logger.info(f'Number of synthesized dialogues: {len(formatted_data)}')

output_filename = "generated_dialogs.json"
with open(output_filename, "w") as f:
json.dump(formatted_data, f, indent=4)

print(f"Data generated and saved to {output_filename}")
return formatted_data
df = pd.DataFrame(formatted_data)
storage.write(df)
self.logger.info(f'Number of synthesized dialogues: {len(df)} written to storage as DataFrame')
return df
Empty file.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading
Loading