<a href="https://colab.research.google.com/github/Dao-you/Whisper-for-Meeting-on-Colab/blob/main/Whisper_for_Meeting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# Whisper Automatic Subtitle Generation: GPU Transcription + CPU Denoising + OpenCC Post-processing (Traditional/Simplified Conversion)
# And LLM Summarization (GPT-OSS-20B / llama.cpp / CUDA)
# - Transcription: faster-whisper (CUDA, compute: float16→int8_float16→int8)
# - Denoising: ffmpeg afftdn (CPU)
# - Progress: Real-time printing of "current sentence + video total length percentage"
# - Network source download and output: MyDrive/whisper; Files in Drive: Output to the same folder
# - LLM Summary: llama.cpp + GPT-OSS-20B GGUF for summarizing transcription
# - Prompts "Delete runtime and restart" if download is blocked or abnormal
# =========================================================

# Restrict multithreading (more stable)
# These variables are set before the heavy libraries are imported further
# down; presumably they are read at import/initialisation time — TODO confirm.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

# [1/8] Mount Google Drive
# The printed prompt asks the user to authorise the mount. The Drive is
# mounted at /content/gdrive, the prefix that ROOT below builds on.
from google.colab import drive
print("請授權以掛載 Google Drive")
drive.mount("/content/gdrive", force_remount=True)

# Built-in Imports
import sys, gc, shutil, datetime, subprocess as sp
from pathlib import Path
import re, math, time, textwrap
from typing import List, Tuple, Optional, Iterable, Iterator
from collections import defaultdict
from IPython.display import display, Markdown
import json


# ROOT is the user's MyDrive folder; every relative path entered in the form
# is resolved against it (see to_abs_mydrive).
ROOT = Path("/content/gdrive/MyDrive")
# Default output folder for files downloaded from the network.
WHISPER_DIR = ROOT / "whisper"
WHISPER_DIR.mkdir(exist_ok=True, parents=True)
# Make MyDrive the working directory and echo it for the user.
os.chdir(ROOT)
print(f"→ 當前工作目錄：{os.getcwd()}")

# [2/8] User Form Parameters (Unified)
# The `#@markdown` / `#@param` comments below are Colab form annotations;
# keep them on the same line as the variable they describe.
#@markdown # Whisper Transcription & LLM Summary Pipeline

#@markdown ## Input & Transcription Settings
#@markdown **Input Source:** Google Drive file (relative to MyDrive) or video URL (YouTube/HTTP).
filename = "whisper/2022第三屆青春來說課-第1會議室吳土城教授回饋.mp4"  #@param {type:"string"}
#@markdown **Download Option:** Check to save network source files to `MyDrive/whisper`.
save_video_to_google_drive = True  #@param {type:"boolean"}
#@markdown **Whisper Model Size:** Choose a model size. `large-v3` requires more GPU VRAM; `medium` is a good alternative if VRAM is limited.
model_size = "medium"  #@param ["tiny", "base", "small", "medium", "large-v2", "large-v3"]
#@markdown **Language:** Select the language for transcription. "自動偵測" (Auto-detect) is usually sufficient.
language = "自動偵測"  #@param ["自動偵測", "中文", "英文"]
#@markdown **Denoising:** Apply CPU-based denoising to the audio before transcription. `afftdn` is recommended.
denoise_method = "afftdn (建議)"  #@param ["afftdn (建議)", "none"]
#@markdown **Text Post-processing (OpenCC):** Convert the transcribed text (SRT/TXT output) between Simplified and Traditional Chinese variants.
text_postprocess = "臺灣繁體中文（預設）"  #@param ["臺灣繁體中文（預設）","香港繁體中文","大陸簡體中文","關閉"]
#@markdown **YouTube Cookies (Optional):** Path to a Netscape-format cookies file (relative to MyDrive) for accessing age-restricted or member-only YouTube videos (e.g., `cookies/youtube.txt`).
youtube_cookies_txt_path = ""  #@param {type:"string"}

#@markdown ## Summarization Settings
#@markdown **Topic Hint (Optional):** Provide a brief hint about the topic to guide the summarization process.
topic_hint = ""  #@param {type:"string"}

# Map the form's display label to the language code handed to the
# transcriber; None is used for the auto-detect choice.
language_code_map = {"自動偵測": None, "中文":"zh", "英文":"en"}
language_code = language_code_map[language]

# =========================================================
# Developer Options
# Advanced users can fine-tune parameters in this section.
# Modify only if you understand the impact.
# =========================================================
# When False, most progress/diagnostic prints and captured subprocess output
# throughout this script are suppressed.
DEBUG_MODE = False # Set to True for more detailed logging

# --- Transcription Parameters ---
# PRIMARY values are the defaults of transcribe_gpu(); FALLBACK values are
# used by run_transcription_attempt() when the primary call raises.
TRANSCRIPTION_BEAM_SIZE_PRIMARY = 1
TRANSCRIPTION_CHUNK_LENGTH_PRIMARY = 20
TRANSCRIPTION_BEAM_SIZE_FALLBACK = 1 # Used if primary fails
TRANSCRIPTION_CHUNK_LENGTH_FALLBACK = 15 # Used if primary fails

# --- Denoising Parameters ---
# Passed to ffmpeg as `afftdn=nf=<value>`.
DENOISE_NOISE_FLOOR_DB = -25

# --- Filtering Parameters ---
# Post-filter applied per segment in run_transcription_attempt(): a segment
# is dropped when it is shorter than the duration bound AND fails the paired
# probability threshold.
FILTER_MIN_DURATION_SHORT = 1.5 # Minimum duration for short segments
FILTER_AVG_LOGPROB_THRESHOLD = -0.5 # Avg log probability threshold for short segments
FILTER_MIN_DURATION_SPEECH_PROB = 2.0 # Minimum duration for speech probability filtering
FILTER_NO_SPEECH_PROB_THRESHOLD = 0.6 # No speech probability threshold

# --- Summary Model Parameters ---
REPO_ID   = "unsloth/gpt-oss-20b-GGUF"   # GGUF Model Repository
GGUF_FILE = "gpt-oss-20b-Q4_K_M.gguf"    # Approx. 10.8GiB, T4 can run

# --- Summary Inference Parameters (Increase available generation space to avoid truncation) ---
CTX_WINDOW_CANDIDATES   = [12288, 16384, 8192]  # T4/Q4_K_M usually handles 12k–16k; fallback to 8192
# Initialised to the last (smallest) candidate; the comment below says the
# runtime later replaces it with the first candidate that loads.
ctx_window              = CTX_WINDOW_CANDIDATES[-1]  # Runtime picks the first successful candidate
map_max_new_tokens      = 800   # Segment output upper bound (~550-800 chars)
map_repeat_penalty      = 1.10  # Tunable repeat penalty for map stage
reduce_repeat_penalty   = 1.10  # Tunable repeat penalty for reduce stage
reduce_max_new_tokens   = 1500  # Summary output upper bound (~1k-1.3k chars)
temperature             = 0.5
top_p                   = 0.9
# NOTE(review): a third repeat penalty alongside the map/reduce ones; where
# it is used is not visible in this part of the file — confirm.
repeat_penalty          = 1.05


# --- Summary Segmentation Heuristics ---
SEMANTIC_VAD_PRESETS = {
    "Conservative": 1.2,  # 1.2s: keep more context for safety (conservative)
    "Aggressive":   0.8,  # 0.8s: quicker resets to dodge loop traps (aggressive)
}
SELECTED_VAD_SILENCE_PRESET = "Aggressive"  # Default preset tuned for repetitive loop mitigation
# Falls back to 0.8s if the selected preset name is not in the table. The
# same value is converted to milliseconds for the Whisper VAD further down.
SEMANTIC_PAUSE_THRESHOLD = SEMANTIC_VAD_PRESETS.get(SELECTED_VAD_SILENCE_PRESET, 0.8)
SEMANTIC_MIN_TOKENS      = 192   # Minimum tokens before we allow punctuation-based splits (tighter for Chinese)
SEMANTIC_MAX_CHARS       = 1800  # Safety valve to avoid overly long segments with no punctuation
SLIDING_OVERLAP_TOKENS   = 200   # Tokens preserved between neighbouring summary windows
SEMANTIC_FORCE_FLUSH_LINES = 40   # Force flush after 40 lines without punctuation
SEMANTIC_FORCE_FLUSH_SECONDS = 120.0  # Force flush if segment spans ≥120s


tokenizer_config_data = None
harmony_chat_formatter = None


def ensure_harmony_formatter():
    """Return the Harmony chat formatter, building and caching it on first use.

    Two sources are tried in order: the module-level ``tokenizer_config_data``
    (when set), then the chat template stored in the metadata of a global
    ``llm`` object (when one exists). The result is cached in the module-level
    ``harmony_chat_formatter``.

    Raises:
        RuntimeError: if the llama_cpp chat-format helpers cannot be imported,
            or if neither source produced a formatter.
    """
    global harmony_chat_formatter
    if harmony_chat_formatter is not None:
        return harmony_chat_formatter

    try:
        from llama_cpp.llama_chat_format import (
            hf_tokenizer_config_to_chat_formatter,
            Jinja2ChatFormatter,
        )
    except Exception as exc:
        raise RuntimeError("llama_cpp Harmony chat helpers are unavailable") from exc

    def _from_tokenizer_config():
        # First choice: the tokenizer_config.json contents, if loaded.
        if not tokenizer_config_data:
            return None
        try:
            return hf_tokenizer_config_to_chat_formatter(
                tokenizer_config_data,
                add_generation_prompt=True,
            )
        except Exception as exc:
            if DEBUG_MODE:
                print("  ✗ Failed to initialize Harmony formatter from tokenizer_config.json:", exc)
            return None

    def _from_gguf_metadata():
        # Second choice: the chat template embedded in the loaded model.
        if 'llm' not in globals() or not hasattr(llm, 'metadata'):
            return None
        chat_template = llm.metadata.get('tokenizer.chat_template')
        if not chat_template:
            return None
        try:
            # Each token falls back to detokenizing the model's own token id
            # only when neither metadata key yields a value.
            bos_text = llm.metadata.get('tokenizer.bos_token')
            if not bos_text:
                bos_text = llm.metadata.get('tokenizer.ggml.bos_token')
            if not bos_text:
                bos_text = llm.detokenize([llm.token_bos()], special=True).decode('utf-8', errors='ignore')

            eos_text = llm.metadata.get('tokenizer.eos_token')
            if not eos_text:
                eos_text = llm.metadata.get('tokenizer.ggml.eos_token')
            if not eos_text:
                eos_text = llm.detokenize([llm.token_eos()], special=True).decode('utf-8', errors='ignore')

            return Jinja2ChatFormatter(
                chat_template,
                eos_token=eos_text,
                bos_token=bos_text,
                stop_token_ids=[llm.token_eos()],
            )
        except Exception as exc:
            if DEBUG_MODE:
                print("  ✗ Failed to build Harmony formatter from GGUF metadata:", exc)
            return None

    built = _from_tokenizer_config()
    if built is None:
        built = _from_gguf_metadata()
    if built is None:
        raise RuntimeError("Harmony chat formatter could not be prepared")

    harmony_chat_formatter = built
    return harmony_chat_formatter




# =========================================================
# End of Developer Options
# =========================================================


