Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/compare_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def step_play(self, side, stage, history, max_time):
history=history,
max_time=max_time,
time_control=self.time_control,
streaming_tts=self.baseline_debaters[side].config.streaming_tts,
streaming_tts=self.config.streaming_tts,
)

# Generate test response using reference history
test_response = test_call[stage](
history=history,
max_time=max_time,
time_control=self.time_control,
streaming_tts=self.test_debaters[side].config.streaming_tts,
streaming_tts=self.config.streaming_tts,
)
return base_response, test_response

Expand Down
1 change: 1 addition & 0 deletions src/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ env:
claim_pool_size: 50
reverse: False
time_control: True
streaming_tts: true

debater:
- side: for
Expand Down
1 change: 1 addition & 0 deletions src/configs/compare.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ env:
claim_pool_size: 50
reverse: False
time_control: True
streaming_tts: true

debater:
- side: for
Expand Down
1 change: 1 addition & 0 deletions src/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class EnvConfig:
claim_pool_size: int = 50
reverse: bool = False
time_control: bool = True
streaming_tts: bool = False


def extract_overall_score(obj_scores): # larger is better
Expand Down
2 changes: 1 addition & 1 deletion src/ouragents.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, config, motion):
+ f"use_rehearsal_tree: {self.use_rehearsal_tree}, use_debate_flow_tree: {self.use_debate_flow_tree}"
)

helper_model = getattr(config, "helper_model", self.config.model)
helper_model = getattr(config, "helper_model", None) or self.config.model
self.helper_client = partial(HelperClient, model=helper_model, temperature=0, max_tokens=config.max_tokens, n=1)
self.simulated_audience = [Audience(AudienceConfig(model=self.config.model, temperature=1)) for _ in range(1)]

Expand Down
16 changes: 15 additions & 1 deletion src/scripts/overlap_viz.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
#!/bin/bash
# Usage: bash overlap_viz.sh [chunks_dir_pattern]
# Default: all _*_chunks/ directories under src/
# When 2+ chunks dirs are matched, also produce a combined top-to-bottom PNG
# (overlap_timeline_combined.png) in the common parent directory.

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PATTERN="${1:-$SCRIPT_DIR/_*_chunks}"
PATTERN="${@:-$SCRIPT_DIR/_*_chunks}"

pngs=()
found=0
for dir in $PATTERN; do
csv="$dir/chunk_profile.csv"
if [ -f "$csv" ]; then
echo "[VIZ] $csv"
python "$SCRIPT_DIR/overlap_viz_par.py" "$csv"
png="$dir/overlap_timeline_par.png"
if [ -f "$png" ]; then
pngs+=("$png")
fi
found=$((found + 1))
else
echo "[SKIP] $csv not found"
Expand All @@ -19,4 +26,11 @@ done

