diff --git a/.gitignore b/.gitignore index f7c1231..4261147 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,6 @@ archive* src/wandb/ src/configs/*/case* .cache/ +.cursor/ -dependencies/ \ No newline at end of file +dependencies \ No newline at end of file diff --git a/README.md b/README.md index dae601b..4c2ecff 100644 --- a/README.md +++ b/README.md @@ -152,18 +152,22 @@ python compare_env.py --config test2/case1/compare_deepseek-chat.yml ``` ### 🔊 Streaming TTS (Optional) -By default, the TTS pipeline generates audio serially and trims sentences that exceed the time budget. With `streaming_tts: true` in your config, a **streaming pipeline** is used instead: +By default, the TTS pipeline generates audio serially and trims sentences that exceed the time budget. With `streaming_tts: true` on a debater’s config entry, a **streaming pipeline** is used instead: - **Chunk-based processing**: the debate speech is split into paragraph-level chunks, each assigned a proportional share of the total time budget (opening: 240s, rebuttal: 240s, closing: 120s). - **Adaptive refinement**: FastSpeech2 estimates each chunk's duration; if off-target, an LLM rewrites the chunk to hit the target word count. Multiple TTS candidates are submitted in parallel and the closest-to-target is picked. - **Streaming overlap**: while chunk N plays, chunk N+1 is being refined and TTS-generated, minimizing gaps. - **No information loss**: instead of trimming sentences, text is rewritten to fit the budget. 
-To enable, set `streaming_tts: true` in the `env` section of your config: +To enable, set `streaming_tts: true` on each **debater** entry in your config (per-agent): ```yaml -env: - time_control: true - streaming_tts: true +debater: + - side: for + type: treedebater + streaming_tts: true + - side: against + type: treedebater + streaming_tts: true ``` After a streaming run, each speech produces a `*_chunks/` directory containing per-chunk audio, text, and a `chunk_profile.csv` with timing details. To visualize the overlap timeline: diff --git a/requirements.txt b/requirements.txt index 01f6999..566bc00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ openai +pydub backoff tqdm numpy @@ -7,6 +8,8 @@ google-generativeai matplotlib seaborn litellm +pydantic>=2 +instructor accelerate>=0.26.0 wandb tavily-python @@ -15,4 +18,5 @@ pulp g2p_en fastapi tavily-python +pydub git+https://github.com/UKPLab/sentence-transformers.git diff --git a/src/agents.py b/src/agents.py index 11c7e2e..c7747ef 100644 --- a/src/agents.py +++ b/src/agents.py @@ -20,7 +20,16 @@ from utils.constants import CLOSING_TIME, OPENING_TIME, REBUTTAL_TIME, WORDRATIO, deepseek_api_key from utils.model import HelperClient, safety_setting from utils.prompts import * -from utils.tool import log_file_path, logger +from utils.timing_log import ( + clear_speak_io_context, + log_io_block, + log_llm_io, + log_timing, + next_call_id, + one_line_preview, + set_speak_io_context, +) +from utils.tool import io_logger, io_logging_enabled, log_file_path, logger @dataclass @@ -42,6 +51,10 @@ class DebaterConfig(AgentConfig): use_rehearsal_tree: bool = True use_debate_flow_tree: bool = True url: str = "http://127.0.0.1:8081/" + streaming_tts: bool = False + streaming_listen: bool = False + single_pass_revision: bool = False + helper_model: str | None = None @dataclass @@ -114,20 +127,110 @@ def __init__(self, config) -> None: def speak(self, prompt, **kwargs): self._add_message("user", prompt) - 
logger.debug(f"[Conversation-History] {json.dumps(self.conversation)}") - logger.debug("[Prompt] " + prompt.strip().replace("\n", " ||| ")) - response = self._get_response(self.conversation, **kwargs) - if kwargs.get("n", 1) > 1: - logger.info(f"[Response] {response}") + call_id = next_call_id() + set_speak_io_context(call_id, "default_speak") + stage = getattr(self, "status", "") + side = getattr(self, "side", "") + try: + if io_logging_enabled(): + log_io_block( + io_logger, + call_id=call_id, + phase="default_speak", + title="Conversation-History", + body=json.dumps(self.conversation), + stage=stage, + side=side, + ) + log_io_block( + io_logger, + call_id=call_id, + phase="default_speak", + title="Prompt", + body=prompt, + stage=stage, + side=side, + ) + else: + log_llm_io( + logger, + phase="default_speak", + title="Conversation-History", + body=json.dumps(self.conversation), + stage=stage, + side=side, + ) + log_llm_io(logger, phase="default_speak", title="Prompt", body=prompt.strip(), stage=stage, side=side) + logger.debug( + f"[timing-meta] call_id={call_id} speak_session=default_speak n_messages={len(self.conversation)}" + ) + response = self._get_response(self.conversation, **kwargs) + if kwargs.get("n", 1) > 1: + if io_logging_enabled(): + log_llm_io( + logger, + phase="default_speak", + title="Response-multi", + body=str(response), + stage=stage, + side=side, + emit_main_ref=False, + ) + logger.info("[Response] " + one_line_preview(str(response)) + " [full text in *_io.log]") + else: + logger.info(f"[Response] {response}") + return response + if io_logging_enabled(): + log_io_block( + io_logger, + call_id=call_id, + phase="default_speak", + title="Response-Before-Post-Process", + body=str(response).strip(), + stage=stage, + side=side, + ) + else: + log_llm_io( + logger, + phase="default_speak", + title="Response-Before-Post-Process", + body=str(response).strip(), + stage=stage, + side=side, + ) + response = self.post_process(response, **kwargs) + if 
io_logging_enabled(): + log_io_block( + io_logger, + call_id=call_id, + phase="default_speak", + title="Response-After-Post-Process", + body=str(response).strip(), + stage=stage, + side=side, + ) + else: + log_llm_io( + logger, + phase="default_speak", + title="Response-After-Post-Process", + body=str(response).strip(), + stage=stage, + side=side, + ) return response - logger.debug("[Response-Before-Post-Process] " + response.strip().replace("\n", " ||| ")) - response = self.post_process(response, **kwargs) - logger.debug("[Response-After-Post-Process] " + response.strip().replace("\n", " ||| ")) - return response + finally: + clear_speak_io_context() def post_process(self, statement, **kwargs): self._add_message("assistant", f"{statement}") - logger.info("[Response] " + statement.strip().replace("\n", " ||| ")) + st = (statement or "").strip() + if io_logging_enabled(): + log_llm_io(logger, phase="post_process", title="Response", body=st, emit_main_ref=False) + logger.info("[Response] " + one_line_preview(st) + " [full text in *_io.log]") + else: + logger.info("[Response] " + st.replace("\n", " ||| ")) return statement def _get_response(self, messages, **kwargs): @@ -136,11 +239,23 @@ def _get_response(self, messages, **kwargs): kwargs.pop("max_words", None) kwargs.pop("time_control", None) kwargs.pop("streaming_tts", None) + kwargs.pop("streaming_listen", None) retry = 0 while retry < 3: try: + t0 = time.perf_counter() response = self.client(messages=messages, **kwargs) + elapsed = time.perf_counter() - t0 self.client_cost += response._hidden_params["response_cost"] + log_timing( + logger, + "debater_litellm_completion", + elapsed, + stage=getattr(self, "status", None), + side=getattr(self, "side", None), + model=self.config.model, + retry_attempt=retry, + ) response = [choice.message.content for choice in response.choices] if len(response) == 1: response = response[0] @@ -197,7 +312,7 @@ def rebuttal_generation(self, history, **kwargs): self.status = "rebuttal" 
self.listen(history) opponent = history[-1]["content"] - prompt = default_rebuttal_prompt.format(counter_act=self.counter_act, opponent=opponent, act=self.act) + prompt = default_rebuttal_prompt.format(counter_act=self.counter_act, opponent=opponent, act=self.act, motion=self.motion) prompt = prompt.replace("{n_words}", str(math.ceil(kwargs.get("max_time", REBUTTAL_TIME) / WORDRATIO["time"]))) response = self.speak(prompt, **kwargs) return response @@ -211,13 +326,20 @@ def closing_generation(self, history, **kwargs): response = self.speak(prompt, **kwargs) return response - def post_process(self, statement, max_time=-1, time_control=False, streaming_tts=False, **kwargs): + def post_process(self, statement, max_time=-1, time_control=False, streaming_tts=None, **kwargs): """ statement: AI生成的原始辩论陈述文本 max_time: 最大允许的发言时间(秒),-1表示无限制 time_control: 是否启用时间控制功能 - streaming_tts: 是否启用流式TTS(分chunk自适应refinement + 并行TTS) + streaming_tts: 是否启用流式TTS(分chunk自适应refinement + 并行TTS);默认取 debater 配置 """ + if streaming_tts is None: + streaming_tts = kwargs.pop("streaming_tts", None) + if streaming_tts is None: + streaming_tts = getattr(self.config, "streaming_tts", False) + else: + kwargs.pop("streaming_tts", None) + kwargs.pop("streaming_listen", None) statement = statement.strip() if statement is None: self._add_message("assistant", f"") @@ -245,7 +367,20 @@ def post_process(self, statement, max_time=-1, time_control=False, streaming_tts if max_time <= 0 or not time_control: self._add_message("assistant", f"{statement}") - logger.info("[Response] " + statement.strip().replace("\n", " ||| ")) + st = statement.strip() + if io_logging_enabled(): + log_llm_io( + logger, + phase="post_process", + title="Response-Final", + body=st, + stage=self.status, + side=self.side, + emit_main_ref=False, + ) + logger.info("[Response] " + one_line_preview(st) + " [full text in *_io.log]") + else: + logger.info("[Response] " + st.replace("\n", " ||| ")) return statement # NOTE the below part is 
time-consuming, can comment them and add "new_statement = statement" when developing @@ -262,22 +397,51 @@ def post_process(self, statement, max_time=-1, time_control=False, streaming_tts os.makedirs(audio_dir, exist_ok=True) audio_file = os.path.join(audio_dir, f"{self.config.type}_{self.status}_{self.side}.mp3") logger.info(f"[TTS-Start] Starting TTS for {self.config.type} {self.status} {self.side} (streaming={streaming_tts}, budget={max_time}s)") - logger.debug("[Time-Control] Statement: " + statement.replace("\n", " ||| ")) + log_llm_io( + logger, + phase="post_process", + title="Time-Control-Statement", + body=statement, + stage=self.status, + side=self.side, + ) if streaming_tts: # ---- Streaming TTS: chunk-based adaptive refinement + parallel TTS ---- logger.debug("[Time-Control] Using streaming TTS pipeline") + wall_t0 = time.perf_counter() content, reference, duration = convert_text_to_speech_streaming( statement, audio_file, total_budget_s=max_time, ) + wall_s = time.perf_counter() - wall_t0 logger.debug(f"[Time-Control] Save Audio: {audio_file}") logger.debug(f"[Time-Control] Streaming TTS Time: {duration:0.2f}") + log_timing( + logger, + "tts_wall_clock", + wall_s, + stage=self.status, + side=self.side, + kind="streaming_tts", + audio_duration_s=float(duration), + ) new_content = content else: # ---- Original serial TTS + trim ---- + wall_t0 = time.perf_counter() content, reference, duration = convert_text_to_speech(statement, audio_file) + wall_s = time.perf_counter() - wall_t0 logger.debug(f"[Time-Control] Save Audio: {audio_file}") logger.debug(f"[Time-Control] Original Time: {duration:0.2f}") + log_timing( + logger, + "tts_wall_clock", + wall_s, + stage=self.status, + side=self.side, + kind="batch_tts_encode", + audio_duration_s=float(duration), + ) if duration <= max_time: logger.debug(f"[Time-Control] Final Time: {duration:0.2f}") @@ -285,7 +449,16 @@ def post_process(self, statement, max_time=-1, time_control=False, streaming_tts else: save_file 
= audio_file.replace(".mp3", "_trimmed.mp3") logger.debug(f"[Time-Control] Save Trimmed Audio: {save_file}") + trim_t0 = time.perf_counter() duration, new_sentences = trim_audio_by_sentences(audio_file, save_file, max_duration=max_time * 1000) + log_timing( + logger, + "tts_trim_wall_clock", + time.perf_counter() - trim_t0, + stage=self.status, + side=self.side, + audio_duration_s=float(duration), + ) last_sentence = new_sentences[-1] idx = content.lower().find(last_sentence[:-1].lower()) # remove the punctuation if idx == -1: @@ -304,7 +477,20 @@ def post_process(self, statement, max_time=-1, time_control=False, streaming_tts # new_statement = statement self._add_message("assistant", f"{new_statement}") logger.info(f"[TTS-Done] Finished TTS for {self.config.type} {self.status} {self.side}") - logger.info("[Response] " + new_statement.strip().replace("\n", " ||| ")) + ns = new_statement.strip() + if io_logging_enabled(): + log_llm_io( + logger, + phase="post_process", + title="Response-After-TTS", + body=ns, + stage=self.status, + side=self.side, + emit_main_ref=False, + ) + logger.info("[Response] " + one_line_preview(ns) + " [full text in *_io.log]") + else: + logger.info("[Response] " + ns.replace("\n", " ||| ")) return new_statement @@ -390,7 +576,7 @@ def __init__(self, config, motion, port=8081) -> None: } self.BASE_URL = f"http://127.0.0.1:{port}/" logger.info(f"[BaselineDebater URL] {self.BASE_URL}") - logger.debug("[BaselineDebater init] " + str(self.input)) + log_llm_io(logger, phase="baseline", title="BaselineDebater-init", body=str(self.input)) def _make_request(self, url, data): max_retries = 3 @@ -415,10 +601,10 @@ def opening_generation(self, history, **kwargs): opening_response = self._make_request(self.BASE_URL + "v1/argument", self.input) opening = opening_response["Result"] self.input["Reference"] = opening_response["Reference"] - logger.debug("[Baseline-opening-input] " + str(self.input).replace("\n", " ||| ")) - 
logger.debug("[Baseline-opening-before] " + opening.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-opening-input", body=str(self.input).replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-opening-before", body=opening.strip()) opening = self.post_process(opening, **kwargs) - logger.debug("[Baseline-opening-after] " + opening.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-opening-after", body=opening.strip()) return opening def rebuttal_generation(self, history, **kwargs): @@ -435,10 +621,10 @@ def rebuttal_generation(self, history, **kwargs): rebuttal_response = self._make_request(self.BASE_URL + "v1/rebuttal", self.input) rebuttal = rebuttal_response["Result"] self.input["Reference"] = rebuttal_response["Reference"] - logger.debug("[Baseline-rebuttal-input] " + str(self.input).replace("\n", " ||| ")) - logger.debug("[Baseline-rebuttal-before] " + rebuttal.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-rebuttal-input", body=str(self.input).replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-rebuttal-before", body=rebuttal.strip()) rebuttal = self.post_process(rebuttal, **kwargs) - logger.debug("[Baseline-rebuttal-after] " + rebuttal.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-rebuttal-after", body=rebuttal.strip()) return rebuttal def closing_generation(self, history, **kwargs): @@ -456,10 +642,10 @@ def closing_generation(self, history, **kwargs): summary_response = self._make_request(self.BASE_URL + "v1/summary", self.input) summary = summary_response["Result"] self.input["Reference"] = summary_response["Reference"] - logger.debug("[Baseline-summary-input] " + str(self.input).replace("\n", " ||| ")) - logger.debug("[Baseline-summary-before] " + summary.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", 
title="Baseline-summary-input", body=str(self.input).replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-summary-before", body=summary.strip()) summary = self.post_process(summary, **kwargs) - logger.debug("[Baseline-summary-after] " + summary.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="baseline", title="Baseline-summary-after", body=summary.strip()) return summary def reset_stage(self, stage, side, new_content): diff --git a/src/compare_env.py b/src/compare_env.py index 041e90b..ab6f366 100644 --- a/src/compare_env.py +++ b/src/compare_env.py @@ -2,6 +2,7 @@ import copy import json import os +import time import yaml @@ -9,6 +10,7 @@ from env import Env, EnvConfig, extract_overall_score from ouragents import TreeDebater from utils.constants import CLOSING_TIME, OPENING_TIME, REBUTTAL_TIME +from utils.timing_log import log_timing from utils.tool import logger @@ -28,7 +30,6 @@ def __init__(self, config, debug, baseline_type="baseline", test_type="treedebat self.claim_pool_size = config.claim_pool_size self.reverse = config.reverse # if to reverse the order of the debaters self.time_control = config.time_control - self.streaming_tts = config.streaming_tts self.debug = debug self.baseline_type = baseline_type self.test_type = test_type @@ -81,12 +82,18 @@ def step_play(self, side, stage, history, max_time): } base_response = baseline_call[stage]( - history=history, max_time=max_time, time_control=self.time_control, streaming_tts=self.streaming_tts, + history=history, + max_time=max_time, + time_control=self.time_control, + streaming_tts=self.baseline_debaters[side].config.streaming_tts, ) # Generate test response using reference history test_response = test_call[stage]( - history=history, max_time=max_time, time_control=self.time_control, streaming_tts=self.streaming_tts, + history=history, + max_time=max_time, + time_control=self.time_control, + streaming_tts=self.test_debaters[side].config.streaming_tts, ) return 
base_response, test_response @@ -105,6 +112,7 @@ def compare_play(self): # Run through each stage for stage in ["preparation", "opening", "rebuttal", "closing"]: logger.info(f"[{stage}] Start Comparison") + t_st = time.perf_counter() if stage == "preparation": # Generate claims for both reference and test debaters @@ -150,6 +158,7 @@ def compare_play(self): "keep_response": keep_response, } + log_timing(logger, "compare_env_stage_wall", time.perf_counter() - t_st, stage=stage, motion=self.motion[:80]) logger.info(f"[{stage}] Comparison Complete") if self.debug: diff --git a/src/configs/base.yml b/src/configs/base.yml index a6450ba..7f36244 100644 --- a/src/configs/base.yml +++ b/src/configs/base.yml @@ -1,28 +1,33 @@ env: - motion: AI will lead to the decline of human creative arts + motion: Learning to be a good writer still matters in the age of ai judge_num: 1 audience_num: 3 claim_pool_size: 50 reverse: False time_control: True - streaming_tts: False debater: - side: for - model: moonshot-v1-128k + model: deepseek-chat type: treedebater temperature: 1.0 use_retrieval: True add_retrieval_feedback: False - pool_file: ../results409/moonshot-v1-128k/ai_will_lead_to_the_decline_of_human_creative_arts_pool_for.json + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: False + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_for.json - side: against - model: moonshot-v1-128k + model: deepseek-chat type: treedebater temperature: 1.0 use_retrieval: True add_retrieval_feedback: False - pool_file: ../results409/moonshot-v1-128k/ai_will_lead_to_the_decline_of_human_creative_arts_pool_against.json + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: False + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_against.json judge: temperature: 0.0 diff --git a/src/configs/base_st.yml b/src/configs/base_st.yml new file mode 100644 
index 0000000..0912b02 --- /dev/null +++ b/src/configs/base_st.yml @@ -0,0 +1,35 @@ +env: + motion: Learning to be a good writer still matters in the age of ai + judge_num: 1 + audience_num: 3 + claim_pool_size: 50 + reverse: False + time_control: True + +debater: + - side: for + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_for.json + + - side: against + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_against.json +judge: + temperature: 0.0 + +audience: + temperature: 1 diff --git a/src/configs/base_st_io.yml b/src/configs/base_st_io.yml new file mode 100644 index 0000000..905a07f --- /dev/null +++ b/src/configs/base_st_io.yml @@ -0,0 +1,30 @@ +env: + motion: Learning to be a good writer still matters in the age of ai + judge_num: 1 + audience_num: 3 + claim_pool_size: 50 + reverse: False + time_control: True + +debater: + - side: for + model: deepseek-chat + type: default + temperature: 1.0 + streaming_tts: False + + - side: against + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_against.json +judge: + temperature: 0.0 + +audience: + temperature: 1 diff --git a/src/configs/compare.yml b/src/configs/compare.yml index c7d6bba..c591d32 100644 --- a/src/configs/compare.yml +++ b/src/configs/compare.yml @@ -5,7 +5,6 @@ env: 
claim_pool_size: 50 reverse: False time_control: True - streaming_tts: False debater: - side: for @@ -14,6 +13,9 @@ debater: temperature: 1.0 use_retrieval: True add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: False pool_file: ../results10/deepseek-chat/it_is_time_to_welcome_an_a.i._tutor_in_the_classroom_pool_for.json - side: against @@ -22,6 +24,9 @@ debater: temperature: 1.0 use_retrieval: True add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: False pool_file: ../results10/deepseek-chat/it_is_time_to_welcome_an_a.i._tutor_in_the_classroom_pool_against.json judge: temperature: 0.0 diff --git a/src/configs/create_config.py b/src/configs/create_config.py index 5d1676a..56bb83a 100644 --- a/src/configs/create_config.py +++ b/src/configs/create_config.py @@ -80,6 +80,8 @@ def get_options(): "type": "treedebater", "temperature": 1.0, "use_retrieval": True, + "single_pass_revision": False, + "helper_model": model1, "pool_file": f"../results{args.pool_version}/{model_name1}/{motion_name}_pool_for.json", } configs["debater"][1] = { @@ -106,6 +108,8 @@ def get_options(): "temperature": 1.0, "use_retrieval": True, "add_retrieval_feedback": True, + "single_pass_revision": False, + "helper_model": model1, "pool_file": f"../results{args.pool_version}/{model_name2}/{motion_name}_pool_against.json", } diff --git a/src/configs/overlap_debate.yml b/src/configs/overlap_debate.yml new file mode 100644 index 0000000..720fcb0 --- /dev/null +++ b/src/configs/overlap_debate.yml @@ -0,0 +1,41 @@ +# Example: overlapping streaming listen + streaming TTS (see streaming/overlap.py; run: python -m streaming.overlap). +# Both sides must be treedebater for per-turn StreamingInputEnv; set streaming_listen on listeners. 
+ +env: + motion: Learning to be a good writer still matters in the age of ai + judge_num: 1 + audience_num: 3 + claim_pool_size: 50 + reverse: False + time_control: True + +debater: + - side: for + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: True + streaming_listen: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_for.json + + - side: against + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: False + helper_model: deepseek-chat + streaming_tts: True + streaming_listen: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_against.json + +judge: + temperature: 0.0 + +audience: + temperature: 1 diff --git a/src/configs/overlap_debate_2.yml b/src/configs/overlap_debate_2.yml new file mode 100644 index 0000000..1cc242d --- /dev/null +++ b/src/configs/overlap_debate_2.yml @@ -0,0 +1,41 @@ +# Example: overlapping streaming listen + streaming TTS with single_pass_revision enabled on both debaters (see streaming/overlap.py; run: python -m streaming.overlap). +# Both sides must be treedebater for per-turn StreamingInputEnv; set streaming_listen on listeners. 
+ +env: + motion: Learning to be a good writer still matters in the age of ai + judge_num: 1 + audience_num: 3 + claim_pool_size: 50 + reverse: False + time_control: True + +debater: + - side: for + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: True + helper_model: deepseek-chat + streaming_tts: True + streaming_listen: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_for.json + + - side: against + model: deepseek-chat + type: treedebater + temperature: 1.0 + use_retrieval: True + add_retrieval_feedback: False + single_pass_revision: True + helper_model: deepseek-chat + streaming_tts: True + streaming_listen: True + pool_file: ../results/deepseek-chat/learning_to_be_a_good_writer_still_matters_in_the_age_of_ai_pool_against.json + +judge: + temperature: 0.0 + +audience: + temperature: 1 diff --git a/src/debate_tree.py b/src/debate_tree.py index 0a7ed11..ab650bc 100644 --- a/src/debate_tree.py +++ b/src/debate_tree.py @@ -15,7 +15,9 @@ from evaluator import evaluate_defense_strength, evaluate_support_strength from utils.constants import get_embeddings +from utils.llm_schemas import StatementsResponse from utils.model import HelperClient, reward_model +from utils.timing_log import log_llm_io from utils.tool import get_response_with_retry, logger @@ -69,9 +71,15 @@ def propose_new_claims(proposer, motion, side, history, n): " ]\n" "}}\n" ) - logger.debug("[Proposer-Tree-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) - content, response = get_response_with_retry(proposer, prompt, "statements", temperature=1) - logger.debug("[Proposer-Tree-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="debate_tree", title="Proposer-Tree-Helper-Prompt", body=prompt.strip(), side=side) + content, response = get_response_with_retry( + proposer, + prompt, + "statements", + 
response_model=StatementsResponse, + temperature=1, + ) + log_llm_io(logger, phase="debate_tree", title="Proposer-Tree-Helper-Response", body=response.strip(), side=side) return content @@ -385,32 +393,33 @@ def get_node_by_status_recursive(self, node, status, side=None): status_nodes.extend(self.search_status(child, status, side=side)) return status_nodes - def print_tree_recursive(self, node, level=0, prefix="", include_status=False, max_print_level=None): + def print_tree_recursive(self, node, level=0, lines=None, include_status=False, max_print_level=None): + if lines is None: + lines = [] if node is not None: if max_print_level is not None and level > max_print_level: - return "" - score_str = "" + return lines + score_parts = [] if node.scores is not None: for k, v in node.scores.items(): if k == "defense" and v != 0: - score_str += f"Attack Score: {v:.1f}, " + score_parts.append(f"Attack Score: {v:.1f}") elif k == "support" and v != 0: - score_str += f"Support Score: {v:.1f}, " - score_str = score_str.strip(", ") - prefix += ( - " " * level * 4 + score_parts.append(f"Support Score: {v:.1f}") + score_str = ", ".join(score_parts) + lines.append( + ' ' * level * 4 + f"Level-{level} Data (Visit: {node.visit_count}, Status: {node.status}): {node.data}, Scores: {score_str}\n" ) for child in node.children: - prefix += self.print_tree_recursive( - child, level + 1, include_status=False, max_print_level=max_print_level - ) - return prefix + self.print_tree_recursive(child, level + 1, lines=lines, include_status=False, max_print_level=max_print_level) + return lines def print_tree(self, prefix="", include_status=False, max_print_level=None): - return self.print_tree_recursive( - self.root, level=0, prefix=prefix, include_status=include_status, max_print_level=max_print_level - ) + lines = [prefix] if prefix else [] + self.print_tree_recursive(self.root, level=0, lines=lines, include_status=include_status, max_print_level=max_print_level) + return "".join(lines) + def 
get_tree_info(self): info = { @@ -535,7 +544,9 @@ def expand_tree(self, node, max_level=3, max_branch=3): node_list.extend(node.children) cur_level += 1 - def print_tree_recursive(self, node, level=0, prefix="", include_status=False, max_print_level=None): + def print_tree_recursive(self, node, level=0, lines=None, include_status=False, max_print_level=None): + if lines is None: + lines = [] if level == 0: position = "Root Claim" elif level % 2 == 1: @@ -545,28 +556,26 @@ def print_tree_recursive(self, node, level=0, prefix="", include_status=False, m if node is not None: if max_print_level is not None and level > max_print_level: - return "" - - score_str = "" + return lines + + score_parts = [] if node.scores is not None: for k, v in node.scores.items(): if k == "defense" and v != 0: - score_str += f"Attack Score: {v:0.1f}, " + score_parts.append(f"Attack Score: {v:0.1f}") elif k == "support" and v != 0: - score_str += f"Support Score: {v:0.1f}, " + score_parts.append(f"Support Score: {v:0.1f}") elif k == "minimax_strength" and v != 0: - score_str += f"Strength: {v:0.1f}, " - score_str = score_str.strip(", ") - + score_parts.append(f"Strength: {v:0.1f}") + score_str = ", ".join(score_parts) + if include_status: - prefix += " " * level * 4 + f"Level-{level} {position}: {node.data}, Scores: {score_str}\n" + lines.append(' ' * level * 4 + f"Level-{level} {position}: {node.data}, Scores: {score_str}\n") else: - prefix += " " * level * 4 + f"Level-{level} {position}: {node.data}\n" + lines.append(' ' * level * 4 + f"Level-{level} {position}: {node.data}\n") for child in node.children: - prefix += self.print_tree_recursive( - child, level + 1, include_status=include_status, max_print_level=max_print_level - ) - return prefix + self.print_tree_recursive(child, level + 1, lines=lines, include_status=include_status, max_print_level=max_print_level) + return lines def backward(self, level_decoy=0.8, support_weight=0.5): self.backward_recursive(self.root, level_decoy, 
support_weight) @@ -619,7 +628,9 @@ def __init__(self, motion, side): self.meta_attack_list = [] self.meta_rebuttal_list = [] - def print_tree_recursive(self, node, level=0, prefix="", include_status=False, max_print_level=None, reverse=False): + def print_tree_recursive(self, node, level=0, lines=None, include_status=False, max_print_level=None, reverse=False): + if lines is None: + lines = [] if level == 0: position = "Motion" elif level == 1: @@ -631,38 +642,34 @@ def print_tree_recursive(self, node, level=0, prefix="", include_status=False, m if node is not None: if max_print_level is not None and level > max_print_level: - return "" + return lines if level == 0: - prefix += " " * level * 4 + f"Level-{level} Motion: {self.motion}, Side: {self.side}\n" + lines.append(' ' * level * 4 + f"Level-{level} Motion: {self.motion}, Side: {self.side}\n") else: if include_status: - prefix += ( - " " * level * 4 - + f"Level-{level} {position} (Visit: {node.visit_count}, Status: {node.status}): {node.data}\n" - ) + lines.append(' ' * level * 4 + f"Level-{level} {position} (Visit: {node.visit_count}, Status: {node.status}): {node.data}\n") else: - prefix += " " * level * 4 + f"Level-{level} {position}: {node.data}\n" + lines.append(' ' * level * 4 + f"Level-{level} {position}: {node.data}\n") for child in node.children: - prefix += self.print_tree_recursive( - child, level + 1, include_status=include_status, max_print_level=max_print_level, reverse=reverse - ) - return prefix + self.print_tree_recursive(child, level + 1, lines=lines, include_status=include_status, max_print_level=max_print_level, reverse=reverse) + return lines def print_tree(self, prefix="", include_status=False, max_print_level=None, meta_info=True, reverse=False): - info_str = self.print_tree_recursive( + lines = [prefix] if prefix else [] + self.print_tree_recursive( self.root, level=0, - prefix=prefix, + lines=lines, include_status=include_status, max_print_level=max_print_level, reverse=reverse, ) if 
meta_info: if len(self.meta_attack_list) > 0: - info_str += f"Meta Attack to this debate tree: {self.meta_attack_list}\n" + lines.append(f"Meta Attack to this debate tree: {self.meta_attack_list}\n") if len(self.meta_rebuttal_list) > 0: - info_str += f"Meta Rebuttal to the attacks on this debate tree: {self.meta_rebuttal_list}" - return info_str + lines.append(f"Meta Rebuttal to the attacks on this debate tree: {self.meta_rebuttal_list}") + return "".join(lines) def update_node(self, action, new_claim=None, new_argument=None, target=None): if len(new_claim) == 0: @@ -709,6 +716,7 @@ def update_node(self, action, new_claim=None, new_argument=None, target=None): if action == "reinforce": match_node.argument.extend(new_argument) + match_node.argument = list(set(match_node.argument)) match_node.update_status(match_node.status) elif action == "rebut" or action == "attack": new_node = match_node.add_node(new_claim=new_claim, new_argument=new_argument) diff --git a/src/env.py b/src/env.py index 0673e7c..9df487c 100644 --- a/src/env.py +++ b/src/env.py @@ -1,9 +1,7 @@ import argparse import json -import os import time from dataclasses import dataclass -from functools import partial from typing import List import yaml @@ -11,7 +9,7 @@ from agents import Audience, AudienceConfig, BaselineDebater, Debater, DebaterConfig, HumanDebater, Judge, JudgeConfig from ouragents import TreeDebater from utils.constants import CLOSING_TIME, OPENING_TIME, REBUTTAL_TIME -from utils.model import HelperClient +from utils.timing_log import log_timing from utils.tool import logger @@ -26,7 +24,6 @@ class EnvConfig: claim_pool_size: int = 50 reverse: bool = False time_control: bool = True - streaming_tts: bool = False def extract_overall_score(obj_scores): # larger is better @@ -47,7 +44,6 @@ def __init__(self, config, debug) -> None: self.claim_pool_size = config.claim_pool_size self.reverse = config.reverse self.time_control = config.time_control - self.streaming_tts = config.streaming_tts 
self.debug = debug # init players @@ -87,6 +83,7 @@ def play(self, pre_only=False): order = ["for", "against"] if not self.reverse else ["against", "for"] for stage in ["preparation", "opening", "rebuttal", "closing"]: logger.info(f"[{stage}] Start") + t_stage = time.perf_counter() if stage == "preparation": for side in order: if self.debaters[side].type in ["treedebater"]: @@ -96,26 +93,33 @@ def play(self, pre_only=False): for side in order: player = self.debaters[side] response = player.opening_generation( - history=self.debate_process[1:], max_time=OPENING_TIME, - time_control=self.time_control, streaming_tts=self.streaming_tts, + history=self.debate_process[1:], + max_time=OPENING_TIME, + time_control=self.time_control, + streaming_tts=player.config.streaming_tts, ) self.debate_process.append({"stage": stage, "side": side, "content": response}) elif stage == "rebuttal": for side in order: player = self.debaters[side] response = player.rebuttal_generation( - history=self.debate_process[1:], max_time=REBUTTAL_TIME, - time_control=self.time_control, streaming_tts=self.streaming_tts, + history=self.debate_process[1:], + max_time=REBUTTAL_TIME, + time_control=self.time_control, + streaming_tts=player.config.streaming_tts, ) self.debate_process.append({"stage": stage, "side": side, "content": response}) elif stage == "closing": - for side in order: # reverse to make compatible with agent4debate + for side in order: player = self.debaters[side] response = player.closing_generation( - history=self.debate_process[1:], max_time=CLOSING_TIME, - time_control=self.time_control, streaming_tts=self.streaming_tts, + history=self.debate_process[1:], + max_time=CLOSING_TIME, + time_control=self.time_control, + streaming_tts=player.config.streaming_tts, ) self.debate_process.append({"stage": stage, "side": side, "content": response}) + log_timing(logger, "env_stage_wall", time.perf_counter() - t_stage, stage=stage, motion=self.motion[:80]) logger.info(f"[{stage}] Done") if 
self.debug: response = input("Press N to stop: ") @@ -124,6 +128,7 @@ def play(self, pre_only=False): def eval(self, process=None): logger.info("[Evaluation] Start") + t0 = time.perf_counter() output = {} process = self.debate_process if process is None else process process = [x for x in process if x["stage"] != "settings"] @@ -161,17 +166,20 @@ def eval(self, process=None): ) output[f"{side}_surprise_explanation"].append(surprise_explanation[0]) + log_timing(logger, "evaluation_wall", time.perf_counter() - t0, motion=self.motion[:80]) logger.info("[Evaluation] Done") return output, side_info def compare_debate(self, comparison_process, order_reverse=False): logger.info("[Comparison Evaluation] Start") + t_cmp = time.perf_counter() order = ["baseline_response", "test_response"] if not order_reverse else ["test_response", "baseline_response"] output = {} context = [] for i, phase in enumerate(comparison_process.keys()): logger.info(f"[{phase}] Start") + t_ph = time.perf_counter() output[phase] = {} stage, side = phase.split("_") @@ -206,7 +214,9 @@ def compare_debate(self, comparison_process, order_reverse=False): "content": comparison_process[phase]["keep_response"].split("**Reference**")[0], } ) + log_timing(logger, "comparison_phase_wall", time.perf_counter() - t_ph, phase=phase, motion=self.motion[:80]) + log_timing(logger, "comparison_evaluation_total_wall", time.perf_counter() - t_cmp, motion=self.motion[:80]) logger.info("[Comparison Evaluation] Done") return output diff --git a/src/evaluator.py b/src/evaluator.py index 694e2c5..1f786f5 100644 --- a/src/evaluator.py +++ b/src/evaluator.py @@ -10,7 +10,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from src.utils.tool import extract_numbers, find_json, logger +from src.utils.timing_log import log_llm_io +from src.utils.tool import extract_numbers, logger, parse_llm_json def extract_claims(llm, title, side, content, verbose=False): @@ -51,11 +52,9 @@ def extract_claims(llm, title, side, 
content, verbose=False): print(prompt) response = llm(prompt=prompt, temperature=0, max_tokens=1500, n=1, sys=system_prompt) - claims = find_json(response[0]) - try: - claims = eval(claims) - except: + claims = parse_llm_json(response[0]) + except Exception: claims = {} print("Error in extracting claims") @@ -116,11 +115,9 @@ def extract_obj_aspect(llm, title, side, content, claim_against=None): response = llm(prompt=prompt, temperature=0, max_tokens=800, n=1, sys=system_prompt) # print("extract_obj_aspect:", response[0]) - scores = find_json(response[0]) - try: - scores = eval(scores) - except: + scores = parse_llm_json(response[0]) + except Exception: scores = {} print("Error in extracting scores") @@ -184,11 +181,11 @@ def eval_surprise(llm, title, side, claims, n=3, reduction=False, verbose=False) surprises = [] score_list = [] for i in range(n): - surprise = find_json(response[i]) + surprise = {} try: - surprise = eval(surprise) + surprise = parse_llm_json(response[i]) surprise = surprise["result"] - except: + except Exception: print(surprise) print("Error in extracting surprise scores") surprise = {} @@ -231,9 +228,9 @@ def evaluate_support_strength(llm, motion, argument1, argument2, history=None): f"Argument 2: {argument2}\n" f"""The two arguments are from the same side in a debate, and the support strength refers to how well the first argument adds to the second argument. Each score ranges from 1 to 3, with 1 being the lowest and 3 being the highest. Provide your evaluation as a single number in the format "Score: [score]". 
You can additionally provide a brief explanation of your evaluation.""" ) - logger.debug("[Support-Strength-Prompt] {}".format(prompt.strip().replace("\n", " ||| "))) + log_llm_io(logger, phase="evaluator", title="Support-Strength-Prompt", body=prompt.strip()) response = llm(prompt=prompt, temperature=0)[0] - logger.debug("[Support-Strength-Response] {}".format(response.strip().replace("\n", " ||| "))) + log_llm_io(logger, phase="evaluator", title="Support-Strength-Response", body=response.strip()) response = response.replace("*", "") pos = response.find("Score: ") numbers = extract_numbers(response[pos : pos + 15]) @@ -254,9 +251,9 @@ def evaluate_defense_strength(llm, motion, argument1, argument2, history=None): f"Argument 2: {argument2}\n" """The two arguments are from the different sides in a debate, and the rebuttal strength refers to how well the first argument undermines the second argument. Each score ranges from 1 to 3, with 1 being the lowest and 3 being the highest. Provide your evaluation as a single number in the format "Score: [score]". 
You can additionally provide a brief explanation of your evaluation.""" ) - logger.debug("[Support-Defense-Prompt] {}".format(prompt.strip().replace("\n", " ||| "))) + log_llm_io(logger, phase="evaluator", title="Support-Defense-Prompt", body=prompt.strip()) response = llm(prompt=prompt, temperature=0)[0] - logger.debug("[Support-Defense-Response] {}".format(response.strip().replace("\n", " ||| "))) + log_llm_io(logger, phase="evaluator", title="Support-Defense-Response", body=response.strip()) pos = response.find("Score: ") numbers = extract_numbers(response[pos : pos + 15]) return numbers[0] diff --git a/src/ouragents.py b/src/ouragents.py index d85c0f5..bce4a41 100644 --- a/src/ouragents.py +++ b/src/ouragents.py @@ -4,10 +4,13 @@ import os import random import re +import threading import time import traceback from dataclasses import dataclass from functools import partial +from pathlib import Path +from typing import Optional import google.generativeai as genai import litellm @@ -30,10 +33,20 @@ get_retrieval_from_rehearsal_tree, rank_evidence, ) +from utils.llm_schemas import SelectedIdsResponse from utils.model import HelperClient from utils.prompts import * from utils.time_estimator import LengthEstimator -from utils.tool import get_response_with_retry, logger, sort_by_action, sort_by_importance +from utils.timing_log import ( + clear_speak_io_context, + log_io_block, + log_llm_io, + log_timing, + next_call_id, + set_speak_io_context, + timed_phase, +) +from utils.tool import get_response_with_retry, io_logger, io_logging_enabled, logger, sort_by_action, sort_by_importance class TreeDebater(Debater): @@ -53,9 +66,8 @@ def __init__(self, config, motion): + f"use_rehearsal_tree: {self.use_rehearsal_tree}, use_debate_flow_tree: {self.use_debate_flow_tree}" ) - self.helper_client = partial( - HelperClient, model=self.config.model, temperature=0, max_tokens=config.max_tokens, n=1 - ) + helper_model = getattr(config, "helper_model", self.config.model) + 
self.helper_client = partial(HelperClient, model=helper_model, temperature=0, max_tokens=config.max_tokens, n=1) self.simulated_audience = [Audience(AudienceConfig(model=self.config.model, temperature=1)) for _ in range(1)] # Initialize debate trees only if they are enabled @@ -91,6 +103,62 @@ def __init__(self, config, motion): self.used_evidence = set() + self._streaming_input_env = None + self._streaming_listen_thread = None + + def start_streaming_listen( + self, + watch_dir: Path, + stage: str, + *, + min_audio_seconds: float = 30.0, + min_text_words: int = 50, + poll_interval: float = 1.0, + audio_format: str = "mp3", + max_audio_wait_seconds: Optional[float] = None, + max_text_wait_seconds: Optional[float] = None, + max_total_audio_seconds: Optional[float] = None, + playback_cursor: Optional[list] = None, + ) -> None: + """Run :class:`StreamingInputEnv` on a thread (chunk audio → ASR → ``_analyze_statement``).""" + from streaming.env import StreamingInputConfig, StreamingInputEnv + + self.stop_streaming_listen(join_timeout=5.0) + + watch_dir = Path(watch_dir) + cfg = StreamingInputConfig( + watch_dir=watch_dir, + motion=self.motion, + stage=stage, + statement_side=self.oppo_side, + min_audio_seconds=min_audio_seconds, + min_text_words=min_text_words, + poll_interval=poll_interval, + audio_file_glob=f"*.{audio_format}", + max_audio_wait_seconds=max_audio_wait_seconds, + max_text_wait_seconds=max_text_wait_seconds, + max_total_audio_seconds=max_total_audio_seconds, + audio_format=audio_format, + playback_cursor=playback_cursor, + ) + self._streaming_input_env = StreamingInputEnv(self, cfg) + self._streaming_listen_thread = threading.Thread( + target=self._streaming_input_env.run, + name="StreamingInputListen", + daemon=True, + ) + self._streaming_listen_thread.start() + + def stop_streaming_listen(self, join_timeout: float = 300.0) -> None: + t = self._streaming_listen_thread + env = self._streaming_input_env + self._streaming_listen_thread = None + 
self._streaming_input_env = None + if env is not None: + env.stop() + if t is not None: + t.join(timeout=join_timeout) + def _get_evidence(self, claim): if self.use_retrieval: evidence = [x for x in claim["retrieved_evidence"] if "PDF" not in x["title"]] @@ -141,9 +209,23 @@ def claim_generation(self, pool_size, definition=None, **kwargs): self.oppo_claim_pool = claim_pool prompt = propose_definition_prompt.format(motion=self.motion, act=self.act) - logger.debug("[Definition-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) + log_llm_io( + logger, + phase="ouragents_prepare", + title="Definition-Helper-Prompt", + body=prompt.strip(), + stage=getattr(self, "status", None), + side=getattr(self, "side", None), + ) response = self.helper_client(prompt=prompt)[0] - logger.debug("[Definition-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io( + logger, + phase="ouragents_prepare", + title="Definition-Helper-Response", + body=response.strip(), + stage=getattr(self, "status", None), + side=getattr(self, "side", None), + ) if "None" in response: self.definition = None else: @@ -351,37 +433,135 @@ def closing_generation(self, history, max_time, time_control=False, **kwargs): return response def speak(self, prompt, max_time, time_control=False, history=None, **kwargs): - self._add_message("user", prompt) - logger.debug(f"[Conversation-History] {json.dumps(self.conversation)}") - logger.debug("[Prompt] " + prompt.strip().replace("\n", " ||| ")) - - # add evidence based on audience feedback - response = self._get_response(self.conversation, **kwargs) - logger.debug("[Response-Before-Post-Process] " + response.strip().replace("\n", " ||| ")) - feedback_for_revision, new_evidence, allocation_plan, ori_statement = self._get_revision_suggestion( - statement=response, history=history, add_evidence=True, **kwargs - ) - response = self._length_adjust( - ori_statement, feedback_for_revision, new_evidence, allocation_plan, max_time, max_retry=1, **kwargs 
- ) + call_id = next_call_id() + ctx = dict(call_id=call_id, stage=self.status, side=self.side) + set_speak_io_context(call_id, "tree_debater_speak") + try: + with timed_phase(logger, "tree_debater_speak", **ctx): + self._add_message("user", prompt) + if io_logging_enabled(): + log_io_block( + io_logger, + call_id=call_id, + phase="tree_debater_speak", + title="Conversation-History", + body=json.dumps(self.conversation), + stage=self.status, + side=self.side, + ) + log_io_block( + io_logger, + call_id=call_id, + phase="tree_debater_speak", + title="Prompt", + body=prompt, + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="tree_debater_speak", + title="Conversation-History", + body=json.dumps(self.conversation), + stage=self.status, + side=self.side, + ) + log_llm_io( + logger, + phase="tree_debater_speak", + title="Prompt", + body=prompt.strip(), + stage=self.status, + side=self.side, + ) + logger.debug( + f"[timing-meta] call_id={call_id} speak_session=tree_debater_speak " + f"n_messages={len(self.conversation)}" + ) - # check audience feedback again - feedback_for_revision, new_evidence, _, _ = self._get_revision_suggestion( - statement=response, history=history, add_evidence=False, **kwargs - ) + with timed_phase(logger, "main_get_response", **ctx): + response = self._get_response(self.conversation, **kwargs) + if io_logging_enabled(): + log_io_block( + io_logger, + call_id=call_id, + phase="tree_debater_speak", + title="Response-Before-Post-Process", + body=str(response).strip(), + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="tree_debater_speak", + title="Response-Before-Post-Process", + body=str(response).strip(), + stage=self.status, + side=self.side, + ) - streaming_tts = kwargs.get("streaming_tts", False) - if not time_control or streaming_tts: - # streaming TTS has its own adaptive refinement, skip expensive retries here - response = self._length_adjust( - response, 
feedback_for_revision, new_evidence, allocation_plan, max_time, max_retry=1, **kwargs - ) - else: - response = self._length_adjust( - response, feedback_for_revision, new_evidence, allocation_plan, max_time, max_retry=10, **kwargs - ) + with timed_phase(logger, "revision_suggestion", pass_index=1, add_evidence=True, **ctx): + feedback_for_revision, new_evidence, allocation_plan, ori_statement = self._get_revision_suggestion( + statement=response, history=history, add_evidence=True, call_id=call_id, **kwargs + ) + with timed_phase(logger, "length_adjust", block=1, max_retry=1, **ctx): + response = self._length_adjust( + ori_statement, + feedback_for_revision, + new_evidence, + allocation_plan, + max_time, + max_retry=1, + call_id=call_id, + **kwargs, + ) - return super().post_process(response, max_time, time_control, **kwargs) + # Default to single-pass revision to reduce latency: + # pass-1 revision + one length-adjust. Set single_pass_revision=False + # (via config or kwargs) to restore the old two-pass behavior. 
+ single_pass_revision = kwargs.get( + "single_pass_revision", + getattr(self.config, "single_pass_revision", False), + ) + if not single_pass_revision: + with timed_phase(logger, "revision_suggestion", pass_index=2, add_evidence=False, **ctx): + feedback_for_revision, new_evidence, _, _ = self._get_revision_suggestion( + statement=response, history=history, add_evidence=False, call_id=call_id, **kwargs + ) + + streaming_tts = kwargs.get("streaming_tts", getattr(self.config, "streaming_tts", False)) + if not time_control or streaming_tts: + with timed_phase(logger, "length_adjust", block=2, max_retry=1, **ctx): + response = self._length_adjust( + response, + feedback_for_revision, + new_evidence, + allocation_plan, + max_time, + max_retry=1, + call_id=call_id, + **kwargs, + ) + else: + with timed_phase(logger, "length_adjust", block=2, max_retry=10, **ctx): + response = self._length_adjust( + response, + feedback_for_revision, + new_evidence, + allocation_plan, + max_time, + max_retry=10, + call_id=call_id, + **kwargs, + ) + + with timed_phase(logger, "post_process", **ctx): + out = super().post_process(response, max_time, time_control, **kwargs) + return out + finally: + clear_speak_io_context() def listen(self, history): if len(history) == 0: @@ -393,7 +573,24 @@ def listen(self, history): # Only analyze statement if debate flow tree is enabled if self.use_debate_flow_tree: - self._analyze_statement(history[-1]["content"], self.oppo_side) + skip_full = getattr(self.config, "streaming_listen", False) and history[-1].get( + "tree_via_streaming" + ) is True + if not skip_full: + st = history[-1]["stage"] + with timed_phase( + logger, + "listen_analyze_statement", + stage=st, + side=self.side, + opponent_side=self.oppo_side, + ): + logger.debug( + f"[BatchListener] analyze_start stage={st} side={self.side} " + f"opponent_side={self.oppo_side} t={time.time():.3f}" + ) + self._analyze_statement(history[-1]["content"], self.oppo_side) + logger.debug(f"[BatchListener] 
analyze_end stage={st} side={self.side} t={time.time():.3f}") # Only prepare opponent tree list if both debate flow tree and rehearsal tree are enabled if self.use_rehearsal_tree and self.prepared_oppo_tree_list is None: @@ -404,7 +601,13 @@ def listen(self, history): def _get_feedback_from_audience(self, statement, history, **kwargs): extra_tree_info = "" if self.add_retrieval_feedback and self.use_debate_flow_tree: - retrieval, retrieval_feedback = self._get_retrieval_debate_tree(include_points=False) + with timed_phase( + logger, + "audience_exemplar_retrieval", + stage=self.status, + side=self.side, + ): + retrieval, retrieval_feedback = self._get_retrieval_debate_tree(include_points=False) if retrieval is not None: extra_tree_info += "\n\n" + retrieval_feedback @@ -420,18 +623,56 @@ def _get_feedback_from_audience(self, statement, history, **kwargs): retrieval=extra_tree_info, history=history_str, ) - logger.debug("[Audience-Feedback-Prompt] " + prompt.strip().replace("\n", " ||| ")) + call_id = kwargs.get("call_id") + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="audience_feedback", + title="Audience-Feedback-Prompt", + body=prompt.strip(), + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="audience_feedback", + title="Audience-Feedback-Prompt", + body=prompt.strip(), + stage=self.status, + side=self.side, + ) audience_feedback = [] flat_audience_feedback = "" - for i, au in enumerate(self.simulated_audience): - feedback = au.feedback(prompt) - audience_feedback.append(feedback) - key_feedback = ( - "Critical Issues and Minimal Revision Suggestions" - + feedback.split("Critical Issues and Minimal Revision Suggestions")[-1] + with timed_phase(logger, "audience_simulated_feedback_llm", stage=self.status, side=self.side, n_audience=len(self.simulated_audience)): + for i, au in enumerate(self.simulated_audience): + feedback = au.feedback(prompt) + 
audience_feedback.append(feedback) + key_feedback = ( + "Critical Issues and Minimal Revision Suggestions" + + feedback.split("Critical Issues and Minimal Revision Suggestions")[-1] + ) + flat_audience_feedback += f"\n\n\nAudience {i+1} Feedback:\n" + key_feedback + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="audience_feedback", + title="Audience-Feedback-Response", + body=flat_audience_feedback.strip(), + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="audience_feedback", + title="Audience-Feedback-Response", + body=flat_audience_feedback.strip(), + stage=self.status, + side=self.side, ) - flat_audience_feedback += f"\n\n\nAudience {i+1} Feedback:\n" + key_feedback - logger.debug("[Audience-Feedback-Response] " + flat_audience_feedback.strip().replace("\n", " ||| ")) return flat_audience_feedback, audience_feedback def _get_retrieval_debate_tree(self, **kwargs): @@ -442,7 +683,15 @@ def _get_retrieval_debate_tree(self, **kwargs): logger.debug( f"[Retrieval-Debate-Tree] Search for {self.side} side: " + current_tree_info.strip().replace("\\n", " ||| ") ) + t_embed = time.perf_counter() current_tree_embedding = self._get_embedding_from_cache(current_tree_info) + log_timing( + logger, + "exemplar_retrieval_query_embedding", + time.perf_counter() - t_embed, + stage=self.status, + side=self.side, + ) memory_tree_embedding = self.pro_embeddings if self.side == "for" else self.con_embeddings if self.status == "opening": memory_tree_embedding = memory_tree_embedding[0] @@ -451,12 +700,21 @@ def _get_retrieval_debate_tree(self, **kwargs): elif self.status == "closing": memory_tree_embedding = memory_tree_embedding[2] + t_search = time.perf_counter() hits = semantic_search( torch.tensor([current_tree_embedding]), torch.tensor(memory_tree_embedding), score_function=dot_score, top_k=1, )[0] + log_timing( + logger, + "exemplar_retrieval_semantic_search", + time.perf_counter() - 
t_search, + stage=self.status, + side=self.side, + top_k=1, + ) retrieval_idx = [x["corpus_id"] for x in hits] retrieval_data = [self.data_list[idx] for idx in retrieval_idx] retrieval_motion = [data["motion"] for data in retrieval_data] @@ -550,8 +808,12 @@ def _get_prepared_tree(self, side): prepared_tree.append(sorted_match_trees[i][0]) similarity = sorted_match_trees[i][1] query_claim = sorted_match_trees[i][2] - logger.debug( - f"[Get-Prepared-Tree] Opponent's Tree (similarity: {similarity:0.2f}) for claim: {query_claim}\n{tree.print_tree(include_status=True)}" + + log_llm_io( + logger, + phase="get_prepared_tree", + title=f"Opponent's Tree (similarity: {similarity:0.2f}) for claim: {query_claim}", + body=tree.print_tree(include_status=True), ) thoughts = { @@ -575,16 +837,23 @@ def _retrieve_on_prepared_tree(self, action): look_ahead_num = REMAINING_ROUND_NUM[f"{self.status}_{self.side}"] query_embedding = self._get_embedding_from_cache(target_claim) - additional_info, retrieval_nodes = get_retrieval_from_rehearsal_tree( - action_type, - target_claim, - self.side, - self.oppo_side, - self.prepared_tree_list, - self.prepared_oppo_tree_list, - look_ahead_num, - query_embedding, - ) + with timed_phase( + logger, + "rehearsal_retrieve_on_prepared_tree", + stage=self.status, + side=self.side, + action_type=action_type, + ): + additional_info, retrieval_nodes = get_retrieval_from_rehearsal_tree( + action_type, + target_claim, + self.side, + self.oppo_side, + self.prepared_tree_list, + self.prepared_oppo_tree_list, + look_ahead_num, + query_embedding, + ) thoughts = { "stage": self.status, @@ -599,7 +868,7 @@ def _retrieve_on_prepared_tree(self, action): return "\n".join(additional_info) - def _get_revision_suggestion(self, statement, history, add_evidence=True, **kwargs): + def _get_revision_suggestion(self, statement, history, add_evidence=True, call_id=None, **kwargs): statement = statement.replace("**Statement:**", "**Statement**").replace("**Statement**:", 
"**Statement**") parts = statement.split("**Statement**") if len(parts) > 1: @@ -612,7 +881,9 @@ def _get_revision_suggestion(self, statement, history, add_evidence=True, **kwar if self.status == "closing": return "", "", allocation_plan, statement - feedback_from_audience, audience_feedback = self._get_feedback_from_audience(statement, history, **kwargs) + feedback_from_audience, audience_feedback = self._get_feedback_from_audience( + statement, history, call_id=call_id, **kwargs + ) feedback_for_revision = f"Revision Guidance:\n{feedback_from_audience}" new_evidence = [] @@ -632,9 +903,51 @@ def _get_revision_suggestion(self, statement, history, add_evidence=True, **kwar statement=statement, feedback=feedback_for_revision, ) - logger.debug("[Evidence-Selection-Prompt] " + prompt.strip().replace("\n", " ||| ")) - selected_ids, response = get_response_with_retry(self.helper_client, prompt, "selected_ids") - logger.debug("[Evidence-Selection-Response] " + response.strip().replace("\n", " ||| ")) + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="evidence_selection", + title="Evidence-Selection-Prompt", + body=prompt.strip(), + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="evidence_selection", + title="Evidence-Selection-Prompt", + body=prompt.strip(), + stage=self.status, + side=self.side, + ) + with timed_phase(logger, "evidence_selection_llm", stage=self.status, side=self.side): + selected_ids, response = get_response_with_retry( + self.helper_client, + prompt, + "selected_ids", + response_model=SelectedIdsResponse, + ) + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="evidence_selection", + title="Evidence-Selection-Response", + body=response.strip(), + stage=self.status, + side=self.side, + ) + else: + log_llm_io( + logger, + phase="evidence_selection", + title="Evidence-Selection-Response", + 
body=response.strip(), + stage=self.status, + side=self.side, + ) new_evidence = [ e for e in new_evidence if e["id"] in selected_ids and e["id"] not in self.used_evidence ] @@ -667,6 +980,7 @@ def _get_revision_suggestion(self, statement, history, add_evidence=True, **kwar def _length_adjust( self, statement, feedback_for_revision, new_evidence, allocation_plan, max_time, max_retry=10, **kwargs ): + call_id = kwargs.pop("call_id", None) budget, threshold = max_time, TIME_TOLERANCE time_adjuster = TimeAdjuster() estimator = LengthEstimator(mode=TIME_MODE_FOR_STATEMENT) @@ -677,6 +991,7 @@ def _length_adjust( retry = 0 response_list = [] while not flag and retry < max_retry: + iter_t0 = time.perf_counter() evidence_str = json.dumps([{k: v for k, v in x.items() if k != "raw_content"} for x in new_evidence]) prompt = post_process_prompt.format( motion=self.motion, @@ -689,18 +1004,90 @@ def _length_adjust( allocation_plan=allocation_plan, ) - logger.debug("[Get-Expert-Audience-Revision-Prompt] " + prompt.strip().replace("\n", " ||| ")) + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="length_adjust", + title=f"Get-Expert-Audience-Revision-Prompt_iter{retry + 1}", + body=prompt.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) + else: + log_llm_io( + logger, + phase="length_adjust", + title=f"Get-Expert-Audience-Revision-Prompt_iter{retry + 1}", + body=prompt.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) revision = self.helper_client(prompt=prompt)[0] - logger.debug("[Get-Expert-Audience-Revision-Response] " + revision.strip().replace("\n", " ||| ")) + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="length_adjust", + title=f"Get-Expert-Audience-Revision-Response_iter{retry + 1}", + body=revision.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) + else: + log_llm_io( + logger, + 
phase="length_adjust", + title=f"Get-Expert-Audience-Revision-Response_iter{retry + 1}", + body=revision.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) new_statement = revision.replace("Revised Statement:\n", "") new_statement = new_statement.replace("et al.,", "") new_statement = new_statement.replace("[X]", "") response = re.sub(r" [X-Z][ \%]", "", new_statement) - logger.debug("[Response-After-Post-Process] " + response.strip().replace("\n", " ||| ")) + if io_logging_enabled() and call_id is not None: + log_io_block( + io_logger, + call_id=call_id, + phase="length_adjust", + title=f"Response-After-Post-Process_iter{retry + 1}", + body=response.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) + else: + log_llm_io( + logger, + phase="length_adjust", + title=f"Response-After-Post-Process_iter{retry + 1}", + body=response.strip(), + stage=self.status, + side=self.side, + iteration=retry + 1, + ) current_cost, n_words, flag = time_adjuster.revise_helper( response, n_words, budget, threshold=threshold, ratio=ratio, estimator=estimator ) + log_timing( + logger, + "length_adjust_iteration", + time.perf_counter() - iter_t0, + stage=self.status, + side=self.side, + iteration=retry + 1, + max_retry=max_retry, + call_id=call_id, + fit_ok=flag, + current_cost=current_cost, + ) response_list.append([response, current_cost]) retry += 1 if not flag and max_retry > 1: @@ -743,7 +1130,17 @@ def _get_embedding_from_cache(self, content: str): retry = 0 while retry < max_retry: try: + t0 = time.perf_counter() embedding = get_embeddings([content])[0] + log_timing( + logger, + "embedding_api_fetch", + time.perf_counter() - t0, + stage=self.status, + side=self.side, + cache_hit=False, + attempt=retry + 1, + ) break except Exception as e: logger.error(f"[Get-Embedding-From-Cache] Error: {e}. 
Sleep 30 seconds and retry.") @@ -768,43 +1165,50 @@ def _analyze_statement(self, statements, statement_side): tree, oppo_tree = self.debate_tree, self.oppo_debate_tree else: tree, oppo_tree = self.oppo_debate_tree, self.debate_tree - claims = extract_statement( - self.helper_client, - self.motion, - statements, - tree=[tree.print_tree(include_status=True), oppo_tree.print_tree(include_status=True, reverse=True)], - side=statement_side, + with timed_phase( + logger, + "analyze_statement", stage=self.status, - ) - - for x in claims: - for p in x["purpose"]: - target_tree = tree if p["targeted_debate_tree"] == "you" else oppo_tree - if p["target"] == "N/A" and target_tree.max_level == 0: - if p["action"] == "propose" or p["action"] == "rebut" or p["action"] == "reinforce": - p["target"] = x["claim"] - - for x in claims: - claim = x["claim"] - arguments = x["arguments"] - if isinstance(x["purpose"], dict): - purpose = [x["purpose"]] - else: - purpose = x["purpose"] - for p in purpose: - target_tree = tree if p["targeted_debate_tree"] == "you" else oppo_tree - action = p["action"] - target = p["target"] - target_tree.update_node(action, new_claim=claim, new_argument=arguments, target=target) + side=self.side, + statement_side=statement_side, + ): + claims = extract_statement( + self.helper_client, + self.motion, + statements, + tree=[tree.print_tree(include_status=True), oppo_tree.print_tree(include_status=True, reverse=True)], + side=statement_side, + stage=self.status, + ) - thoughts = { - "stage": self.status, - "side": statement_side, - "mode": "analyze_statement", - "statement": statements, - "claims": claims, - } - self.debate_thoughts.append(thoughts) + for x in claims: + for p in x["purpose"]: + target_tree = tree if p["targeted_debate_tree"] == "you" else oppo_tree + if p["target"] == "N/A" and target_tree.max_level == 0: + if p["action"] == "propose" or p["action"] == "rebut" or p["action"] == "reinforce": + p["target"] = x["claim"] + + for x in claims: + 
claim = x["claim"] + arguments = x["arguments"] + if isinstance(x["purpose"], dict): + purpose = [x["purpose"]] + else: + purpose = x["purpose"] + for p in purpose: + target_tree = tree if p["targeted_debate_tree"] == "you" else oppo_tree + action = p["action"] + target = p["target"] + target_tree.update_node(action, new_claim=claim, new_argument=arguments, target=target) + + thoughts = { + "stage": self.status, + "side": statement_side, + "mode": "analyze_statement", + "statement": statements, + "claims": claims, + } + self.debate_thoughts.append(thoughts) return claims diff --git a/src/prepare.py b/src/prepare.py index 3eefdf0..53edda9 100644 --- a/src/prepare.py +++ b/src/prepare.py @@ -2,6 +2,7 @@ import json import os import re +import time from functools import partial import google.generativeai as genai @@ -13,8 +14,10 @@ from debate_tree import PrepareTree from searcher import MAX_QUERY, get_search_query, get_search_result, get_source_info, update_search_query from utils.constants import EMBEDDING_MODEL, google_api_key +from utils.llm_schemas import ResultsResponse from utils.model import HelperClient, reward_model from utils.prompts import claim_propose_prompt, propose_definition_prompt +from utils.timing_log import log_llm_io, log_timing from utils.tool import get_response_with_retry, logger genai.configure(api_key=google_api_key) @@ -62,9 +65,9 @@ def __init__( def create_claim(self, need_score=True, need_evidence=True, max_search_depth=2, max_search_branch=3): prompt = propose_definition_prompt.format(motion=self.motion, act=self.act) - logger.debug("[Definition-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="prepare", title="Definition-Helper-Prompt", body=prompt.strip(), side=self.side) response = self.client(prompt=prompt)[0] - logger.debug("[Definition-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="prepare", title="Definition-Helper-Response", body=response.strip(), 
side=self.side) if "None" in response: self.definition = "" else: @@ -72,9 +75,15 @@ def create_claim(self, need_score=True, need_evidence=True, max_search_depth=2, prompt = claim_propose_prompt.format(motion=self.motion, act=self.act, size=self.pool_size) - logger.debug("[Claim-Propose-Prompt] " + prompt.replace("\n", " ||| ")) - results, response = get_response_with_retry(self.client, prompt, "results", temperature=1.0) - logger.debug("[Claim-Propose-Response] " + json.dumps(response, indent=2).replace("\n", " ||| ")) + log_llm_io(logger, phase="prepare", title="Claim-Propose-Prompt", body=prompt, side=self.side) + results, response = get_response_with_retry( + self.client, + prompt, + "results", + response_model=ResultsResponse, + temperature=1.0, + ) + log_llm_io(logger, phase="prepare", title="Claim-Propose-Response", body=json.dumps(response, indent=2), side=self.side) for item in results: strength = item.get("strength", 5) @@ -293,6 +302,7 @@ def get_evidence_pool(self): logger.info(f"Skip motion: {motion}") else: logger.info(f"Create motion for {save_file_name}...") + t_prep = time.perf_counter() claim_workspace = ClaimPool( motion=motion, side=side, model=model, pool_size=pool_size, use_rm_model=not args.ban_rm_model ) @@ -302,6 +312,14 @@ def get_evidence_pool(self): max_search_depth=args.max_search_depth, max_search_branch=args.max_search_branch, ) + log_timing( + logger, + "prepare_claim_pool_wall", + time.perf_counter() - t_prep, + motion=motion_name, + side=side, + n_claims=len(claim_pool), + ) logger.info(f"Claim Pool Size: {len(claim_pool)}") if len(claim_pool) > 0: diff --git a/src/scripts/README_LOGGING.md b/src/scripts/README_LOGGING.md new file mode 100644 index 0000000..c29c555 --- /dev/null +++ b/src/scripts/README_LOGGING.md @@ -0,0 +1,268 @@ +# Streaming Debate Logging Guide + +This document explains the DEBUG-level logging added to analyze streaming debate performance across all operating modes. 
+ +## Four Operating Modes + +The system supports 4 combinations of `streaming_tts` and `streaming_listen`: + +| Mode | streaming_tts | streaming_listen | Description | +|------|---------------|------------------|-------------| +| **Full Overlap** | True | True | Speaker generates chunks in real-time, listener processes while speaking | +| **Speaker Stream** | True | False | Speaker generates chunks in real-time, listener processes after completion | +| **Listener Stream** | False | True | Speaker generates full audio then chunks, listener processes while "playing" | +| **Sequential** | False | False | Speaker generates full audio, listener processes after completion | + +## Log Components and Events + +### 1. Turn-Level Events + +``` +[Turn] turn_start stage=opening side=for t=1234567890.123 +[Turn] mode_config stage=opening side=for streaming_tts=True streaming_listen=True mode=tts=stream_listen=stream t=1234567890.125 +[Turn] turn_end stage=opening side=for t=1234567920.000 +``` + +**Captures**: Turn boundaries and mode configuration + +### 2. Speaker Thread Events + +``` +[SpeakerWorker] thread_start stage=opening side=for t=1234567890.130 +[SpeakerWorker] generation_start stage=opening side=for t=1234567890.131 +[SpeakerWorker] generation_end stage=opening side=for response_len=1250 t=1234567905.500 +[SpeakerWorker] thread_end stage=opening side=for t=1234567925.200 +``` + +**Captures**: LLM generation timing + +#### Batch TTS Mode (streaming_tts=False) + +``` +[SpeakerWorker] posthoc_chunk_start mode=batch_tts mp3_path=treedebater_opening_for.mp3 t=1234567905.600 +[SpeakerWorker] posthoc_chunk_split audio_duration=120.50s num_chunks=12 t=1234567905.650 +[SpeakerWorker] posthoc_chunk_end num_chunks=12 stream_time=0.450s t=1234567906.050 +``` + +**Captures**: Post-hoc chunking overhead + +### 3. 
TTS Chunk Bridge Events (streaming_tts=True only) + +``` +[TtsChunkBridge] chunk_detected chunk_idx=1 size=45120 t=1234567890.145 +[TtsChunkBridge] chunk_copied chunk_idx=1 detection_latency=0.805s copy_time=0.002s t=1234567890.950 +``` + +**Captures**: Real-time chunk availability + +### 4. Playback Main Thread Events + +``` +[PlaybackMain] playback_start side=for t=1234567890.150 +[PlaybackMain] wait_chunk_start chunk_idx=1 cursor=0.00s t=1234567890.150 +[PlaybackMain] wait_chunk_end chunk_idx=1 wait_time=0.850s t=1234567891.000 +[PlaybackMain] chunk_assembled chunk_idx=1 duration=10.00s total_audio=10.00s t=1234567891.015 +[PlaybackMain] file_write chunk_idx=1 size_sec=10.00s write_time=0.120s t=1234567891.135 +[PlaybackMain] chunk_playback_start chunk_idx=1 duration=10.00s t=1234567891.136 +[PlaybackMain] chunk_playback_end chunk_idx=1 cursor=10.00s t=1234567901.136 +[PlaybackMain] playback_end side=for cursor=120.50s t=1234567920.000 +``` + +**Captures**: +- Speaker bubbles (wait_chunk gaps) +- File I/O overhead +- Chunk assembly timing + +### 5. Streaming Listener Events (streaming_listen=True only) + +``` +[StreamingInputEnv] thread_start stage=opening statement_side=for cursor_mode=True t=1234567890.127 +[StreamingInputEnv] file_read cursor=3.00s available=10.00s read_time=0.015s t=1234567894.200 +[StreamingInputEnv] asr_start audio_range=0.00-3.00s t=1234567894.215 +[StreamingInputEnv] asr_end audio_range=0.00-3.00s text_len=450 asr_time=0.850s t=1234567895.065 +[StreamingInputEnv] tree_update_start words=75 text_preview=In this opening statement... 
t=1234567895.066 +[StreamingInputEnv] tree_update_end words=75 update_time=0.245s t=1234567895.311 +[StreamingInputEnv] wait_audio_accumulation available=2.50s need=3.00s t=1234567896.000 +[StreamingInputEnv] thread_end stage=opening statement_side=for t=1234567925.123 +``` + +**Captures**: +- ASR timing and RTF +- Tree update costs +- Audio starvation events +- Listener bubble (thread_end - playback_end) + +### 6. Batch Listener Events (streaming_listen=False only) + +``` +[NonStreamingListener] batch_listen_start stage=opening side=against t=1234567920.100 +[BatchListener] analyze_start stage=opening side=against opponent_side=for t=1234567920.200 +[BatchListener] analyze_end stage=opening side=against analyze_time=2.450s t=1234567922.650 +``` + +**Captures**: Sequential processing overhead + +## Key Metrics Calculation + +### Speaker Bubble (waiting for TTS chunks) +``` +speaker_bubble = sum(wait_chunk_end - wait_chunk_start) +``` + +### Listener Bubble (post-playback processing) +``` +listener_bubble = listener_thread_end - playback_end +``` + +### True Overlap +``` +true_overlap = playback_duration - speaker_bubble +``` + +### Overlap Efficiency +``` +overlap_efficiency = true_overlap / (playback_duration + listener_bubble) +``` + +### ASR Real-Time Factor +``` +rtf = asr_time / audio_duration +# rtf < 1.0 means real-time capable +``` + +### Chunk End-to-End Latency +``` +e2e_latency = chunk_playback_end - chunk_detected +``` + +### Bottleneck Detection +``` +bottleneck = "SPEAKER" if speaker_bubble >= listener_bubble else "LISTENER" # compares accumulated wait time; unset when neither side waited +``` + +## Mode-Specific Metrics + +### Full Overlap (tts=stream, listen=stream) +- Speaker bubbles +- Listener bubbles +- True overlap +- ASR RTF +- Chunk E2E latency + +### Speaker Stream (tts=stream, listen=batch) +- Speaker bubbles (should be minimal) +- Batch analyze time +- No listener bubble (sequential) + +### Listener Stream (tts=batch, listen=stream) +- Post-hoc chunk time +- Listener bubbles +- ASR RTF +- No speaker
bubbles (batch TTS) + +### Sequential (tts=batch, listen=batch) +- Post-hoc chunk time +- Batch analyze time +- No bubbles (no overlap) + +## Using the Analysis Script + +```bash +# Analyze any mode +python src/scripts/analyze_streaming_performance.py log_files/debate.log + +# Compare modes +python src/scripts/analyze_streaming_performance.py log_files/full_overlap.log -o full.json +python src/scripts/analyze_streaming_performance.py log_files/sequential.log -o sequential.json + +# Then compare JSON outputs to understand performance differences +``` + +## Example: Comparing Modes + +**Full Overlap Mode:** +``` +Mode: tts=stream_listen=stream +Overlap efficiency: 84.5% +Speaker bubble: 8.2s (6.8% of playback) +Listener bubble: 12.3s (9.3% of listener time) +Bottleneck: LISTENER +``` + +**Sequential Mode:** +``` +Mode: tts=batch_listen=batch +Post-hoc chunk time: 0.450s +Batch analyze time: 2.450s +No overlap (sequential execution) +``` + +The difference shows the benefit of streaming: ~96s saved through parallelization! + +## Agent / LLM timing and I/O logs (TreeDebater) + +These lines use a **separate format** from streaming events above. They are intended for grep and ad-hoc profiling, not for `analyze_streaming_performance.py`. + +### Main debate log (`N.log`) + +Single-line timing records share the prefix **`[timing]`**, then **`phase=...`**, **`duration_s=...`**, and optional context keys (`stage`, `side`, `call_id`, `pass_index`, `iteration`, …). 
+ +Examples: + +``` +[timing] phase=tree_debater_speak duration_s=45.2301 call_id=3 stage=rebuttal side=for +[timing] phase=main_get_response duration_s=12.1000 call_id=3 stage=rebuttal side=for +[timing] phase=revision_suggestion duration_s=8.0200 pass_index=1 add_evidence=True call_id=3 stage=rebuttal side=for +[timing] phase=length_adjust_iteration duration_s=3.1000 iteration=1 max_retry=10 call_id=3 fit_ok=False current_cost=4.2 stage=rebuttal side=for +[timing] phase=helper_client_litellm duration_s=2.5000 model=gpt-4o n_index=1 max_tokens=4096 +[timing] phase=debater_litellm_completion duration_s=11.8000 stage=opening side=for model=gpt-4o retry_attempt=0 +[timing] phase=tts_wall_clock duration_s=8.4000 stage=opening side=for kind=streaming_tts audio_duration_s=7.5 +[timing] phase=env_stage_wall duration_s=120.0000 stage=opening motion=... +[timing] phase=evaluation_wall duration_s=90.0000 motion=... +``` + +**Macro phases** (env / parallel / compare / prepare scripts): `env_stage_wall`, `evaluation_wall`, `comparison_phase_wall`, `comparison_evaluation_total_wall`, `compare_env_stage_wall`, `prepare_claim_pool_wall`. + +**Meso phases** (TreeDebater `speak`): `tree_debater_speak`, `main_get_response`, `revision_suggestion` (with `pass_index` / `add_evidence`), `length_adjust` blocks (outer `timed_phase` with `block=`), `post_process`, `listen_analyze_statement`, `analyze_statement`, `audience_exemplar_retrieval`, `audience_simulated_feedback_llm`, `evidence_selection_llm`, `rehearsal_retrieve_on_prepared_tree`. + +**Atoms**: `helper_client_litellm`, `debater_litellm_completion`, `get_response_with_retry_llm`, `embedding_api_fetch`, `exemplar_retrieval_query_embedding`, `exemplar_retrieval_semantic_search`, `tts_wall_clock`, `tts_trim_wall_clock`. + +Short metadata without full prompts: **`[timing-meta] call_id=... 
speak_session=...`** where **`speak_session`** is either **`default_speak`** (base :class:`~agents.Agent`) or **`tree_debater_speak`** (:class:`~ouragents.TreeDebater`). That value matches the default **`[io] phase=`** for blocks in that turn (Prompt, Response, TTS-related bodies, etc.). Submodule blocks may override **`phase=`** (e.g. **`audience_feedback`**, **`length_adjust`**) while keeping the same **`call_id`**. + +### I/O log (`N_io.log`) + +When prompt/response logging is enabled (default), large bodies are written to a **sibling file** next to the main log: if the main file is `log_files/14.log`, the I/O file is `log_files/14_io.log`. + +- Disable I/O file (and fall back to legacy DEBUG prompts on the main log only): set environment variable **`DEBATE_LOG_PROMPTS=0`** (also accepts `false`, `no`, `off`). +- Each I/O block starts with **`[io]`**, then key/value pairs: + - **`call_id`**: same id as **`[timing-meta]`** for that speak turn (sub-blocks under one turn share this id). + - **`phase`**: usually the **`speak_session`** (`default_speak` / `tree_debater_speak`) or a **submodule** label (`audience_feedback`, `length_adjust`, …) when the block is not the top-level speak transcript. + - **`title`**: what the block is (e.g. **`Prompt`**, **`Conversation-History`**, **`Audience-Feedback-Prompt`**, **`Response-After-TTS`**). This is **not** the same field as **`phase`** — use **`title`** to tell Prompt vs Response apart. + - Optional **`stage`** / **`side`** for attribution. +- Then a separator line and the body. + +When I/O logging is on, optional **`[io-ref]`** lines on the main log use **`speak_session`** / **`title`** / **`call_id`** for quick correlation without duplicating bodies. + +On startup, when the I/O log is attached, the main log records: **`[timing] phase=io_log_ready io_log=...`**. 
+ +### Analyzing agent timing logs + +Use **`src/scripts/analyze_agent_timing.py`** (separate from `analyze_streaming_performance.py`): + +```bash +python src/scripts/analyze_agent_timing.py log_files/14.log +python src/scripts/analyze_agent_timing.py log_files/14.log --io-log log_files/14_io.log -v +python src/scripts/analyze_agent_timing.py log_files/14.log --json-out agent_timing_report.json +``` + +It summarizes macro vs speak-pipeline vs atom phases, groups nested phases by **`call_id`**, lists **`length_adjust_iteration`** rows, and optionally counts **`[io]`** blocks in the I/O file. + +## Log File Location + +Logs go to the same file as debate system logs: +- File handler: DEBUG level (captures all events) +- Console handler: INFO level (shows high-level progress only) + +Large LLM prompts/responses default to **`N_io.log`** (see Agent section above) so the main file stays readable. + +This keeps detailed timing data available for analysis while maintaining clean console output. diff --git a/src/scripts/analyze_agent_timing.py b/src/scripts/analyze_agent_timing.py new file mode 100644 index 0000000..ae13231 --- /dev/null +++ b/src/scripts/analyze_agent_timing.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +""" +Analyze agent / LLM timing lines from TreeDebater main logs (``[timing]``) and optional I/O logs (``[io]``). + +Parses the format emitted by ``utils/timing_log.py`` (see ``src/scripts/README_LOGGING.md``, Agent section). 
+ +Usage: + python src/scripts/analyze_agent_timing.py log_files/14.log + python src/scripts/analyze_agent_timing.py log_files/14.log --io-log log_files/14_io.log + python src/scripts/analyze_agent_timing.py log_files/14.log --json-out report.json +""" + +from __future__ import annotations + +import argparse +import json +import re +from collections import defaultdict +from dataclasses import dataclass, field, asdict +from pathlib import Path +from typing import Any, DefaultDict, Dict, List, Optional, Tuple + + +# Strip standard debate file formatter prefix: "... DEBUG module - funcName: message" +_PREFIX_RE = re.compile( + r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}\s+\w+\s+[\w.]+\s+-\s+\w+:\s*" +) + +_TIMING_HEAD = re.compile( + r"^\[timing\]\s+phase=(\S+)\s+duration_s=([\d.]+)\s*(.*)$" +) +_TIMING_META = re.compile( + r"^\[timing-meta\]\s+(.*)$" +) +_KV = re.compile(r"(\w+)=([^\s]+)") + +# Phases to exclude from analysis output/report. +# Edit this list to hide noisy phases without changing runtime logging. 
+EXCLUDED_PHASES = { + "tts_wall_clock", + "length_adjust", +} + +# Phases grouped for the human-readable report (edit as you add new phases) +MACRO_PHASES = frozenset( + { + "env_stage_wall", + "evaluation_wall", + "comparison_phase_wall", + "comparison_evaluation_total_wall", + "compare_env_stage_wall", + "prepare_claim_pool_wall", + "io_log_ready", + } +) +SPEAK_PIPELINE = frozenset( + { + "tree_debater_speak", + "main_get_response", + "revision_suggestion", + "length_adjust", + "length_adjust_iteration", + "post_process", + } +) +LISTENER_TREE = frozenset( + { + "listen_analyze_statement", + "analyze_statement", + } +) +AUDIENCE_REVISION = frozenset( + { + "audience_exemplar_retrieval", + "audience_simulated_feedback_llm", + "evidence_selection_llm", + } +) +RETRIEVAL = frozenset( + { + "rehearsal_retrieve_on_prepared_tree", + "exemplar_retrieval_query_embedding", + "exemplar_retrieval_semantic_search", + } +) +ATOM_LLM = frozenset( + { + "helper_client_litellm", + "debater_litellm_completion", + "get_response_with_retry_llm", + } +) +ATOM_OTHER = frozenset( + { + "embedding_api_fetch", + "tts_wall_clock", + "tts_trim_wall_clock", + } +) + + +@dataclass +class TimingRecord: + phase: str + duration_s: float + fields: Dict[str, str] = field(default_factory=dict) + raw: str = "" + + +def _parse_kv_tail(tail: str) -> Dict[str, str]: + out: Dict[str, str] = {} + for m in _KV.finditer(tail.strip()): + out[m.group(1)] = m.group(2) + return out + + +def _strip_log_prefix(line: str) -> str: + line = line.rstrip("\n") + m = _PREFIX_RE.match(line) + if m: + return line[m.end() :] + return line + + +def parse_timing_line(line: str) -> Optional[TimingRecord]: + s = _strip_log_prefix(line) + if not s.startswith("[timing]"): + return None + m = _TIMING_HEAD.match(s) + if not m: + return None + phase, dur_s, tail = m.group(1), float(m.group(2)), m.group(3) or "" + return TimingRecord(phase=phase, duration_s=dur_s, fields=_parse_kv_tail(tail), raw=s) + + +def 
parse_timing_meta(line: str) -> Optional[Dict[str, str]]: + s = _strip_log_prefix(line) + if not s.startswith("[timing-meta]"): + return None + m = _TIMING_META.match(s) + if not m: + return None + return _parse_kv_tail(m.group(1)) + + +def parse_io_header_line(line: str) -> Optional[Dict[str, str]]: + s = _strip_log_prefix(line) + if "[io]" not in s: + return None + # First line of a block: "[io] call_id=1 phase=... title=..." + idx = s.find("[io]") + if idx < 0: + return None + rest = s[idx + len("[io]") :].strip() + return _parse_kv_tail(rest) + + +def load_timing_records(path: Path) -> List[TimingRecord]: + records: List[TimingRecord] = [] + with path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + rec = parse_timing_line(line) + if rec: + records.append(rec) + return records + + +def filter_excluded_phases( + records: List[TimingRecord], excluded_phases: set[str] +) -> Tuple[List[TimingRecord], int]: + if not excluded_phases: + return records, 0 + filtered = [r for r in records if r.phase not in excluded_phases] + return filtered, len(records) - len(filtered) + + +def load_meta_records(path: Path) -> List[Dict[str, str]]: + rows: List[Dict[str, str]] = [] + with path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + m = parse_timing_meta(line) + if m: + rows.append(m) + return rows + + +def count_io_blocks(path: Path) -> Tuple[int, Dict[Tuple[str, str], int]]: + """Count I/O blocks (header lines with ``[io]``) and histogram (phase, title).""" + total = 0 + hist: Dict[Tuple[str, str], int] = defaultdict(int) + with path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + h = parse_io_header_line(line) + if h and "call_id" in h: + total += 1 + key = (h.get("phase", "?"), h.get("title", "?")) + hist[key] += 1 + return total, dict(hist) + + +def aggregate_by_phase(records: List[TimingRecord]) -> Dict[str, Dict[str, Any]]: + by_phase: DefaultDict[str, List[float]] = defaultdict(list) + for r in records: 
+ by_phase[r.phase].append(r.duration_s) + out: Dict[str, Dict[str, Any]] = {} + for phase, xs in sorted(by_phase.items()): + out[phase] = { + "count": len(xs), + "total_s": round(sum(xs), 4), + "mean_s": round(sum(xs) / len(xs), 4), + "min_s": round(min(xs), 4), + "max_s": round(max(xs), 4), + } + return out + + +def group_by_call_id(records: List[TimingRecord]) -> Dict[str, List[TimingRecord]]: + groups: DefaultDict[str, List[TimingRecord]] = defaultdict(list) + for r in records: + cid = r.fields.get("call_id") + if cid is not None: + groups[cid].append(r) + return dict(groups) + + +def print_report( + records: List[TimingRecord], + meta: List[Dict[str, str]], + io_path: Optional[Path], + verbose: bool, + excluded_phases: Optional[set[str]] = None, + filtered_count: int = 0, +) -> Dict[str, Any]: + agg = aggregate_by_phase(records) + by_cid = group_by_call_id(records) + + lines: List[str] = [] + W = lines.append + + W("=" * 80) + W("AGENT / LLM TIMING ANALYSIS ([timing] lines)") + W("=" * 80) + W(f"Total timing records: {len(records)}") + if excluded_phases: + W(f"Excluded phases: {sorted(excluded_phases)}") + W(f"Excluded records: {filtered_count}") + W("") + + # --- Bucket summary --- + def bucket_sum(phases: frozenset) -> Tuple[int, float]: + n, t = 0, 0.0 + for p in phases: + if p in agg: + n += agg[p]["count"] + t += agg[p]["total_s"] + return n, t + + W("--- Interest summary (by category) ---") + for name, pset in [ + ("Macro (env / eval / compare / prep)", MACRO_PHASES), + ("Speak pipeline (TreeDebater turn)", SPEAK_PIPELINE), + ("Listen + debate-flow tree", LISTENER_TREE), + ("Audience + revision LLM blocks", AUDIENCE_REVISION), + ("Retrieval (exemplar + rehearsal)", RETRIEVAL), + ("Atom: LLM completions", ATOM_LLM), + ("Atom: embed / TTS wall", ATOM_OTHER), + ]: + n, t = bucket_sum(pset) + W(f" {name}: events={n} total_time_s={t:.2f}") + W("") + + # --- Per-phase table --- + W("--- Per-phase statistics ---") + W(f"{'phase':<42} {'n':>5} {'total_s':>10} 
{'mean_s':>10} {'max_s':>10}") + for phase in sorted(agg.keys()): + a = agg[phase] + W(f"{phase:<42} {a['count']:>5} {a['total_s']:>10.2f} {a['mean_s']:>10.2f} {a['max_s']:>10.2f}") + W("") + + # --- Speak sessions by call_id --- + if by_cid: + W("--- TreeDebater speak sessions (by call_id) ---") + for cid in sorted(by_cid.keys(), key=lambda x: int(x) if x.isdigit() else 0): + sess = by_cid[cid] + total = sum(r.duration_s for r in sess if r.phase == "tree_debater_speak") + W(f" call_id={cid} (tree_debater_speak wall={total:.2f}s if present)") + for r in sess: + if r.phase == "tree_debater_speak": + continue + extra = " ".join(f"{k}={v}" for k, v in sorted(r.fields.items()) if k != "call_id") + W(f" {r.duration_s:8.2f}s {r.phase}" + (f" | {extra}" if extra else "")) + W("") + else: + W("--- No records with call_id= (speak pipeline grouping skipped) ---") + W("") + + # --- Length-adjust iterations --- + iters = [r for r in records if r.phase == "length_adjust_iteration"] + if iters: + W("--- Length adjust iterations ---") + for r in iters: + W( + f" iter={r.fields.get('iteration', '?')} max_retry={r.fields.get('max_retry', '?')} " + f"fit_ok={r.fields.get('fit_ok', '?')} cost={r.fields.get('current_cost', '?')} " + f"duration_s={r.duration_s:.3f} stage={r.fields.get('stage')} side={r.fields.get('side')}" + ) + W("") + + # --- Slowest single events --- + W("--- Slowest 25 timing events ---") + slow = sorted(records, key=lambda r: r.duration_s, reverse=True)[:25] + for r in slow: + loc = f"{r.fields.get('stage', '')}/{r.fields.get('side', '')}".strip("/") + W(f" {r.duration_s:10.2f}s {r.phase}" + (f" ({loc})" if loc else "")) + W("") + + # --- Meta lines --- + if meta: + W(f"--- timing-meta lines: {len(meta)} ---") + if verbose: + for m in meta[:50]: + W(f" {m}") + if len(meta) > 50: + W(f" ... 
({len(meta) - 50} more)") + W("") + + # --- I/O log --- + io_report: Dict[str, Any] = {} + if io_path and io_path.is_file(): + n_io, hist = count_io_blocks(io_path) + W("--- I/O log (prompt/response blocks) ---") + W(f" file: {io_path}") + W(f" total [io] blocks: {n_io}") + if hist and verbose: + W(" histogram (phase, title) -> count:") + for (ph, title), c in sorted(hist.items(), key=lambda x: -x[1])[:40]: + W(f" ({ph}, {title}): {c}") + W("") + io_report = {"io_file": str(io_path), "io_blocks": n_io, "histogram": {f"{a}|{b}": c for (a, b), c in hist.items()}} + elif io_path: + W(f"--- I/O log not found: {io_path} ---") + W("") + + text = "\n".join(lines) + print(text) + + return { + "total_records": len(records), + "excluded_phases": sorted(excluded_phases) if excluded_phases else [], + "excluded_records": filtered_count, + "by_phase": agg, + "call_id_sessions": {k: [asdict(x) for x in v] for k, v in by_cid.items()}, + "meta_count": len(meta), + "io": io_report, + } + + +def main() -> None: + ap = argparse.ArgumentParser(description="Analyze [timing] / [io] agent logs.") + ap.add_argument("log_file", type=Path, help="Main debate log (e.g. 
log_files/14.log)") + ap.add_argument("--io-log", type=Path, default=None, help="I/O log path (default: _io.log next to main log)") + ap.add_argument("--json-out", type=Path, default=None, help="Write structured JSON summary") + ap.add_argument("--verbose", "-v", action="store_true", help="Extra detail (meta + I/O histogram)") + args = ap.parse_args() + + main_log = args.log_file + if not main_log.is_file(): + raise SystemExit(f"Log not found: {main_log}") + + io_log = args.io_log + if io_log is None: + p = str(main_log) + if p.endswith(".log"): + candidate = main_log.with_name(main_log.stem + "_io.log") # swap only the final suffix; str.replace would also rewrite ".log" inside directory names + if candidate.is_file(): + io_log = candidate + + records_raw = load_timing_records(main_log) + records, filtered_count = filter_excluded_phases(records_raw, EXCLUDED_PHASES) + meta = load_meta_records(main_log) + + report = print_report( + records, + meta, + io_log, + args.verbose, + excluded_phases=EXCLUDED_PHASES, + filtered_count=filtered_count, + ) + + if args.json_out: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + with args.json_out.open("w", encoding="utf-8") as f: + json.dump(report, f, indent=2) + print(f"Wrote JSON summary to {args.json_out}") + + +if __name__ == "__main__": + main() diff --git a/src/scripts/analyze_streaming_performance.py b/src/scripts/analyze_streaming_performance.py new file mode 100644 index 0000000..b19314b --- /dev/null +++ b/src/scripts/analyze_streaming_performance.py @@ -0,0 +1,871 @@ +#!/usr/bin/env python3 +""" +Analyze streaming debate performance from log files.
+ +Extracts timing metrics from DEBUG logs to calculate: +- Speaker bubbles (waiting for TTS chunks) +- Listener bubbles (post-playback processing) +- ASR real-time factors +- End-to-end chunk latency +- File I/O overhead +- Tree update costs +- Pipeline efficiency + +Usage: + python analyze_streaming_performance.py LOG_FILE [--output output.json] [--verbose] +""" + +import argparse +import csv +import json +import re +from collections import defaultdict +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +@dataclass +class Event: + """Single log event with timestamp.""" + timestamp: float + component: str + event_type: str + data: Dict[str, str] + raw_line: str + + +@dataclass +class ChunkMetrics: + """Metrics for a single TTS chunk.""" + chunk_idx: int + detected_time: Optional[float] = None + copied_time: Optional[float] = None + assembled_time: Optional[float] = None + playback_start_time: Optional[float] = None + playback_end_time: Optional[float] = None + duration: Optional[float] = None + detection_latency: Optional[float] = None + + @property + def e2e_latency(self) -> Optional[float]: + """End-to-end latency: detection to playback complete.""" + if self.detected_time and self.playback_end_time: + return self.playback_end_time - self.detected_time + return None + + +@dataclass +class ASRMetrics: + """Metrics for an ASR operation.""" + audio_start: float + audio_end: float + asr_start_time: float + asr_end_time: float + text_len: int + + @property + def audio_duration(self) -> float: + return self.audio_end - self.audio_start + + @property + def asr_time(self) -> float: + return self.asr_end_time - self.asr_start_time + + @property + def rtf(self) -> float: + """Real-time factor: ASR time / audio duration.""" + if self.audio_duration > 0: + return self.asr_time / self.audio_duration + return 0.0 + + +@dataclass +class TreeUpdateMetrics: + """Metrics for a tree update.""" + start_time: float +
@dataclass
class TurnMetrics:
    """Complete metrics for one debate turn.

    Timestamps are stored exactly as they appear in the log (``t=...``).
    Every derived duration is ``None`` when either endpoint is missing; a
    timestamp of ``0.0`` is a valid value and is never treated as "missing".

    The collection fields default to ``None`` in the constructor signature and
    are replaced by fresh empty containers in ``__post_init__`` so instances
    never share mutable state.
    """
    stage: str
    side: str
    turn_start: Optional[float] = None
    turn_end: Optional[float] = None

    # Mode detection
    streaming_tts: Optional[bool] = None
    streaming_listen: Optional[bool] = None
    mode: Optional[str] = None

    # Thread lifecycle
    speaker_thread_start: Optional[float] = None
    speaker_thread_end: Optional[float] = None
    listener_thread_start: Optional[float] = None
    listener_thread_end: Optional[float] = None
    playback_start: Optional[float] = None
    playback_end: Optional[float] = None

    # Generation timing
    generation_start: Optional[float] = None
    generation_end: Optional[float] = None

    # Batch processing timing
    posthoc_chunk_start: Optional[float] = None
    posthoc_chunk_end: Optional[float] = None
    batch_analyze_start: Optional[float] = None
    batch_analyze_end: Optional[float] = None

    # Chunk metrics, keyed by chunk index
    chunks: Optional[Dict[int, "ChunkMetrics"]] = None

    # ASR metrics
    asr_operations: Optional[List["ASRMetrics"]] = None

    # Tree updates
    tree_updates: Optional[List["TreeUpdateMetrics"]] = None

    # Bubbles (waits of the playback loop for the next chunk)
    speaker_bubbles: Optional[List["BubbleMetrics"]] = None

    # File I/O
    file_writes: Optional[List[Tuple[float, float]]] = None  # (start, end)
    file_reads: Optional[List[Tuple[float, float]]] = None  # (start, end)

    def __post_init__(self):
        # Replace the None placeholders with per-instance containers.
        if self.chunks is None:
            self.chunks = {}
        if self.asr_operations is None:
            self.asr_operations = []
        if self.tree_updates is None:
            self.tree_updates = []
        if self.speaker_bubbles is None:
            self.speaker_bubbles = []
        if self.file_writes is None:
            self.file_writes = []
        if self.file_reads is None:
            self.file_reads = []

    @staticmethod
    def _span(start: Optional[float], end: Optional[float]) -> Optional[float]:
        """Return ``end - start``, or ``None`` when either endpoint is unset."""
        if start is None or end is None:
            return None
        return end - start

    @property
    def total_duration(self) -> Optional[float]:
        return self._span(self.turn_start, self.turn_end)

    @property
    def playback_duration(self) -> Optional[float]:
        return self._span(self.playback_start, self.playback_end)

    @property
    def speaker_duration(self) -> Optional[float]:
        return self._span(self.speaker_thread_start, self.speaker_thread_end)

    @property
    def listener_duration(self) -> Optional[float]:
        return self._span(self.listener_thread_start, self.listener_thread_end)

    @property
    def generation_time(self) -> Optional[float]:
        return self._span(self.generation_start, self.generation_end)

    @property
    def total_speaker_bubble(self) -> float:
        return sum(b.duration for b in self.speaker_bubbles)

    @property
    def time_to_first_chunk(self) -> Optional[float]:
        """Wait time for the first playable chunk (chunk_1)."""
        for b in self.speaker_bubbles:
            if b.context == "chunk_1":
                return b.duration
        return None

    @property
    def time_between_chunks(self) -> float:
        """Speaker bubble excluding the first chunk wait."""
        return sum(b.duration for b in self.speaker_bubbles if b.context != "chunk_1")

    @property
    def listener_bubble(self) -> Optional[float]:
        """Time from playback end to listener thread end."""
        return self._span(self.playback_end, self.listener_thread_end)

    @property
    def true_overlap(self) -> Optional[float]:
        """Playback time minus speaker bubbles."""
        playback = self.playback_duration
        if playback is None:
            return None
        return playback - self.total_speaker_bubble

    @property
    def avg_asr_rtf(self) -> Optional[float]:
        if self.asr_operations:
            return sum(op.rtf for op in self.asr_operations) / len(self.asr_operations)
        return None

    @property
    def audio_duration(self) -> Optional[float]:
        """Best-effort total audio duration for the turn.

        Prefers the sum of known chunk durations and falls back to the largest
        ``audio_end`` seen by ASR when no chunk duration was recorded.
        """
        chunk_durations = [
            chunk.duration for chunk in self.chunks.values() if chunk.duration is not None
        ]
        if chunk_durations:
            return sum(chunk_durations)

        if self.asr_operations:
            return max(op.audio_end for op in self.asr_operations)

        return None

    @property
    def total_tree_update_time(self) -> float:
        return sum(u.update_time for u in self.tree_updates)

    @property
    def avg_tree_update_time(self) -> Optional[float]:
        if self.tree_updates:
            return self.total_tree_update_time / len(self.tree_updates)
        return None

    @property
    def total_file_write_time(self) -> float:
        return sum(end - start for start, end in self.file_writes)

    @property
    def total_file_read_time(self) -> float:
        return sum(end - start for start, end in self.file_reads)

    @property
    def bottleneck(self) -> Optional[str]:
        """Identify bottleneck using bubble/wait time."""
        times = []
        if self.total_speaker_bubble > 0:
            times.append(('SPEAKER', self.total_speaker_bubble))
        if self.listener_bubble is not None and self.listener_bubble > 0:
            times.append(('LISTENER', self.listener_bubble))

        if times:
            return max(times, key=lambda x: x[1])[0]
        return None

    @property
    def overlap_efficiency(self) -> Optional[float]:
        """True overlap / (playback + listener bubble)."""
        overlap = self.true_overlap
        bubble = self.listener_bubble
        if overlap is None or bubble is None:
            return None
        # true_overlap is only defined when playback_duration is defined.
        total = self.playback_duration + bubble
        if total > 0:
            return overlap / total
        return None

    @property
    def posthoc_chunk_time(self) -> Optional[float]:
        """Time to split and stream chunks post-hoc (batch TTS mode)."""
        return self._span(self.posthoc_chunk_start, self.posthoc_chunk_end)

    @property
    def batch_analyze_time(self) -> Optional[float]:
        """Time to analyze statement in batch mode (non-streaming listener)."""
        return self._span(self.batch_analyze_start, self.batch_analyze_end)
def parse_log_file(log_path: Path) -> Dict[Tuple[str, str], TurnMetrics]:
    """Parse a timing log and collect metrics for every (stage, side) turn.

    Each line is parsed with ``parse_log_line``; lines that do not parse, or
    that cannot be attributed to a turn, are skipped. ``StreamingInputEnv``
    lines without stage/side are attributed to the turn of the most recent
    ``StreamingInputEnv thread_start`` until that listener's ``thread_end``.

    Events with a malformed numeric field (chunk index, word count, duration)
    are skipped instead of aborting the whole parse.

    Args:
        log_path: Path of the log file to read.

    Returns:
        Mapping of ``(stage, side)`` to the ``TurnMetrics`` built from the log.
    """
    turns: Dict[Tuple[str, str], TurnMetrics] = {}

    # Start timestamps of operations logged as start/end pairs. An entry is
    # consumed by its matching end event so a repeated end is not counted twice.
    wait_chunk_starts: Dict[Tuple[str, str, int], float] = {}  # (stage, side, chunk_idx) -> time
    asr_starts: Dict[Tuple[str, str, float, float], float] = {}  # (stage, side, start, end) -> time
    tree_starts: Dict[Tuple[str, str, int], float] = {}  # (stage, side, word_count) -> time
    active_streaming_turn: Optional[Tuple[str, str]] = None

    def _int_field(data, key):
        """Return ``data[key]`` as int (0 when absent), or None when malformed."""
        try:
            return int(data.get(key, 0))
        except (TypeError, ValueError):
            return None

    def _seconds_field(data, key):
        """Return ``data[key]`` (e.g. ``'1.25s'``) as float, or None when malformed."""
        try:
            return float(str(data.get(key, '0')).rstrip('s'))
        except ValueError:
            return None

    def _chunk_for(turn, chunk_idx):
        """Return the ChunkMetrics for ``chunk_idx``, creating it on first use."""
        if chunk_idx not in turn.chunks:
            turn.chunks[chunk_idx] = ChunkMetrics(chunk_idx=chunk_idx)
        return turn.chunks[chunk_idx]

    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            event = parse_log_line(line)
            if not event:
                continue

            component = event.component
            kind = event.event_type

            turn_key = extract_turn_key(event)
            if component == 'StreamingInputEnv':
                if turn_key is None:
                    # StreamingInputEnv lines often omit stage/side after thread_start.
                    # Reuse the active listener context so ASR/tree events are still attributed.
                    turn_key = active_streaming_turn
                elif kind == 'thread_start':
                    active_streaming_turn = turn_key
                if kind == 'thread_end':
                    # The listener is done whether or not this line named its turn;
                    # later anonymous lines must not be charged to it.
                    active_streaming_turn = None
            if not turn_key:
                continue

            stage, side = turn_key
            if turn_key not in turns:
                turns[turn_key] = TurnMetrics(stage=stage, side=side)

            turn = turns[turn_key]

            # Process event
            if component == 'Turn':
                if kind == 'turn_start':
                    turn.turn_start = event.timestamp
                elif kind == 'turn_end':
                    turn.turn_end = event.timestamp
                elif kind == 'mode_config':
                    turn.streaming_tts = event.data.get('streaming_tts') == 'True'
                    turn.streaming_listen = event.data.get('streaming_listen') == 'True'
                    turn.mode = event.data.get('mode')

            elif component == 'SpeakerWorker':
                if kind == 'thread_start':
                    turn.speaker_thread_start = event.timestamp
                elif kind == 'thread_end':
                    turn.speaker_thread_end = event.timestamp
                elif kind == 'generation_start':
                    turn.generation_start = event.timestamp
                elif kind == 'generation_end':
                    turn.generation_end = event.timestamp
                elif kind == 'posthoc_chunk_start':
                    turn.posthoc_chunk_start = event.timestamp
                elif kind == 'posthoc_chunk_end':
                    turn.posthoc_chunk_end = event.timestamp

            elif component == 'BatchListener':
                if kind == 'analyze_start':
                    turn.batch_analyze_start = event.timestamp
                elif kind == 'analyze_end':
                    turn.batch_analyze_end = event.timestamp

            elif component == 'StreamingInputEnv':
                if kind == 'thread_start':
                    turn.listener_thread_start = event.timestamp
                elif kind == 'thread_end':
                    turn.listener_thread_end = event.timestamp
                elif kind in ('asr_start', 'asr_end'):
                    audio_range = _parse_audio_range(event.data.get('audio_range', ''))
                    if audio_range is None:
                        continue
                    audio_start, audio_end = audio_range
                    asr_key = (stage, side, audio_start, audio_end)
                    if kind == 'asr_start':
                        asr_starts[asr_key] = event.timestamp
                    elif asr_key in asr_starts:
                        turn.asr_operations.append(ASRMetrics(
                            audio_start=audio_start,
                            audio_end=audio_end,
                            asr_start_time=asr_starts.pop(asr_key),
                            asr_end_time=event.timestamp,
                            text_len=_int_field(event.data, 'text_len') or 0
                        ))
                elif kind in ('tree_update_start', 'tree_update_end'):
                    words = _int_field(event.data, 'words')
                    if words is None:
                        continue
                    tree_key = (stage, side, words)
                    if kind == 'tree_update_start':
                        tree_starts[tree_key] = event.timestamp
                    elif tree_key in tree_starts:
                        turn.tree_updates.append(TreeUpdateMetrics(
                            start_time=tree_starts.pop(tree_key),
                            end_time=event.timestamp,
                            word_count=words
                        ))

            elif component == 'PlaybackMain':
                if kind == 'playback_start':
                    turn.playback_start = event.timestamp
                elif kind == 'playback_end':
                    turn.playback_end = event.timestamp
                elif kind == 'file_write':
                    # File write is atomic in our case (start/end in same log line)
                    write_time = _seconds_field(event.data, 'write_time')
                    if write_time is None:
                        continue
                    turn.file_writes.append((event.timestamp - write_time, event.timestamp))
                elif kind in (
                    'wait_chunk_start',
                    'wait_chunk_end',
                    'chunk_assembled',
                    'chunk_playback_start',
                    'chunk_playback_end',
                ):
                    chunk_idx = _int_field(event.data, 'chunk_idx')
                    if chunk_idx is None:
                        continue
                    if kind == 'wait_chunk_start':
                        wait_chunk_starts[(stage, side, chunk_idx)] = event.timestamp
                    elif kind == 'wait_chunk_end':
                        wait_key = (stage, side, chunk_idx)
                        if wait_key in wait_chunk_starts:
                            turn.speaker_bubbles.append(BubbleMetrics(
                                wait_start=wait_chunk_starts.pop(wait_key),
                                wait_end=event.timestamp,
                                context=f"chunk_{chunk_idx}"
                            ))
                    elif kind == 'chunk_assembled':
                        chunk = _chunk_for(turn, chunk_idx)
                        chunk.assembled_time = event.timestamp
                        duration = _seconds_field(event.data, 'duration')
                        if duration is not None:
                            chunk.duration = duration
                    elif kind == 'chunk_playback_start':
                        _chunk_for(turn, chunk_idx).playback_start_time = event.timestamp
                    else:  # chunk_playback_end
                        _chunk_for(turn, chunk_idx).playback_end_time = event.timestamp

            elif component == 'TtsChunkBridge':
                if kind in ('chunk_detected', 'chunk_copied'):
                    chunk_idx = _int_field(event.data, 'chunk_idx')
                    if chunk_idx is None:
                        continue
                    chunk = _chunk_for(turn, chunk_idx)
                    if kind == 'chunk_detected':
                        chunk.detected_time = event.timestamp
                    else:
                        chunk.copied_time = event.timestamp
                        detection_latency = _seconds_field(event.data, 'detection_latency')
                        if detection_latency is not None:
                            chunk.detection_latency = detection_latency

    return turns
def load_tts_chunk_profiles(outputs_dir: Path) -> Dict[Tuple[str, str], List[Dict[str, str]]]:
    """Load per-turn streaming TTS chunk profiles from *_chunks/chunk_profile.csv.

    A profile directory is expected to be named
    ``<prefix>_<stage>_<for|against>_chunks`` (e.g.
    ``treedebater_opening_for_chunks``). The prefix may itself contain
    underscores; ``<stage>`` is the last underscore-free segment before the
    side. Directories that do not follow the pattern, and CSV files that
    cannot be read, are skipped.

    Args:
        outputs_dir: Directory that contains the ``*_chunks`` sub-directories.

    Returns:
        Mapping of ``(stage, side)`` to the CSV rows (one dict per chunk).
        Empty when ``outputs_dir`` does not exist.
    """
    profiles: Dict[Tuple[str, str], List[Dict[str, str]]] = {}
    if not outputs_dir.exists():
        return profiles

    # Sorted so the result is deterministic if two directories map to one key.
    for chunk_csv in sorted(outputs_dir.glob("*_chunks/chunk_profile.csv")):
        parent_name = chunk_csv.parent.name  # e.g., treedebater_opening_for_chunks
        m = re.match(r"^.+_([^_]+)_(for|against)_chunks$", parent_name)
        if not m:
            continue
        stage, side = m.group(1), m.group(2)
        try:
            # newline="" lets the csv module handle embedded newlines itself.
            with chunk_csv.open("r", encoding="utf-8", errors="replace", newline="") as f:
                rows: List[Dict[str, str]] = [row for row in csv.DictReader(f) if row]
        except (OSError, csv.Error):
            continue
        profiles[(stage, side)] = rows

    return profiles
'speaker_bubble_pct': (turn.total_speaker_bubble / turn.playback_duration * 100) if turn.playback_duration else None, + 'listener_bubble': turn.listener_bubble, + 'listener_bubble_pct': (turn.listener_bubble / turn.listener_duration * 100) if turn.listener_duration and turn.listener_bubble else None, + 'true_overlap': turn.true_overlap, + 'overlap_efficiency': turn.overlap_efficiency, + 'bottleneck': turn.bottleneck, + 'chunk_count': len(turn.chunks), + 'audio_duration': turn.audio_duration, + 'asr_operations': len(turn.asr_operations), + 'avg_asr_rtf': turn.avg_asr_rtf, + 'tree_updates': len(turn.tree_updates), + 'total_tree_update_time': turn.total_tree_update_time, + 'avg_tree_update_time': turn.avg_tree_update_time, + 'total_file_write_time': turn.total_file_write_time, + 'total_file_read_time': turn.total_file_read_time, + 'posthoc_chunk_time': turn.posthoc_chunk_time, + 'batch_analyze_time': turn.batch_analyze_time, + } + + # Chunk details + if turn.chunks: + chunk_latencies = [c.e2e_latency for c in turn.chunks.values() if c.e2e_latency] + if chunk_latencies: + turn_summary['chunk_latency_avg'] = sum(chunk_latencies) / len(chunk_latencies) + turn_summary['chunk_latency_min'] = min(chunk_latencies) + turn_summary['chunk_latency_max'] = max(chunk_latencies) + + # ASR details + if turn.asr_operations: + rtfs = [op.rtf for op in turn.asr_operations] + turn_summary['asr_rtf_min'] = min(rtfs) + turn_summary['asr_rtf_max'] = max(rtfs) + turn_summary['asr_real_time'] = all(rtf < 1.0 for rtf in rtfs) + + # Bubble breakdown + turn_summary['speaker_bubbles'] = [ + {'duration': b.duration, 'context': b.context} + for b in turn.speaker_bubbles + ] + + # Streaming TTS chunk-profile stats (if available) + profile_rows = tts_profiles.get(turn_key, []) + if profile_rows: + ref_counts = [_to_float(r.get("n_ref_used", "0"), 0.0) for r in profile_rows] + refined_counts = [x for x in ref_counts if x > 0] + timed_out = sum(1 for r in profile_rows if _to_bool(r.get("timed_out", 
"false"))) + chunk_total_times = [_to_float(r.get("chunk_total_s", "0"), 0.0) for r in profile_rows] + tts_api_times = [_to_float(r.get("tts_api_s", "0"), 0.0) for r in profile_rows] + refine_times = [_to_float(r.get("refine_total_s", "0"), 0.0) for r in profile_rows] + first_profile_row = min( + profile_rows, + key=lambda r: int(_to_float(r.get("chunk_idx", "0"), 0.0)), + ) + first_chunk_gen_time_s = _to_float(first_profile_row.get("chunk_total_s", "0"), 0.0) + turn_summary["tts_profile_chunks"] = len(profile_rows) + turn_summary["tts_chunks_refined"] = len(refined_counts) + turn_summary["tts_total_refinements"] = int(sum(ref_counts)) + turn_summary["tts_avg_refinements_per_chunk"] = sum(ref_counts) / len(ref_counts) + turn_summary["tts_avg_refinements_refined_chunks"] = ( + sum(refined_counts) / len(refined_counts) if refined_counts else 0.0 + ) + turn_summary["tts_timed_out_chunks"] = timed_out + turn_summary["first_chunk_gen_time_s"] = first_chunk_gen_time_s + turn_summary["chunk_gen_total_s_avg"] = sum(chunk_total_times) / len(chunk_total_times) + turn_summary["chunk_gen_total_s_min"] = min(chunk_total_times) + turn_summary["chunk_gen_total_s_max"] = max(chunk_total_times) + turn_summary["chunk_tts_api_s_avg"] = sum(tts_api_times) / len(tts_api_times) + turn_summary["chunk_refine_s_avg"] = sum(refine_times) / len(refine_times) + + # Speaker bubble definition for reporting: + # first chunk generation time + time between chunks. 
def print_summary(summary: Dict, verbose: bool = False):
    """Print a human-readable report for a summary built by ``generate_summary``.

    Optional metrics that are ``None`` (or absent) are printed as ``N/A`` or
    omitted; a value of ``0.0`` is a real measurement and is printed as such.

    Args:
        summary: Dict with ``total_turns`` and a ``turns`` mapping of per-turn
            metric dicts.
        verbose: When true, also list every individual speaker bubble.
    """

    def _secs(value, digits=2):
        """Format a duration in seconds, or 'N/A' when it is unknown."""
        return f"{value:.{digits}f}s" if value is not None else "N/A"

    def _or_na(value):
        """Show a tri-state flag; only None means 'not recorded'."""
        return "N/A" if value is None else value

    print("\n" + "="*80)
    print("STREAMING DEBATE PERFORMANCE ANALYSIS")
    print("="*80 + "\n")

    definitions = (
        "--- Metric Definitions ---",
        " --- Timing Overview ---",
        " Total duration: turn_end - turn_start",
        " Playback duration: playback_end - playback_start",
        " Audio duration: sum(chunk durations), fallback=max(ASR audio_end)",
        " Speaker duration: speaker_thread_end - speaker_thread_start",
        " Listener duration: listener_thread_end - listener_thread_start",
        " Generation time: generation_end - generation_start",
        " --- Bubble Analysis ---",
        " Speaker bubble: first chunk gen time + Time Between Chunks",
        " Listener bubble: listener_thread_end - playback_end",
        " --- Efficiency Metrics ---",
        " Time to First Chunk: wait time for chunk_1",
        " Time Between Chunks: sum(waits for chunk_2+)",
        " True overlap: playback_duration - speaker_bubble",
        " Overlap efficiency: true_overlap / (playback_duration + listener_bubble)",
        " Bottleneck: max(speaker bubble, listener bubble)",
        " --- Pipeline Stats ---",
        " Speaking side:",
        " Avg chunk latency: mean(playback_end - detected)",
        " Chunk gen time: from chunk_profile.csv (chunk_total_s/tts_api_s/refine_total_s)",
        " TTS refinements: from chunk_profile.csv n_ref_used stats",
        " TTS timeouts: count(chunks where timed_out=True in chunk_profile.csv)",
        " Listening side:",
        " Avg ASR RTF: mean((asr_end - asr_start) / (audio_end - audio_start))",
        " Avg update time: mean(tree_update_end - tree_update_start)",
        " --- I/O Overhead ---",
        " File write time: sum(file_write.write_time)",
        " File read time: sum(file_read_end - file_read_start)",
        " --- Batch Mode Metrics ---",
        " Post-hoc chunk time: posthoc_chunk_end - posthoc_chunk_start",
        " Batch analyze time: batch_analyze_end - batch_analyze_start",
    )
    for text in definitions:
        print(text)
    print()

    print(f"Total turns analyzed: {summary['total_turns']}\n")

    for turn in summary['turns'].values():
        print(f"\n{'='*80}")
        print(f"Turn: {turn['stage']} ({turn['side']})")
        print(f"{'='*80}")

        # Target speech length in seconds for the stage (closing is shorter).
        ideal_audio_duration = 120 if turn['stage'] == 'closing' else 240

        print("\n--- Mode Configuration ---")
        # The key exists with value None when no mode_config event was logged.
        print(f" Mode: {turn.get('mode') or 'unknown'}")
        print(f" Streaming TTS: {_or_na(turn.get('streaming_tts'))}")
        print(f" Streaming Listen: {_or_na(turn.get('streaming_listen'))}")

        print("\n--- Timing Overview ---")
        print(f" Total duration: {_secs(turn.get('total_duration'))}")
        audio_duration = turn.get('audio_duration')
        if audio_duration is not None:
            gap = audio_duration - ideal_audio_duration
            print(f" Audio duration: {audio_duration:.2f}s (ideal={ideal_audio_duration}s, gap={gap:.2f}s)")
        else:
            print(" Audio duration: N/A")
        print(f" Playback duration: {_secs(turn.get('playback_duration'))}")
        print(f" Speaker duration: {_secs(turn.get('speaker_duration'))}")
        print(f" Listener duration: {_secs(turn.get('listener_duration'))}")
        print(f" Generation time: {_secs(turn.get('generation_time'))}")

        print("\n--- Bubble Analysis ---")
        sb_total = turn.get('speaker_bubble_total')
        sb_pct = turn.get('speaker_bubble_pct')
        if sb_total is not None:
            if sb_pct is not None:
                print(f" Speaker bubble: {sb_total:.2f}s ({sb_pct:.1f}% of playback)")
            else:
                print(f" Speaker bubble: {sb_total:.2f}s (playback duration N/A, no %)")
        listener_bubble = turn.get('listener_bubble')
        if listener_bubble is not None:
            lb_pct = turn.get('listener_bubble_pct')
            if lb_pct is not None:
                print(f" Listener bubble: {listener_bubble:.2f}s ({lb_pct:.1f}% of listener time)")
            else:
                print(f" Listener bubble: {listener_bubble:.2f}s")

        print("\n--- Efficiency Metrics ---")
        print(f" Time to First Chunk: {_secs(turn.get('time_to_first_chunk'))}")
        print(f" Time Between Chunks: {_secs(turn.get('time_between_chunks'))}")
        if turn.get('true_overlap') is not None:
            print(f" True overlap: {turn['true_overlap']:.2f}s")
        if turn.get('overlap_efficiency') is not None:
            print(f" Overlap efficiency: {turn['overlap_efficiency']*100:.1f}%")
        if turn.get('bottleneck'):
            print(f" Bottleneck: {turn['bottleneck']}")

        print("\n--- Pipeline Stats (Speaking Side) ---")
        tts_profile_chunks = turn.get('tts_profile_chunks')
        print(
            f" Chunks processed: tts_profile_chunks={_or_na(tts_profile_chunks)} "
            f"(playback_chunks={turn['chunk_count']})"
        )
        if turn.get('chunk_latency_avg') is not None:
            print(
                f" Avg chunk latency: {turn['chunk_latency_avg']:.2f}s "
                f"(min={turn['chunk_latency_min']:.2f}s, max={turn['chunk_latency_max']:.2f}s)"
            )
        if turn.get('chunk_gen_total_s_avg') is not None:
            print(
                f" Chunk gen time: avg={turn['chunk_gen_total_s_avg']:.2f}s "
                f"(min={turn['chunk_gen_total_s_min']:.2f}s, max={turn['chunk_gen_total_s_max']:.2f}s)"
            )
            print(
                f" Chunk gen breakdown: avg_tts_api={turn['chunk_tts_api_s_avg']:.2f}s "
                f"avg_refine={turn['chunk_refine_s_avg']:.2f}s"
            )
        if tts_profile_chunks is not None:
            print(
                f" TTS refinements: total={turn['tts_total_refinements']} "
                f"refined_chunks={turn['tts_chunks_refined']}/{tts_profile_chunks} "
                f"avg/chunk={turn['tts_avg_refinements_per_chunk']:.2f}"
            )
            print(f" TTS timeouts: {turn['tts_timed_out_chunks']} chunk(s)")

        print("\n--- Pipeline Stats (Listening Side) ---")
        print(f" ASR operations: {turn['asr_operations']}")
        if turn.get('avg_asr_rtf') is not None:
            status = "✓ REAL-TIME" if turn.get('asr_real_time') else "✗ LAGGING"
            print(f" Avg ASR RTF: {turn['avg_asr_rtf']:.3f} {status}")
        print(f" Tree updates: {turn['tree_updates']}")
        print(f" Avg update time: {_secs(turn.get('avg_tree_update_time'))}")

        print("\n--- I/O Overhead ---")
        print(f" File write time: {turn['total_file_write_time']:.3f}s")
        print(f" File read time: {turn['total_file_read_time']:.3f}s")

        # Mode-specific metrics
        if turn.get('posthoc_chunk_time') is not None:
            print("\n--- Batch TTS Processing ---")
            print(f" Post-hoc chunk time: {turn['posthoc_chunk_time']:.3f}s (split + stream)")

        if turn.get('batch_analyze_time') is not None:
            print("\n--- Batch Listener Processing ---")
            print(f" Batch analyze time: {turn['batch_analyze_time']:.3f}s (statement analysis)")

        if verbose and turn.get('speaker_bubbles'):
            print("\n--- Speaker Bubble Breakdown ---")
            for i, bubble in enumerate(turn['speaker_bubbles'], 1):
                print(f" Bubble {i}: {bubble['duration']:.3f}s ({bubble['context']})")
get_search_query(llm_client, motion, stance, claim=None, extra_prompt=None): prompt += "\n\n**Claim**: {claim}\n\n".format(claim=claim) if extra_prompt is not None: prompt += "\n\n" + extra_prompt - logger.debug("[Search-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Helper-Prompt", body=prompt.strip()) response = llm_client(prompt=prompt)[0] - logger.debug("[Search-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Helper-Response", body=response.strip()) queries = find_tavily(response) queries = [q.replace('"', "") for q in queries] - logger.debug("[Search-Helper-Queries] " + " ||| ".join(queries)) + log_llm_io(logger, phase="searcher", title="Search-Helper-Queries", body=" ||| ".join(queries)) return queries @@ -123,11 +124,11 @@ def update_search_query(llm_client, motion, stance, claim, results): prompt = iterative_search_prompt.format( motion=motion, stance=stance, claim=claim, results=json.dumps(simple_results, indent=2) ) - logger.debug("[Search-Helper-Update-Prompt] " + prompt.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Helper-Update-Prompt", body=prompt.strip()) response = llm_client(prompt=prompt)[0] - logger.debug("[Search-Helper-Update-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Helper-Update-Response", body=response.strip()) queries = find_tavily(response) - logger.debug("[Search-Helper-Queries] " + " ||| ".join(queries)) + log_llm_io(logger, phase="searcher", title="Search-Helper-Queries", body=" ||| ".join(queries)) return queries @@ -136,9 +137,9 @@ def summarize_search_result(llm_client, claim, search_results): query = r["query"] content = {"title": r["title"], "url": r["url"], "content": r["content"]} prompt = summarize_result_prompt.format(claim=claim, query=query, results=json.dumps(content, indent=2)) - 
logger.debug("[Search-Summarize-Prompt] " + prompt.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Summarize-Prompt", body=prompt.strip()) response = llm_client(prompt=prompt)[0] - logger.debug("[Search-Summarize-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="searcher", title="Search-Summarize-Response", body=response.strip()) r["argument"] = response return search_results diff --git a/src/streaming/README.md b/src/streaming/README.md new file mode 100644 index 0000000..804ca7d --- /dev/null +++ b/src/streaming/README.md @@ -0,0 +1,77 @@ +# `streaming` package + +This package implements **chunked audio → ASR → debate-tree updates** for TreeDebater, plus helpers that **feed chunk files into a watch directory** so the listener can process speech incrementally (similar in spirit to debate-anonymous chunked pipelines). + +Run everything from **`TreeDebater/src`** so the `streaming` package and sibling modules (`env`, `agents`, `utils`, …) resolve correctly. Alternatively, put `src` on `PYTHONPATH` and run from the repo root. + +--- + +## Modules + +| Module | Role | +|--------|------| +| **`env.py`** | `StreamingInputEnv` (watch directory → Whisper → `TreeDebater._analyze_statement`), `StreamingDebateEnv` (full debate with a streaming listener each speech turn, post-hoc MP3 split into chunks), path helpers (`tts_outputs_dir_from_log`, …), and the main CLI (`--debate` or watch-only). | +| **`overlap.py`** | `OverlappingStreamingDebateEnv`: playback-driven main thread, optional `streaming_listen`, optional live **streaming TTS** chunk bridge, post-hoc fallback when no live chunks were copied. | +| **`bridges.py`** | **`run_live_chunk_bridge`**: poll `log_files/_outputs` for stable full-speaker MP3s, split with pydub, write `{side}_chunkNNN.*` into a watch dir (for demos next to `env.py`). 
**`run_streaming_tts_chunk_copy_bridge`**: copy stable `chunk_NNN.*` from streaming TTS output into `{speaker_side}_chunkNNN.*` for overlap runs. | +| **`chunk_audio.py`** | Split audio with fixed duration or silence detection, **`stream_chunks_to_directory`** (real-time or burst pace), **`clear_watch_chunk_files`**, and a small CLI for one-off file simulation. | +| **`run_listen_demo.py`** | Standalone **listener + bridge** demo: start `StreamingInputEnv` on a TreeDebater side, optionally run the live MP3 bridge against `N_outputs`, or one-shot chunk a file. | + +The package **`__init__.py`** does **not** eagerly import heavy modules. Use submodule imports (see below), or lazy access such as `import streaming` then `streaming.env` (same effect as `import streaming.env`). + +--- + +## Command-line entrypoints + +From `TreeDebater/src`: + +```bash +# Full debate with streaming listener (non-overlap): same YAML idea as env.py +python -m streaming.env --debate --config configs/base_st_io.yml + +# Watch-only: one TreeDebater listens on a directory of chunk MP3s +python -m streaming.env --config configs/base_st.yml --watch-dir /tmp/watch --debater-side for + +# Overlap + playback-driven timing (see configs/overlap_debate.yml) +python -m streaming.overlap --config configs/overlap_debate.yml + +# Split one file into chunks and write into a watch dir (pydub-only logic) +python -m streaming.chunk_audio --audio-file path/to/speech.mp3 --watch-dir /tmp/watch + +# Live bridge from log_files/_outputs + listener (run while env.py produces TTS) +python -m streaming.run_listen_demo --config configs/base_st.yml --watch-dir /tmp/watch --debater-side for +``` + +Use `--help` on any of the above for full flags. 
+ +--- + +## Python imports + +Prefer explicit submodule imports: + +```python +from streaming.env import StreamingDebateEnv, StreamingInputEnv, StreamingInputConfig +from streaming.env import opponent_side, tts_outputs_dir_from_log +from streaming.overlap import OverlappingStreamingDebateEnv +from streaming.bridges import run_live_chunk_bridge, run_streaming_tts_chunk_copy_bridge +from streaming.chunk_audio import split_audio, stream_chunks_to_directory, clear_watch_chunk_files +``` + +TreeDebater’s `ouragents.py` uses `StreamingInputEnv` / `StreamingInputConfig` from **`streaming.env`** for `start_streaming_listen`. + +--- + +## YAML knobs (debate configs) + +These are documented more fully in the main TreeDebater README; at a glance: + +- **`streaming_tts`** (speaker): use the streaming TTS pipeline; overlap mode can bridge `chunk_NNN` files when combined with `time_control` and the overlap env. +- **`streaming_listen`** (listener): in overlap runs, start `StreamingInputEnv` on a background thread and set `tree_via_streaming` on the turn record so `listen()` can avoid duplicating full `_analyze_statement` when the tree was already updated from audio. + +--- + +## Dependencies + +- **`env.py` / `overlap.py`**: `pydub`, `openai` (Whisper), `yaml`, TreeDebater `env` / agents / utils (API keys as elsewhere). +- **`chunk_audio.py`**: `pydub` only. +- **`bridges.py`**: `pydub`, `utils.tool` logger for the streaming-TTS copy bridge; live MP3 bridge prints to stdout for simple demos. diff --git a/src/streaming/__init__.py b/src/streaming/__init__.py new file mode 100644 index 0000000..fe8f8d5 --- /dev/null +++ b/src/streaming/__init__.py @@ -0,0 +1,38 @@ +""" +Streaming ASR pipeline, debate wrappers, chunk bridges, and audio simulation for TreeDebater. 
+ +Use explicit submodule imports (recommended):: + + from streaming.env import StreamingDebateEnv, StreamingInputEnv + from streaming.overlap import OverlappingStreamingDebateEnv + +Or lazy submodules:: + + import streaming + env = streaming.env # same as ``import streaming.env as env`` + +Runnable modules:: + + python -m streaming.env --help + python -m streaming.overlap --help + python -m streaming.chunk_audio --help + python -m streaming.run_listen_demo --help +""" + +from __future__ import annotations + +import importlib +from typing import Any + +_SUBMODULES = frozenset({"bridges", "chunk_audio", "env", "overlap", "run_listen_demo"}) +__all__ = sorted(_SUBMODULES) + + +def __getattr__(name: str) -> Any: + if name in _SUBMODULES: + return importlib.import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return sorted(set(globals()) | set(_SUBMODULES)) diff --git a/src/streaming/bridges.py b/src/streaming/bridges.py new file mode 100644 index 0000000..211f29b --- /dev/null +++ b/src/streaming/bridges.py @@ -0,0 +1,266 @@ +""" +Bridges speaker audio into a **watch directory** for :class:`streaming.env.StreamingInputEnv`. + +1. **Full TTS MP3** (``run_live_chunk_bridge``): poll ``log_files/N_outputs`` for stable + ``{type}_{stage}_{speaker}.mp3`` from ``agents.post_process``, split with pydub, write + ``{chunk_log_id}_chunkNNN.*`` — for running **alongside** ``env.py`` / demos. + +2. **Streaming TTS partial files** (``run_streaming_tts_chunk_copy_bridge``): poll a + per-turn ``*_chunks`` directory for stable ``chunk_NNN.*`` from streaming TTS, **copy** + (no re-encode) to ``{speaker_side}_chunkNNN.*`` for overlap debate + playback-driven env. 
+""" + +from __future__ import annotations + +import shutil +import threading +import time +from pathlib import Path +from typing import Dict, List, Optional, Set + +from pydub import AudioSegment + +from .chunk_audio import split_audio, stream_chunks_to_directory +from utils.tool import logger + +_REPO_ROOT = Path(__file__).resolve().parents[2] + +_STAGE_ORDER = {"opening": 0, "rebuttal": 1, "closing": 2} + +__all__ = [ + "default_outputs_dir", + "infer_session_log_id", + "list_speaker_mp3s", + "run_live_chunk_bridge", + "run_streaming_tts_chunk_copy_bridge", +] + + +def default_outputs_dir(log_id: str) -> Path: + return (_REPO_ROOT / "log_files" / f"{log_id}_outputs").resolve() + + +def infer_session_log_id(repo_root: Path | None = None) -> str | None: + """Largest numeric ``N`` with ``log_files/N_outputs`` present.""" + root = repo_root if repo_root is not None else _REPO_ROOT + log_dir = root / "log_files" + if not log_dir.is_dir(): + return None + best: int | None = None + for p in log_dir.iterdir(): + if not p.is_dir() or not p.name.endswith("_outputs"): + continue + stem = p.name[: -len("_outputs")] + if stem.isdigit(): + n = int(stem) + if best is None or n > best: + best = n + return str(best) if best is not None else None + + +def list_speaker_mp3s(outputs_dir: Path, speaker_side: str, ext: str) -> list[Path]: + """``agents.py`` TTS names: ``{type}_{stage}_{side}.mp3`` — keep files for ``speaker_side``.""" + ext = ext.lower().lstrip(".") + out: list[Path] = [] + for p in outputs_dir.glob(f"*.{ext}"): + if not p.is_file(): + continue + parts = p.stem.split("_") + if len(parts) >= 2 and parts[-1] == speaker_side: + out.append(p) + + def sort_key(p: Path) -> tuple: + parts = p.stem.split("_") + if len(parts) >= 2: + rank = _STAGE_ORDER.get(parts[-2], 99) + else: + rank = 99 + return (rank, p.name.lower()) + + return sorted(out, key=sort_key) + + +def run_live_chunk_bridge( + outputs_dir: Path, + watch_dir: Path, + speaker_side: str, + chunk_log_id: str, + 
stop_event: threading.Event, + *, + audio_format: str = "mp3", + split_mode: str = "fixed", + chunk_seconds: float = 10.0, + silence_window_seconds: float = 0.7, + poll_interval: float = 1.0, + stable_polls: int = 2, + skip_initial_files: bool = True, +) -> None: + """ + Until ``stop_event`` is set, poll ``outputs_dir`` for new ``*_*_{speaker_side}.mp3``, + split each stable file into chunks, and append to ``watch_dir`` using monotonic + ``{chunk_log_id}_chunkNNN`` indices. + """ + outputs_dir = Path(outputs_dir).resolve() + watch_dir = Path(watch_dir).resolve() + ext = audio_format.lower().lstrip(".") + + streamed: Set[str] = set() + if skip_initial_files: + for p in list_speaker_mp3s(outputs_dir, speaker_side, ext): + streamed.add(str(p.resolve())) + + size_stable: Dict[str, tuple[int, int]] = {} + next_chunk_idx = 1 + + print( + f"[live_chunk_bridge] Watching {outputs_dir} for new {speaker_side!r} speech " + f"(skip_initial={skip_initial_files}), writing chunks to {watch_dir} as log_id={chunk_log_id!r}." 
+ ) + + while not stop_event.is_set(): + for path in list_speaker_mp3s(outputs_dir, speaker_side, ext): + key = str(path.resolve()) + if key in streamed: + continue + try: + sz = path.stat().st_size + except OSError: + continue + if sz < 2048: + continue + prev = size_stable.get(key) + if prev is None or prev[0] != sz: + size_stable[key] = (sz, 1) + continue + last_sz, cnt = prev + cnt += 1 + size_stable[key] = (last_sz, cnt) + if cnt < stable_polls: + continue + + try: + audio = AudioSegment.from_file(str(path)) + except Exception as e: + print(f"[live_chunk_bridge] Could not load {path.name} yet ({e}); retrying later.") + size_stable.pop(key, None) + continue + + if split_mode == "fixed": + chunks = split_audio(audio, mode="fixed", time_seconds=chunk_seconds) + else: + chunks = split_audio(audio, mode="silence", time_seconds=silence_window_seconds) + + print( + f"[live_chunk_bridge] {path.name} ({len(audio) / 1000.0:.1f}s) → {len(chunks)} chunk(s), " + f"starting at chunk index {next_chunk_idx}." + ) + next_chunk_idx = stream_chunks_to_directory( + chunks, + watch_dir, + chunk_log_id, + audio_format=audio_format, + dry_run=False, + max_total_seconds=None, + chunk_index_start=next_chunk_idx, + ) + streamed.add(key) + size_stable.pop(key, None) + + time.sleep(poll_interval) + + print("[live_chunk_bridge] Stopped.") + + +def run_streaming_tts_chunk_copy_bridge( + chunks_dir: Path, + watch_dir: Path, + speaker_side: str, + stop_event: threading.Event, + live_chunk_counter: List[int], + *, + audio_format: str = "mp3", + poll_interval: float = 0.4, + stable_rounds: int = 2, + min_bytes: int = 512, +) -> None: + """ + Poll ``chunks_dir`` for ``chunk_NNN.`` from streaming TTS; when byte-size is + stable, copy to ``watch_dir`` as ``{speaker_side}_chunkNNN.`` (same contract as + :func:`streaming.chunk_audio.stream_chunks_to_directory` filenames for that log id). 
+ + Increments ``live_chunk_counter[0]`` for each successful copy (used to skip post-hoc + split when overlap debate already fed chunks). + """ + chunks_dir = Path(chunks_dir).resolve() + watch_dir = Path(watch_dir).resolve() + watch_dir.mkdir(parents=True, exist_ok=True) + ext = audio_format.lower().lstrip(".") + size_stable: dict[str, tuple[int, int]] = {} + copied_indices: set[int] = set() + chunk_detected_time: dict[int, float] = {} + + def _chunk_index(path: Path) -> Optional[int]: + if path.suffix.lower() != f".{ext}": + return None + parts = path.stem.split("_") + if len(parts) != 2 or parts[0] != "chunk": + return None + try: + return int(parts[1]) + except ValueError: + return None + + while not stop_event.is_set(): + if not chunks_dir.is_dir(): + time.sleep(poll_interval) + continue + + for path in sorted(chunks_dir.glob(f"chunk_*.{ext}"), key=lambda p: p.stat().st_mtime): + idx = _chunk_index(path) + if idx is None or idx in copied_indices: + continue + key = str(path.resolve()) + try: + sz = path.stat().st_size + except OSError: + continue + if sz < min_bytes: + continue + + if idx not in chunk_detected_time: + chunk_detected_time[idx] = time.time() + logger.debug( + f"[TtsChunkBridge] chunk_detected chunk_idx={idx} size={sz} t={time.time():.3f}" + ) + + prev = size_stable.get(key) + if prev is None or prev[0] != sz: + size_stable[key] = (sz, 1) + continue + _, cnt = size_stable[key] + cnt += 1 + size_stable[key] = (sz, cnt) + if cnt < stable_rounds: + continue + + # PlaybackMain consumes chunk001+, while streaming TTS emits chunk_000+. + # Shift to 1-based indices when copying into watch_dir. 
+ playback_idx = idx + 1 + dest = watch_dir / f"{speaker_side}_chunk{playback_idx:03d}.{ext}" + try: + copy_start = time.time() + shutil.copy2(path, dest) + copy_end = time.time() + copied_indices.add(idx) + live_chunk_counter[0] += 1 + detection_latency = copy_end - chunk_detected_time[idx] + logger.debug( + f"[TtsChunkBridge] chunk_copied chunk_idx={idx} playback_chunk_idx={playback_idx} " + f"detection_latency={detection_latency:.3f}s " + f"copy_time={copy_end - copy_start:.3f}s t={time.time():.3f}" + ) + logger.info(f"[TtsChunkBridge] {path.name} → {dest.name} (n={live_chunk_counter[0]})") + except OSError as e: + logger.warning(f"[TtsChunkBridge] copy failed {path} → {dest}: {e}") + + time.sleep(poll_interval) diff --git a/src/streaming/chunk_audio.py b/src/streaming/chunk_audio.py new file mode 100644 index 0000000..414ca5d --- /dev/null +++ b/src/streaming/chunk_audio.py @@ -0,0 +1,196 @@ +""" +Write timed audio chunks into a watch directory (same contract as +debate-anonymous ``simulate_online_audio_stream.py``): filenames +``_chunkNNN.`` so :mod:`streaming.env` groups by ``log_id``. + +Use an **opponent** recording with :mod:`streaming.env` / :mod:`streaming.run_listen_demo` +(``--debater-side`` is your agent; audio is attributed to the other side). + +Depends on ``pydub`` only for chunk I/O. :func:`clear_watch_chunk_files` is a small helper +for demos (no extra deps). 
Run from ``TreeDebater/src``:: + + python -m streaming.chunk_audio --audio-file path/to/opponent_speech.mp3 --watch-dir /tmp/watch_demo +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path +from typing import List + +from pydub import AudioSegment +from pydub.silence import split_on_silence + + +def clear_watch_chunk_files(watch_dir: Path, audio_format: str = "mp3") -> None: + """Remove ``*.{audio_format}`` files under ``watch_dir`` (ignore subdirs / other suffixes).""" + watch_dir = Path(watch_dir) + if not watch_dir.exists(): + return + suf = f".{audio_format.lower().lstrip('.')}" + for entry in watch_dir.iterdir(): + try: + if entry.is_file() and entry.suffix.lower() == suf: + entry.unlink() + except OSError as e: + print(f"Warning: could not remove {entry}: {e}", file=sys.stderr) + + +def split_audio(audio: AudioSegment, mode: str, time_seconds: float) -> List[AudioSegment]: + """Split ``audio`` into chunks (``time_seconds`` is chunk length for ``fixed``, min silence for ``silence``).""" + if mode == "fixed": + chunk_ms = int(time_seconds * 1000) + chunks: List[AudioSegment] = [] + for i in range(0, len(audio), chunk_ms): + chunk = audio[i : i + chunk_ms] + if len(chunk) > 500: + chunks.append(chunk) + return chunks + if mode == "silence": + chunks = split_on_silence( + audio, + min_silence_len=int(time_seconds * 1000), + silence_thresh=-60, + keep_silence=200, + seek_step=10, + ) + return [c for c in chunks if len(c) > 500] + raise ValueError(f"Invalid mode: {mode}") + + +def stream_chunks_to_directory( + chunks: List[AudioSegment], + watch_dir: Path, + log_id: str, + audio_format: str = "mp3", + dry_run: bool = False, + max_total_seconds: float | None = None, + chunk_index_start: int = 1, + realtime_pace: bool = True, +) -> int: + """ + Write each chunk to ``watch_dir``, sleeping ~chunk duration between writes (approx. real-time). 
+ + ``chunk_index_start`` continues ``{log_id}_chunkNNN`` across multiple source files. + Returns the next chunk index after this batch. + + If ``realtime_pace`` is False, writes all chunks back-to-back (no sleep); use when a + separate thread simulates playback timing (e.g. playback-driven debate env). + """ + watch_dir = Path(watch_dir) + watch_dir.mkdir(parents=True, exist_ok=True) + + total = len(chunks) + accumulated_seconds = 0.0 + last_written_idx = chunk_index_start - 1 + for i, chunk in enumerate(chunks): + idx = chunk_index_start + i + filename = f"{log_id}_chunk{idx:03d}.{audio_format}" + out_path = watch_dir / filename + + duration_seconds = len(chunk) / 1000.0 + accumulated_seconds += duration_seconds + if dry_run: + print( + f"[dry-run] Would write {out_path} " + f"(chunk_duration_s={duration_seconds:.3f}, accumulated_s={accumulated_seconds:.3f})" + ) + else: + chunk.export(str(out_path), format=audio_format) + ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + print( + f"[{ts}] Wrote {out_path} " + f"(chunk_duration_s={duration_seconds:.3f}, accumulated_s={accumulated_seconds:.3f})" + ) + last_written_idx = idx + + if max_total_seconds is not None and accumulated_seconds >= max_total_seconds: + print( + f"Reached max_total_seconds={max_total_seconds:.3f}s for log_id='{log_id}'. " + "Stopping streaming; remaining chunks will be dropped." + ) + return last_written_idx + 1 + + if realtime_pace and i < total - 1: + sleep_seconds = len(chunk) / 1000.0 + if sleep_seconds > 0: + time.sleep(sleep_seconds) + + return last_written_idx + 1 + + +__all__ = [ + "clear_watch_chunk_files", + "split_audio", + "stream_chunks_to_directory", + "parse_args", + "main", +] + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Split an audio file and stream chunk files into a directory in (approximate) real time." 
+ ) + p.add_argument("--audio-file", type=str, required=True) + p.add_argument("--watch-dir", type=str, required=True) + p.add_argument( + "--log-id", + type=str, + default=None, + help="Filename prefix before _chunkNNN. Default: stem of audio file before first underscore.", + ) + p.add_argument("--split-mode", type=str, choices=["fixed", "silence"], default="fixed") + p.add_argument("--chunk-seconds", type=float, default=10.0) + p.add_argument("--silence-window-seconds", type=float, default=0.7) + p.add_argument("--audio-format", type=str, default="mp3") + p.add_argument("--dry-run", action="store_true") + p.add_argument("--max-total-seconds", type=float, default=None) + return p.parse_args() + + +def main() -> None: + args = parse_args() + audio_path = Path(args.audio_file) + if not audio_path.exists(): + raise FileNotFoundError(audio_path) + + if args.log_id is not None: + log_id: str = args.log_id + else: + stem = audio_path.stem + log_id = stem.split("_")[0] if "_" in stem else stem + + print(f"Using log_id='{log_id}'") + audio = AudioSegment.from_file(str(audio_path)) + if args.split_mode == "fixed": + print(f"Splitting into fixed {args.chunk_seconds:.2f}s chunks...") + chunks = split_audio(audio, mode="fixed", time_seconds=args.chunk_seconds) + else: + print(f"Splitting on silence (min {args.silence_window_seconds:.2f}s)...") + chunks = split_audio(audio, mode="silence", time_seconds=args.silence_window_seconds) + + print(f"Created {len(chunks)} chunks.") + watch_dir = Path(args.watch_dir) + print(f"Streaming to '{watch_dir}'...") + stream_chunks_to_directory( + chunks=chunks, + watch_dir=watch_dir, + log_id=log_id, + audio_format=args.audio_format, + dry_run=bool(args.dry_run), + max_total_seconds=args.max_total_seconds, + ) + print("Done streaming chunks.") + + +if __name__ == "__main__": + _src = Path(__file__).resolve().parent.parent + _root = _src.parent + for _p in (_src, _root): + _ps = str(_p) + if _ps not in sys.path: + sys.path.insert(0, _ps) + 
main() diff --git a/src/streaming/env.py b/src/streaming/env.py new file mode 100644 index 0000000..f4c863a --- /dev/null +++ b/src/streaming/env.py @@ -0,0 +1,885 @@ +""" +Streaming input environment: watch a directory for chunked audio, transcribe, +buffer text, and incrementally update a TreeDebater's debate trees. + +This mirrors debate-anonymous ``chunked_pipeline.py`` (watch → min-audio buffer → +ASR → min-text buffer → analyze) but routes tree updates through +``TreeDebater._analyze_statement`` instead of a standalone ``analyze_statement``. + +**Full debate + streaming listen** + +Use ``StreamingDebateEnv`` (same YAML as ``env.py``): runs ``Env.play()`` stages, and on each +speech turn starts ``StreamingInputEnv`` on the **listener** (opponent TreeDebater), runs the +speaker’s generation + TTS into ``log_files/N_outputs``, then streams that MP3 into per-turn +watch subdirs so the listener does streaming ASR + ``_analyze_statement`` while chunks arrive. + +CLI (from ``TreeDebater/src``):: + + python -m streaming.env --debate --config configs/base_st_io.yml + +Watch-only (listener alone, no ``env.py``):: + + python -m streaming.env --config ... --watch-dir /tmp/w --debater-side for + +Requires: ``pydub``, ``openai`` (Whisper), and API keys as in ``utils.constants``. + +Run scripts from ``TreeDebater/src``, or set ``PYTHONPATH`` to include ``src``. + +When run as ``python -m streaming.env``, debate ``env`` is imported before ``utils.tool`` +where applicable so logging matches ``env.py``. 
+""" + +from __future__ import annotations + +import sys + +import argparse +import json +import threading +import time +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +import yaml +from openai import OpenAI +from pydub import AudioSegment + +from utils.constants import CLOSING_TIME, OPENING_TIME, REBUTTAL_TIME +from utils.tool import logger + + +def transcribe_audio_segment(segment: AudioSegment, audio_format: str = "mp3") -> str: + """Transcribe a pydub ``AudioSegment`` using OpenAI Whisper (``whisper-1``).""" + buf = BytesIO() + segment.export(buf, format=audio_format) + buf.seek(0) + buf.name = f"audio.{audio_format}" + client = OpenAI() + transcript = client.audio.transcriptions.create(model="whisper-1", file=buf, language="en") + return (transcript.text or "").strip() + + +def _log_id_from_filename(filename: str) -> str: + """Match ``chunked_pipeline`` grouping: stem before first underscore.""" + return Path(filename).stem.split("_")[0] + + +def opponent_side(side: str) -> str: + """Return the other debate side.""" + if side == "for": + return "against" + if side == "against": + return "for" + raise ValueError(f"side must be 'for' or 'against', got {side!r}") + + +@dataclass +class StreamingInputConfig: + watch_dir: Path + motion: str + stage: str # opening | rebuttal | closing — sets TreeDebater.status for analysis + statement_side: str # side whose speech is in the audio (here: opponent of debater.side) + min_audio_seconds: float = 30.0 + min_text_words: int = 50 + poll_interval: float = 1.0 + audio_file_glob: str = "*.mp3" + max_audio_wait_seconds: Optional[float] = None + max_text_wait_seconds: Optional[float] = None + max_total_audio_seconds: Optional[float] = None + audio_format: str = "mp3" + playback_cursor: Optional[List[float]] = None # shared cursor for real-time synchronization + + +class StreamingInputEnv: + """ + Watch ``watch_dir`` for new 
audio files, buffer by duration and text by word + count, then call ``TreeDebater._analyze_statement`` on each flushed text batch. + + Configure ``statement_side`` as the **speaker** in the audio (typically the + opponent of ``debater.side``). Read updated trees from ``self.debater`` anytime. + + Intended for one ``log_id`` stream (e.g. ``speaker_chunk001.mp3``); multiple + log_ids are supported with independent buffers like ``chunked_pipeline``. + """ + + def __init__(self, debater: "TreeDebater", config: StreamingInputConfig) -> None: + self.debater = debater + self.config = config + self._tree_lock = threading.Lock() + self._stop = threading.Event() + + if not debater.use_debate_flow_tree: + logger.warning( + "[StreamingInputEnv] use_debate_flow_tree is False; " + "_analyze_statement will no-op. Enable debate flow tree for updates." + ) + + def stop(self) -> None: + self._stop.set() + + def _flush_text(self, chunks: List[str]) -> None: + merged = " ".join(chunks).strip() + if not merged: + return + self.debater.status = self.config.stage + wc = len(merged.split()) + logger.debug(f"[StreamingInputEnv] tree_update_start words={wc} text_preview={merged[:50]}... 
t={time.time():.3f}") + tree_start = time.time() + with self._tree_lock: + self.debater._analyze_statement(merged, self.config.statement_side) + tree_end = time.time() + logger.debug(f"[StreamingInputEnv] tree_update_end words={wc} update_time={tree_end - tree_start:.3f}s t={tree_end:.3f}") + logger.info( + f"[StreamingInputEnv] analyze_statement flush: side={self.config.statement_side} " + f"stage={self.config.stage} words={wc} chunks={len(chunks)}" + ) + + def run(self) -> None: + """Block until :meth:`stop` or ``KeyboardInterrupt``.""" + cfg = self.config + watch_dir = Path(cfg.watch_dir) + watch_dir.mkdir(parents=True, exist_ok=True) + + min_audio_ms = int(cfg.min_audio_seconds * 1000) + cap_audio_ms = ( + int(cfg.max_total_audio_seconds * 1000) + if cfg.max_total_audio_seconds and cfg.max_total_audio_seconds > 0 + else None + ) + + seen: set = set() + audio_buf: Dict[str, List[Tuple[str, Any]]] = {} + audio_first_seen: Dict[str, float] = {} + emitted_ms: Dict[str, int] = {} + completed: set = set() + + text_buf: Dict[str, List[str]] = {} + text_first_seen: Dict[str, float] = {} + + use_cursor = cfg.playback_cursor is not None + continuous_file = watch_dir / "continuous_audio.mp3" if use_cursor else None + + logger.debug(f"[StreamingInputEnv] thread_start stage={cfg.stage} statement_side={cfg.statement_side} cursor_mode={use_cursor} t={time.time():.3f}") + logger.info( + f"[StreamingInputEnv] watching {watch_dir} ({cfg.audio_file_glob}), " + f"min_audio_s={cfg.min_audio_seconds}, min_text_words={cfg.min_text_words}, " + f"statement_side={cfg.statement_side}, stage={cfg.stage}, " + f"cursor_mode={use_cursor}" + ) + + try: + while not self._stop.is_set(): + if use_cursor and continuous_file and continuous_file.is_file(): + # Cursor mode: read from continuous audio up to cursor position + try: + available_audio = self._read_audio_up_to_cursor(continuous_file, cfg.playback_cursor[0]) + if available_audio is not None and len(available_audio) > 0: + 
self._process_cursor_audio( + available_audio, + emitted_ms, + min_audio_ms, + cap_audio_ms, + text_buf, + text_first_seen, + ) + except Exception as e: + logger.warning(f"[StreamingInputEnv] Cursor audio processing failed: {e}") + else: + # Chunk mode: original behavior + try: + paths = sorted(watch_dir.glob(cfg.audio_file_glob), key=lambda p: p.stat().st_mtime) + except Exception: + time.sleep(cfg.poll_interval) + continue + + for path in paths: + path_str = str(path.resolve()) + if path_str in seen: + continue + try: + seg = AudioSegment.from_file(path_str) + except Exception as e: + logger.warning(f"[StreamingInputEnv] Skip unreadable audio {path_str}: {e}") + continue + seen.add(path_str) + log_id = _log_id_from_filename(path.name) + if log_id in completed: + continue + audio_buf.setdefault(log_id, []).append((path_str, seg)) + if log_id not in audio_first_seen: + audio_first_seen[log_id] = time.perf_counter() + + for log_id in list(audio_buf.keys()): + self._process_one_log_audio( + log_id, + audio_buf, + audio_first_seen, + emitted_ms, + completed, + min_audio_ms, + cap_audio_ms, + text_buf, + text_first_seen, + ) + + self._idle_text_flush(text_buf, text_first_seen) + time.sleep(cfg.poll_interval) + except KeyboardInterrupt: + logger.info("[StreamingInputEnv] KeyboardInterrupt; stopping.") + finally: + for log_id in list(text_buf.keys()): + if text_buf.get(log_id): + self._flush_text(text_buf[log_id]) + text_buf[log_id] = [] + logger.debug(f"[StreamingInputEnv] thread_end stage={cfg.stage} statement_side={cfg.statement_side} t={time.time():.3f}") + + def _idle_text_flush(self, text_buf: Dict[str, List[str]], text_first_seen: Dict[str, float]) -> None: + cfg = self.config + if cfg.max_text_wait_seconds is None or cfg.max_text_wait_seconds <= 0: + return + now = time.perf_counter() + for log_id in list(text_buf.keys()): + if log_id not in text_first_seen or not text_buf.get(log_id): + continue + if now - text_first_seen[log_id] >= cfg.max_text_wait_seconds: + 
total_words = sum(len(s.split()) for s in text_buf[log_id]) + logger.info( + f"[StreamingInputEnv] Max text wait exceeded for log_id={log_id!r} " + f"(words={total_words}); flushing." + ) + self._flush_text(text_buf[log_id]) + text_buf[log_id] = [] + text_first_seen.pop(log_id, None) + + def _process_one_log_audio( + self, + log_id: str, + audio_buf: Dict[str, List[Tuple[str, Any]]], + audio_first_seen: Dict[str, float], + emitted_ms: Dict[str, int], + completed: set, + min_audio_ms: int, + cap_audio_ms: Optional[int], + text_buf: Dict[str, List[str]], + text_first_seen: Dict[str, float], + ) -> None: + cfg = self.config + acc_ms = 0 + to_emit: List[Tuple[str, Any]] = [] + while audio_buf.get(log_id) and acc_ms < min_audio_ms: + path_str, seg = audio_buf[log_id].pop(0) + to_emit.append((path_str, seg)) + acc_ms += len(seg) + + elapsed = None + if log_id in audio_first_seen: + elapsed = time.perf_counter() - audio_first_seen[log_id] + + should_emit = False + wait_time_exceeded = False + total_so_far_ms = emitted_ms.get(log_id, 0) + + if to_emit and acc_ms >= min_audio_ms: + should_emit = True + elif to_emit and cfg.max_audio_wait_seconds and cfg.max_audio_wait_seconds > 0: + if elapsed is not None and elapsed >= cfg.max_audio_wait_seconds: + should_emit = True + wait_time_exceeded = True + logger.info( + f"[StreamingInputEnv] Max audio wait exceeded for log_id={log_id!r}; " + f"emitting partial audio {acc_ms / 1000.0:.2f}s." 
+ ) + + if should_emit and cap_audio_ms is not None: + if total_so_far_ms >= cap_audio_ms: + should_emit = False + audio_buf[log_id] = [] + completed.add(log_id) + audio_first_seen.pop(log_id, None) + logger.info(f"[StreamingInputEnv] Total audio cap already reached for log_id={log_id!r}; dropping rest.") + elif total_so_far_ms + acc_ms > cap_audio_ms: + logger.info(f"[StreamingInputEnv] Total audio cap crossed for log_id={log_id!r} on this batch.") + + if not should_emit: + for item in reversed(to_emit): + audio_buf.setdefault(log_id, []).insert(0, item) + if not audio_buf.get(log_id): + del audio_buf[log_id] + audio_first_seen.pop(log_id, None) + return + + combined = to_emit[0][1] + for _, s in to_emit[1:]: + combined += s + + force_flush = wait_time_exceeded or ( + cap_audio_ms is not None and (total_so_far_ms + acc_ms) >= cap_audio_ms + ) + + try: + text = transcribe_audio_segment(combined, audio_format=cfg.audio_format) + except Exception as e: + logger.warning(f"[StreamingInputEnv] Transcription failed for log_id={log_id!r}: {e}") + for item in reversed(to_emit): + audio_buf.setdefault(log_id, []).insert(0, item) + return + + emitted_ms[log_id] = total_so_far_ms + acc_ms + logger.info( + f"[StreamingInputEnv] Transcribed log_id={log_id!r} batch_s={len(combined) / 1000.0:.2f} " + f"text_words={len(text.split())} force_flush={force_flush}" + ) + + if cap_audio_ms is not None and emitted_ms[log_id] >= cap_audio_ms: + completed.add(log_id) + logger.info(f"[StreamingInputEnv] Total audio cap reached for log_id={log_id!r}; future audio ignored.") + + self._append_transcript_text(log_id, text, force_flush, text_buf, text_first_seen) + + if not audio_buf.get(log_id): + del audio_buf[log_id] + audio_first_seen.pop(log_id, None) + + def _append_transcript_text( + self, + log_id: str, + text: str, + force_flush: bool, + text_buf: Dict[str, List[str]], + text_first_seen: Dict[str, float], + ) -> None: + cfg = self.config + if not text.strip(): + return + 
text_buf.setdefault(log_id, []).append(text) + if log_id not in text_first_seen: + text_first_seen[log_id] = time.perf_counter() + + total_words = sum(len(s.split()) for s in text_buf[log_id]) + elapsed = time.perf_counter() - text_first_seen[log_id] + should_flush = False + if total_words >= cfg.min_text_words: + should_flush = True + elif cfg.max_text_wait_seconds and cfg.max_text_wait_seconds > 0 and elapsed >= cfg.max_text_wait_seconds: + should_flush = True + logger.info(f"[StreamingInputEnv] Max text wait exceeded for log_id={log_id!r}; flushing ({total_words} words).") + if force_flush: + should_flush = True + logger.info(f"[StreamingInputEnv] Force flush for log_id={log_id!r} ({total_words} words).") + + if should_flush: + self._flush_text(text_buf[log_id]) + text_buf[log_id] = [] + text_first_seen.pop(log_id, None) + + def _read_audio_up_to_cursor(self, continuous_file: Path, cursor_seconds: float) -> Optional[AudioSegment]: + """Read continuous audio file up to the cursor position.""" + try: + read_start = time.time() + full_audio = AudioSegment.from_file(str(continuous_file)) + cursor_ms = int(cursor_seconds * 1000) + if cursor_ms <= 0: + return None + result = full_audio[:cursor_ms] + read_end = time.time() + logger.debug(f"[StreamingInputEnv] file_read cursor={cursor_seconds:.2f}s " + f"available={len(full_audio)/1000.0:.2f}s read_time={read_end - read_start:.3f}s t={read_end:.3f}") + return result + except Exception as e: + logger.warning(f"[StreamingInputEnv] Failed to read continuous audio: {e}") + return None + + def _process_cursor_audio( + self, + available_audio: AudioSegment, + emitted_ms: Dict[str, int], + min_audio_ms: int, + cap_audio_ms: Optional[int], + text_buf: Dict[str, List[str]], + text_first_seen: Dict[str, float], + ) -> None: + """Process audio from cursor-based continuous file.""" + cfg = self.config + log_id = "continuous" + + available_ms = len(available_audio) + already_processed_ms = emitted_ms.get(log_id, 0) + + # Check if 
there's new audio to process + if available_ms <= already_processed_ms: + return + + # Check cap + if cap_audio_ms is not None and already_processed_ms >= cap_audio_ms: + return + + # Extract new audio segment + new_audio = available_audio[already_processed_ms:] + new_audio_ms = len(new_audio) + + # Check if we have enough audio to process + if new_audio_ms < min_audio_ms: + logger.debug(f"[StreamingInputEnv] wait_audio_accumulation available={new_audio_ms/1000.0:.2f}s " + f"need={min_audio_ms/1000.0:.2f}s t={time.time():.3f}") + return + + # Decide how much to process (in chunks of min_audio_ms) + to_process_ms = (new_audio_ms // min_audio_ms) * min_audio_ms + if to_process_ms == 0: + return + + # Apply cap if needed + if cap_audio_ms is not None: + remaining_cap = cap_audio_ms - already_processed_ms + to_process_ms = min(to_process_ms, remaining_cap) + + segment_to_transcribe = new_audio[:to_process_ms] + audio_start_sec = already_processed_ms / 1000.0 + audio_end_sec = (already_processed_ms + to_process_ms) / 1000.0 + + logger.debug(f"[StreamingInputEnv] asr_start audio_range={audio_start_sec:.2f}-{audio_end_sec:.2f}s t={time.time():.3f}") + asr_start = time.time() + try: + text = transcribe_audio_segment(segment_to_transcribe, audio_format=cfg.audio_format) + asr_end = time.time() + asr_duration = asr_end - asr_start + audio_duration = to_process_ms / 1000.0 + logger.debug(f"[StreamingInputEnv] asr_end audio_range={audio_start_sec:.2f}-{audio_end_sec:.2f}s " + f"text_len={len(text)} asr_time={asr_duration:.3f}s t={asr_end:.3f}") + + emitted_ms[log_id] = already_processed_ms + to_process_ms + logger.info( + f"[StreamingInputEnv] Cursor transcribed s={len(segment_to_transcribe) / 1000.0:.2f} " + f"words={len(text.split())} total_processed={emitted_ms[log_id] / 1000.0:.2f}s" + ) + + force_flush = cap_audio_ms is not None and emitted_ms[log_id] >= cap_audio_ms + self._append_transcript_text(log_id, text, force_flush, text_buf, text_first_seen) + except Exception as 
e: + logger.warning(f"[StreamingInputEnv] Cursor transcription failed: {e}") + + +def load_treedebater_for_side(config_path: Path, side: str) -> Tuple[dict, "TreeDebater"]: + from agents import DebaterConfig + from ouragents import TreeDebater + + with config_path.open("r", encoding="utf-8") as f: + full = yaml.load(f, Loader=yaml.FullLoader) + debater_cfgs = full["debater"] + chosen = None + for d in debater_cfgs: + if d.get("side") == side and d.get("type") == "treedebater": + chosen = d + break + if chosen is None: + raise ValueError( + f"No treedebater entry with side={side!r} in {config_path}. " + "Streaming tree updates require a TreeDebater config for that side." + ) + motion = full["env"]["motion"] + debater = TreeDebater(DebaterConfig(**chosen), motion=motion) + return full, debater + + +def tts_outputs_dir_from_log() -> Path: + """Resolved ``log_files/_outputs`` for the current debate log session.""" + from utils.tool import log_file_path + + if not log_file_path: + return Path("log_files") / "1_outputs" + lp = Path(log_file_path) + return (lp.parent / f"{lp.stem}_outputs").resolve() + + +def default_watch_root_from_log() -> Path: + """Resolved ``log_files/_watch`` for per-turn streaming watch directories.""" + from utils.tool import log_file_path + + if not log_file_path: + return Path("log_files") / "1_watch" + lp = Path(log_file_path) + return (lp.parent / f"{lp.stem}_watch").resolve() + + +# Backward-compatible names +_tts_outputs_dir_from_log = tts_outputs_dir_from_log +_default_watch_root_from_log = default_watch_root_from_log + + +class StreamingDebateEnv: + """ + Full debate (``Env``) plus streaming ASR for the **listener** on each speech turn. 
+ + When the opponent is a ``TreeDebater``, starts ``StreamingInputEnv`` before the speaker + generates; after TTS writes ``{type}_{stage}_{side}.mp3`` under ``log_files/N_outputs``, + splits that file into timed chunks and streams them into a per-turn watch dir so the + listener updates trees incrementally while the debate proceeds. + """ + + def __init__( + self, + env_config, + debug: bool, + watch_root: Optional[Path] = None, + *, + min_audio_seconds: float = 30.0, + min_text_words: int = 50, + poll_interval: float = 1.0, + audio_format: str = "mp3", + split_mode: str = "fixed", + chunk_seconds: float = 10.0, + silence_window_seconds: float = 0.7, + max_audio_wait_seconds: float = 0.0, + max_text_wait_seconds: float = 0.0, + max_total_audio_seconds: float = 0.0, + listener_join_timeout: float = 300.0, + min_playback_increment: float = 3.0, + ) -> None: + from env import Env + + self._env = Env(env_config, debug) + self._listener_join_timeout = listener_join_timeout + self._watch_root = Path(watch_root).resolve() if watch_root else default_watch_root_from_log() + self._watch_root.mkdir(parents=True, exist_ok=True) + self._min_audio_seconds = min_audio_seconds + self._min_text_words = min_text_words + self._poll_interval = poll_interval + self._audio_format = audio_format + self._split_mode = split_mode + self._chunk_seconds = chunk_seconds + self._silence_window_seconds = silence_window_seconds + self._max_audio_wait = max_audio_wait_seconds if max_audio_wait_seconds > 0 else None + self._max_text_wait = max_text_wait_seconds if max_text_wait_seconds > 0 else None + self._max_total_audio = max_total_audio_seconds if max_total_audio_seconds > 0 else None + self._min_playback_increment = min_playback_increment + + def __getattr__(self, name: str): + return getattr(self._env, name) + + def _clear_dir(self, d: Path) -> None: + if not d.exists(): + return + for p in d.iterdir(): + try: + if p.is_file(): + p.unlink() + except OSError: + pass + + def 
_play_speech_turn(self, stage_key: str, side: str, max_time: float, generate_fn: Callable[[], str]) -> None: + from .chunk_audio import split_audio, stream_chunks_to_directory + + listener = opponent_side(side) + player = self._env.debaters[side] + listener_deb = self._env.debaters[listener] + + if listener_deb.type != "treedebater": + response = generate_fn() + self._env.debate_process.append({"stage": stage_key, "side": side, "content": response}) + return + + turn_watch = self._watch_root / f"{stage_key}_{side}" + turn_watch.mkdir(parents=True, exist_ok=True) + self._clear_dir(turn_watch) + + sic = StreamingInputConfig( + watch_dir=turn_watch, + motion=self._env.motion, + stage=stage_key, + statement_side=side, + min_audio_seconds=self._min_audio_seconds, + min_text_words=self._min_text_words, + poll_interval=self._poll_interval, + audio_file_glob=f"*.{self._audio_format}", + max_audio_wait_seconds=self._max_audio_wait, + max_text_wait_seconds=self._max_text_wait, + max_total_audio_seconds=self._max_total_audio, + audio_format=self._audio_format, + ) + sin_env = StreamingInputEnv(listener_deb, sic) + listener_thread = threading.Thread(target=sin_env.run, name="StreamingInputEnv", daemon=True) + listener_thread.start() + time.sleep(max(0.5, self._poll_interval)) + + response: Optional[str] = None + try: + response = generate_fn() + finally: + try: + mp3_path = tts_outputs_dir_from_log() / f"{player.config.type}_{player.status}_{player.side}.mp3" + if mp3_path.is_file() and mp3_path.stat().st_size > 2048: + audio = AudioSegment.from_file(str(mp3_path)) + if self._split_mode == "fixed": + chunks = split_audio(audio, mode="fixed", time_seconds=self._chunk_seconds) + else: + chunks = split_audio( + audio, mode="silence", time_seconds=self._silence_window_seconds + ) + logger.info( + f"[StreamingDebateEnv] Streaming {len(chunks)} chunk(s) from {mp3_path.name} → {turn_watch} " + f"(listener={listener!r})" + ) + stream_chunks_to_directory( + chunks, + turn_watch, + 
side, + audio_format=self._audio_format, + dry_run=False, + max_total_seconds=None, + chunk_index_start=1, + ) + else: + logger.warning(f"[StreamingDebateEnv] No TTS MP3 at {mp3_path} (skip chunk stream for listener).") + except Exception as e: + logger.error(f"[StreamingDebateEnv] Chunk stream failed: {e}") + sin_env.stop() + listener_thread.join(timeout=self._listener_join_timeout) + + if response is not None: + self._env.debate_process.append({"stage": stage_key, "side": side, "content": response}) + + def play(self, pre_only: bool = False) -> None: + order = ["for", "against"] if not self._env.reverse else ["against", "for"] + for stage in ["preparation", "opening", "rebuttal", "closing"]: + logger.info(f"[{stage}] Start") + if stage == "preparation": + for side in order: + if self._env.debaters[side].type in ["treedebater"]: + self._env.debaters[side].claim_generation(self._env.claim_pool_size, temperature=1) + elif stage == "opening": + for side in order: + + def _gen_opening(side=side): + return self._env.debaters[side].opening_generation( + history=self._env.debate_process[1:], + max_time=OPENING_TIME, + time_control=self._env.time_control, + streaming_tts=getattr(self._env.debaters[side].config, "streaming_tts", False), + ) + + self._play_speech_turn("opening", side, OPENING_TIME, _gen_opening) + elif stage == "rebuttal": + for side in order: + + def _gen_rebuttal(side=side): + return self._env.debaters[side].rebuttal_generation( + history=self._env.debate_process[1:], + max_time=REBUTTAL_TIME, + time_control=self._env.time_control, + streaming_tts=getattr(self._env.debaters[side].config, "streaming_tts", False), + ) + + self._play_speech_turn("rebuttal", side, REBUTTAL_TIME, _gen_rebuttal) + elif stage == "closing": + for side in order: + + def _gen_closing(side=side): + return self._env.debaters[side].closing_generation( + history=self._env.debate_process[1:], + max_time=CLOSING_TIME, + time_control=self._env.time_control, + 
streaming_tts=getattr(self._env.debaters[side].config, "streaming_tts", False), + ) + + self._play_speech_turn("closing", side, CLOSING_TIME, _gen_closing) + logger.info(f"[{stage}] Done") + if self._env.debug: + if input("Press N to stop: ").lower() == "n": + break + + +__all__ = [ + "transcribe_audio_segment", + "opponent_side", + "StreamingInputConfig", + "StreamingInputEnv", + "load_treedebater_for_side", + "tts_outputs_dir_from_log", + "default_watch_root_from_log", + "_tts_outputs_dir_from_log", + "_default_watch_root_from_log", + "StreamingDebateEnv", + "parse_args", + "main", +] + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Watch-only: stream ASR on a folder. --debate: full Env debate + streaming listener per turn." + ) + p.add_argument("--config", type=str, required=True, help="YAML config (same layout as env.py).") + p.add_argument( + "--debate", + action="store_true", + help="Run full debate (like env.py) with StreamingInputEnv on opponent TreeDebater each speech turn.", + ) + p.add_argument( + "--watch-dir", + type=str, + default=None, + help="Watch-only: required chunk directory. Debate mode: optional root for per-turn subdirs (default: log_files/_watch).", + ) + p.add_argument( + "--debater-side", + type=str, + choices=["for", "against"], + default=None, + help="Watch-only: load treedebater for this listener side. 
Not used with --debate.", + ) + p.add_argument("--debug", action="store_true", default=False, help="Debate mode: pause prompts like env.py.") + p.add_argument( + "--stage", + type=str, + choices=["opening", "rebuttal", "closing"], + default="opening", + help="Watch-only: stage label for extract_statement.", + ) + p.add_argument("--min-audio-seconds", type=float, default=30.0) + p.add_argument("--min-text-words", type=int, default=50) + p.add_argument("--poll-interval", type=float, default=1.0) + p.add_argument("--audio-glob", type=str, default="*.mp3") + p.add_argument("--audio-format", type=str, default="mp3") + p.add_argument("--split-mode", type=str, choices=["fixed", "silence"], default="fixed") + p.add_argument("--chunk-seconds", type=float, default=10.0) + p.add_argument("--silence-window-seconds", type=float, default=0.7) + p.add_argument("--max-audio-wait-seconds", type=float, default=0.0) + p.add_argument("--max-text-wait-seconds", type=float, default=0.0) + p.add_argument("--max-total-audio-seconds", type=float, default=0.0) + p.add_argument("--listener-join-timeout", type=float, default=300.0, help="Debate mode: max seconds to join listener after each turn.") + p.add_argument("--min-playback-increment", type=float, default=3.0, help="Minimum playback increment in seconds (env-level cursor update frequency).") + return p.parse_args() + + +def main() -> None: + args = parse_args() + config_path = Path(args.config).resolve() + if not config_path.is_file(): + raise FileNotFoundError(config_path) + + if args.debate: + from agents import AudienceConfig, DebaterConfig, JudgeConfig + from env import EnvConfig + + with open(config_path, "r", encoding="utf-8") as f: + full_config = yaml.load(f, Loader=yaml.FullLoader) + logger.info(f"Config: {full_config}") + env_config = EnvConfig( + debater_config=[DebaterConfig(**c) for c in full_config["debater"]], + judge_config=JudgeConfig(**full_config["judge"]), + audience_config=AudienceConfig(**full_config["audience"]), 
+ **full_config["env"], + ) + watch_root = Path(args.watch_dir).resolve() if args.watch_dir else None + sde = StreamingDebateEnv( + env_config, + args.debug, + watch_root, + min_audio_seconds=args.min_audio_seconds, + min_text_words=args.min_text_words, + poll_interval=args.poll_interval, + audio_format=args.audio_format, + split_mode=args.split_mode, + chunk_seconds=args.chunk_seconds, + silence_window_seconds=args.silence_window_seconds, + max_audio_wait_seconds=args.max_audio_wait_seconds, + max_text_wait_seconds=args.max_text_wait_seconds, + max_total_audio_seconds=args.max_total_audio_seconds, + listener_join_timeout=args.listener_join_timeout, + min_playback_increment=args.min_playback_increment, + ) + sde.play() + log_file = logger.handlers[0].baseFilename + save_file = log_file.replace(".log", ".json") + logger.info(f"Saving to {save_file}") + record = { + "motion": sde.motion, + "config": full_config, + "debate_process": sde.debate_process[1:], + "debate_thoughts": { + "for": sde.debaters["for"].debate_thoughts, + "against": sde.debaters["against"].debate_thoughts, + }, + "debate_tree": { + "for": [ + ( + sde.debaters["for"].debate_tree.get_tree_info() + if sde.debaters["for"].type in ["treedebater"] + else {} + ), + ( + sde.debaters["for"].oppo_debate_tree.get_tree_info() + if sde.debaters["for"].type in ["treedebater"] + else {} + ), + ], + "against": [ + ( + sde.debaters["against"].debate_tree.get_tree_info() + if sde.debaters["against"].type in ["treedebater"] + else {} + ), + ( + sde.debaters["against"].oppo_debate_tree.get_tree_info() + if sde.debaters["against"].type in ["treedebater"] + else {} + ), + ], + }, + "conversation": { + "for": sde.debaters["for"].conversation, + "against": sde.debaters["against"].conversation, + }, + } + json.dump(record, open(save_file, "w"), indent=2) + if not args.debug: + evaluation, side_into = sde.eval() + logger.info(f"Result: {evaluation}") + record.update({"evaluation": evaluation, "eval_side_info": side_into}) 
+ json.dump(record, open(save_file, "w"), indent=2) + return + + if not args.watch_dir or not args.debater_side: + raise SystemExit("Watch-only mode requires --watch-dir and --debater-side (or use --debate).") + + debater_side = args.debater_side + _, debater = load_treedebater_for_side(config_path, debater_side) + if debater.side != debater_side: + raise RuntimeError("Debater side mismatch after load.") + + statement_side = opponent_side(debater_side) + logger.info( + f"[StreamingInputEnv] TreeDebater (listener) side={debater.side!r}; " + f"chunk audio is speaker side={statement_side!r}." + ) + + sic = StreamingInputConfig( + watch_dir=Path(args.watch_dir).resolve(), + motion=debater.motion, + stage=args.stage, + statement_side=statement_side, + min_audio_seconds=args.min_audio_seconds, + min_text_words=args.min_text_words, + poll_interval=args.poll_interval, + audio_file_glob=args.audio_glob, + max_audio_wait_seconds=args.max_audio_wait_seconds if args.max_audio_wait_seconds > 0 else None, + max_text_wait_seconds=args.max_text_wait_seconds if args.max_text_wait_seconds > 0 else None, + max_total_audio_seconds=args.max_total_audio_seconds if args.max_total_audio_seconds > 0 else None, + audio_format=args.audio_format, + ) + + env = StreamingInputEnv(debater, sic) + env.run() + + +if __name__ == "__main__": + # Run as: ``python -m streaming.env`` from ``TreeDebater/src`` (or with ``src`` on ``PYTHONPATH``). + src_dir = Path(__file__).resolve().parent.parent + root = src_dir.parent + for p in (src_dir, root): + ps = str(p) + if ps not in sys.path: + sys.path.insert(0, ps) + main() diff --git a/src/streaming/overlap.py b/src/streaming/overlap.py new file mode 100644 index 0000000..2e483ea --- /dev/null +++ b/src/streaming/overlap.py @@ -0,0 +1,484 @@ +""" +Full debate with **true overlap** and a **playback-driven main thread**. 
+ +The **main thread** only waits for stable ``{side}_chunkNNN.mp3`` files in the turn watch +directory and simulates speaking (sleeps for each chunk's duration). The **speaker** runs +``generate_fn()`` (LLM + refinements + ``streaming_tts``) on a **background thread**, with +the TTS chunk **bridge** started/stopped in that thread. The **listener** runs +:class:`StreamingInputEnv` via :meth:`TreeDebater.start_streaming_listen` on another +background thread, so ASR and tree updates can proceed while chunks appear. + +``streaming_listen: true`` on the listener enables ``StreamingInputEnv``; if false, only +the speaker worker + main playback run (no duplicate tree ingest). + +Run from ``TreeDebater/src``:: + + python -m streaming.overlap --config configs/overlap_debate.yml +""" + +from __future__ import annotations + +import argparse +import json +import sys +import threading +import time +from pathlib import Path +from typing import Callable, List, Optional + +import yaml +from pydub import AudioSegment + +from .bridges import run_streaming_tts_chunk_copy_bridge +from .env import StreamingDebateEnv, opponent_side, tts_outputs_dir_from_log +from utils.tool import logger + + +class OverlappingStreamingDebateEnv(StreamingDebateEnv): + """ + Extends :class:`StreamingDebateEnv` with playback-driven main: simulate debate delivery + on the main thread while speaker (LLM + TTS + bridge) and optional listener ASR run in + background threads. 
+ """ + + @staticmethod + def _playback_chunk_ready(path: Path, min_size: int, poll_interval: float) -> bool: + if not path.is_file(): + return False + try: + s1 = path.stat().st_size + except OSError: + return False + if s1 < min_size: + return False + time.sleep(poll_interval) + try: + s2 = path.stat().st_size + except OSError: + return False + return s1 == s2 + + def _playback_main_loop( + self, + turn_watch: Path, + side: str, + stage: str, + speaker_done: threading.Event, + error_holder: List[Optional[BaseException]], + done_ts: List[Optional[float]], + playback_cursor: List[float], + *, + grace_seconds: float = 4.0, + ) -> None: + """Main thread: assemble continuous audio and sleep in fixed increments, updating cursor.""" + ext = self._audio_format.lower().lstrip(".") + poll = min(0.2, self._poll_interval) + next_idx = 1 + t0 = time.monotonic() + deadline = t0 + max(600.0, self._listener_join_timeout * 6) + min_size = 512 + + continuous_audio = AudioSegment.empty() + continuous_file = turn_watch / "continuous_audio.mp3" + + logger.debug( + f"[PlaybackMain] playback_start stage={stage} side={side} t={time.time():.3f}" + ) + + while time.monotonic() < deadline: + if error_holder[0] is not None: + logger.info("[PlaybackMain] stopping playback due to speaker thread error.") + return + + path = turn_watch / f"{side}_chunk{next_idx:03d}.{ext}" + should_stop = False + + # Start waiting for chunk + wait_start = time.time() + logger.debug( + f"[PlaybackMain] wait_chunk_start stage={stage} side={side} chunk_idx={next_idx} " + f"cursor={playback_cursor[0]:.2f}s t={wait_start:.3f}" + ) + + while time.monotonic() < deadline: + if error_holder[0] is not None: + return + if self._playback_chunk_ready(path, min_size, poll): + break + if speaker_done.is_set() and done_ts[0] is not None: + try: + chunk_ok = path.is_file() and path.stat().st_size >= min_size + except OSError: + chunk_ok = False + if not chunk_ok and time.monotonic() - done_ts[0] >= grace_seconds: + should_stop 
= True + break + time.sleep(poll) + else: + logger.warning("[PlaybackMain] deadline waiting for next chunk.") + return + + # End waiting for chunk + wait_end = time.time() + wait_duration = wait_end - wait_start + if not should_stop: + logger.debug( + f"[PlaybackMain] wait_chunk_end stage={stage} side={side} chunk_idx={next_idx} " + f"wait_time={wait_duration:.3f}s t={wait_end:.3f}" + ) + + if should_stop: + logger.info(f"[PlaybackMain] no further chunks after index {next_idx - 1}; done.") + logger.debug( + f"[PlaybackMain] playback_end stage={stage} side={side} " + f"cursor={playback_cursor[0]:.2f}s t={time.time():.3f}" + ) + return + + try: + seg = AudioSegment.from_file(str(path)) + chunk_duration = len(seg) / 1000.0 + except Exception as e: + logger.warning(f"[PlaybackMain] skip unreadable {path}: {e}") + next_idx += 1 + continue + + # Assemble into continuous audio + assemble_start = time.time() + continuous_audio += seg + logger.debug( + f"[PlaybackMain] chunk_assembled stage={stage} side={side} chunk_idx={next_idx} " + f"duration={chunk_duration:.2f}s total_audio={len(continuous_audio)/1000.0:.2f}s t={time.time():.3f}" + ) + + try: + write_start = time.time() + continuous_audio.export(str(continuous_file), format=ext) + write_end = time.time() + logger.debug( + f"[PlaybackMain] file_write stage={stage} side={side} chunk_idx={next_idx} " + f"size_sec={len(continuous_audio)/1000.0:.2f}s write_time={write_end - write_start:.3f}s t={time.time():.3f}" + ) + except Exception as e: + logger.warning(f"[PlaybackMain] failed to export continuous audio: {e}") + + # Start playback of this chunk + playback_start = time.time() + logger.debug( + f"[PlaybackMain] chunk_playback_start stage={stage} side={side} chunk_idx={next_idx} " + f"duration={chunk_duration:.2f}s t={playback_start:.3f}" + ) + + # Simulate speaking in fixed increments + elapsed_in_chunk = 0.0 + while elapsed_in_chunk < chunk_duration: + sleep_time = min(self._min_playback_increment, chunk_duration - 
elapsed_in_chunk) + time.sleep(sleep_time) + + # Advance cursor - listener can now access this audio + playback_cursor[0] += sleep_time + elapsed_in_chunk += sleep_time + logger.debug(f"[PlaybackMain] cursor={playback_cursor[0]:.1f}s (chunk {next_idx}, elapsed={elapsed_in_chunk:.1f}s/{chunk_duration:.1f}s)") + + # End playback of this chunk + logger.debug( + f"[PlaybackMain] chunk_playback_end stage={stage} side={side} chunk_idx={next_idx} " + f"cursor={playback_cursor[0]:.2f}s t={time.time():.3f}" + ) + + next_idx += 1 + + logger.debug( + f"[PlaybackMain] playback_end stage={stage} side={side} cursor={playback_cursor[0]:.2f}s t={time.time():.3f}" + ) + + def _play_speech_turn(self, stage_key: str, side: str, max_time: float, generate_fn: Callable[[], str]) -> None: + from .chunk_audio import split_audio, stream_chunks_to_directory + + listener = opponent_side(side) + player = self._env.debaters[side] + listener_deb = self._env.debaters[listener] + + logger.debug(f"[Turn] turn_start stage={stage_key} side={side} t={time.time():.3f}") + + if listener_deb.type != "treedebater": + response = generate_fn() + self._env.debate_process.append({"stage": stage_key, "side": side, "content": response}) + logger.debug(f"[Turn] turn_end stage={stage_key} side={side} t={time.time():.3f}") + return + + use_streaming_listen = getattr(listener_deb.config, "streaming_listen", False) + use_streaming_tts = self._env.time_control and getattr(player.config, "streaming_tts", False) + + # Log mode configuration + mode = f"tts={'stream' if use_streaming_tts else 'batch'}_listen={'stream' if use_streaming_listen else 'batch'}" + logger.debug(f"[Turn] mode_config stage={stage_key} side={side} streaming_tts={use_streaming_tts} " + f"streaming_listen={use_streaming_listen} mode={mode} t={time.time():.3f}") + + turn_watch = self._watch_root / f"{stage_key}_{side}" + turn_watch.mkdir(parents=True, exist_ok=True) + self._clear_dir(turn_watch) + + # Create playback cursor for synchronization + 
playback_cursor = [0.0] # seconds of audio "played" so far + + if use_streaming_listen: + listener_deb.start_streaming_listen( + turn_watch, + stage_key, + min_audio_seconds=self._min_audio_seconds, + min_text_words=self._min_text_words, + poll_interval=self._poll_interval, + audio_format=self._audio_format, + max_audio_wait_seconds=self._max_audio_wait, + max_text_wait_seconds=self._max_text_wait, + max_total_audio_seconds=self._max_total_audio, + playback_cursor=playback_cursor, + ) + time.sleep(max(0.5, self._poll_interval)) + + use_live_bridge = self._env.time_control and getattr(player.config, "streaming_tts", False) + + response_holder: List[Optional[str]] = [None] + error_holder: List[Optional[BaseException]] = [None] + live_counts: List[int] = [0] + speaker_done = threading.Event() + done_ts: List[Optional[float]] = [None] + + def speaker_worker() -> None: + logger.debug(f"[SpeakerWorker] thread_start stage={stage_key} side={side} t={time.time():.3f}") + bridge_stop = threading.Event() + bridge_thread: Optional[threading.Thread] = None + if use_live_bridge: + chunks_dir = tts_outputs_dir_from_log() / f"{player.config.type}_{stage_key}_{side}_chunks" + bridge_thread = threading.Thread( + target=run_streaming_tts_chunk_copy_bridge, + name="TtsChunkBridge", + args=(chunks_dir, turn_watch, side, bridge_stop, live_counts), + kwargs={ + "audio_format": self._audio_format, + "poll_interval": min(0.5, self._poll_interval), + }, + daemon=True, + ) + bridge_thread.start() + try: + logger.debug(f"[SpeakerWorker] generation_start stage={stage_key} side={side} t={time.time():.3f}") + response_holder[0] = generate_fn() + logger.debug(f"[SpeakerWorker] generation_end stage={stage_key} side={side} response_len={len(response_holder[0]) if response_holder[0] else 0} t={time.time():.3f}") + except BaseException as e: + error_holder[0] = e + logger.exception("[OverlappingStreamingDebateEnv] Speaker worker failed") + finally: + bridge_stop.set() + if bridge_thread is not None: 
+ bridge_thread.join(timeout=self._listener_join_timeout) + + if error_holder[0] is None: + try: + skip_posthoc = use_live_bridge and live_counts[0] > 0 + mp3_path = tts_outputs_dir_from_log() / f"{player.config.type}_{stage_key}_{player.side}.mp3" + if not skip_posthoc and mp3_path.is_file() and mp3_path.stat().st_size > 2048: + logger.debug(f"[SpeakerWorker] posthoc_chunk_start mode=batch_tts mp3_path={mp3_path.name} t={time.time():.3f}") + audio = AudioSegment.from_file(str(mp3_path)) + if self._split_mode == "fixed": + chunks = split_audio(audio, mode="fixed", time_seconds=self._chunk_seconds) + else: + chunks = split_audio( + audio, mode="silence", time_seconds=self._silence_window_seconds + ) + logger.info( + f"[OverlappingStreamingDebateEnv] Post-hoc {len(chunks)} chunk(s) from " + f"{mp3_path.name} → {turn_watch} (realtime_pace=False)" + ) + logger.debug(f"[SpeakerWorker] posthoc_chunk_split audio_duration={len(audio)/1000.0:.2f}s " + f"num_chunks={len(chunks)} t={time.time():.3f}") + stream_start = time.time() + stream_chunks_to_directory( + chunks, + turn_watch, + side, + audio_format=self._audio_format, + dry_run=False, + max_total_seconds=None, + chunk_index_start=1, + realtime_pace=False, + ) + stream_end = time.time() + logger.debug(f"[SpeakerWorker] posthoc_chunk_end num_chunks={len(chunks)} " + f"stream_time={stream_end - stream_start:.3f}s t={stream_end:.3f}") + elif not skip_posthoc: + logger.warning( + f"[OverlappingStreamingDebateEnv] No TTS MP3 at {mp3_path} (skip chunk stream)." 
+ ) + except Exception as e: + logger.error(f"[OverlappingStreamingDebateEnv] Post-hoc chunk stream failed: {e}") + + done_ts[0] = time.monotonic() + speaker_done.set() + logger.debug(f"[SpeakerWorker] thread_end stage={stage_key} side={side} t={time.time():.3f}") + + wt = threading.Thread(target=speaker_worker, name="SpeakerTurn", daemon=True) + wt.start() + + self._playback_main_loop( + turn_watch, + side, + stage_key, + speaker_done, + error_holder, + done_ts, + playback_cursor, + grace_seconds=4.0, + ) + + wt.join(timeout=self._listener_join_timeout) + if wt.is_alive(): + logger.warning("[OverlappingStreamingDebateEnv] Speaker thread still alive after join timeout.") + + if use_streaming_listen: + listener_deb.stop_streaming_listen(self._listener_join_timeout) + else: + # Non-streaming listener: process after playback completes + logger.debug(f"[NonStreamingListener] batch_listen_start stage={stage_key} side={listener} t={time.time():.3f}") + # The listener will process via regular listen() call in next turn + # Just log that we're in batch mode + logger.debug(f"[NonStreamingListener] batch_listen_mode stage={stage_key} side={listener} " + f"listener_will_process_in_next_turn=True t={time.time():.3f}") + + if error_holder[0] is not None: + err = error_holder[0] + if isinstance(err, Exception): + raise err + raise RuntimeError(str(err)) + + response = response_holder[0] + if response is not None: + rec: dict = {"stage": stage_key, "side": side, "content": response} + if use_streaming_listen: + rec["tree_via_streaming"] = True + self._env.debate_process.append(rec) + + logger.debug(f"[Turn] turn_end stage={stage_key} side={side} t={time.time():.3f}") + + +__all__ = ["OverlappingStreamingDebateEnv", "parse_args", "main"] + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Full debate with overlapping streaming listen (see module docstring)." 
+ ) + p.add_argument("--config", type=str, required=True, help="YAML config (same layout as env.py).") + p.add_argument("--watch-dir", type=str, default=None, help="Optional root for per-turn watch subdirs.") + p.add_argument("--debug", action="store_true", default=False) + p.add_argument("--min-audio-seconds", type=float, default=15.0, help="Lower default than streaming.env for overlap.") + p.add_argument("--min-text-words", type=int, default=40) + p.add_argument("--poll-interval", type=float, default=1.0) + p.add_argument("--audio-format", type=str, default="mp3") + p.add_argument("--split-mode", type=str, choices=["fixed", "silence"], default="fixed") + p.add_argument("--chunk-seconds", type=float, default=10.0) + p.add_argument("--silence-window-seconds", type=float, default=0.7) + p.add_argument("--max-audio-wait-seconds", type=float, default=0.0) + p.add_argument("--max-text-wait-seconds", type=float, default=0.0) + p.add_argument("--max-total-audio-seconds", type=float, default=0.0) + p.add_argument("--listener-join-timeout", type=float, default=300.0) + p.add_argument("--min-playback-increment", type=float, default=3.0, help="Minimum playback increment in seconds (env-level cursor update frequency).") + return p.parse_args() + + +def main() -> None: + args = parse_args() + config_path = Path(args.config).resolve() + if not config_path.is_file(): + raise FileNotFoundError(config_path) + + from agents import AudienceConfig, DebaterConfig, JudgeConfig + from env import EnvConfig + + with open(config_path, "r", encoding="utf-8") as f: + full_config = yaml.load(f, Loader=yaml.FullLoader) + logger.info(f"Config: {full_config}") + env_config = EnvConfig( + debater_config=[DebaterConfig(**c) for c in full_config["debater"]], + judge_config=JudgeConfig(**full_config["judge"]), + audience_config=AudienceConfig(**full_config["audience"]), + **full_config["env"], + ) + watch_root = Path(args.watch_dir).resolve() if args.watch_dir else None + sde = 
OverlappingStreamingDebateEnv( + env_config, + args.debug, + watch_root, + min_audio_seconds=args.min_audio_seconds, + min_text_words=args.min_text_words, + poll_interval=args.poll_interval, + audio_format=args.audio_format, + split_mode=args.split_mode, + chunk_seconds=args.chunk_seconds, + silence_window_seconds=args.silence_window_seconds, + max_audio_wait_seconds=args.max_audio_wait_seconds, + max_text_wait_seconds=args.max_text_wait_seconds, + max_total_audio_seconds=args.max_total_audio_seconds, + listener_join_timeout=args.listener_join_timeout, + min_playback_increment=args.min_playback_increment, + ) + sde.play() + log_file = logger.handlers[0].baseFilename + save_file = log_file.replace(".log", ".json") + logger.info(f"Saving to {save_file}") + record = { + "motion": sde.motion, + "config": full_config, + "debate_process": sde.debate_process[1:], + "debate_thoughts": { + "for": sde.debaters["for"].debate_thoughts, + "against": sde.debaters["against"].debate_thoughts, + }, + "debate_tree": { + "for": [ + ( + sde.debaters["for"].debate_tree.get_tree_info() + if sde.debaters["for"].type in ["treedebater"] + else {} + ), + ( + sde.debaters["for"].oppo_debate_tree.get_tree_info() + if sde.debaters["for"].type in ["treedebater"] + else {} + ), + ], + "against": [ + ( + sde.debaters["against"].debate_tree.get_tree_info() + if sde.debaters["against"].type in ["treedebater"] + else {} + ), + ( + sde.debaters["against"].oppo_debate_tree.get_tree_info() + if sde.debaters["against"].type in ["treedebater"] + else {} + ), + ], + }, + "conversation": { + "for": sde.debaters["for"].conversation, + "against": sde.debaters["against"].conversation, + }, + } + json.dump(record, open(save_file, "w"), indent=2) + if not args.debug: + evaluation, side_into = sde.eval() + logger.info(f"Result: {evaluation}") + record.update({"evaluation": evaluation, "eval_side_info": side_into}) + json.dump(record, open(save_file, "w"), indent=2) + + +if __name__ == "__main__": + src_dir = 
Path(__file__).resolve().parent.parent + root = src_dir.parent + for p in (src_dir, root): + ps = str(p) + if ps not in sys.path: + sys.path.insert(0, ps) + main() diff --git a/src/streaming/run_listen_demo.py b/src/streaming/run_listen_demo.py new file mode 100644 index 0000000..2dfee96 --- /dev/null +++ b/src/streaming/run_listen_demo.py @@ -0,0 +1,288 @@ +""" +Live streaming demo (three roles): + +1. **Speaker** — ``env.py`` (or any TTS) writes ``log_files/N_outputs/{type}_{stage}_{side}.mp3`` + for the side that is speaking. +2. **Chunk bridge** — polls that folder for **new** speaker-side MP3s, waits until each file is + stable, splits into timed chunks, and writes ``{speaker}_chunkNNN.mp3`` into ``--watch-dir``. +3. **Listener** — ``StreamingInputEnv`` (TreeDebater for ``--debater-side``) watches + ``--watch-dir``, runs streaming ASR, and calls ``_analyze_statement`` with the speaker as + ``statement_side``. + +Start this script **before** or **while** running ``env.py`` so TTS drops new files into +``N_outputs``; the bridge feeds the listener in (approximate) real time. + +**Audio source** + +- Default: latest ``log_files/N_outputs`` (see ``streaming.bridges.infer_session_log_id``), + or ``--log-id`` / ``--outputs-dir``. +- ``--audio-file``: one-shot test (no bridge): stream a single file then stop (no live poll). + +Example (listener = ``for``, speaker = ``against``, run while ``env.py`` produces TTS):: + + python -m streaming.run_listen_demo \\ + --config configs/base_st.yml \\ + --watch-dir /tmp/td_stream_demo \\ + --debater-side for + +Requires ``OPENAI_API_KEY`` (Whisper) and keys for ``TreeDebater`` / ``HelperClient``. 
+""" + +from __future__ import annotations + +import argparse +import sys +import threading +import time +from pathlib import Path + +_src_dir = Path(__file__).resolve().parent.parent +for _p in (_src_dir, _src_dir.parent): + _ps = str(_p) + if _ps not in sys.path: + sys.path.insert(0, _ps) + +from pydub import AudioSegment + +from .bridges import default_outputs_dir, infer_session_log_id, run_live_chunk_bridge +from .chunk_audio import clear_watch_chunk_files, split_audio, stream_chunks_to_directory +from .env import StreamingInputConfig, StreamingInputEnv, load_treedebater_for_side, opponent_side + + +def _resolve_chunk_log_id(args: argparse.Namespace, speaker_side: str, audio_path: Path | None) -> str: + if args.log_id is not None: + return args.log_id + if audio_path is not None: + stem = audio_path.stem + return stem.split("_", 1)[0]  # a stem without "_" is returned whole + return speaker_side + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Live bridge: speaker TTS → log_files/N_outputs → chunks → watch-dir; " + "listener TreeDebater runs streaming ASR on watch-dir." + ) + p.add_argument("--config", type=str, required=True, help="YAML config (same as env.py).") + p.add_argument( + "--log-id", + type=str, + default=None, + help="Session id N selecting the folder log_files/N_outputs.
If omitted, uses latest existing N_outputs.", + ) + p.add_argument( + "--outputs-dir", + type=str, + default=None, + help="Override TTS output directory (same naming as agents.post_process).", + ) + p.add_argument( + "--audio-file", + type=str, + default=None, + help="If set, run a one-shot chunk stream from this file (no live bridge on N_outputs).", + ) + p.add_argument("--watch-dir", type=str, required=True) + p.add_argument( + "--debater-side", + type=str, + choices=["for", "against"], + required=True, + help="TreeDebater / listener side in YAML (the agent that watches watch-dir and runs ASR).", + ) + p.add_argument( + "--speaker-side", + type=str, + choices=["for", "against"], + default=None, + help="Who is speaking into log_files/N_outputs (default: opponent of --debater-side).", + ) + p.add_argument("--stage", type=str, choices=["opening", "rebuttal", "closing"], default="opening") + p.add_argument("--split-mode", type=str, choices=["fixed", "silence"], default="fixed") + p.add_argument("--chunk-seconds", type=float, default=10.0) + p.add_argument("--silence-window-seconds", type=float, default=0.7) + p.add_argument("--audio-format", type=str, default="mp3") + p.add_argument("--min-audio-seconds", type=float, default=30.0) + p.add_argument("--min-text-words", type=int, default=50) + p.add_argument("--poll-interval", type=float, default=1.0) + p.add_argument("--bridge-stable-polls", type=int, default=2, help="Consecutive polls with same file size before loading.") + p.add_argument( + "--process-existing-speaker-mp3s", + action="store_true", + help="Also chunk speaker MP3s already in N_outputs when the bridge starts (default: only new files).", + ) + p.add_argument("--max-audio-wait-seconds", type=float, default=0.0) + p.add_argument("--max-text-wait-seconds", type=float, default=0.0) + p.add_argument("--max-total-audio-seconds", type=float, default=0.0) + p.add_argument("--max-total-seconds", type=float, default=None, help="Only for --audio-file one-shot 
cap.") + p.add_argument( + "--max-wall-seconds", + type=float, + default=0.0, + help="Live mode: stop bridge and listener after this many wall-clock seconds (0 = until Ctrl+C).", + ) + p.add_argument( + "--pipeline-timeout-seconds", + type=float, + default=90.0, + help="After stop, wait up to this long for the listener thread to finish draining.", + ) + p.add_argument("--no-clear-watch", action="store_true", help="Do not delete existing chunk files in watch-dir before run.") + return p.parse_args() + + +def _run_one_shot_audio_file( + args: argparse.Namespace, + watch_dir: Path, + debater, + statement_side: str, + chunk_log_id: str, + env: StreamingInputEnv, + worker: threading.Thread, +) -> None: + audio_path = Path(args.audio_file).resolve() + if not audio_path.is_file(): + raise FileNotFoundError(audio_path) + print(f"One-shot mode: loading {audio_path}...") + audio = AudioSegment.from_file(str(audio_path)) + if args.split_mode == "fixed": + chunks = split_audio(audio, mode="fixed", time_seconds=args.chunk_seconds) + else: + chunks = split_audio(audio, mode="silence", time_seconds=args.silence_window_seconds) + print(f"Streaming {len(chunks)} chunks to {watch_dir}...") + stream_chunks_to_directory( + chunks, + watch_dir, + chunk_log_id, + audio_format=args.audio_format, + dry_run=False, + max_total_seconds=args.max_total_seconds, + chunk_index_start=1, + ) + print("One-shot stream finished; draining listener...") + timeout = args.pipeline_timeout_seconds + if timeout and timeout > 0: + t0 = time.time() + while time.time() - t0 < timeout: + if not worker.is_alive(): + break + time.sleep(1.0) + env.stop() + worker.join(timeout=min(timeout, 30.0) if timeout and timeout > 0 else 30.0)  # non-positive timeout falls back to a 30 s drain + + +def main() -> None: + args = parse_args() + + watch_dir = Path(args.watch_dir).resolve() + watch_dir.mkdir(parents=True, exist_ok=True) + if not args.no_clear_watch: + clear_watch_chunk_files(watch_dir, args.audio_format) + + config_path = Path(args.config).resolve() + debater_side =
args.debater_side + _, debater = load_treedebater_for_side(config_path, debater_side) + speaker_side = args.speaker_side or opponent_side(debater_side) + statement_side = speaker_side + chunk_log_id = speaker_side + + print( + f"Listener TreeDebater side={debater.side!r}; speaker (TTS → N_outputs) side={speaker_side!r}; " + f"chunks use log_id={chunk_log_id!r}." + ) + + audio_glob = f"*.{args.audio_format}" + sic = StreamingInputConfig( + watch_dir=watch_dir, + motion=debater.motion, + stage=args.stage, + statement_side=statement_side, + min_audio_seconds=args.min_audio_seconds, + min_text_words=args.min_text_words, + poll_interval=args.poll_interval, + audio_file_glob=audio_glob, + max_audio_wait_seconds=args.max_audio_wait_seconds if args.max_audio_wait_seconds > 0 else None, + max_text_wait_seconds=args.max_text_wait_seconds if args.max_text_wait_seconds > 0 else None, + max_total_audio_seconds=args.max_total_audio_seconds if args.max_total_audio_seconds > 0 else None, + audio_format=args.audio_format, + ) + + env = StreamingInputEnv(debater, sic) + listener = threading.Thread(target=env.run, name="StreamingInputEnv", daemon=False) + print(f"Starting listener on {watch_dir} (glob {audio_glob}).") + listener.start() + time.sleep(max(0.5, args.poll_interval)) + + stop_bridge = threading.Event() + + try: + if args.audio_file: + chunk_log_id = _resolve_chunk_log_id(args, speaker_side, Path(args.audio_file).resolve()) + _run_one_shot_audio_file(args, watch_dir, debater, statement_side, chunk_log_id, env, listener) + return + + if args.outputs_dir: + outputs_dir = Path(args.outputs_dir).resolve() + else: + session_log_id = args.log_id or infer_session_log_id() + if not session_log_id: + raise SystemExit( + "No log_files/N_outputs found. Create one (e.g. run env.py once) or pass --log-id / --outputs-dir." 
+ ) + if args.log_id is None: + print(f"Inferred session log-id={session_log_id!r} (latest log_files/N_outputs).") + outputs_dir = default_outputs_dir(session_log_id) + + if not outputs_dir.is_dir(): + raise NotADirectoryError(f"Speaker TTS directory missing: {outputs_dir}") + + bridge_thread = threading.Thread( + target=run_live_chunk_bridge, + kwargs={ + "outputs_dir": outputs_dir, + "watch_dir": watch_dir, + "speaker_side": speaker_side, + "chunk_log_id": chunk_log_id, + "stop_event": stop_bridge, + "audio_format": args.audio_format, + "split_mode": args.split_mode, + "chunk_seconds": args.chunk_seconds, + "silence_window_seconds": args.silence_window_seconds, + "poll_interval": args.poll_interval, + "stable_polls": args.bridge_stable_polls, + "skip_initial_files": not args.process_existing_speaker_mp3s, + }, + name="LiveChunkBridge", + daemon=True, + ) + bridge_thread.start() + + wall = args.max_wall_seconds + t0 = time.time() + print("Live mode running (Ctrl+C to stop).") + while True: + if not listener.is_alive(): + print("Listener thread exited unexpectedly.") + break + if wall and wall > 0 and (time.time() - t0) >= wall: + print(f"--max-wall-seconds={wall} reached; stopping.") + break + time.sleep(0.5) + + except KeyboardInterrupt: + print("Interrupted.") + finally: + stop_bridge.set() + time.sleep(0.5) + env.stop() + drain = args.pipeline_timeout_seconds if args.pipeline_timeout_seconds and args.pipeline_timeout_seconds > 0 else 90.0 + listener.join(timeout=drain) + if listener.is_alive(): + print("Listener still running after drain timeout.") + + if not args.no_clear_watch: + clear_watch_chunk_files(watch_dir, args.audio_format) + + +if __name__ == "__main__": + main() diff --git a/src/utils/db.py b/src/utils/db.py index ceeed92..26589bd 100644 --- a/src/utils/db.py +++ b/src/utils/db.py @@ -3,21 +3,21 @@ from datetime import datetime from typing import Optional -CACHE_DIR = "../.cache" +# Repo root is .../TreeDebater (parent of src/), not cwd-relative 
../.cache +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +CACHE_DIR = os.path.join(_REPO_ROOT, ".cache") def init_db(force: bool = False): - db_path = f"{CACHE_DIR}/search.db" - if not os.path.exists(db_path) or os.path.getsize(db_path) == 0: - init_db() - conn = sqlite3.connect(db_path) + db_path = os.path.join(CACHE_DIR, "search.db") if not os.path.exists(CACHE_DIR): - os.makedirs(CACHE_DIR) + os.makedirs(CACHE_DIR, exist_ok=True) - if force and os.path.exists(".cache/search.db"): - os.remove(".cache/search.db") - conn = sqlite3.connect(db_name) + if force and os.path.exists(db_path): + os.remove(db_path) + + conn = sqlite3.connect(db_path) c = conn.cursor() c.execute( @@ -46,7 +46,7 @@ def init_db(force: bool = False): def save_query(query: str, answer: str): - conn = sqlite3.connect(f"{CACHE_DIR}/search.db") + conn = sqlite3.connect(os.path.join(CACHE_DIR, "search.db")) c = conn.cursor() current_time = datetime.now().isoformat() @@ -66,7 +66,7 @@ def save_query(query: str, answer: str): def get_cached_answer(query: str) -> Optional[tuple]: - conn = sqlite3.connect(f"{CACHE_DIR}/search.db") + conn = sqlite3.connect(os.path.join(CACHE_DIR, "search.db")) c = conn.cursor() c.execute("SELECT answer, created_at, updated_at FROM queries WHERE query = ?", (query,)) result = c.fetchone() @@ -75,7 +75,7 @@ def get_cached_answer(query: str) -> Optional[tuple]: def remove_query(query: str) -> bool: - conn = sqlite3.connect("cache/search.db") + conn = sqlite3.connect(os.path.join(CACHE_DIR, "search.db")) c = conn.cursor() try: diff --git a/src/utils/helper.py b/src/utils/helper.py index 0236616..ca19c65 100644 --- a/src/utils/helper.py +++ b/src/utils/helper.py @@ -4,8 +4,16 @@ from debate_tree import PrepareTree +from .llm_schemas import ( + BattlefieldResponse, + QueryResponse, + SelectionClaimsOnlyResponse, + SelectionFrameworkResponse, + StatementsResponse, +) from .prompts import * from .time_estimator import LengthEstimator 
+from .timing_log import log_llm_io from .tool import get_response_with_retry, identify_number_in_text, logger, sort_by_action ##################### Evidence ##################### @@ -19,9 +27,9 @@ def select_query(llm, motion, stance, claim, action, candidate_queries): prompt = select_query_prompt.format( claim=claim, motion=motion, stance=stance, action=action, candidate_queries=candidate_queries ) - logger.debug("[Query-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) - query, response = get_response_with_retry(llm, prompt, "query") - logger.debug("[Query-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="helper", title="Query-Helper-Prompt", body=prompt.strip()) + query, response = get_response_with_retry(llm, prompt, "query", response_model=QueryResponse) + log_llm_io(logger, phase="helper", title="Query-Helper-Response", body=response.strip()) return query @@ -83,9 +91,14 @@ def build_cot_claims(llm, motion, side, claim_pool): "Use Json format with one key of **selection**. The value is a list of selected claims (string) that can be used in this debate.\n" ) - logger.debug("[CoT-Claims-Prompt] " + prompt.strip().replace("\n", " ||| ")) - selected_claims, response = get_response_with_retry(llm, prompt, "selection") - logger.debug("[CoT-Claims-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="helper", title="CoT-Claims-Prompt", body=prompt.strip(), side=side) + selected_claims, response = get_response_with_retry( + llm, + prompt, + "selection", + response_model=SelectionClaimsOnlyResponse, + ) + log_llm_io(logger, phase="helper", title="CoT-Claims-Response", body=response.strip(), side=side) # selected_claims = [x if x.endswith(".") else x + '.' 
for x in selected_claims] claim_content = [x[0]["claim"] for x in claim_pool] @@ -138,9 +151,14 @@ def build_logic_claims(llm, motion, side, claim_pool, context="", definition="", prompt = main_claim_selection.format( motion=motion, side=side, tree=tree_info, claims="\n".join(ori_claims), context=context, definition=definition ) - logger.debug("[Logic-Claims-Prompt] " + prompt.strip().replace("\n", " ||| ")) - content, response = get_response_with_retry(llm, prompt, "selection") - logger.debug("[Logic-Claims-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="helper", title="Logic-Claims-Prompt", body=prompt.strip(), side=side) + content, response = get_response_with_retry( + llm, + prompt, + "selection", + response_model=SelectionFrameworkResponse, + ) + log_llm_io(logger, phase="helper", title="Logic-Claims-Response", body=response.strip(), side=side) # Step 5. Parse model outputs selected_claims = content["claims"] @@ -236,7 +254,7 @@ def get_actions_from_tree(claims, tree, oppo_tree): } ) - logger.debug(f"[Debate-Flow-Tree-Action] {actions}") + log_llm_io(logger, phase="helper", title="Debate-Flow-Tree-Action", body=json.dumps(actions, indent=2)) df = pd.DataFrame(actions) df = df.drop_duplicates(subset=["target_claim"]) @@ -254,9 +272,14 @@ def get_battlefields_from_actions(llm, motion, side, claims, actions, tree, oppo tree=tree.print_tree(include_status=True), oppo_tree=oppo_tree.print_tree(include_status=True), ) - logger.debug("[Debate-Flow-Tree-Action-Eval-Prompt] " + prompt.strip().replace("\n", " ||| ")) - eval_results, response = get_response_with_retry(llm, prompt, "response") - logger.debug("[Debate-Flow-Tree-Action-Eval-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io(logger, phase="helper", title="Debate-Flow-Tree-Action-Eval-Prompt", body=prompt.strip(), side=side) + eval_results, response = get_response_with_retry( + llm, + prompt, + "response", + response_model=BattlefieldResponse, + ) + 
log_llm_io(logger, phase="helper", title="Debate-Flow-Tree-Action-Eval-Response", body=response.strip(), side=side) battlefields = [] for eval_result in eval_results: @@ -484,7 +507,26 @@ def extract_statement(llm, motion, statement, claims=None, tree=None, side=None, else: prompt = extract_statment_prompt.format(motion=motion, statement=statement) - logger.debug("[Analyze-Helper-Prompt] " + prompt.strip().replace("\n", " ||| ")) - claims, response = get_response_with_retry(llm, prompt, "statements") - logger.debug("[Analyze-Helper-Response] " + response.strip().replace("\n", " ||| ")) + log_llm_io( + logger, + phase="helper", + title="Analyze-Helper-Prompt", + body=prompt.strip(), + side=side, + stage=stage, + ) + claims, response = get_response_with_retry( + llm, + prompt, + "statements", + response_model=StatementsResponse, + ) + log_llm_io( + logger, + phase="helper", + title="Analyze-Helper-Response", + body=response.strip(), + side=side, + stage=stage, + ) return claims diff --git a/src/utils/llm_schemas.py b/src/utils/llm_schemas.py new file mode 100644 index 0000000..45e5489 --- /dev/null +++ b/src/utils/llm_schemas.py @@ -0,0 +1,96 @@ +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + + +class SchemaBase(BaseModel): + model_config = ConfigDict(extra="ignore") + + +class PurposeItem(SchemaBase): + action: Literal["propose", "rebut", "reinforce", "attack"] + target: str + targeted_debate_tree: Literal["you", "opponent"] + + +class StatementItem(SchemaBase): + claim: str + arguments: list[str] = Field(default_factory=list) + content: str | None = None + type: Literal["common", "definition", "criteria"] | None = None + purpose: list[PurposeItem] | PurposeItem | None = None + + +class StatementsResponse(SchemaBase): + statements: list[StatementItem] + + +class SelectionClaimsOnlyResponse(SchemaBase): + selection: list[str] + + +class SelectionFramework(SchemaBase): + claims: list[str] + framework: str + explanation: str + + 
+class SelectionFrameworkResponse(SchemaBase): + selection: SelectionFramework + + +class ActionItem(SchemaBase): + action: str + target_claim: str + target_argument: str | None = None + prepared_materials: str | None = None + targeted_debate_tree: Literal["you", "opponent"] | None = None + idx: int | None = None + argument: str | None = None + importance: Literal["high", "medium", "low"] | None = None + + +class ActionListResponse(SchemaBase): + response: list[ActionItem] + + +class BattlefieldEvalItem(SchemaBase): + battlefield: str + idx_list: list[int] = Field(default_factory=list) + unified_argument: str = "" + importance: Literal["high", "medium", "low"] = "medium" + + +class BattlefieldResponse(SchemaBase): + response: list[BattlefieldEvalItem] + + +class ResultsItem(SchemaBase): + claim: str + explanation: str | None = None + perspective: str | None = None + concepts: list[str] | None = None + strength: int | float | None = None + + +class ResultsResponse(SchemaBase): + results: list[ResultsItem] + + +class AuthorItem(SchemaBase): + id: str | int | None = None + author: str | None = None + author_info: str | None = None + publication: str | None = None + + +class AuthorsResponse(SchemaBase): + authors: list[AuthorItem] + + +class SelectedIdsResponse(SchemaBase): + selected_ids: list[int | str] + + +class QueryResponse(SchemaBase): + query: list[str] diff --git a/src/utils/model.py b/src/utils/model.py index 923820d..3235453 100644 --- a/src/utils/model.py +++ b/src/utils/model.py @@ -1,11 +1,21 @@ +import time +from typing import Any, Type + import litellm import numpy as np import torch +from pydantic import BaseModel from transformers import AutoTokenizer, LlamaForSequenceClassification from utils.constants import ATTACK_RM_PATH, SUPPORT_RM_PATH, google_api_key +from utils.timing_log import log_timing from utils.tool import logger +try: + import instructor +except Exception: + instructor = None + safety_setting = [ { "category": 
"HARM_CATEGORY_SEXUALLY_EXPLICIT", @@ -34,7 +44,9 @@ def HelperClient( n=1, stop=None, sys=None, -) -> list: + response_model: Type[BaseModel] | None = None, + use_instructor: bool | None = None, +) -> list[str] | list[BaseModel]: if sys is not None: messages = [{"role": "system", "content": sys}] else: @@ -65,30 +77,101 @@ def HelperClient( messages.append({"role": "user", "content": prompt}) responses = [] - for _ in range(n): - # # Check if we need JSON response format and if the model supports it - # use_json_format = ("json" in prompt.lower() or (sys is not None and "json" in sys.lower())) - - # # Only use response_format for models that support it - # if use_json_format and ("gpt-4o" in model_name.lower() or "gpt-4-turbo" in model_name.lower() or "gpt-3.5-turbo" in model_name.lower()): - if "json" in prompt.lower() or (sys is not None and "json" in sys.lower()): - response = litellm.completion( - model=model_name, - response_format={"type": "json_object"}, + for i in range(n): + t0 = time.perf_counter() + wants_json = "json" in prompt.lower() or (sys is not None and "json" in sys.lower()) + structured_enabled = response_model is not None and ( + use_instructor is True or (use_instructor is None and _supports_structured_output(model_name)) + ) + response = None + structured_value = None + if structured_enabled: + try: + structured_value = _completion_structured( + model_name=model_name, + messages=messages, + response_model=response_model, + temperature=temperature, + max_tokens=max_tokens, + stop=stop, + kwargs=kwargs, + ) + except Exception as e: + logger.warning(f"Structured output fallback for {model_name}: {e}") + + if structured_value is None: + response = _completion_text( + model_name=model_name, messages=messages, + wants_json=wants_json or response_model is not None, temperature=temperature, max_tokens=max_tokens, stop=stop, - **kwargs, + kwargs=kwargs, ) + + elapsed = time.perf_counter() - t0 + ctx = {"model": model, "n_index": i + 1, "max_tokens": 
max_tokens} + if response is not None: + try: + cost = getattr(response, "_hidden_params", {}).get("response_cost") + if cost is not None: + ctx["response_cost"] = cost + except Exception: + pass + log_timing(logger, "helper_client_litellm", elapsed, **ctx) + if response_model is not None: + responses.append(structured_value if structured_value is not None else response.choices[0].message.content) else: - response = litellm.completion( - model=model_name, messages=messages, temperature=temperature, max_tokens=max_tokens, stop=stop, **kwargs - ) - responses.append(response.choices[0].message.content) + responses.append(response.choices[0].message.content) return responses +def _supports_structured_output(model_name: str) -> bool: + name = model_name.lower() + return any(x in name for x in ["gpt", "o1", "claude", "gemini"]) + + +def _completion_text(model_name: str, messages, wants_json: bool, temperature: float, max_tokens: int, stop, kwargs): + call_kwargs: dict[str, Any] = dict( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + stop=stop, + **kwargs, + ) + if wants_json: + call_kwargs["response_format"] = {"type": "json_object"} + return litellm.completion(**call_kwargs) + + +def _completion_structured( + model_name: str, + messages, + response_model: Type[BaseModel], + temperature: float, + max_tokens: int, + stop, + kwargs, +) -> BaseModel: + if instructor is None: + raise RuntimeError("instructor is not installed.") + + if hasattr(instructor, "from_litellm"): + client = instructor.from_litellm(litellm.completion) + return client.chat.completions.create( + model=model_name, + messages=messages, + response_model=response_model, + temperature=temperature, + max_tokens=max_tokens, + stop=stop, + **kwargs, + ) + raise RuntimeError("Installed instructor version does not support from_litellm.") + + models_loaded = False pro_model = None con_model = None diff --git a/src/utils/prompts/others.py b/src/utils/prompts/others.py 
index 383acf8..ec14a32 100644 --- a/src/utils/prompts/others.py +++ b/src/utils/prompts/others.py @@ -176,38 +176,36 @@ extract_statment_with_tree_prompt = ( - "## Task: Analyze the statements\n" - "Your task is to analyze the statements and identify the key claims presented in the statement and evidence or reasoning to support the claims.\n" - "1. These claims are used to support your stance on the debate topic. Therefore, do not include the debate topic as the claim.\n" - "2. Identify the key claims presented in the statement and evidence or reasoning to support the claims.\n" - "3. For each claim, put the original statement for this claim in **content** and summarize the evidence or reasoning in the statement in **arguments**.\n" - "4. The type of the claim can be **common**, **definition**, **criteria**. **definition** and **criteria** only appear in the opening stage to clarify the definition of the debate topic and the criteria for judging the debate topic.\n" - "5. You are given two debate trees that models the back-and-forth between you and your opponent. Your extracted claims can be used to: \n" - "\t- propose the main claims under Level-0 of your debate tree (only if there is no Level-1 in your debate tree) \n" - "\t- rebut the opponent's attacks in Level-2 of your debate tree. The extracted claim should be the counter-claim to the opponent's attack in Level-2 of your debate tree\n" - "\t- reinforce the main claims in Level-1 of your debate tree. Only use this action if this claim is not designed to rebut the opponent's attack in Level-2 of your debate tree\n" - "\t- attack the opponent's proposed claims in Level-1 of your opponent's debate tree. The extracted claim should be the counter-claim to the opponent's proposed claim in Level-1 of your opponent's debate tree\n" - "The purpose of the claim should be consistent with the debate tree. \n" - "Each claim should be used for one of the above purposes or the combination of them. 
For example, if the node in Level-2 of your debate tree is the same with the node in Level-1 of your opponent's debate tree, the purpose of the claim will be **rebut** and **attack**.\n" - "Please provide all the possible purposes for each claim. The purpose includes a list of dictionaries with the following three keys: " - "\n- **action**: propose, reinforce, rebut or attack " - "\n- **targeted_debate_tree**: you or opponent" - "\n- **target**: the *claim* value of the node in the debate tree. return `N/A` if there is no target" - "\n - For propose: the target is the proposed claim to be added in Level-0 of your debate tree. It should be the same as the claim" - "\n - For rebut: the target should be the claim of the Level-2 nodes in your debate tree" - "\n - For attack: the target should be the claim of the Level-1 nodes in your opponent's debate tree" - "\n - For reinforce: the target should be the claim of the Level-1 nodes in your debate tree, or the claim of the Level-2 nodes in your opponent's debate tree" - "6. It should be at least 3 claims in the statement.\n\n" + "## Task: Analyze the statement\n" + "Read the statement and extract distinct claims, each with the evidence or reasoning that supports it in the text.\n" + "1. Claims support your stance on the debate topic. Do not treat the motion or restating the topic as a claim.\n" + "2. For each claim, copy the verbatim span from the statement into **content** (the claim plus its supporting reasoning or evidence in the statement). Summarize that support as short items in **arguments**.\n" + "3. Assign **type**: **common**, **definition**, or **criteria**. Use **definition** and **criteria** only in the opening stage—for clarifying how the motion is defined or how the debate should be judged.\n" + "4. You are given two debate trees that model the exchange between you and your opponent. Each extracted claim must align with the trees. 
It may serve one role or several; list every applicable role in **purpose**.\n" + "\t- **propose**: add a main claim at Level-0 of your tree (only when your tree has no Level-1 nodes yet).\n" + "\t- **rebut**: counter an opponent attack at Level-2 of your tree; the extracted claim should oppose that Level-2 node's **claim**.\n" + "\t- **reinforce**: strengthen a Level-1 main claim in your tree. Do not use this for material whose primary role is rebutting a Level-2 attack on your tree.\n" + "\t- **attack**: counter a Level-1 main claim in the opponent's tree; the extracted claim should oppose that Level-1 node's **claim**.\n" + "Purposes must be consistent with the trees. A claim can combine roles (e.g. if a Level-2 node in your tree matches a Level-1 node in the opponent's tree, **rebut** and **attack** may both apply).\n" + "**purpose** is a JSON array of objects. Each object has exactly these keys:\n" + "\t- **action**: one of propose, reinforce, rebut, attack\n" + "\t- **targeted_debate_tree**: \"you\" or \"opponent\" (which tree the action refers to)\n" + "\t- **target**: the **claim** field of the targeted tree node—never an argument text. Use the string `N/A` only when no node applies.\n" + "\t - **propose**: target is the claim to add at Level-0 of your tree; it must match this item's **claim**.\n" + "\t - **rebut**: target is the **claim** of the Level-2 node in your tree you are answering.\n" + "\t - **attack**: target is the **claim** of the Level-1 node in the opponent's tree you are attacking.\n" + "\t - **reinforce**: target is the **claim** of a Level-1 node in your tree, or the **claim** of a Level-2 node in the opponent's tree you are reinforcing against.\n" + "5. 
Output at least three distinct claims drawn from the statement.\n\n" "## Tree Structure\n" - "You are given a debate tree that models the back-and-forth between you and your opponent:\n" - "Your debate tree: \n" - "* Level-1: The main claims proposed by you\n" - "* Level-2: Your opponent's attacks on your claims\n" - "* Level-3: Your rebuttal on the attacks\n\n" - "Opponent's debate tree: \n" - "* Level-1: The main claims proposed by your opponent\n" - "* Level-2: Your attacks on the opponent's claims\n" - "* Level-3: The opponent's rebuttal on your attacks\n\n" + "Each tree is a back-and-forth model. Levels describe depth, not speaker order alone.\n" + "Your debate tree:\n" + "* Level-1: Your main claims\n" + "* Level-2: Opponent attacks on your claims\n" + "* Level-3: Your rebuttals to those attacks\n\n" + "Opponent's debate tree:\n" + "* Level-1: Opponent main claims\n" + "* Level-2: Your attacks on those claims\n" + "* Level-3: Opponent rebuttals to your attacks\n\n" "## Input Information\n" "**Debate Topic**: {motion} \n\n" "**Your Stance**: {side} \n\n" @@ -216,13 +214,13 @@ "**Your Debate Tree**: \n{tree} \n\n" "**Opponent's Debate Tree**: \n{oppo_tree} \n\n" "##Response Format\n" - "Provide your response in JSON format with one key of **statements**. The value of this key is a list of claims and their arguments (evidence or reasoning). \n" - "The keys of each element of the list are **claim**, **content**, **type**, **arguments**, **purpose**.\n" - "- The value of **claim** is the main claim. \n" - "- The value of **content** is the original part of the statement for the claim, including the claim and the evidence or reasoning used to support the claim. \n" - "- The value of **type** is the type of the claim, it can be **common**, **definition**, **criteria**. \n" - "- The value of **arguments** is a list of summarized reasoning and evidence used to support the claim. \n" - "- The value of **purpose** is a list of all possible purposes of the claim. 
\n\n" + "Return JSON with a single top-level key **statements**. Its value is an array of objects.\n" + "Each object must have keys **claim**, **content**, **type**, **arguments**, **purpose**.\n" + "- **claim**: concise statement of the main claim.\n" + "- **content**: verbatim excerpt from the statement covering the claim and its support.\n" + "- **type**: **common**, **definition**, or **criteria**.\n" + "- **arguments**: list of short summaries of reasoning or evidence for the claim. Omit items already represented as arguments on the corresponding tree nodes.\n" + "- **purpose**: array of purpose objects as specified above.\n\n" ) extract_statment_by_claim_prompt = ( diff --git a/src/utils/timing_log.py b/src/utils/timing_log.py new file mode 100644 index 0000000..4a173d7 --- /dev/null +++ b/src/utils/timing_log.py @@ -0,0 +1,191 @@ +"""Single-line agent timing logs (decoupled from streaming parse_log_line format).""" + +from __future__ import annotations + +import itertools +import logging +import threading +import time +from contextlib import contextmanager +from typing import Any, Iterator, Mapping, Optional + +# While ``Agent.speak`` / ``TreeDebater.speak`` runs, I/O helpers reuse this ``call_id`` and +# ``speak_session`` so ``[timing-meta]``, ``[io]``, and ``log_llm_io`` stay aligned. +_speak_io_tls = threading.local() + +# Monotonic correlation id per process (thread-safe enough for itertools.count in CPython GIL). 
+_call_id_counter = itertools.count(1) + + +def next_call_id() -> int: + return next(_call_id_counter) + + +def set_speak_io_context(call_id: int, speak_session: str) -> None: + """Bind ``call_id`` and ``speak_session`` for nested ``log_llm_io`` / ``post_process`` I/O.""" + _speak_io_tls.call_id = int(call_id) + _speak_io_tls.speak_session = speak_session + + +def clear_speak_io_context() -> None: + for attr in ("call_id", "speak_session"): + if hasattr(_speak_io_tls, attr): + delattr(_speak_io_tls, attr) + + +def get_speak_io_call_id() -> Optional[int]: + return getattr(_speak_io_tls, "call_id", None) + + +def get_speak_io_session() -> Optional[str]: + return getattr(_speak_io_tls, "speak_session", None) + + +def _fmt_val(v: Any) -> str: + if isinstance(v, bool): + return str(v) + if isinstance(v, float): + return f"{v:.4f}".rstrip("0").rstrip(".") + s = str(v).replace("\n", " ").replace("\r", " ") + if len(s) > 200: + return s[:197] + "..." + return s + + +def _timing_kv_parts(ctx: Mapping[str, Any]) -> list[str]: + """Stable key order first, then remaining keys sorted.""" + preferred = ( + "stage", + "side", + "speak_session", + "call_id", + "pass_index", + "add_evidence", + "block", + "iteration", + "max_retry", + "kind", + "model", + "cache_hit", + "audio_duration_s", + "n_claims", + ) + seen = set() + parts: list[str] = [] + for k in preferred: + if k not in ctx: + continue + v = ctx[k] + if v is None: + continue + parts.append(f"{k}={_fmt_val(v)}") + seen.add(k) + for k in sorted(ctx.keys()): + if k in seen: + continue + v = ctx[k] + if v is None: + continue + parts.append(f"{k}={_fmt_val(v)}") + return parts + + +def format_timing_line(phase: str, duration_s: float, **ctx: Any) -> str: + parts = ["[timing]", f"phase={phase}", f"duration_s={duration_s:.4f}", *_timing_kv_parts(ctx)] + return " ".join(parts) + + +def log_timing( + log: logging.Logger, + phase: str, + duration_s: float, + *, + level: int = logging.DEBUG, + **ctx: Any, +) -> None: + 
log.log(level, format_timing_line(phase, duration_s, **ctx)) + + +@contextmanager +def timed_phase( + log: logging.Logger, + phase: str, + *, + log_start: bool = False, + level: int = logging.DEBUG, + **ctx: Any, +) -> Iterator[None]: + if log_start: + start_parts = ["[timing]", f"phase={phase}", "event=start", *_timing_kv_parts(ctx)] + log.log(level, " ".join(start_parts)) + t0 = time.perf_counter() + try: + yield + finally: + log_timing(log, phase, time.perf_counter() - t0, level=level, **ctx) + + +def log_io_block( + io_log: logging.Logger, + *, + call_id: int, + phase: str, + title: str, + body: str, + level: int = logging.DEBUG, + **ctx: Any, +) -> None: + """One prompt/response block in the I/O log file (not on main debate logger).""" + head = ["[io]", f"call_id={call_id}", f"phase={phase}", f"title={title}", *_timing_kv_parts(ctx)] + header = " ".join(head) + sep = "\n" + ("-" * 60) + "\n" + io_log.log(level, header + sep + (body or "").rstrip() + "\n" + ("=" * 60)) + + +def one_line_preview(s: str, max_len: int = 280) -> str: + """Short single-line preview for main log when full body is in *_io.log.""" + t = (s or "").strip().replace("\n", " ||| ") + if len(t) <= max_len: + return t + return t[: max_len - 24] + "... [truncated]" + + +def log_llm_io( + main_log: logging.Logger, + *, + phase: str, + title: str, + body: str, + level: int = logging.DEBUG, + emit_main_ref: bool = True, + call_id: Optional[int] = None, + io_block_phase: Optional[str] = None, + **ctx: Any, +) -> None: + """ + When I/O logging is enabled, write full ``body`` to *_io.log via :func:`log_io_block`. + Otherwise emit the legacy single-line ``[{title}]`` message on ``main_log``. + + ``call_id``: reuse the id from ``[timing-meta]`` when inside ``set_speak_io_context`` or when + passed explicitly; otherwise allocate a new id (e.g. BaselineDebater). 
+ + ``io_block_phase``: value for ``[io] phase=...``; defaults to active ``speak_session`` from + :func:`set_speak_io_context`, then to ``phase``. Keeps one speak turn under one session label + (e.g. ``default_speak``) while ``title`` distinguishes Prompt vs Response vs TTS artifacts. + """ + from utils.tool import io_logger, io_logging_enabled + + text = (body or "").strip() + if io_logging_enabled(): + cid = call_id if call_id is not None else get_speak_io_call_id() + if cid is None: + cid = next_call_id() + io_ph = io_block_phase if io_block_phase is not None else (get_speak_io_session() or phase) + log_io_block(io_logger, call_id=cid, phase=io_ph, title=title, body=text, level=level, **ctx) + if emit_main_ref: + main_log.log( + level, + f"[io-ref] speak_session={io_ph} title={title} call_id={cid}", + ) + else: + main_log.log(level, f"[{title}] " + text.replace("\n", " ||| ")) diff --git a/src/utils/tool.py b/src/utils/tool.py index d8a18a6..2d89b56 100644 --- a/src/utils/tool.py +++ b/src/utils/tool.py @@ -2,14 +2,115 @@ import os import re import time -from typing import List +import json +from typing import Any, List, TypeVar +from pydantic import BaseModel, ValidationError from pulp import LpMaximize, LpProblem, LpVariable from .constants import MAX_TRY_NUM from .prompts import debater_system_prompt log_file_path = "" +io_log_file_path = "" + +debate_io_logger = logging.getLogger("debate_io_logger") +debate_io_logger.setLevel(logging.DEBUG) +debate_io_logger.propagate = False + + +def io_logging_enabled() -> bool: + """True when prompt/response blocks go to the I/O log file (default on).""" + if os.environ.get("DEBATE_LOG_PROMPTS", "1").lower() in ("0", "false", "no", "off"): + return False + return bool(debate_io_logger.handlers) + + +def _setup_debate_io_logger(main_log_file: str) -> None: + """Sibling file ``N_io.log`` next to ``N.log`` for large prompt/response bodies.""" + global io_log_file_path + if not main_log_file or not 
main_log_file.endswith(".log"): + return + if os.environ.get("DEBATE_LOG_PROMPTS", "1").lower() in ("0", "false", "no", "off"): + return + if debate_io_logger.handlers: + return + io_path = main_log_file.replace(".log", "_io.log") + io_log_file_path = io_path + fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + h = LazyFileHandler(io_path, mode="a", encoding="utf-8") + h.setLevel(logging.DEBUG) + h.setFormatter(fmt) + debate_io_logger.addHandler(h) + + +class LazyFileHandler(logging.FileHandler): + """ + FileHandler that only creates the log file when the first log record is emitted. + This prevents creation of empty log files when programs exit early or crash during init. + """ + + def __init__(self, filename, mode='a', encoding=None, delay=True): + """ + Initialize with delay=True to defer file creation. + File will be created on first emit() call. + """ + # Store filename for later use + self._lazy_filename = filename + self._lazy_mode = mode + self._lazy_encoding = encoding + + # Don't call parent __init__ yet - we'll do it lazily + logging.Handler.__init__(self) + + self.baseFilename = os.path.abspath(filename) + self.mode = mode + self.encoding = encoding + self.stream = None + self._file_created = False + + def _open(self): + """Open the log file (called on first emit).""" + # Ensure directory exists + log_dir = os.path.dirname(self.baseFilename) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir) + + return open(self.baseFilename, self.mode, encoding=self.encoding) + + def emit(self, record): + """ + Emit a record. Create file on first call. 
+ """ + if not self._file_created: + # Create file now that we have something to log + self.stream = self._open() + self._file_created = True + + # Now emit normally + if self.stream: + try: + msg = self.format(record) + stream = self.stream + stream.write(msg + self.terminator) + self.flush() + except Exception: + self.handleError(record) + + def close(self): + """Close file handler.""" + self.acquire() + try: + if self.stream and self._file_created: + try: + self.flush() + if hasattr(self.stream, "close"): + self.stream.close() + finally: + self.stream = None + finally: + self.release() + logging.Handler.close(self) def get_output_path(base_dir="../log_files/", suffix="log"): @@ -17,8 +118,10 @@ def get_output_path(base_dir="../log_files/", suffix="log"): if not os.path.exists(base_dir): os.makedirs(base_dir) log_files = [f for f in os.listdir(base_dir) if f.endswith(".log")] - if log_files: - max_num = max(int(f.split(".")[0]) for f in log_files) + # Only "N.log" (integer N), not e.g. 
"19_io.log" or "debug.log" + numbered_logs = [f for f in log_files if len(f) > 4 and f[:-4].isdigit()] + if numbered_logs: + max_num = max(int(f[:-4]) for f in numbered_logs) new_log_file = f"{max_num + 1}.{suffix}" else: new_log_file = f"1.{suffix}" @@ -35,8 +138,9 @@ def create_log(log_file=None): log_file = get_output_path() print(f"Log file: {log_file}") - # File handler for logging to a file with DEBUG level - file_handler = logging.FileHandler(log_file) + # Lazy file handler for logging to a file with DEBUG level + # File will only be created when first log record is written + file_handler = LazyFileHandler(log_file, mode='a', encoding='utf-8') file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter( "%(asctime)s %(levelname)s %(module)s - %(funcName)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S" @@ -52,18 +156,79 @@ def create_log(log_file=None): # Add handlers to the logger log.addHandler(file_handler) log.addHandler(stream_handler) + _setup_debate_io_logger(log_file) + if io_log_file_path: + log.debug(f"[timing] phase=io_log_ready io_log={io_log_file_path}") return log logger = create_log() +# Alias for imports: ``from utils.tool import io_logger`` +io_logger = debate_io_logger + +T = TypeVar("T", bound=BaseModel) + + +def _strip_markdown_json_fence(text: str) -> str: + fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.IGNORECASE | re.DOTALL) + if fenced: + return fenced.group(1).strip() + return text + def find_json(x): - idx = x.find("{") - ridx = x.rfind("}") - if idx == -1 or ridx == -1: + return extract_json_object(x) + + +def extract_json_object(text: str) -> str: + if text is None: return "" - return x[idx : ridx + 1] + if isinstance(text, (dict, list)): + return json.dumps(text) + if not isinstance(text, str): + text = str(text) + + text = _strip_markdown_json_fence(text).strip() + idx = text.find("{") + ridx = text.rfind("}") + if idx != -1 and ridx != -1 and idx <= ridx: + return text[idx : ridx + 1] + lidx = text.find("[") + 
rridx = text.rfind("]") + if lidx != -1 and rridx != -1 and lidx <= rridx: + return text[lidx : rridx + 1] + return text + + +def parse_llm_json(text: Any, *, response_model: type[T] | None = None, required_key: str | None = None) -> T | Any: + if isinstance(text, BaseModel): + parsed = text + elif isinstance(text, (dict, list)): + parsed = text + else: + payload = extract_json_object(text) + parsed = json.loads(payload) + + if response_model is not None: + if isinstance(parsed, response_model): + validated = parsed + else: + validated = response_model.model_validate(parsed) + if required_key is None: + return validated + dumped = validated.model_dump() + if required_key not in dumped: + raise KeyError(f"Missing required key '{required_key}' in validated response.") + return dumped[required_key] + + if required_key is None: + return parsed + if not isinstance(parsed, dict): + raise TypeError(f"Expected dict for required_key='{required_key}', got {type(parsed).__name__}") + if required_key not in parsed: + raise KeyError(f"Missing required key '{required_key}' in parsed response.") + return parsed[required_key] def extract_numbers(s): @@ -188,23 +353,49 @@ def lp_optimize(actions: List[str], rewards: List[float], costs: List[float], bu return selected_actions, total_reward, total_cost -def get_response_with_retry(llm, prompt, required_key, **kwargs): +def get_response_with_retry(llm, prompt, required_key, *, response_model: type[T] | None = None, **kwargs): + from utils.timing_log import log_timing + retry = 0 response = "" content = {} - while len(content) == 0 and retry < MAX_TRY_NUM: + while retry < MAX_TRY_NUM: try: - response = llm(prompt=prompt, sys=debater_system_prompt, **kwargs)[0] - content = find_json(response) - response = response.replace("null", "") - content = eval(content) - content = content[required_key] - except Exception as e: + t0 = time.perf_counter() + response_obj = llm(prompt=prompt, sys=debater_system_prompt, response_model=response_model, 
**kwargs)[0] + llm_s = time.perf_counter() - t0 + log_timing( + logger, + "get_response_with_retry_llm", + llm_s, + required_key=required_key, + attempt=retry + 1, + response_model=response_model.__name__ if response_model is not None else None, + ) + if isinstance(response_obj, BaseModel): + response = json.dumps(response_obj.model_dump(), ensure_ascii=False) + content = parse_llm_json( + response_obj, + response_model=response_model or type(response_obj), + required_key=required_key, + ) + else: + response = response_obj if isinstance(response_obj, str) else json.dumps(response_obj, ensure_ascii=False) + content = parse_llm_json(response_obj, response_model=response_model, required_key=required_key) + if content is not None and content != {}: + return content, response + except (json.JSONDecodeError, ValidationError, KeyError, TypeError, ValueError) as e: logger.warning(f"Error {e} in extracting {required_key} from: {response}") content = {} retry += 1 logger.debug(f"Retry {retry} times.") time.sleep(30) + except Exception as e: + logger.warning(f"Unexpected error {e} in extracting {required_key} from: {response}") + content = {} + retry += 1 + logger.debug(f"Retry {retry} times.") + time.sleep(30) return content, response