# [3/8] Install Dependencies
# Only packages that are not already present are installed (see the probes
# further down); the banner is printed in debug mode only.
if DEBUG_MODE: print("[Install] faster-whisper / yt-dlp / soundfile / opencc / srt / huggingface_hub / llama-cpp-python ...")

def pip_install(pkgs, extra_args=None, env=None):
    """Run ``pip install --upgrade`` with the current interpreter.

    Args:
        pkgs: requirement strings to install.
        extra_args: optional extra pip arguments, placed before the packages.
        env: optional environment mapping for the subprocess.

    Returns:
        The ``CompletedProcess`` from ``subprocess.run``; stderr is merged
        into ``stdout`` and captured as text. A failing install is reported
        through ``returncode`` and does not raise.
    """
    command = [
        sys.executable, "-m", "pip", "install", "--upgrade",
        *(extra_args or []),
        *pkgs,
    ]
    return sp.run(command, stdout=sp.PIPE, stderr=sp.STDOUT, text=True, env=env)

import importlib
import importlib.util
# Install common dependencies first
# Each entry is (importable module name to probe, pip requirement to install
# when the probe finds nothing).
_COMMON_REQUIREMENTS = [
    ("srt", "srt>=3.5.3"),
    ("huggingface_hub", "huggingface_hub>=0.23.0"),
    ("soundfile", "soundfile"),
    ("opencc", "opencc-python-reimplemented"),
    ("jinja2", "jinja2>=3.1.0"),
]
common_missing = [
    requirement
    for module_name, requirement in _COMMON_REQUIREMENTS
    if importlib.util.find_spec(module_name) is None
]

if common_missing:
    if DEBUG_MODE: print("→ Installing common missing packages:", ", ".join(common_missing))
    r = pip_install(common_missing)
    if r.returncode != 0:
        if DEBUG_MODE: print(r.stdout)
        raise RuntimeError("基礎依賴安裝失敗，請重啟執行階段後重試。")

# faster-whisper and yt-dlp are probed independently: a runtime that already
# ships faster-whisper must still get yt-dlp. yt-dlp is probed by its
# executable because ytdl() invokes the `yt-dlp` command, not the module.
transcription_missing = []
if importlib.util.find_spec("faster_whisper") is None:
    transcription_missing.append("faster-whisper")
if shutil.which("yt-dlp") is None:
    transcription_missing.append("yt-dlp")

if transcription_missing:
    if DEBUG_MODE: print("→ Installing missing package:", " ".join(transcription_missing))
    r = pip_install(transcription_missing)
    if r.returncode != 0:
        if DEBUG_MODE: print(r.stdout)
        raise RuntimeError("faster-whisper / yt-dlp 安裝失敗，請重啟執行階段後重試。")

# Import external packages after ensuring installation
import soundfile as sf
from faster_whisper import WhisperModel
from opencc import OpenCC
import srt as _srt  # Import srt as _srt to avoid name conflict later with the module itself
from huggingface_hub import snapshot_download
def suggest_runtime_reset():
    """Print the recommended Colab recovery steps (factory-reset, then re-run)."""
    heading = "\n🧹 建議動作（Colab）"
    steps = "1) 依序：『執行階段 Runtime』 → 『刪除執行階段/還原出廠設定 Factory reset runtime』\n2) 重新執行本 Notebook（從掛載雲端硬碟那格開始）"
    print(heading)
    # Flush so the advice is shown before the caller raises.
    print(steps, flush=True)

def run_cmd(cmd:list, check=True):
    """Run *cmd* and return its ``CompletedProcess``.

    stderr is merged into stdout and captured as text. In debug mode the
    command line and its output are echoed.

    Raises:
        RuntimeError: if ``check`` is true and the command exits non-zero.
    """
    if DEBUG_MODE:
        print("  $", " ".join(cmd))
    result = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if DEBUG_MODE and result.stdout:
        sys.stdout.write(result.stdout)
    failed = result.returncode != 0
    if failed and check:
        raise RuntimeError(f"命令失敗：{' '.join(cmd)}")
    return result

def is_youtube_url(s:str)->bool:
    """Return True when *s* is a string containing a YouTube host name."""
    if not isinstance(s, str):
        return False
    return any(host in s for host in ("youtu.be", "youtube.com"))
def is_http_url(s:str)->bool:
    """Return True when *s* is a string starting with an http(s) scheme.

    The full ``http://`` / ``https://`` prefix is required (case-insensitive)
    so that a Drive file whose name merely begins with "http" is not mistaken
    for a URL.
    """
    return isinstance(s, str) and s.lower().startswith(("http://", "https://"))
def to_abs_mydrive(p:str)->Path:
    """Resolve *p* to an absolute path; relative paths are taken under ROOT."""
    if p.startswith("/"):
        candidate = Path(p)
    else:
        candidate = ROOT / p
    return candidate.resolve()