if [ "$found" -eq 0 ]; then
echo "No chunk_profile.csv found. Run with streaming_tts: true first."
elif [ "${#pngs[@]}" -ge 2 ]; then
parent=$(python -c "
import os, sys
print(os.path.commonpath([os.path.dirname(p) for p in sys.argv[1:]]))
" "${pngs[@]}")
out="$parent/overlap_timeline_combined.png"
python "$SCRIPT_DIR/stack_pngs.py" "${pngs[@]}" -o "$out"
fi
206 changes: 153 additions & 53 deletions src/scripts/overlap_viz_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,43 +62,85 @@ def compute_timeline(rows):
play0_start = t + rows[0]["total_elapsed_s"] # = tts_api_s for chunk 0
play0_end = play0_start + rows[0]["audio_seconds"]

# chunk 0 has a single TTS candidate (no workers); legacy path.
legacy_tts_times = rows[0].get("iter_tts_times_s", []) or []
legacy_cands = []
for k, tts_dur in enumerate(legacy_tts_times):
legacy_cands.append({
"k": k,
"submit_t": t,
"duration": float(tts_dur),
"is_chosen": True,
"worker": "normal",
"intra_iter": k,
})

events.append({
"chunk": 0,
"prep_start": t,
"main_segments": [], # no fs/llm
"tts_candidates": _build_tts_candidates(rows[0], t),
"play_start": play0_start,
"play_end": play0_end,
"gap": 0.0,
"chunk": 0,
"prep_start_normal": t,
"prep_start_prestart": None,
"prep_end": play0_start,
"normal_segments": [],
"prestart_segments": [],
"tts_candidates": legacy_cands,
"play_start": play0_start,
"play_end": play0_end,
"gap": 0.0,
"is_prestart": False,
"prestart_kind": "",
})

for i in range(1, len(rows)):
prep_start = events[i - 1]["play_start"]
prep_duration = rows[i]["total_elapsed_s"] # true wall-clock (parallel)
prep_start_default = events[i - 1]["play_start"]
prep_duration = rows[i]["total_elapsed_s"] # streaming-adjusted (lead already subtracted for pre-started)

lead_s = float(rows[i].get("prep_start_lead_s", 0.0) or 0.0)
is_prestart = lead_s > 0.01
prestart_kind = rows[i].get("prestart_kind", "") or ""
prep_start_normal = prep_start_default
prep_start_prestart = (prep_start_default - lead_s) if is_prestart else None

prev_play_end = events[i - 1]["play_end"]
ready_at = prep_start + prep_duration
ready_at = prep_start_default + prep_duration # wall-clock when prep is done
gap = max(0.0, ready_at - prev_play_end)
play_start = max(prev_play_end, ready_at)
play_end = play_start + rows[i]["audio_seconds"]

prestart_segments = _build_worker_segments(rows[i], "prestart", prep_start_prestart) if is_prestart else []
normal_segments = _build_worker_segments(rows[i], "normal", prep_start_normal)
tts_candidates = _build_dual_candidates(rows[i], prep_start_prestart, prep_start_normal)

events.append({
"chunk": i,
"prep_start": prep_start,
"main_segments": _build_main_segments(rows[i], prep_start),
"tts_candidates": _build_tts_candidates(rows[i], prep_start),
"play_start": play_start,
"play_end": play_end,
"gap": gap,
"chunk": i,
"prep_start_normal": prep_start_normal,
"prep_start_prestart": prep_start_prestart,
"prep_end": ready_at,
"normal_segments": normal_segments,
"prestart_segments": prestart_segments,
"tts_candidates": tts_candidates,
"play_start": play_start,
"play_end": play_end,
"gap": gap,
"is_prestart": is_prestart,
"prestart_kind": prestart_kind if is_prestart else "",
})

return events


def _build_main_segments(row, prep_start):
"""fs / llm alternating segments on the main thread."""
fs_times = row.get("iter_fs_times_s", []) or []
llm_times = row.get("iter_llm_times_s", []) or []
def _build_worker_segments(row, worker, prep_start):
"""fs / llm alternating segments for one worker, starting at prep_start."""
if prep_start is None:
return []
fs_times = row.get(f"{worker}_fs_times_s", []) or []
llm_times = row.get(f"{worker}_llm_times_s", []) or []
# Backward compat: if per-worker fields are absent (old CSV), fall back to combined.
if not fs_times and not llm_times:
if worker != "normal":
return []
fs_times = row.get("iter_fs_times_s", []) or []
llm_times = row.get("iter_llm_times_s", []) or []

segments = []
t = prep_start
n_iters = max(len(fs_times), len(llm_times))
Expand All @@ -112,28 +154,67 @@ def _build_main_segments(row, prep_start):
return segments


def _build_tts_candidates(row, prep_start):
def _build_dual_candidates(row, prep_start_prestart, prep_start_normal):
"""
Return list of dicts {k, submit_t, duration, is_chosen}.
Build TTS candidate list combining prestart + normal workers.
Each candidate is {k, submit_t, duration, is_chosen, worker, intra_iter}.
k is assigned by sorting candidates by submit_t (so vertical stacking
matches chronological submission order).

Candidate k is submitted after sum(fs[0..k]) + sum(llm[0..k-1])
duration == -1.0 means the future had not completed at selection time.
Within a single worker: candidate at intra_iter j is submitted after
sum(fs[0..j]) + sum(llm[0..j-1]) of that worker's timeline.
"""
fs_times = row.get("iter_fs_times_s", []) or []
llm_times = row.get("iter_llm_times_s", []) or []
tts_times = row.get("iter_tts_times_s", []) or []
used_iter = int(row.get("used_candidate_iter", 0))

result = []
for k, tts_dur in enumerate(tts_times):
submit_t = prep_start + sum(fs_times[: k + 1]) + sum(llm_times[:k])
result.append({
"k": k,
"submit_t": submit_t,
"duration": float(tts_dur),
"is_chosen": k == used_iter,
})
return result
chosen_label = (row.get("chosen_worker_label", "") or "").strip()
chosen_intra = int(row.get("chosen_intra_iter", 0) or 0)
used_iter_legacy = int(row.get("used_candidate_iter", 0) or 0)

cands = []

def _add_worker(worker, start_t):
if start_t is None:
return
fs_times = row.get(f"{worker}_fs_times_s", []) or []
llm_times = row.get(f"{worker}_llm_times_s", []) or []
tts_times = row.get(f"{worker}_tts_times_s", []) or []
for j, tts_dur in enumerate(tts_times):
submit_t = start_t + sum(fs_times[: j + 1]) + sum(llm_times[:j])
is_chosen = (chosen_label == worker and j == chosen_intra)
cands.append({
"submit_t": submit_t,
"duration": float(tts_dur),
"is_chosen": is_chosen,
"worker": worker,
"intra_iter": j,
})

has_per_worker = bool(
row.get("normal_tts_times_s")
or row.get("prestart_tts_times_s")
)

if has_per_worker:
_add_worker("prestart", prep_start_prestart)
_add_worker("normal", prep_start_normal)
else:
# Backward compat: old CSV with combined iter_*_times_s
fs_times = row.get("iter_fs_times_s", []) or []
llm_times = row.get("iter_llm_times_s", []) or []
tts_times = row.get("iter_tts_times_s", []) or []
start_t = prep_start_prestart if prep_start_prestart is not None else prep_start_normal
for j, tts_dur in enumerate(tts_times):
submit_t = start_t + sum(fs_times[: j + 1]) + sum(llm_times[:j])
cands.append({
"submit_t": submit_t,
"duration": float(tts_dur),
"is_chosen": j == used_iter_legacy,
"worker": "normal",
"intra_iter": j,
})

cands.sort(key=lambda c: c["submit_t"])
for k, c in enumerate(cands):
c["k"] = k
return cands


# ── plotting ──────────────────────────────────────────────────────────────────
Expand All @@ -148,10 +229,11 @@ def _build_tts_candidates(row, prep_start):
"tts_nd": "#BBBBBB", # grey – TTS still running at selection
}