def fmt_ts_srt(t:float)->str:
    """Format *t* seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds first and then split into
    fields, so a fraction that rounds up carries into the seconds (0.9996 s
    becomes ``00:00:01,000``, never a four-digit millisecond field).
    Negative inputs are clamped to zero.
    """
    total_ms = max(0, int(round(float(t) * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
def verify_wav_ok(path: Path)->bool:
    """Return True when soundfile can read *path* as mono or stereo audio.

    Any exception while inspecting the file is treated as "not OK".
    """
    try:
        meta = sf.info(str(path))
        if not meta.samplerate > 0:
            return False
        return meta.channels in (1, 2)
    except Exception:
        return False

# OpenCC converter setup
def build_opencc_pipeline(choice:str):
    """Return the OpenCC converters for the form choice, in the order to apply.

    The choice is matched by its leading characters; a choice that matches
    none of the known prefixes yields an empty list (conversion disabled).
    """
    conversion_chains = (
        ("臺灣", ('s2t', 't2tw')),
        ("香港", ('s2t', 't2hk')),
        ("大陸", ('t2s',)),
    )
    for prefix, config_names in conversion_chains:
        if choice.startswith(prefix):
            return [OpenCC(config_name) for config_name in config_names]
    return []  # Disable

def apply_opencc(text:str, pipeline)->str:
    """Run *text* through each converter in *pipeline*, in order.

    Each converter's ``convert`` output feeds the next one; an empty pipeline
    returns the text unchanged.
    """
    converted = text
    for converter in pipeline:
        converted = converter.convert(converted)
    return converted

def ytdl(yturl:str)->Path:
    """Download a YouTube video with yt-dlp into /tmp/dl and return its path.

    The scratch directory is emptied first so the download can be located by
    listing it afterwards. When ``youtube_cookies_txt_path`` is set and the
    file exists it is passed to yt-dlp via ``--cookies``. When
    ``save_video_to_google_drive`` is true the file is also copied into
    ``WHISPER_DIR``.

    Raises:
        RuntimeError: if yt-dlp exits with a non-zero status.
        FileNotFoundError: if yt-dlp succeeded but left no finished file.
    """
    tmp = Path("/tmp/dl"); tmp.mkdir(parents=True, exist_ok=True)
    for x in tmp.glob("*"):
        try:
            x.unlink()
        except OSError:
            # unlink() cannot remove a directory; remove it recursively.
            shutil.rmtree(x, ignore_errors=True)
    if DEBUG_MODE: print("[Download] Getting YouTube video ...")
    cmd = ["yt-dlp", "-f", "mp4", "-o", str(tmp / "%(title)s.%(ext)s")]
    if youtube_cookies_txt_path.strip():
        cookies_abs = to_abs_mydrive(youtube_cookies_txt_path.strip())
        if cookies_abs.exists():
            cmd += ["--cookies", str(cookies_abs)]
        else:
            if DEBUG_MODE: print(f"⚠️ 找不到 cookies 檔：{cookies_abs}（改為不帶 cookies）")
    cmd.append(yturl)
    p = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if DEBUG_MODE and p.stdout: sys.stdout.write(p.stdout)
    if p.returncode != 0:
        if "Sign in to confirm" in (p.stdout or ""):
            print()
            print("❗YouTube 要求登入/驗證，請提供 cookies 或先自行下載到雲端硬碟。")
        print("🔄 若多次失敗，請刪除執行階段並重啟後重試。")
        suggest_runtime_reset()
        raise RuntimeError("yt-dlp 下載失敗")
    # Consider only finished regular files: yt-dlp's partial-download
    # artefacts (.part / .ytdl) must not be mistaken for the video.
    files = [
        c for c in tmp.glob("*")
        if c.is_file() and c.suffix not in (".part", ".ytdl")
    ]
    if not files:
        print("🔄 下載為空，建議刪除執行階段再重試。")
        suggest_runtime_reset()
        raise FileNotFoundError("YouTube 下載失敗：/tmp/dl 為空")
    # Directory listing order is arbitrary; if more than one file is present
    # take the largest, which is presumably the media file.
    f = max(files, key=lambda c: c.stat().st_size)
    if save_video_to_google_drive:
        shutil.copy2(f, WHISPER_DIR / f.name)
    return f

def http_dl(url:str)->Path:
    """Download *url* with curl into /tmp/dl and return the saved path.

    The scratch directory is emptied first. The file is saved as
    ``downloaded_<timestamp>.mp4`` and, when ``save_video_to_google_drive``
    is true, also copied into ``WHISPER_DIR``.

    Raises:
        RuntimeError: if curl exits non-zero (raised by ``run_cmd``).
    """
    tmp = Path("/tmp/dl"); tmp.mkdir(parents=True, exist_ok=True)
    for x in tmp.glob("*"):
        try:
            x.unlink()
        except OSError:
            # unlink() cannot remove a directory; remove it recursively.
            shutil.rmtree(x, ignore_errors=True)
    ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out = tmp / f"downloaded_{ts}.mp4"
    if DEBUG_MODE: print("[Download] Getting HTTP(S) video ...")
    # --fail makes curl exit non-zero on HTTP errors (e.g. 404/403) instead
    # of saving the server's error page as if it were the video.
    run_cmd(["curl", "--fail", "-L", "-o", str(out), url])
    if save_video_to_google_drive:
        shutil.copy2(out, WHISPER_DIR / out.name)
    return out

# Extract audio: ffmpeg -> 16k/mono WAV
def ffmpeg_extract_wav(in_path:Path, out_wav:Path, sr=16000):
    """Extract the audio of *in_path* into a mono WAV at *sr* Hz.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    ffmpeg_args = [
        "ffmpeg", "-y",
        "-i", str(in_path),
        "-vn",
        "-ac", "1",
        "-ar", str(sr),
        "-f", "wav",
        str(out_wav),
    ]
    result = sp.run(ffmpeg_args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if result.returncode == 0:
        return
    if DEBUG_MODE:
        print(result.stdout)
    raise RuntimeError("ffmpeg 轉 WAV 失敗")

# CPU Denoising: ffmpeg afftdn
def ffmpeg_afftdn(in_wav: Path, out_wav: Path, noise_floor_db=DENOISE_NOISE_FLOOR_DB):
    """Denoise *in_wav* with ffmpeg's afftdn filter into a 16 kHz mono WAV.

    Args:
        noise_floor_db: value passed as the filter's ``nf`` option.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    audio_filter = f"afftdn=nf={noise_floor_db}"
    ffmpeg_args = [
        "ffmpeg", "-y",
        "-i", str(in_wav),
        "-af", audio_filter,
        "-ac", "1",
        "-ar", "16000",
        "-f", "wav",
        str(out_wav),
    ]
    result = sp.run(ffmpeg_args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if result.returncode == 0:
        return
    if DEBUG_MODE:
        print(result.stdout)
    raise RuntimeError("ffmpeg afftdn 失敗")

# Safeguard: Repack WAV header if format is strange
def ffmpeg_repack_wav(in_wav: Path, out_wav: Path, sr=16000):
    """Re-encode *in_wav* as 16-bit PCM mono at *sr* Hz into *out_wav*.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    ffmpeg_args = [
        "ffmpeg", "-y",
        "-i", str(in_wav),
        "-acodec", "pcm_s16le",
        "-ac", "1",
        "-ar", str(sr),
        str(out_wav),
    ]
    result = sp.run(ffmpeg_args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if result.returncode == 0:
        return
    if DEBUG_MODE:
        print(result.stdout)
    raise RuntimeError("ffmpeg 重包 WAV 失敗")

# [4/8] Parse Source (Transcription) - Uses 'filename' and 'save_video_to_google_drive'
# Sets two names used by later steps:
#   src_path     - local path of the media file to transcribe
#   out_base_dir - folder that receives the SRT/TXT output
if DEBUG_MODE: print("[4/8] Parsing input source ...")
try:
    # YouTube is checked first because YouTube links are also http(s) URLs.
    if is_youtube_url(filename):
        src_path = ytdl(filename); out_base_dir = WHISPER_DIR
    elif is_http_url(filename):
        src_path = http_dl(filename); out_base_dir = WHISPER_DIR
    else:
        # Anything else is treated as a Drive path; output goes next to it.
        src_path = to_abs_mydrive(filename)
        if not src_path.exists(): raise FileNotFoundError(f"找不到檔案：{src_path}")
        out_base_dir = src_path.parent
except Exception as e:
    # Show the recovery advice, then re-raise so the run stops here.
    print(f"\n⛔ 來源解析/下載失敗：{e}")
    print("🔄 請刪除執行階段並重新啟動後重跑。"); suggest_runtime_reset(); raise

print(f"→ 來源檔：{src_path}")
print(f"→ 輸出資料夾：{out_base_dir}")

# [5/8] Extract Audio & CPU Denoising (Transcription) - Uses 'denoise_method' and 'DENOISE_NOISE_FLOOR_DB'
# Result of this step: `denoised_audio`, the WAV path handed to the
# transcriber (it may be the un-denoised file, see below).
AUDIO_16K = Path("/tmp/audio_16k.wav")
if DEBUG_MODE: print("[5/8] Extracting audio (ffmpeg → 16k/mono WAV) ...")
ffmpeg_extract_wav(src_path, AUDIO_16K, sr=16000)

if denoise_method.startswith("afftdn"):
    if DEBUG_MODE: print("[5.5/8] Denoising (ffmpeg afftdn, CPU) ...")
    DENOISED = Path("/tmp/audio_16k_denoised.wav")
    ffmpeg_afftdn(AUDIO_16K, DENOISED, noise_floor_db=DENOISE_NOISE_FLOOR_DB)
    # Fall back to the plain extraction if the denoised file is unreadable.
    denoised_audio = DENOISED if verify_wav_ok(DENOISED) else AUDIO_16K
else:
    denoised_audio = AUDIO_16K

# Last resort: if the chosen file still fails the check, re-encode it as
# 16-bit PCM and use that copy instead.
if not verify_wav_ok(denoised_audio):
    if DEBUG_MODE: print("  - 音訊格式異常；嘗試重包 WAV ...")
    FIXED = Path("/tmp/audio_16k_fixed.wav")
    ffmpeg_repack_wav(denoised_audio, FIXED, sr=16000)
    denoised_audio = FIXED

if DEBUG_MODE: print(f"→ 最終輸入音訊：{denoised_audio}")

# [6/8] Load faster-whisper (GPU enforced) - Uses 'model_size'
if DEBUG_MODE: print("[6/8] Loading faster-whisper model (GPU) ...")
device = "cuda"  # Enforce GPU
model = None; last_err = None
# Compute types are tried in this order; the first one that loads is kept.
compute_type_candidates = ["float16", "int8_float16", "int8"]
chosen_compute_type = None
for ctype in compute_type_candidates:
    try:
        if DEBUG_MODE: print(f"  - Trying compute_type={ctype}")
        model = WhisperModel(model_size, device=device, compute_type=ctype)
        chosen_compute_type = ctype
        if DEBUG_MODE: print("  - Model loaded successfully")
        break
    except Exception as e:
        # Remember the error for the final message and try the next type.
        last_err = e
        if DEBUG_MODE: print(f"  - Load failed: {e}")
if model is None:
    # Every candidate failed: print the recovery advice and stop the run.
    print()
    print("⛔ GPU 模型載入失敗。請確認『變更執行階段類型』選了 GPU（T4/A100），或刪除執行階段後重試。")
    suggest_runtime_reset()
    raise RuntimeError(f"無法載入模型：{last_err}")
print(f"→ faster-whisper compute_type: {chosen_compute_type}")

gc.collect()  # Clean up before transcription (safety)

# [7/8] Transcribe (GPU; real-time progress per segment) - Uses 'language_code', 'TRANSCRIPTION_BEAM_SIZE_PRIMARY', 'TRANSCRIPTION_CHUNK_LENGTH_PRIMARY', 'TRANSCRIPTION_BEAM_SIZE_FALLBACK', 'TRANSCRIPTION_CHUNK_LENGTH_FALLBACK'
if DEBUG_MODE: print(f"[7/8] Starting transcription (GPU: beam={TRANSCRIPTION_BEAM_SIZE_PRIMARY} / chunk={TRANSCRIPTION_CHUNK_LENGTH_PRIMARY}s / VAD+no-repeat) ...")

# VAD minimum-silence in milliseconds, derived from the pause threshold in
# seconds and floored at 200 ms.
SELECTED_VAD_SILENCE_MS = int(max(0.2, SEMANTIC_PAUSE_THRESHOLD) * 1000)
# "De-loop sentinel": identical consecutive lines whose start times are about
# one second apart are counted as one repeat streak.
LOOP_SENTINEL_MIN_REPEAT = 5          # Trigger guard if the same line repeats >=5 times
LOOP_SENTINEL_TOLERANCE = 0.35        # Allow ±350ms drift when checking 1s increments
# Escalation decides whether the whole transcription is retried with
# stricter decoding parameters.
LOOP_ESCALATE_REPEAT_THRESHOLD = 8    # Require ≥8 repeats before escalating guardrails
LOOP_ESCALATE_DURATION_SECONDS = 10.0 # Require loops to span ≥10s when escalating by duration
LOOP_ESCALATE_EVENTS_PER_MIN = 2      # Escalate when ≥2 loop events occur within the same minute


def transcribe_gpu(
    _beam: int = TRANSCRIPTION_BEAM_SIZE_PRIMARY,
    _chunk: int = TRANSCRIPTION_CHUNK_LENGTH_PRIMARY,
    *,
    cond_prev: bool = True,
    compression_threshold: float = 1.35,
    no_repeat_ngram: int = 2,
):
    """Call ``model.transcribe`` on the prepared audio with this script's decoding options.

    Args:
        _beam: beam size.
        _chunk: chunk length passed as ``chunk_length``.
        cond_prev: passed as ``condition_on_previous_text``.
        compression_threshold: passed as ``compression_ratio_threshold``.
        no_repeat_ngram: passed as ``no_repeat_ngram_size``.

    Returns:
        Whatever ``model.transcribe`` returns; callers unpack it as
        ``(segments, info)``.
    """
    decode_options = {
        "task": "transcribe",
        "language": language_code,
        "temperature": 0.0,
        "condition_on_previous_text": cond_prev,
        "compression_ratio_threshold": compression_threshold,
        "log_prob_threshold": -0.5,
        "no_speech_threshold": 0.6,
        "beam_size": _beam,
        "chunk_length": _chunk,
        "vad_filter": True,
        "vad_parameters": {"min_silence_duration_ms": SELECTED_VAD_SILENCE_MS},
        "no_repeat_ngram_size": no_repeat_ngram,
        "word_timestamps": False,
    }
    audio_path = str(denoised_audio)
    return model.transcribe(audio_path, **decode_options)


def run_transcription_attempt(
    label: str,
    *,
    cond_prev: bool,
    compression_threshold: float,
    no_repeat_ngram: int,
):
    """Run one transcription pass and collect segments plus loop diagnostics.

    Each segment is printed as it arrives, recorded, post-filtered with the
    ``FILTER_*`` thresholds and fed to the "de-loop sentinel", which looks for
    the same text repeating at roughly one-second intervals.

    Args:
        label: name of this attempt; stored on each loop event and printed.
        cond_prev: passed through as ``condition_on_previous_text``.
        compression_threshold: passed through as the compression-ratio
            threshold; also used here to count segments at or above it.
        no_repeat_ngram: passed through as ``no_repeat_ngram_size``.

    Returns:
        dict with keys ``segments`` (all ``(start, end, text)`` tuples),
        ``filtered`` (those that survived the post-filter), ``info`` (the
        transcriber's info object) and ``stats`` (durations, counters, loop
        events and escalation reasons).
    """
    try:
        seg_iter, info = transcribe_gpu(
            cond_prev=cond_prev,
            compression_threshold=compression_threshold,
            no_repeat_ngram=no_repeat_ngram,
        )
    except Exception as exc:
        # NOTE(review): faster-whisper documents the returned segments as a
        # lazy generator, so decoding errors presumably surface in the loop
        # below rather than here — confirm whether this fallback can trigger
        # for them.
        if DEBUG_MODE:
            print(
                f"  - Transcription attempt \'{label}\' failed ({exc})\n"
                f"    → Retrying with fallback beam/chunk (beam={TRANSCRIPTION_BEAM_SIZE_FALLBACK}, "
                f"chunk={TRANSCRIPTION_CHUNK_LENGTH_FALLBACK})"
            )
        seg_iter, info = transcribe_gpu(
            _beam=TRANSCRIPTION_BEAM_SIZE_FALLBACK,
            _chunk=TRANSCRIPTION_CHUNK_LENGTH_FALLBACK,
            cond_prev=cond_prev,
            compression_threshold=compression_threshold,
            no_repeat_ngram=no_repeat_ngram,
        )

    # A missing/zero duration would break the percentage below; use 1s.
    duration = float(getattr(info, "duration", 0.0) or 0.0)
    if duration <= 0:
        duration = 1.0

    segments: List[Tuple[float, float, str]] = []
    filtered: List[Tuple[float, float, str]] = []
    last_text: Optional[str] = None
    last_start: Optional[float] = None
    last_end: Optional[float] = None
    repeat_streak = 1
    max_repeat_streak = 1
    loop_anchor_start: Optional[float] = None
    loop_events: List[dict] = []
    raw_speech_seconds = 0.0
    filtered_speech_seconds = 0.0

    stats = {
        "duration": duration,
        "speech_coverage": 0.0,
        "raw_segment_seconds": 0.0,
        "vad_removed_seconds": 0.0,
        "post_filter_removed_seconds": 0.0,
        "compression_hits": 0,
        "max_repeat_streak": 1,
        "loop_events": [],
        "loop_escalate_reasons": [],
        "loop_events_per_minute": {},
        "language": getattr(info, "language", "未知"),
        "language_probability": float(getattr(info, "language_probability", 0.0) or 0.0),
        "condition_on_previous_text": cond_prev,
        "compression_threshold": compression_threshold,
        "no_repeat_ngram_size": no_repeat_ngram,
    }

    for seg in seg_iter:
        text = seg.text.strip()
        # Progress is the segment's end time relative to the audio length.
        pct = int(min(100, round((float(seg.end) / duration) * 100)))
        print(f"[{pct:3d}%] {fmt_ts_srt(seg.start)} → {fmt_ts_srt(seg.end)}  {text}", flush=True)

        seg_tuple = (float(seg.start), float(seg.end), text)
        segments.append(seg_tuple)

        # Post-filter: drop short segments that are low-confidence or that
        # the model considers likely non-speech.
        keep = True
        seg_dur = max(0.0, float(seg_tuple[1] - seg_tuple[0]))
        raw_speech_seconds += seg_dur
        if (
            seg_dur < FILTER_MIN_DURATION_SHORT
            and getattr(seg, "avg_logprob", None) is not None
            and seg.avg_logprob < FILTER_AVG_LOGPROB_THRESHOLD
        ):
            keep = False
        if (
            seg_dur < FILTER_MIN_DURATION_SPEECH_PROB
            and getattr(seg, "no_speech_prob", None) is not None
            and seg.no_speech_prob > FILTER_NO_SPEECH_PROB_THRESHOLD
        ):
            keep = False
        if keep:
            filtered.append(seg_tuple)
            filtered_speech_seconds += seg_dur

        compression_ratio = getattr(seg, "compression_ratio", None)
        if compression_ratio is not None and compression_ratio >= compression_threshold:
            stats["compression_hits"] += 1

        # De-loop sentinel: same text as the previous segment, starting about
        # one second after it, extends the current repeat streak.
        if last_text is not None and text == last_text:
            # Compare against None explicitly: a previous start of 0.0 is a
            # valid timestamp and must not be treated as "missing".
            previous_start = last_start if last_start is not None else seg_tuple[0]
            gap = float(seg_tuple[0] - previous_start)
            if abs(gap - 1.0) <= LOOP_SENTINEL_TOLERANCE:
                repeat_streak += 1
                if loop_anchor_start is None:
                    loop_anchor_start = last_start
                if repeat_streak == LOOP_SENTINEL_MIN_REPEAT:
                    # The streak just reached the threshold: open an event.
                    event_start = (
                        loop_anchor_start
                        if loop_anchor_start is not None
                        else (last_start if last_start is not None else seg_tuple[0])
                    )
                    event = {
                        "label": label,
                        "text": text,
                        "start": event_start,
                        "end": seg_tuple[1],
                        "count": repeat_streak,
                    }
                    loop_events.append(event)
                    print(
                        f"⚠️ De-loop sentinel triggered ({label}): {fmt_ts_srt(event_start)} → {fmt_ts_srt(seg_tuple[1])} | repeats={repeat_streak}"
                    )
                elif repeat_streak > LOOP_SENTINEL_MIN_REPEAT and loop_events:
                    # Extend the event opened by THIS streak. A streak still
                    # below the threshold has no event of its own, so it must
                    # not modify the event left by an earlier streak.
                    loop_events[-1]["end"] = seg_tuple[1]
                    loop_events[-1]["count"] = repeat_streak
            else:
                # Same text but not at the ~1s cadence: close the streak.
                if repeat_streak >= LOOP_SENTINEL_MIN_REPEAT and loop_events:
                    loop_events[-1]["end"] = last_end if last_end is not None else seg_tuple[1]
                    loop_events[-1]["count"] = repeat_streak
                repeat_streak = 1
                loop_anchor_start = None
        else:
            # Different text: close the streak.
            if repeat_streak >= LOOP_SENTINEL_MIN_REPEAT and loop_events:
                loop_events[-1]["end"] = last_end if last_end is not None else seg_tuple[1]
                loop_events[-1]["count"] = repeat_streak
            repeat_streak = 1
            loop_anchor_start = None

        max_repeat_streak = max(max_repeat_streak, repeat_streak)
        last_text = text
        last_start = seg_tuple[0]
        last_end = seg_tuple[1]

    # Close a streak that was still running when the audio ended.
    if repeat_streak >= LOOP_SENTINEL_MIN_REPEAT and loop_events:
        loop_events[-1]["end"] = last_end if last_end is not None else loop_events[-1]["end"]
        loop_events[-1]["count"] = repeat_streak

    stats["raw_segment_seconds"] = max(0.0, raw_speech_seconds)
    stats["speech_coverage"] = max(0.0, filtered_speech_seconds)
    stats["vad_removed_seconds"] = max(0.0, duration - raw_speech_seconds)
    stats["post_filter_removed_seconds"] = max(0.0, raw_speech_seconds - filtered_speech_seconds)
    stats["max_repeat_streak"] = max(max_repeat_streak, repeat_streak)

    # Escalation: a single long/heavy loop, or several loops starting within
    # the same minute, each add a human-readable reason.
    minute_counts = defaultdict(int)
    escalate_reasons = []
    for event in loop_events:
        start_val = float(event.get("start", 0.0) or 0.0)
        end_val = float(event.get("end", start_val) or start_val)
        duration_val = max(0.0, end_val - start_val)
        event["duration"] = duration_val
        bucket = int(max(0.0, start_val) // 60)
        minute_counts[bucket] += 1
        if (
            event.get("count", 0) >= LOOP_ESCALATE_REPEAT_THRESHOLD
            and duration_val >= LOOP_ESCALATE_DURATION_SECONDS
        ):
            escalate_reasons.append(
                f"repeat_streak={event.get('count', 0)} lasting {duration_val:.1f}s near {fmt_ts_srt(start_val)}"
            )

    for minute, count in sorted(minute_counts.items()):
        if count >= LOOP_ESCALATE_EVENTS_PER_MIN:
            window_start = minute * 60
            window_end = window_start + 60
            escalate_reasons.append(
                f"{count} loop events within {fmt_ts_srt(window_start)}–{fmt_ts_srt(window_end)}"
            )

    stats["loop_events"] = loop_events
    stats["loop_escalate_reasons"] = escalate_reasons
    stats["loop_events_per_minute"] = dict(minute_counts)

    return {
        "segments": segments,
        "filtered": filtered,
        "info": info,
        "stats": stats,
    }


# Decoding presets tried in order. Each later preset is stricter: first
# conditioning on previous text is switched off, then the compression
# threshold is lowered and the no-repeat n-gram size raised.
transcription_attempts = [

    ("Baseline", dict(cond_prev=True,  compression_threshold=1.35, no_repeat_ngram=2)),
    ("Reset",    dict(cond_prev=False, compression_threshold=1.35, no_repeat_ngram=2)),
    ("Reinforce",dict(cond_prev=False, compression_threshold=1.30, no_repeat_ngram=3)),
]

# Results of the most recent attempt; overwritten on every retry.
segments: List[Tuple[float, float, str]] = []
filtered: List[Tuple[float, float, str]] = []
info = None
transcription_stats = {}

for idx, (label, params) in enumerate(transcription_attempts, 1):
    if idx > 1:
        print(
            f"→ De-loop retry {idx}/{len(transcription_attempts)}: {label} "
            f"(condition_on_previous_text={params['cond_prev']}, no_repeat_ngram_size={params['no_repeat_ngram']}, "
            f"compression_ratio_threshold={params['compression_threshold']})"
        )
    attempt = run_transcription_attempt(label, **params)
    segments = attempt["segments"]
    filtered = attempt["filtered"]
    info = attempt["info"]
    transcription_stats = attempt["stats"]
    loop_events = transcription_stats.get("loop_events", [])
    escalate_reasons = transcription_stats.get("loop_escalate_reasons", [])
    # Stop as soon as an attempt has no loops, or only loops that stay below
    # the escalation thresholds.
    if not loop_events:
        break
    if not escalate_reasons:
        print("  ↳ De-loop sentinel noted repeats but below escalation thresholds; keeping current parameters.")
        break
    if idx < len(transcription_attempts):
        print("  ↳ De-loop sentinel escalation triggered due to:")
        for reason in escalate_reasons:
            print(f"     • {reason}")
        print("    → Retrying with tighter decoding guardrails ...")
    else:
        # Last preset still looped: its output is kept as the result.
        print("  ↳ Warning: loop persisted despite all guardrails.")

if DEBUG_MODE:
    print(
        f"  - Detected language: {transcription_stats.get('language', '未知')} "
        f"(p={transcription_stats.get('language_probability', 0.0):.2f})"
    )
    print(f"  - Audio length: {transcription_stats.get('duration', 0.0):.2f}s")


# ---- OpenCC Normalization (for output text) ---- - Uses 'text_postprocess'
pipeline = build_opencc_pipeline(text_postprocess)


def norm(txt: str) -> str:
    """Convert *txt* with the selected OpenCC pipeline; unchanged when disabled."""
    if not pipeline:
        return txt
    return apply_opencc(txt, pipeline)

# [8/8] Output (text after OpenCC) - Uses 'out_base_dir' (derived from 'filename')
print("[8/8] 輸出 SRT / TXT ...")
# Determine the output directory for transcription based on input type
# If input is a network source, output to WHISPER_DIR
# If input is a local file, output to the same directory as the input file
# (This repeats the decision already made in step [4/8].)
if is_youtube_url(filename) or is_http_url(filename):
    out_base_dir = WHISPER_DIR
else:
    src_path_abs = to_abs_mydrive(filename)
    out_base_dir = src_path_abs.parent

# Create the transcription output directory if it doesn't exist
out_dir = out_base_dir
out_dir.mkdir(exist_ok=True, parents=True)

# Determine the stem from the original source file path
stem = Path(src_path).stem
SRT = out_dir / f"{stem}.srt"
TXT = out_dir / f"{stem}.txt"

# Only segments that survived the post-filter are written; cue numbers start
# at 1 and each cue is followed by a blank line.
with open(SRT, "w", encoding="utf-8") as f:
    for i, (seg_start, seg_end, seg_text) in enumerate(filtered, 1):
        text_out = norm(seg_text.strip())
        f.write(f"{i}\n{fmt_ts_srt(seg_start)} --> {fmt_ts_srt(seg_end)}\n{text_out}\n\n")

# Plain-text transcript: one converted segment per line, no timestamps.
with open(TXT, "w", encoding="utf-8") as f:
    for _, _, seg_text in filtered:
        f.write(norm(seg_text.strip()) + "\n")

print(f"→ 完成！\n  SRT: {SRT}\n  TXT: {TXT}")
# Summary of the final transcription attempt, read from transcription_stats.
print("[Transcription Metrics]")
print(f"→ faster-whisper compute_type: {chosen_compute_type}")
print(f"→ VAD preset: {SELECTED_VAD_SILENCE_PRESET} (silence ≥ {SEMANTIC_PAUSE_THRESHOLD:.2f}s → {SELECTED_VAD_SILENCE_MS} ms)")
print(
    f"→ Language: {transcription_stats.get('language', '未知')} "
    f"(p={transcription_stats.get('language_probability', 0.0):.2f})"
)
print(
    f"→ Audio length: {transcription_stats.get('duration', 0.0):.2f}s；"
    f"Raw speech: {transcription_stats.get('raw_segment_seconds', 0.0):.2f}s；"
    f"VAD removed: {transcription_stats.get('vad_removed_seconds', 0.0):.2f}s；"
    f"Post-filter removed: {transcription_stats.get('post_filter_removed_seconds', 0.0):.2f}s；"
    f"Speech kept: {transcription_stats.get('speech_coverage', 0.0):.2f}s"
)
print(f"→ compression_ratio_threshold hits: {transcription_stats.get('compression_hits', 0)} segments")
print(f"→ Max consecutive repeat streak: {transcription_stats.get('max_repeat_streak', 1)}")
if transcription_stats.get('loop_events'):
    # One line per detected loop event, then any escalation reasons.
    for idx_evt, evt in enumerate(transcription_stats.get('loop_events', []), 1):
        start_ts = fmt_ts_srt(evt.get('start', 0.0))
        end_ts = fmt_ts_srt(evt.get('end', evt.get('start', 0.0)))
        # Use the stored duration when present, otherwise derive it.
        duration_val = evt.get('duration', max(0.0, (evt.get('end', 0.0) or 0.0) - (evt.get('start', 0.0) or 0.0)))
        print(
            f"   ⚠️ Event {idx_evt}: {start_ts} → {end_ts} "
            f"(duration≈{duration_val:.1f}s, repeats={evt.get('count', LOOP_SENTINEL_MIN_REPEAT)}, tag={evt.get('label')})"
        )
    if transcription_stats.get('loop_escalate_reasons'):
        for reason in transcription_stats.get('loop_escalate_reasons', []):
            print(f"   ↳ escalation note: {reason}")
else:
    print("→ No de-loop sentinel events detected.")

# Release model (release GPU memory)
# Dropping the name lets the collector reclaim the model before the
# summarization model is loaded.
try:
    del model
except NameError:
    # The name is already gone (e.g. this section was re-run); nothing to do.
    pass
gc.collect()
if DEBUG_MODE: print("→ Model released; can run again directly if needed.")


# ===== Summarization Logic Starts Here =====
# Use SRT from transcription step for summarization
summary_srt_path_abs = SRT
# Stop early if the transcription step did not leave the SRT file on disk.
assert summary_srt_path_abs.exists(), f"SRT 檔不存在：{summary_srt_path_abs}"

# ===== Summary 1/6) Check GPU and Install Dependencies (llama-cpp-python specific) =====
# llama-cpp-python installation logic - Keep this separate as it has specific CUDA requirements
# Moved this section to just before reading the SRT for summarization
if DEBUG_MODE: print("[Summary 1/6] Checking GPU and installing llama-cpp-python ...")

def detect_cuda_tag():
    """Map the CUDA version printed by ``nvidia-smi`` to a wheel tag.

    Returns one of "cu125", "cu124", "cu118" or "cu117". "cu124" doubles as the
    fallback when ``nvidia-smi`` cannot be run, prints no usable
    "CUDA Version:" field, or anything else goes wrong while parsing.
    """
    try:
        report = sp.check_output(["nvidia-smi"], text=True)

        # Take the first line carrying a "CUDA Version:" field with digits after it.
        version_token = None
        for row in report.splitlines():
            if "CUDA Version" not in row:
                continue
            tail = row.partition("CUDA Version:")[2]
            digits_and_dots = "".join(c for c in tail if c == "." or c.isdigit())
            if digits_and_dots:
                version_token = digits_and_dots
                break
        if version_token is None:
            return "cu124"

        fields = version_token.split(".")
        major = int(fields[0]) if fields[0].isdigit() else 0
        minor = int(fields[1]) if len(fields) > 1 and fields[1].isdigit() else 0
        version = (major, minor)

        if version >= (12, 5):
            return "cu125"
        if major == 12:
            return "cu124"
        if version >= (11, 8):
            return "cu118"
        return "cu117"
    except Exception:
        # Best effort: any failure falls back to the default tag.
        return "cu124"

# Resolve the wheel tag once; it is the first candidate tried when installing llama-cpp-python below.
cuda_tag = detect_cuda_tag()
if DEBUG_MODE: print(f"GPU 0: Detected CUDA version tag {cuda_tag}")

def try_import_llama():
    """Return the ``llama_cpp.Llama`` class, or ``None`` if it cannot be imported.

    Besides a missing package (``ModuleNotFoundError``, a subclass of
    ``ImportError``), an installed build can fail while loading its native
    library, for example a wheel built for a different CUDA runtime. Those
    failures are reported as ``None`` too, so the caller can move on to the
    next install candidate instead of aborting.
    NOTE(review): the OSError/RuntimeError cases are based on how
    llama-cpp-python reports shared-library load failures — confirm against
    the installed version.
    """
    try:
        from llama_cpp import Llama
    except (ImportError, OSError, RuntimeError):
        return None
    return Llama

Llama = try_import_llama()
if Llama is None:
    # Installation strategy: pre-built CUDA wheels from the extra index first,
    # then fall back to compiling from source if none of them can be imported.
    # dict.fromkeys de-duplicates while preserving order, so the detected tag
    # is not installed a second time when it is also in the static list.
    candidates = list(dict.fromkeys([cuda_tag, "cu125", "cu124", "cu122", "cu121"]))
    ok = False
    for tag in candidates:
        idx = f"https://abetlen.github.io/llama-cpp-python/whl/{tag}"
        if DEBUG_MODE: print(f"→ Attempting to install llama-cpp-python ({tag}) ...")
        r = pip_install(["llama-cpp-python"], extra_args=["--extra-index-url", idx])
        if r.returncode == 0:
            # A successful pip run is not enough: the wheel must also import.
            Llama = try_import_llama()
            if Llama is not None:
                ok = True
                break
        else:
            if DEBUG_MODE: print("  ✗ Installation failed (summary):", "\n".join(r.stdout.splitlines()[-5:]))
    if not ok:
        if DEBUG_MODE: print("→ Pre-compiled wheels not available, switching to 'source compilation (CUDA=ON)' ... (takes longer)")
        try:
            import ninja # noqa: F401 # Import ninja to check if installed
        except ModuleNotFoundError:
            if DEBUG_MODE: print("→ Installing missing package: ninja")
            r = pip_install(["ninja"])
            if r.returncode != 0:
                if DEBUG_MODE: print(r.stdout)
                raise RuntimeError("安裝 ninja 失敗。請重啟後重試。")
        # Build with CUDA enabled; both the current and the legacy CMake flag are passed.
        env = os.environ.copy()
        env["CMAKE_ARGS"] = "-DGGML_CUDA=on -DLLAMA_CUBLAS=on"
        env["FORCE_CMAKE"] = "1"
        r = pip_install(["llama-cpp-python"], env=env)
        if r.returncode != 0:
            if DEBUG_MODE: print(r.stdout)
            raise RuntimeError("無法安裝 GPU 版 llama-cpp-python。")
        Llama = try_import_llama()
        if Llama is None:
            # Fail here with a clear message instead of a later "'NoneType' object is not callable".
            raise RuntimeError("llama-cpp-python 已安裝但無法匯入。請刪除執行階段並重新啟動後重試。")


if DEBUG_MODE:
    print("[Summary 2/6] Reading SRT ...")
with open(summary_srt_path_abs, "r", encoding="utf-8") as f:
    srt_text = f.read()
# The parser is reached through the name `_srt`, as bound elsewhere in this file.
subs = list(_srt.parse(srt_text))


def td2s(td):
    """Return the length of *td* in seconds via its ``total_seconds()`` method."""
    return td.total_seconds()


# Collect (start_seconds, end_seconds, text) for every cue that still has text after trimming.
segments = []
for cue in subs:
    cue_text = cue.content.strip()
    if cue_text:
        segments.append((td2s(cue.start), td2s(cue.end), cue_text))

def compress_repetitive_segments(
    segs: List[Tuple[float, float, str]],
    *,
    short_len: int = 2,
    long_duration: float = 60.0,
    long_count: int = 30,
    max_samples: int = 10,
) -> Tuple[List[Tuple[float, float, str]], List[dict]]:
    """Collapse repetitive floods in subtitles while keeping context samples.

    A "flood" is a run of consecutive segments whose stripped text is
    identical and at most ``short_len`` characters long. Each flood is
    replaced by one marker segment spanning the whole run. Floods that last at
    least ``long_duration`` seconds (sum of segment durations) and contain at
    least ``long_count`` segments additionally keep up to ``max_samples``
    sample segments, always including the last segment of the run.

    Args:
        segs: ``(start, end, text)`` tuples in playback order.
        short_len: Maximum stripped length for a text to count as a flood.
        long_duration: Minimum accumulated duration for sampling.
        long_count: Minimum run length for sampling.
        max_samples: Upper bound on samples kept per flood.

    Returns:
        ``(compressed_segments, reports)`` with one report dict per flood
        (keys: text, count, duration, start, end, samples).
    """
    if not segs:
        return [], []
    compressed: List[Tuple[float, float, str]] = []
    reports: List[dict] = []
    total = len(segs)
    i = 0
    while i < total:
        s, e, text = segs[i]
        stripped = text.strip()
        run_start = i
        run_end = i + 1
        total_duration = max(0.0, e - s)
        while run_end < total and segs[run_end][2].strip() == stripped:
            total_duration += max(0.0, segs[run_end][1] - segs[run_end][0])
            run_end += 1
        run_len = run_end - run_start

        if run_len <= 1 or len(stripped) > short_len:
            # Not a flood: copy the whole run through unchanged in one step so a
            # long run of identical long lines is not re-scanned from every position.
            compressed.extend((a, b, c) for a, b, c in segs[run_start:run_end])
            i = run_end
            continue

        start_ts = segs[run_start][0]
        end_ts = segs[run_end - 1][1]
        compressed.append((start_ts, end_ts, f"[重複 x{run_len}: {stripped}]"))
        report = {
            "text": stripped,
            "count": run_len,
            "duration": total_duration,
            "start": start_ts,
            "end": end_ts,
            "samples": [],
        }
        if total_duration >= long_duration and run_len >= long_count:
            stride = max(5, min(10, max(1, run_len // 10)))
            last_idx = run_end - 1
            sample_indices = list(range(run_start, run_end, stride))
            # Always keep the final segment for context
            if sample_indices[-1] != last_idx:
                sample_indices.append(last_idx)
            if max_samples < 1:
                sample_indices = []
            elif len(sample_indices) > max_samples:
                # Trim from the middle/end but keep the final segment; a plain
                # head slice would discard the sample that was just appended.
                sample_indices = sample_indices[: max_samples - 1] + [last_idx]
            for idx in sample_indices:
                sample_seg = segs[idx]
                sample_txt = sample_seg[2].strip()
                compressed.append((sample_seg[0], sample_seg[1], f"[重複樣本] {sample_txt}"))
                report["samples"].append({
                    "index": idx,
                    "start": sample_seg[0],
                    "end": sample_seg[1],
                    "text": sample_txt,
                })
        reports.append(report)
        i = run_end
    return compressed, reports

# Replace the parsed segments with their flood-compressed form; the reports are reused
# later for console diagnostics and for the reduce-input debug file.
segments, repetition_reports = compress_repetitive_segments(segments)
if repetition_reports:
    print(f"→ 偵測到 {len(repetition_reports)} 組重複洪水，已壓縮並保留樣本")

# Estimated length = end of the last segment minus start of the first (0 when there are no segments).
total_secs = (segments[-1][1] - segments[0][0]) if segments else 0
if DEBUG_MODE: print(f"→ Number of subtitle segments: {len(segments)}；Video length (est): {total_secs/60:.1f} minutes")


# ===== Summary 3/6) Download and Load GGUF Model (Summary) - Uses summary model parameters (REPO_ID, GGUF_FILE, ctx_window, etc.)
# Runs after the SRT has been parsed and compressed.
if DEBUG_MODE: print("[Summary 3/6] Loading GPT-OSS-20B (GGUF, CUDA) ...")
# Only the GGUF weights and the tokenizer config are fetched from the repository.
local_repo = snapshot_download(REPO_ID, allow_patterns=[GGUF_FILE, "tokenizer_config.json"])
gguf_path = str(Path(local_repo)/GGUF_FILE)

tokenizer_config_path = Path(local_repo)/"tokenizer_config.json"
if tokenizer_config_path.exists():
    try:
        # JSON text is UTF-8; read it explicitly so the result does not depend on the locale default.
        tokenizer_config_data = json.loads(tokenizer_config_path.read_text(encoding="utf-8"))
        if DEBUG_MODE: print("→ Loaded tokenizer_config.json (Harmony template)")
    except Exception as exc:
        # Best effort: a broken config is reported in debug mode and otherwise ignored.
        if DEBUG_MODE: print("  ✗ Failed to parse tokenizer_config.json:", exc)

# Try each context size in CTX_WINDOW_CANDIDATES order and keep the first one that loads.
llm = None
selected_ctx_window = None
ctx_errors = []
for ctx_candidate in CTX_WINDOW_CANDIDATES:
    try:
        if DEBUG_MODE:
            print(f"  - Trying ctx_window={ctx_candidate}")
        llm = Llama(
            model_path=gguf_path,
            n_ctx=ctx_candidate,
            n_gpu_layers=-1,
            seed=0,
            logits_all=False,
            verbose=True  # Display the actual chat format used
        )
        selected_ctx_window = ctx_candidate
        break
    except Exception as exc:
        ctx_errors.append((ctx_candidate, exc))
        if DEBUG_MODE:
            print(f"  ✗ Failed ctx_window={ctx_candidate}: {exc}")
        # Collect garbage before the next (presumably smaller) attempt.
        gc.collect()

if llm is None:
    # Surface only the last three failures to keep the message readable.
    msgs = ", ".join(f"{cand}→{err}" for cand, err in ctx_errors[-3:])
    raise RuntimeError(f"無法載入 GGUF（ctx candidates={CTX_WINDOW_CANDIDATES}）：{msgs}")

ctx_window = selected_ctx_window or ctx_window
print(f"→ Selected ctx_window: {ctx_window}")
if DEBUG_MODE:
    print("→ Model loaded successfully (GPU)")




# Build the Harmony prompt formatter up front so a problem shows before any generation starts.
try:
    ensure_harmony_formatter()
except Exception as exc:
    # Chain explicitly so the traceback marks the original error as the direct cause.
    raise RuntimeError(f"Failed to prepare Harmony formatter: {exc}") from exc
else:
    if DEBUG_MODE: print("→ Harmony formatter prepared")


# ===== Summary 4/6) Token-aware Segmentation (Summary) - Uses ctx_window, map_max_new_tokens, prompt_overhead
if DEBUG_MODE: print("[Summary 4/6] Generating segments (token-aware; single segment ≤ safety limit) ...")

def count_tokens_text(text: str) -> int:
    """Return how many tokens the loaded model's tokenizer produces for *text*.

    The text is UTF-8 encoded before tokenization.

    Raises:
        RuntimeError: if the module-level ``llm`` is missing or ``None``.
    """
    model = globals().get("llm")
    if model is None:
        raise RuntimeError("LLM model is not loaded. Cannot count tokens.")
    encoded = text.encode("utf-8")
    return len(model.tokenize(encoded))

# System message used for both the map and the reduce requests.
SYSTEM_INSTR = (
  "你是一個會議總結機器人。根據使用者提供的逐字稿（可能雜訊、重複、錯字），"
  "請去除雜訊與重複、嚴守事實、不腦補。遇到不明確資訊以「待補充／未明確」標註。"
  "輸出為 Markdown（繁體中文），不要輸出任何系統／思考標記。"
)

# Map-step user prompt. Placeholders: {topic} (topic hint) and {chunk} (one transcript window).
MAP_USER_TMPL = textwrap.dedent("""\
主題（可留空）：{topic}

以下是逐字稿片段（非完整全文）：
{chunk}

請就此片段輸出「條列式重點摘要」（500–900 字，繁體中文），注意：
- 只寫最終內容，不要寫解題想法、不要出現任何系統提示或中英括號標記。
- 聚焦可驗證事實（時間、人物、任務、結論、未決事項、行動）。
- 結構：可用小標題＋項目符號，語句務必短、準確、無贅詞。
""")

# Reduce-step user prompt. Placeholders: {topic} (topic hint) and {maps} (all map summaries joined).
# It asks for the three-section output structure (overview / chapter points / action items).
REDUCE_USER_TMPL = textwrap.dedent("""\
主題（可留空）：{topic}

以下是所有片段的重點摘要彙整（仍可能有重疊）：
{maps}

請整合為一份會議筆記（Markdown，繁體）：
1) **整體提要**（3–6 句，避免冗言）
2) **章節要點（含時間脈絡）**：條列呈現，每點一行，可附粗略時間，**不得忽略任何片段**，每條尾端標註對應片段編號 (片段 i) 或時間。
3) **可執行重點**：具體待辦（每條以動詞開頭）
請確保所有片段至少納入 1–2 個重點，若資訊不足請註明「待補充」。
請只輸出最終筆記，不要出現系統或思考標記，不要加入未出現的新資訊。
""")

# Single segment token budget (reserve space for prompt and generation)
prompt_overhead = 700
# Window target: what is left of the context after overhead and generation, clamped to [1024, 2048].
# NOTE(review): the lower clamp can exceed the real remaining budget on a small ctx_window — confirm intended.
chunk_target    = max(1024, min(2048, ctx_window - prompt_overhead - map_max_new_tokens))
# Matches text ending in sentence-final punctuation, optionally followed by closing quotes/brackets.
SENTENCE_END_RE = re.compile(r"[。！？?!…]+[\"”』】）]*$")

def build_semantic_segments(raw_segments: List[Tuple[float, float, str]]) -> List[Tuple[float, float, str, int]]:
    """Merge subtitle lines into larger blocks of related text.

    Lines are accumulated in order and the accumulated block is closed when:
      * the silence before the next line reaches ``SEMANTIC_PAUSE_THRESHOLD``;
      * the block reaches ``chunk_target`` tokens;
      * the current line ends a sentence and the block has at least
        ``SEMANTIC_MIN_TOKENS`` tokens or ``SEMANTIC_MAX_CHARS`` characters;
      * ``SEMANTIC_FORCE_FLUSH_LINES`` consecutive lines lacked sentence-final
        punctuation;
      * the block spans at least ``SEMANTIC_FORCE_FLUSH_SECONDS``.

    Args:
        raw_segments: ``(start, end, text)`` tuples; blank texts are skipped.

    Returns:
        ``(start, end, text, token_count)`` tuples. Lines inside a block are
        joined with newlines; ``token_count`` is the sum of the per-line counts.
    """
    blocks: List[Tuple[float, float, str, int]] = []
    pending: List[str] = []
    # Running totals for the block under construction.
    state = {"tokens": 0, "chars": 0, "unpunctuated": 0, "start": None, "end": None}

    def flush_buffer():
        """Close the block under construction (if any) and reset the accumulator."""
        if not pending:
            return
        joined = "\n".join(pending).strip()
        if joined:
            first = state["start"] or 0.0
            last = state["end"] or state["start"] or 0.0
            blocks.append((first, last, joined, state["tokens"]))
        pending.clear()
        state.update(tokens=0, chars=0, unpunctuated=0, start=None, end=None)

    previous_end = None
    for seg_start, seg_end, raw_text in raw_segments:
        line = raw_text.strip()
        if not line:
            # Blank lines add no text but still move the "previous end" marker.
            previous_end = seg_end
            continue

        # A long enough pause before this line closes the current block first.
        if pending and previous_end is not None:
            if max(0.0, seg_start - previous_end) >= SEMANTIC_PAUSE_THRESHOLD:
                flush_buffer()

        if not pending:
            state["start"] = seg_start

        pending.append(line)
        state["tokens"] += count_tokens_text(line)
        state["chars"] += len(line)
        ends_sentence = bool(SENTENCE_END_RE.search(line))
        state["unpunctuated"] = 0 if ends_sentence else state["unpunctuated"] + 1
        state["end"] = seg_end
        previous_end = seg_end

        span = 0.0
        if state["start"] is not None and state["end"] is not None:
            span = max(0.0, state["end"] - state["start"])

        if (
            state["tokens"] >= chunk_target
            or (
                ends_sentence
                and (state["tokens"] >= SEMANTIC_MIN_TOKENS or state["chars"] >= SEMANTIC_MAX_CHARS)
            )
            or state["unpunctuated"] >= SEMANTIC_FORCE_FLUSH_LINES
            or span >= SEMANTIC_FORCE_FLUSH_SECONDS
        ):
            flush_buffer()

    # Whatever is left becomes the final block.
    flush_buffer()
    return blocks


semantic_segments = build_semantic_segments(segments)

# Group semantic blocks into summary windows of roughly chunk_target tokens.
# Consecutive windows may share trailing blocks (up to SLIDING_OVERLAP_TOKENS) for context.
chunks: List[Tuple[float, float, str]] = []
i = 0
while i < len(semantic_segments):
    window: List[Tuple[float, float, str, int]] = []
    total_tokens = 0
    start_ts = semantic_segments[i][0]
    end_ts = semantic_segments[i][1]
    j = i
    while j < len(semantic_segments):
        seg = semantic_segments[j]
        seg_tokens = seg[3]
        # Always include at least one semantic block per chunk; stop before a block that
        # would overflow the target, but only when the window already ends on a sentence.
        if window and total_tokens + seg_tokens > chunk_target and SENTENCE_END_RE.search(window[-1][2]):
            break
        window.append(seg)
        total_tokens += seg_tokens
        end_ts = seg[1]
        j += 1
        if total_tokens >= chunk_target or (total_tokens >= SEMANTIC_MIN_TOKENS and SENTENCE_END_RE.search(seg[2])):
            break

    # Join blocks with a newline, matching how lines are joined inside a block. Joining with
    # "" glued the last word of one block to the first word of the next whenever a block was
    # closed without trailing punctuation (forced flushes, space-delimited languages).
    chunk_text = "\n".join(seg[2] for seg in window).strip()
    if chunk_text:
        chunks.append((start_ts, end_ts, chunk_text))

    if j >= len(semantic_segments):
        break

    # Count how many trailing blocks fit under the overlap budget; never re-use the whole window.
    overlap_segments = 0
    if len(window) > 1:
        trailing_tokens = 0
        for idx in range(len(window) - 1, -1, -1):
            trailing_tokens += window[idx][3]
            if trailing_tokens >= SLIDING_OVERLAP_TOKENS:
                break
            overlap_segments += 1
        overlap_segments = min(overlap_segments, len(window) - 1)

    # Step back by the overlap, but always advance by at least one block so the loop terminates.
    next_i = j if overlap_segments == 0 else max(j - overlap_segments, i + 1)
    i = max(next_i, i + 1)

if DEBUG_MODE:
    avg_tokens = sum(seg[3] for seg in semantic_segments) / max(len(semantic_segments), 1)
    print(f"→ Semantic segments: {len(semantic_segments)} (avg tokens ≈ {avg_tokens:.0f}); summary windows: {len(chunks)} (target ~{chunk_target} tokens)")

# ===== Common: Streaming Tools (channel-aware Harmony parsing; regex tag cleanup only on the fallback path) - Uses temperature, top_p, repeat_penalty, map_max_new_tokens, reduce_max_new_tokens


# Set to True by stream_harmony_final_pieces when it has to fall back to the raw stream text.
STREAM_FALLBACK_USED = False
# A "<|...|>" marker whose inner text contains no "|" character.
ARTIFACT_TAG_RE = re.compile(r"<\|[^|]*\|>")


def clean_harmony_artifacts(s: str) -> str:
    """Remove every ``<|...|>`` marker from *s*; falsy input yields ``""``."""
    if not s:
        return ""
    return ARTIFACT_TAG_RE.sub("", s)

def stream_harmony_final_pieces(text_chunks: Iterable[str]) -> Iterator[str]:
    """Yield Harmony streamed text, preferring the final channel.

    The stream is scanned for ``<|tag|>`` markers (``start``, ``channel``,
    ``message``, ``end``, ``return``). Text in a channel whose name contains
    "final" is yielded as it arrives. Text in a channel whose name contains
    "assistant" is cached and released once a final channel appears, or at
    ``<|return|>`` / end of stream if none ever does. Some community models
    only emit the assistant channel; falling back to it avoids dropping the
    actual content. Text in other channels is discarded unless a final channel
    has already been seen.

    If nothing at all was yielded by the time the stream ends, the raw
    concatenated input (with markers stripped) is yielded once and the global
    ``STREAM_FALLBACK_USED`` flag is set.

    A ``<|`` marker may be split across two chunks right between ``<`` and
    ``|``; a trailing ``<`` is therefore held back until the next chunk
    arrives instead of being emitted as text.

    Args:
        text_chunks: Text pieces in stream order; empty pieces are ignored.

    Yields:
        Non-empty text pieces intended for the user.
    """
    global STREAM_FALLBACK_USED
    buffer = ""
    current_channel: Optional[str] = None
    pending_channel = False  # True while collecting the name that follows <|channel|>
    channel_name_buffer = ""
    in_message = False
    assistant_cache: List[str] = []
    final_seen = False
    any_text_emitted = False
    raw_plain_chunks: List[str] = []

    def canonical_channel(name: Optional[str]) -> str:
        """Normalize a channel name to "final", "assistant", the lowercased name, or ""."""
        if not name:
            return ""
        lowered = name.strip().lower()
        if not lowered:
            return ""
        if "final" in lowered:
            return "final"
        if "assistant" in lowered:
            return "assistant"
        return lowered

    def emit_text(text: str):
        """Route *text* by the current channel: yield it, cache it, or drop it."""
        nonlocal final_seen, assistant_cache, any_text_emitted
        if not text:
            return
        channel = canonical_channel(current_channel)
        if channel == "final":
            final_seen = True
            if assistant_cache:
                # Release what was cached before the final channel showed up.
                cached = assistant_cache[:]
                assistant_cache.clear()
                for cached_piece in cached:
                    if cached_piece:
                        any_text_emitted = True
                        yield cached_piece
            any_text_emitted = True
            yield text
        elif channel == "assistant":
            if final_seen:
                any_text_emitted = True
                yield text
            else:
                assistant_cache.append(text)
        elif final_seen:
            any_text_emitted = True
            yield text

    def is_visible() -> bool:
        """True when buffered text belongs to a message or a user-facing channel."""
        return in_message or canonical_channel(current_channel) in {"assistant", "final"}

    def split_partial_tag(text):
        """Split off a trailing "<" that could be the first half of a "<|" marker."""
        if text.endswith("<"):
            return text[:-1], "<"
        return text, ""

    for piece in text_chunks:
        if not piece:
            continue
        raw_plain_chunks.append(piece)
        buffer += piece
        while True:
            if pending_channel:
                # Everything up to the next marker is the channel name.
                idx = buffer.find("<|")
                if idx == -1:
                    ready, buffer = split_partial_tag(buffer)
                    channel_name_buffer += ready
                    break
                channel_name_buffer += buffer[:idx]
                buffer = buffer[idx:]
                channel = channel_name_buffer.strip()
                channel_name_buffer = ""
                pending_channel = False
                current_channel = channel
                channel_canonical = canonical_channel(current_channel)
                in_message = bool(channel_canonical in {"assistant", "final"})
                continue
            tag_start = buffer.find("<|")
            if tag_start == -1:
                # No marker in sight: emit the text, keeping a possible half-marker for later.
                ready, buffer = split_partial_tag(buffer)
                if is_visible():
                    yield from emit_text(ready)
                break
            if tag_start > 0:
                text = buffer[:tag_start]
                if is_visible():
                    yield from emit_text(text)
                buffer = buffer[tag_start:]
            tag_end = buffer.find("|>")
            if tag_end == -1:
                # Marker not complete yet; wait for the next chunk.
                break
            tag = buffer[2:tag_end].strip().lower()
            buffer = buffer[tag_end + 2:]
            if tag == "start":
                current_channel = None
                in_message = False
            elif tag == "channel":
                pending_channel = True
            elif tag == "message":
                in_message = True
            elif tag == "end":
                in_message = False
                current_channel = None
            elif tag == "return":
                # End of the response: release the assistant cache if no final channel appeared.
                if not final_seen and assistant_cache:
                    for cached_piece in assistant_cache:
                        if cached_piece:
                            any_text_emitted = True
                            yield cached_piece
                    assistant_cache.clear()
                return
            else:
                # Unknown marker: skip it.
                continue
    # Stream ended: emit leftover text unless it is an unfinished channel name.
    if buffer and not pending_channel and is_visible():
        yield from emit_text(buffer)
    if not final_seen and assistant_cache:
        for cached_piece in assistant_cache:
            if cached_piece:
                any_text_emitted = True
                yield cached_piece
    if not any_text_emitted and raw_plain_chunks:
        # Last resort: nothing was recognized, so hand back the raw text without markers.
        fallback_text = "".join(raw_plain_chunks).strip()
        if fallback_text:
            STREAM_FALLBACK_USED = True
            fallback_text = clean_harmony_artifacts(fallback_text)
            print("stream-flush fallback used")
            yield fallback_text


def llm_stream(messages, max_tokens, *, repeat_penalty_override=None):
    """Stream a Harmony-formatted completion, yielding only user-facing text.

    Args:
        messages: Chat messages handed to the Harmony formatter.
        max_tokens: Generation cap for this request.
        repeat_penalty_override: Used instead of the module-level
            ``repeat_penalty`` when not ``None``.

    Yields:
        Non-empty text pieces selected by ``stream_harmony_final_pieces``.

    Raises:
        RuntimeError: if the module-level ``llm`` is missing or ``None``.
    """
    if globals().get("llm") is None:
        raise RuntimeError("LLM model is not loaded. Cannot stream generation.")

    render_prompt = ensure_harmony_formatter()
    rendered = render_prompt(messages=messages, add_generation_prompt=True)

    request = {
        "prompt": rendered.prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "repeat_penalty": float(
            repeat_penalty if repeat_penalty_override is None else repeat_penalty_override
        ),
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    # Stop sequences / stopping criteria are forwarded only when the formatter supplies them.
    if rendered.stop:
        request["stop"] = rendered.stop
    if rendered.stopping_criteria is not None:
        request["stopping_criteria"] = rendered.stopping_criteria

    events = llm.create_completion(**request)
    # Pull the text of the first choice out of every streamed event.
    raw_text = (event.get("choices", [{}])[0].get("text", "") for event in events)
    yield from (piece for piece in stream_harmony_final_pieces(raw_text) if piece)

# ===== Summary 5/6) Segment Summary (map) - Uses map_max_new_tokens, ctx_window, prompt_overhead, topic_hint
if DEBUG_MODE: print("[Summary 5/6] Segment summarization (map) ...")
# Per-window results of the map step, filled in parallel by the loop below:
maps: List[str] = []               # summary text of each window
map_stats: List[dict] = []         # size/limit metrics of each window
map_debug_payload: List[dict] = [] # metrics plus text, used for the debug files


def escape_braces(text: str) -> str:
    """Double every ``{`` and ``}`` in *text*, the escape form used by ``str.format`` templates."""
    doubled = {ord("{"): "{{", ord("}"): "}}"}
    return text.translate(doubled)


def format_timestamp(seconds: float) -> str:
    """Render *seconds* as ``MM:SS``, or ``HH:MM:SS`` once it reaches an hour.

    The value is truncated to whole seconds and negative values are shown as zero.
    """
    whole = int(seconds)
    if whole < 0:
        whole = 0
    minutes, secs = divmod(whole, 60)
    hours, minutes = divmod(minutes, 60)
    clock = f"{minutes:02d}:{secs:02d}"
    return f"{hours:02d}:{clock}" if hours else clock


safe_topic_hint = escape_braces(topic_hint or "（無）")
# With many windows, cap the per-window generation length at 600 tokens.
map_generation_limit = map_max_new_tokens
if len(chunks) > 6:
    map_generation_limit = min(map_generation_limit, 600)
    if DEBUG_MODE:
        print(f"→ 動態調整 map_max_new_tokens → {map_generation_limit}")


def shrink_to_budget(text: str, budget_tokens: int) -> str:
    """Trim the tail of *text* until it fits *budget_tokens* (at most 6 rounds).

    Each round keeps 85% of the characters but never fewer than 800. The text
    is returned as-is once it fits, when the 800-character floor prevents any
    further trimming, or after the last round even if it is still too long.
    """
    cur = text
    for _ in range(6):
        if count_tokens_text(cur) <= budget_tokens:
            return cur
        keep = max(800, int(len(cur) * 0.85))
        if keep >= len(cur):
            # The floor stops any further shrinking; avoid re-tokenizing the same text.
            break
        cur = cur[:keep]
    return cur


# The input budget does not change between windows, so compute it once.
budget_tokens = max(512, ctx_window - map_generation_limit - prompt_overhead)

for i, (s, e, body) in enumerate(chunks, 1):
    pct = i / max(len(chunks), 1) * 100
    print(f"  - 處理分段 {i}/{len(chunks)}（~{pct:.1f}%）")

    body2 = shrink_to_budget(body, budget_tokens)
    input_tokens = count_tokens_text(body2)
    input_chars = len(body2)

    # Braces are doubled before formatting and restored afterwards, so transcript text
    # containing "{" or "}" reaches the model unchanged.
    safe_body = escape_braces(body2)
    user_txt = MAP_USER_TMPL.format(topic=safe_topic_hint, chunk=safe_body)
    user_txt = user_txt.replace("{{", "{").replace("}}", "}")
    messages = [
        {"role": "system", "content": SYSTEM_INSTR},
        {"role": "user",   "content": user_txt},
    ]

    # Live preview: refresh the Markdown display every 24 streamed pieces.
    live = display(Markdown(""), display_id=True)
    part_buf: List[str] = []
    for token in llm_stream(messages, map_generation_limit, repeat_penalty_override=map_repeat_penalty):
        part_buf.append(token)
        if len(part_buf) % 24 == 0:
            cur_txt = "".join(part_buf)
            live.update(Markdown(cur_txt))
            print(f"    ↳ 分段 {i} 已產生字元：{len(cur_txt)}")
    cur_txt = "".join(part_buf)
    if STREAM_FALLBACK_USED:
        cur_txt = clean_harmony_artifacts(cur_txt)
    live.update(Markdown(cur_txt))
    print(f"    ↳ 分段 {i} 最終字元：{len(cur_txt)}")

    cleaned_txt = cur_txt.strip()
    map_empty = not cleaned_txt
    if map_empty:
        # Keep a placeholder so the reduce step still sees this window's slot and time range.
        cleaned_txt = f"[片段 {i} 無輸出；時間 {format_timestamp(s)}–{format_timestamp(e)}]"
        print(f"    ⚠️ 片段 {i} 無輸出，已寫入占位訊息")

    # Tokenize the output once and derive both the count and the uniqueness ratio from it.
    token_ids = llm.tokenize(cleaned_txt.encode("utf-8"))
    output_tokens = len(token_ids)
    unique_ratio = len(set(token_ids)) / max(1, len(token_ids))
    map_entry = {
        "index": i,
        "input_tokens": input_tokens,
        "input_chars": input_chars,
        "output_tokens": output_tokens,
        "output_chars": len(cleaned_txt),
        "hit_limit": output_tokens >= map_generation_limit,
        "start": s,
        "end": e,
        "empty": map_empty,
        "unique_token_ratio": unique_ratio,
        "max_new_tokens": map_generation_limit,
    }
    map_stats.append(map_entry)
    map_debug_payload.append({
        "meta": map_entry,
        "text": cleaned_txt,
    })
    maps.append(cleaned_txt)

# Console report for the map step: window/summary counts, repetition floods, per-window metrics.
print(f"[Map Summary] chunks={len(chunks)} / maps={len(maps)}")
if repetition_reports:
    for report in repetition_reports:
        duration = report["duration"]
        start_ts = format_timestamp(report["start"])
        end_ts = format_timestamp(report["end"])
        print(
            f"  ↳ 重複壓縮：'{report['text']}' x{report['count']} | 時間 {start_ts}–{end_ts} | 持續 {duration:.1f}s"
        )

# One line per window: input size, output size, token uniqueness, limit hit, empty flag.
for stat in map_stats:
    print(
        f"  - Map {stat['index']:02d}: in={stat['input_tokens']} tok/{stat['input_chars']} chars, "
        f"out={stat['output_tokens']} tok/{stat['output_chars']} chars, "
        f"unique_ratio={stat['unique_token_ratio']:.2f}, hit_limit={stat['hit_limit']}, empty={stat['empty']}"
    )

if DEBUG_MODE: print("→ Segment summarization complete")

# Write one Markdown file per map summary plus the combined reduce input, all named after
# the SRT stem and placed in out_base_dir.
summary_stem = Path(summary_srt_path_abs).stem
debug_output_dir = out_base_dir
debug_output_dir.mkdir(parents=True, exist_ok=True)
map_file_paths = []

maps_md_parts = []
for payload in map_debug_payload:
    meta = payload["meta"]
    text_body = payload["text"]
    # Header line carrying the segment index, time range and output length for the reduce prompt.
    seg_header = (
        f"[SEG {meta['index']} | t≈{format_timestamp(meta['start'])}–{format_timestamp(meta['end'])} | "
        f"len≈{meta['output_tokens']} tok]"
    )
    maps_md_parts.append(f"{seg_header}\n### 片段 {meta['index']} 要點\n\n{text_body}")
    map_path = debug_output_dir / f"{summary_stem}_map_{meta['index']:02d}.md"
    # NOTE(review): the body follows the "> ..." line with no blank line in between, so a
    # Markdown renderer may fold the first paragraph into the blockquote — confirm intended.
    header_lines = [
        f"# 片段 {meta['index']:02d} | 時間 {format_timestamp(meta['start'])}–{format_timestamp(meta['end'])}",
        f"> in={meta['input_tokens']} tok/{meta['input_chars']} chars | out={meta['output_tokens']} tok/{meta['output_chars']} chars | unique_ratio={meta['unique_token_ratio']:.2f}",
        "",
    ]
    map_path.write_text("\n".join(header_lines) + text_body + "\n", encoding="utf-8")
    map_file_paths.append(map_path)

# The reduce prompt receives maps_md_with_headers; the file on disk additionally gets the
# repetition diagnostics appended as HTML comments.
maps_md_with_headers = "\n\n---\n\n".join(maps_md_parts)
reduce_input_path = debug_output_dir / f"{summary_stem}_reduce_input.md"
reduce_input_content = maps_md_with_headers
if repetition_reports:
    diag_lines = ["", "<!-- 重複洪水壓縮原始片段 -->"]
    for report in repetition_reports:
        diag_lines.append(
            f"<!-- {format_timestamp(report['start'])}–{format_timestamp(report['end'])} | '{report['text']}' x{report['count']} | {report['duration']:.1f}s -->"
        )
        for sample in report.get('samples', []):
            diag_lines.append(
                f"<!--   sample {format_timestamp(sample['start'])}–{format_timestamp(sample['end'])}: {sample['text']} -->"
            )
    reduce_input_content += "\n".join(diag_lines)
reduce_input_path.write_text(reduce_input_content + "\n", encoding="utf-8")
print(f"→ Map 輸出已寫入 {len(map_file_paths)} 段；reduce_input: {reduce_input_path}")

if DEBUG_MODE: print("[Summary 6/6] Consolidating summary (reduce) ...")

def fit_reduce_payload(md_text: str, max_ctx_tokens: int) -> str:
    """Trim the tail of *md_text* until it fits the reduce request.

    The text fits when its token count plus ``reduce_max_new_tokens`` plus a
    400-token margin is within *max_ctx_tokens*. Otherwise the last 10% of the
    characters is cut off and the check repeats, for at most 8 rounds. The cut
    is made by character position, so it can land inside a segment and later
    segments may be dropped entirely. After the last round the text is
    returned even if it still does not fit.
    """
    rounds_left = 8
    while rounds_left > 0:
        rounds_left -= 1
        required = count_tokens_text(md_text) + reduce_max_new_tokens + 400
        if required <= max_ctx_tokens:
            return md_text
        md_text = md_text[: int(len(md_text) * 0.9)]
    return md_text

# Reduce step: merge all map summaries into the final meeting notes.
md_cur = fit_reduce_payload(maps_md_with_headers, ctx_window)

# Braces are doubled before formatting and restored afterwards, so summary text containing
# "{" or "}" reaches the model unchanged.
safe_md_cur = escape_braces(md_cur)
user_txt = REDUCE_USER_TMPL.format(topic=safe_topic_hint, maps=safe_md_cur)
user_txt = user_txt.replace("{{", "{").replace("}}", "}")
messages = [{"role":"system","content":SYSTEM_INSTR},
            {"role":"user","content":user_txt}]

# Live preview: refresh the Markdown display every 24 streamed pieces.
live2 = display(Markdown(""), display_id=True)
final_buf = []
for token in llm_stream(messages, reduce_max_new_tokens, repeat_penalty_override=reduce_repeat_penalty):
    final_buf.append(token)
    if len(final_buf) % 24 == 0:
        current_text = "".join(final_buf)
        live2.update(Markdown(current_text))
        print(f"    ↳ 彙整 已產生字元：{len(current_text)}")
# Join once and reuse. The previous form nested double quotes inside a double-quoted
# f-string, which is a SyntaxError on Python versions before 3.12.
final_joined = "".join(final_buf)
live2.update(Markdown(final_joined))
print(f"    ↳ 彙整 最終字元：{len(final_joined)}")

final_text = final_joined.strip()
if STREAM_FALLBACK_USED:
    final_text = clean_harmony_artifacts(final_text)
reduce_input_tokens = count_tokens_text(md_cur)
reduce_output_tokens = count_tokens_text(final_text)
reduce_hit_limit = reduce_output_tokens >= reduce_max_new_tokens

# Coverage check: a window counts as referenced when the final notes mention its number
# ("片段 N") or either of its boundary timestamps.
missing_segments = []
for payload in map_debug_payload:
    meta = payload["meta"]
    idx = meta["index"]
    patterns = [
        # (?!\d) keeps "片段 1" from matching "片段 10". A \b here would fail whenever the
        # number is directly followed by a CJK character, because those count as word characters.
        rf"片段\s*{idx}(?!\d)",
        re.escape(format_timestamp(meta["start"])),
        re.escape(format_timestamp(meta["end"])),
    ]
    if not any(re.search(pattern, final_text) for pattern in patterns):
        missing_segments.append(idx)
if missing_segments:
    print(f"⚠️ 章節覆蓋檢查：缺少片段 {', '.join(str(i) for i in missing_segments)}")
else:
    print("→ 章節覆蓋檢查通過（所有片段皆被提及）")

# Determine and create the summary output directory
summary_output_dir_abs = out_base_dir
summary_output_dir_abs.mkdir(parents=True, exist_ok=True)

# Determine the summary output file path using the stem of the input SRT
out_md = summary_output_dir_abs / f"{Path(summary_srt_path_abs).stem}_summary.md"


# Persist the final notes as UTF-8 Markdown.
with open(out_md, "w", encoding="utf-8") as f:
    f.write(final_text)

print(f"→ 完成 ✅  {out_md}")
# Closing metrics: context size used, per-window map statistics, reduce statistics, coverage.
print("[Summary Metrics]")
print(f"→ ctx_window candidates: {CTX_WINDOW_CANDIDATES} | selected: {ctx_window}")
print(f"→ Segments: {len(chunks)} | Maps: {len(maps)}")
for stat in map_stats:
    # ratio = output tokens relative to the generation cap that applied to this window.
    ratio = stat["output_tokens"] / max(1, stat.get("max_new_tokens", map_max_new_tokens))
    print(
        f"  - Map {stat['index']:02d}: in={stat['input_tokens']} tok/{stat['input_chars']} chars, "
        f"out={stat['output_tokens']} tok/{stat['output_chars']} chars, "
        f"unique_ratio={stat['unique_token_ratio']:.2f}, hit_limit={stat['hit_limit']}, empty={stat['empty']} (ratio={ratio:.2f})"
    )
print(
    f"  ↳ Reduce: in={reduce_input_tokens} tok, out={reduce_output_tokens} tok, "
    f"hit_limit={reduce_hit_limit} (ratio={reduce_output_tokens / max(1, reduce_max_new_tokens):.2f})"
)
if missing_segments:
    print(f"  ↳ Reduce coverage warning: 未提及片段 {missing_segments}")
else:
    print("  ↳ Reduce coverage OK (all segments referenced)")
if STREAM_FALLBACK_USED:
    print("  ↳ stream-flush fallback used")
# Drop the model reference and collect garbage so the cell can be re-run.
try:
    del llm
except Exception:
    pass
gc.collect()
if DEBUG_MODE: print("（顯存已釋放，如需重跑可直接再次執行）")