BAR_H = 0.32 # playback / main-thread bar height
TTS_H = 0.18 # height of each TTS candidate sub-row
TTS_GAP = 0.06 # vertical gap between candidate sub-rows
Y_MAIN = 0.0 # centre of main-thread lane
BAR_H = 0.32 # playback / main-thread bar height
TTS_H = 0.18 # height of each TTS candidate sub-row
TTS_GAP = 0.06 # vertical gap between candidate sub-rows
Y_MAIN = 0.0 # centre of main-thread lane
Y_PRESTART = -0.6 # centre of pre-start lane (only shown when a pre-started chunk exists)


def _tts_y(k):
Expand All @@ -169,8 +251,10 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline
n = len(rows)
max_cands = max((len(ev["tts_candidates"]) for ev in events), default=1)
Y_PLAY = _y_play(max_cands)
has_prestart = any(ev.get("is_prestart") for ev in events)
bottom_y = Y_PRESTART if has_prestart else Y_MAIN

fig_h = max(5.0, Y_PLAY + 1.2)
fig_h = max(5.0, Y_PLAY - bottom_y + 1.2)
fig, ax = plt.subplots(figsize=(max(14, n * 2.5), fig_h))

for ev in events:
Expand All @@ -181,8 +265,9 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline
dur = ev["play_end"] - ev["play_start"]
ax.barh(Y_PLAY, dur, left=ev["play_start"], height=BAR_H,
color=COLORS["play"], alpha=0.88, edgecolor="white", linewidth=0.6)
kind_suffix = f" [{ev['prestart_kind']}-prestart]" if ev.get("prestart_kind") else ""
ax.text(ev["play_start"] + dur / 2, Y_PLAY,
f"▶{i} {dur:.1f}s",
f"▶{i} {dur:.1f}s{kind_suffix}",
ha="center", va="center", fontsize=8, fontweight="bold", color="white")

# ── gap / silence ──────────────────────────────────────────────────────
Expand All @@ -195,7 +280,7 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline
ha="center", va="center", fontsize=7, color="white")

# ── TTS candidates (parallel band) ────────────────────────────────────
prep_end = ev["prep_start"] + rows[i]["total_elapsed_s"]
prep_end = ev["prep_end"]
for cand in cands:
y = _tts_y(cand["k"])
if cand["duration"] > 0:
Expand All @@ -220,8 +305,8 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline
f"T{cand['k']} (still running)",
fontsize=5.5, color="#888888", va="bottom")

# ── main thread (fs / llm) ─────────────────────────────────────────────
for seg_start, seg_end, kind in ev["main_segments"]:
# ── main thread (normal worker fs/llm) ────────────────────────────────
for seg_start, seg_end, kind in ev.get("normal_segments", []):
seg_dur = seg_end - seg_start
if seg_dur < 0.05:
continue
Expand All @@ -231,6 +316,17 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline
f"{kind}\n{seg_dur:.1f}s",
ha="center", va="center", fontsize=6.5, color="white")

# ── pre-start lane (prestart worker fs/llm; runs in parallel with normal) ──
for seg_start, seg_end, kind in ev.get("prestart_segments", []):
seg_dur = seg_end - seg_start
if seg_dur < 0.05:
continue
ax.barh(Y_PRESTART, seg_dur, left=seg_start, height=BAR_H,
color=COLORS[kind], alpha=0.88, edgecolor="white", linewidth=0.5)
ax.text(seg_start + seg_dur / 2, Y_PRESTART,
f"{kind}\n{seg_dur:.1f}s",
ha="center", va="center", fontsize=6.5, color="white")

# chunk boundary lines at each play_start
for ev in events:
ax.axvline(ev["play_start"], color="gray", linewidth=0.9,
Expand All @@ -245,12 +341,16 @@ def plot_timeline(rows, events, out_path, title="Parallel Chunk Overlap Timeline

# y-axis ticks
tts_mid = _tts_y((max_cands - 1) / 2)
ax.set_yticks([Y_MAIN, tts_mid, Y_PLAY])
ax.set_yticklabels(["Main thread\n(fs + llm)", "TTS candidates\n(parallel)", "Playback"],
fontsize=9)
yticks = [Y_MAIN, tts_mid, Y_PLAY]
ylabels = ["Main thread\n(fs + llm)", "TTS candidates\n(parallel)", "Playback"]
if has_prestart:
yticks = [Y_PRESTART] + yticks
ylabels = ["Pre-start lane\n(parallel worker)"] + ylabels
ax.set_yticks(yticks)
ax.set_yticklabels(ylabels, fontsize=9)
ax.set_xlabel("Wall-clock time (s)", fontsize=10)
ax.set_title(title, fontsize=12, fontweight="bold", pad=10)
ax.set_ylim(Y_MAIN - 0.5, Y_PLAY + 0.7)
ax.set_ylim(bottom_y - 0.5, Y_PLAY + 0.7)
ax.grid(axis="x", alpha=0.25)

legend_patches = [
Expand Down
Loading
Loading