<a href="https://colab.research.google.com/github/Dao-you/Whisper-for-Meeting-on-Colab/blob/main/Whisper_for_Meeting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# Whisper Automatic Subtitle Generation: GPU Transcription + CPU Denoising + OpenCC Post-processing (Traditional/Simplified Conversion)
# And LLM Summarization (GPT-OSS-20B / llama.cpp / CUDA)
# - Transcription: faster-whisper (CUDA, compute: int8_float16→float16→int8)
# - Denoising: ffmpeg afftdn (CPU)
# - Progress: Real-time printing of "current sentence + video total length percentage"
# - Output location: network sources are saved/transcribed under MyDrive/whisper; files already in Drive are transcribed next to the source file
# - LLM Summary: llama.cpp + GPT-OSS-20B GGUF for summarizing transcription
# - Prompts "Delete runtime and restart" if download is blocked or abnormal
# =========================================================

# Restrict multithreading (more stable)
import os
# Single-thread the tokenizer / OpenMP / OpenBLAS pools in one shot.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)

# [1/8] Mount Google Drive
# The mount point /content/gdrive is the parent of ROOT (".../MyDrive") defined below.
from google.colab import drive
print("請授權以掛載 Google Drive")
drive.mount("/content/gdrive", force_remount=True)

# Built-in Imports
import sys, gc, shutil, datetime, subprocess as sp
from pathlib import Path
import re, math, time, textwrap
from typing import List, Tuple, Optional, Iterable, Iterator
from IPython.display import display, Markdown
import json


# MyDrive inside the mounted Google Drive; relative user paths are resolved against it
# (see to_abs_mydrive).
ROOT = Path("/content/gdrive/MyDrive")
# Destination for saved network downloads and for the transcripts of network sources.
WHISPER_DIR = ROOT / "whisper"
WHISPER_DIR.mkdir(exist_ok=True, parents=True)
# Make MyDrive the working directory for the rest of the run.
os.chdir(ROOT)
print(f"→ 當前工作目錄：{os.getcwd()}")

# [2/8] User Form Parameters (Unified)
#@markdown # Whisper Transcription & LLM Summary Pipeline

#@markdown ## Input & Transcription Settings
#@markdown **Input Source:** Google Drive file (relative to MyDrive) or video URL (YouTube/HTTP).
filename = "whisper/jcz-mfkq-frc (2025-08-08 10_00 GMT+8).mp4"  #@param {type:"string"}
#@markdown **Download Option:** Check to save network source files to `MyDrive/whisper`.
save_video_to_google_drive = True  #@param {type:"boolean"}
#@markdown **Whisper Model Size:** Choose a model size. `large-v3` requires more GPU VRAM; `medium` is a good alternative if VRAM is limited.
model_size = "medium"  #@param ["tiny", "base", "small", "medium", "large-v2", "large-v3"]
#@markdown **Language:** Select the language for transcription. "自動偵測" (Auto-detect) is usually sufficient.
language = "自動偵測"  #@param ["自動偵測", "中文", "英文"]
#@markdown **Denoising:** Apply CPU-based denoising to the audio before transcription. `afftdn` is recommended.
denoise_method = "afftdn (建議)"  #@param ["afftdn (建議)", "none"]
#@markdown **Text Post-processing (OpenCC):** Convert the transcribed text (SRT/TXT output) between Simplified and Traditional Chinese variants.
text_postprocess = "臺灣繁體中文（預設）"  #@param ["臺灣繁體中文（預設）","香港繁體中文","大陸簡體中文","關閉"]
#@markdown **YouTube Cookies (Optional):** Path to a Netscape-format cookies file (relative to MyDrive) for accessing age-restricted or member-only YouTube videos (e.g., `cookies/youtube.txt`).
youtube_cookies_txt_path = ""  #@param {type:"string"}

#@markdown ## Summarization Settings
#@markdown **Topic Hint (Optional):** Provide a brief hint about the topic to guide the summarization process.
topic_hint = ""  #@param {type:"string"}

# Translate the form label into the value handed to model.transcribe(language=...);
# the auto-detect label maps to None.
language_code_map = {"自動偵測": None, "中文":"zh", "英文":"en"}
language_code = language_code_map[language]

# =========================================================
# Developer Options
# Advanced users can fine-tune parameters in this section.
# Modify only if you understand the impact.
# =========================================================
DEBUG_MODE = False # Set to True for more detailed logging (also the only place subprocess output is printed)

# --- Transcription Parameters ---
# Primary values are tried first; the fallback pair is used when the primary attempt raises.
TRANSCRIPTION_BEAM_SIZE_PRIMARY = 3
TRANSCRIPTION_CHUNK_LENGTH_PRIMARY = 20
TRANSCRIPTION_BEAM_SIZE_FALLBACK = 1 # Used if primary fails
TRANSCRIPTION_CHUNK_LENGTH_FALLBACK = 15 # Used if primary fails

# --- Denoising Parameters ---
# Passed to ffmpeg as `afftdn=nf=<value>`.
DENOISE_NOISE_FLOOR_DB = -25

# --- Filtering Parameters ---
# A segment is dropped from the output when it is shorter than the duration limit
# AND crosses the paired confidence threshold.
FILTER_MIN_DURATION_SHORT = 1.5 # Minimum duration for short segments
FILTER_AVG_LOGPROB_THRESHOLD = -1.0 # Avg log probability threshold for short segments
FILTER_MIN_DURATION_SPEECH_PROB = 2.0 # Minimum duration for speech probability filtering
FILTER_NO_SPEECH_PROB_THRESHOLD = 0.6 # No speech probability threshold

# --- Summary Model Parameters ---
REPO_ID   = "unsloth/gpt-oss-20b-GGUF"   # GGUF Model Repository
GGUF_FILE = "gpt-oss-20b-Q4_K_M.gguf"    # Approx. 10.8GiB, T4 can run

# --- Summary Inference Parameters (Increase available generation space to avoid truncation) ---
ctx_window            = 8192
map_max_new_tokens    = 512   # Segment output: original 256 -> 512 (approx. 350-450 chars)
reduce_max_new_tokens = 1024  # Summary output: original 512 -> 1024 (approx. 700-900+ chars)
temperature           = 0.2
top_p                 = 0.9
repeat_penalty        = 1.05

# Parsed tokenizer_config.json of the summary model; stays None until the model repo is downloaded.
tokenizer_config_data = None
# Cache slot filled by ensure_harmony_formatter() on its first successful call.
harmony_chat_formatter = None

def ensure_harmony_formatter():
    """Return the cached Harmony chat formatter, building it on first use.

    Two sources are tried in order: the parsed ``tokenizer_config.json``
    (``tokenizer_config_data``), then the chat template stored in the loaded
    model's GGUF metadata (``llm.metadata``).

    NOTE(review): an identical definition of this function appears again
    later in the file, after the model is loaded and before the first call;
    that later definition is the one actually bound when the function runs.

    Returns:
        The formatter object stored in the module-level ``harmony_chat_formatter``.

    Raises:
        RuntimeError: if the llama_cpp chat-format helpers cannot be imported,
            or if neither source yields a formatter.
    """
    global harmony_chat_formatter
    # Fast path: formatter already built by an earlier call.
    if harmony_chat_formatter is not None:
        return harmony_chat_formatter
    try:
        from llama_cpp.llama_chat_format import (
            hf_tokenizer_config_to_chat_formatter,
            Jinja2ChatFormatter,
        )
    except Exception as exc:
        raise RuntimeError("llama_cpp Harmony chat helpers are unavailable") from exc

    formatter = None
    # Source 1: tokenizer_config.json, when it was downloaded and parsed.
    if tokenizer_config_data:
        try:
            formatter = hf_tokenizer_config_to_chat_formatter(
                tokenizer_config_data,
                add_generation_prompt=True,
            )
        except Exception as exc:
            # Best-effort: the failure is only reported in debug mode, then source 2 is tried.
            if DEBUG_MODE:
                print("  ✗ Failed to initialize Harmony formatter from tokenizer_config.json:", exc)

    # Source 2: chat template embedded in the GGUF metadata of the loaded model.
    if formatter is None:
        template = None
        # `llm` only exists once the summary model has been loaded.
        if 'llm' in globals() and hasattr(llm, 'metadata'):
            template = llm.metadata.get('tokenizer.chat_template')
        if template:
            try:
                # Prefer the token strings from metadata; otherwise detokenize the BOS/EOS ids.
                bos_token = (
                    llm.metadata.get('tokenizer.bos_token')
                    or llm.metadata.get('tokenizer.ggml.bos_token')
                    or llm.detokenize([llm.token_bos()], special=True).decode('utf-8', errors='ignore')
                )
                eos_token = (
                    llm.metadata.get('tokenizer.eos_token')
                    or llm.metadata.get('tokenizer.ggml.eos_token')
                    or llm.detokenize([llm.token_eos()], special=True).decode('utf-8', errors='ignore')
                )
                formatter = Jinja2ChatFormatter(
                    template,
                    eos_token=eos_token,
                    bos_token=bos_token,
                    stop_token_ids=[llm.token_eos()],
                )
            except Exception as exc:
                if DEBUG_MODE:
                    print("  ✗ Failed to build Harmony formatter from GGUF metadata:", exc)

    if formatter is None:
        raise RuntimeError("Harmony chat formatter could not be prepared")

    # Cache for subsequent calls.
    harmony_chat_formatter = formatter
    return harmony_chat_formatter




# =========================================================
# End of Developer Options
# =========================================================


# [3/8] Install Dependencies
# Combine installation steps from both original cells
if DEBUG_MODE: print("[Install] faster-whisper / yt-dlp / soundfile / opencc / srt / huggingface_hub / llama-cpp-python ...")

def pip_install(pkgs, extra_args=None, env=None):
    """Run ``pip install --upgrade`` for *pkgs* with the current interpreter.

    Args:
        pkgs: requirement strings appended after all options.
        extra_args: optional extra pip options placed before the packages.
        env: optional environment mapping for the child process.

    Returns:
        The ``subprocess.CompletedProcess``; stdout and stderr are merged
        into ``.stdout`` as text. The return code is not checked here.
    """
    command = [
        sys.executable, "-m", "pip", "install", "--upgrade",
        *(extra_args or []),
        *pkgs,
    ]
    return sp.run(command, stdout=sp.PIPE, stderr=sp.STDOUT, text=True, env=env)

import importlib.util
# Install common dependencies first.
# Each entry is (importable module name, pip requirement); only modules that
# cannot be found are installed.
_COMMON_REQUIREMENTS = (
    ("srt", "srt>=3.5.3"),
    ("huggingface_hub", "huggingface_hub>=0.23.0"),
    ("soundfile", "soundfile"),
    ("opencc", "opencc-python-reimplemented"),
    ("jinja2", "jinja2>=3.1.0"),
)
common_missing = [
    requirement
    for module_name, requirement in _COMMON_REQUIREMENTS
    if importlib.util.find_spec(module_name) is None
]

if common_missing:
    if DEBUG_MODE: print("→ Installing common missing packages:", ", ".join(common_missing))
    r = pip_install(common_missing)
    if r.returncode != 0:
        if DEBUG_MODE: print(r.stdout)
        raise RuntimeError("基礎依賴安裝失敗，請重啟執行階段後重試。")

# faster-whisper and yt-dlp are checked independently: a runtime that already
# has faster-whisper may still lack yt-dlp, which the YouTube download needs.
_TRANSCRIPTION_REQUIREMENTS = (
    ("faster_whisper", "faster-whisper"),
    ("yt_dlp", "yt-dlp"),
)
transcription_missing = [
    requirement
    for module_name, requirement in _TRANSCRIPTION_REQUIREMENTS
    if importlib.util.find_spec(module_name) is None
]

if transcription_missing:
    if DEBUG_MODE: print("→ Installing missing package:", " ".join(transcription_missing))
    r = pip_install(transcription_missing)
    if r.returncode != 0:
        if DEBUG_MODE: print(r.stdout)
        raise RuntimeError("faster-whisper / yt-dlp 安裝失敗，請重啟執行階段後重試。")

# Import external packages after ensuring installation
import soundfile as sf
from faster_whisper import WhisperModel
from opencc import OpenCC
import srt as _srt  # Import srt as _srt to avoid name conflict later with the module itself
from huggingface_hub import snapshot_download
def suggest_runtime_reset():
    """Print the Colab "factory reset runtime, then rerun" recovery hint."""
    hint_lines = (
        "\n🧹 建議動作（Colab）",
        "1) 依序：『執行階段 Runtime』 → 『刪除執行階段/還原出廠設定 Factory reset runtime』",
        "2) 重新執行本 Notebook（從掛載雲端硬碟那格開始）\n",
    )
    # One flushed write produces the same text as three separate prints.
    print("\n".join(hint_lines), flush=True)

def run_cmd(cmd:list, check=True):
    """Run *cmd* (argument list, no shell) and return the CompletedProcess.

    stdout and stderr are merged and captured as text. In debug mode the
    command line and its output are echoed.

    Raises:
        RuntimeError: when *check* is true and the exit status is non-zero.
    """
    verbose = DEBUG_MODE
    if verbose:
        print("  $", " ".join(cmd))
    result = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if verbose and result.stdout:
        sys.stdout.write(result.stdout)
    failed = result.returncode != 0
    if failed and check:
        raise RuntimeError(f"命令失敗：{' '.join(cmd)}")
    return result

def is_youtube_url(s:str)->bool:
    """Return True when *s* is a string mentioning a YouTube host name."""
    if not isinstance(s, str):
        return False
    return any(host in s for host in ("youtu.be", "youtube.com"))
def is_http_url(s:str)->bool:
    """Return True when *s* is a string starting with an http:// or https:// scheme.

    The scheme separator is required so that local file names which merely
    begin with the letters "http" (e.g. ``http_notes.mp4``) are not mistaken
    for URLs. Matching is case-insensitive.
    """
    if not isinstance(s, str):
        return False
    return s.lower().startswith(("http://", "https://"))
def to_abs_mydrive(p:str)->Path:
    """Resolve *p* to an absolute path; non-absolute input is taken relative to ROOT."""
    if p.startswith("/"):
        candidate = Path(p)
    else:
        candidate = ROOT / p
    return candidate.resolve()
def fmt_ts_srt(t:float)->str:
    """Format *t* seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds first and then split into
    fields, so a fraction that rounds up (e.g. 1.9996 s) carries into the
    seconds field instead of producing an out-of-range ``,1000``.
    Negative input is clamped to zero.
    """
    total_ms = max(0, int(round(float(t) * 1000)))
    total_secs, millis = divmod(total_ms, 1000)
    total_mins, secs = divmod(total_secs, 60)
    hours, mins = divmod(total_mins, 60)
    return f"{hours:02d}:{mins:02d}:{secs:02d},{millis:03d}"
def verify_wav_ok(path: Path)->bool:
    """Return True when soundfile can read *path* and it is mono or stereo.

    Any exception raised while probing the file is treated as "not OK".
    """
    try:
        meta = sf.info(str(path))
        has_rate = meta.samplerate > 0
        return has_rate and meta.channels in (1, 2)
    except Exception:
        return False

# OpenCC converter setup
def build_opencc_pipeline(choice:str):
    """Return the list of OpenCC converters for the form label *choice*.

    The label is matched by prefix; an unrecognised label (including the
    "off" option) yields an empty list, i.e. no conversion.
    """
    configs_by_prefix = (
        ("臺灣", ("s2t", "t2tw")),
        ("香港", ("s2t", "t2hk")),
        ("大陸", ("t2s",)),
    )
    for prefix, config_names in configs_by_prefix:
        if choice.startswith(prefix):
            return [OpenCC(config_name) for config_name in config_names]
    return []  # Disable

def apply_opencc(text:str, pipeline)->str:
    """Feed *text* through each converter in *pipeline*, in order, and return the result."""
    converted = text
    stages = iter(pipeline)
    for stage in stages:
        converted = stage.convert(converted)
    return converted

def ytdl(yturl:str)->Path:
    """Download *yturl* with yt-dlp into /tmp/dl and return the downloaded file.

    The scratch directory is emptied first. When ``youtube_cookies_txt_path``
    is set and the file exists, it is passed to yt-dlp via ``--cookies``.
    When ``save_video_to_google_drive`` is true the file is also copied into
    ``WHISPER_DIR``.

    Raises:
        RuntimeError: yt-dlp exited with a non-zero status.
        FileNotFoundError: yt-dlp succeeded but left no file behind.
    """
    tmp = Path("/tmp/dl"); tmp.mkdir(parents=True, exist_ok=True)
    # Best-effort cleanup so the new download is the only content afterwards.
    for stale in tmp.glob("*"):
        if stale.is_dir() and not stale.is_symlink():
            shutil.rmtree(stale, ignore_errors=True)
            continue
        try:
            stale.unlink()
        except OSError:
            pass  # leftover is tolerated; the largest-file pick below still works
    if DEBUG_MODE: print("[Download] Getting YouTube video ...")
    # yt-dlp names the output after the video title, so the result is found by listing tmp.
    cmd = ["yt-dlp", "-f", "mp4", "-o", str(tmp / "%(title)s.%(ext)s")]
    cookies_rel = youtube_cookies_txt_path.strip()
    if cookies_rel:
        cookies_abs = to_abs_mydrive(cookies_rel)
        if cookies_abs.exists():
            cmd += ["--cookies", str(cookies_abs)]
        else:
            if DEBUG_MODE: print(f"⚠️ 找不到 cookies 檔：{cookies_abs}（改為不帶 cookies）")
    cmd.append(yturl)
    p = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if DEBUG_MODE and p.stdout: sys.stdout.write(p.stdout)
    if p.returncode != 0:
        if "Sign in to confirm" in (p.stdout or ""):
            print("\n❗YouTube 要求登入/驗證，請提供 cookies 或先自行下載到雲端硬碟。")
        print("🔄 若多次失敗，請刪除執行階段並重啟後重試。")
        suggest_runtime_reset()
        raise RuntimeError("yt-dlp 下載失敗")
    files = [entry for entry in tmp.glob("*") if entry.is_file()]
    if not files:
        print("🔄 下載為空，建議刪除執行階段再重試。")
        suggest_runtime_reset()
        raise FileNotFoundError("YouTube 下載失敗：/tmp/dl 為空")
    # Directory listing order is arbitrary; take the largest file so a small
    # leftover or sidecar file is never mistaken for the video.
    f = max(files, key=lambda entry: entry.stat().st_size)
    if save_video_to_google_drive:
        shutil.copy2(f, WHISPER_DIR / f.name)
    return f

def http_dl(url:str)->Path:
    """Download *url* with curl into /tmp/dl and return the saved file.

    The scratch directory is emptied first. When
    ``save_video_to_google_drive`` is true the file is also copied into
    ``WHISPER_DIR``.

    Raises:
        RuntimeError: curl failed (including HTTP status >= 400) or the
            downloaded file is missing/empty.
    """
    tmp = Path("/tmp/dl"); tmp.mkdir(parents=True, exist_ok=True)
    # Best-effort cleanup of earlier downloads.
    for stale in tmp.glob("*"):
        if stale.is_dir() and not stale.is_symlink():
            shutil.rmtree(stale, ignore_errors=True)
            continue
        try:
            stale.unlink()
        except OSError:
            pass
    ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out = tmp / f"downloaded_{ts}.mp4"
    if DEBUG_MODE: print("[Download] Getting HTTP(S) video ...")
    # --fail makes curl exit non-zero on HTTP errors instead of writing the
    # server's error page into the .mp4 file.
    run_cmd(["curl", "-L", "--fail", "-o", str(out), url])
    if not out.exists() or out.stat().st_size == 0:
        raise RuntimeError("HTTP 下載失敗：下載的檔案不存在或為空")
    if save_video_to_google_drive:
        shutil.copy2(out, WHISPER_DIR / out.name)
    return out

# Extract audio: ffmpeg -> 16k/mono WAV
def ffmpeg_extract_wav(in_path:Path, out_wav:Path, sr=16000):
    """Extract the audio of *in_path* into a mono WAV at *sr* Hz (video dropped).

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
    """
    args = [
        "ffmpeg", "-y",
        "-i", str(in_path),
        "-vn",
        "-ac", "1",
        "-ar", str(sr),
        "-f", "wav",
        str(out_wav),
    ]
    proc = sp.run(args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if proc.returncode == 0:
        return
    if DEBUG_MODE: print(proc.stdout)
    raise RuntimeError("ffmpeg 轉 WAV 失敗")

# CPU Denoising: ffmpeg afftdn
def ffmpeg_afftdn(in_wav: Path, out_wav: Path, noise_floor_db=DENOISE_NOISE_FLOOR_DB):
    """Denoise *in_wav* with ffmpeg's ``afftdn`` filter into a 16 kHz mono WAV.

    Args:
        noise_floor_db: value passed as the filter's ``nf`` option.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
    """
    filter_spec = f"afftdn=nf={noise_floor_db}"
    args = [
        "ffmpeg", "-y",
        "-i", str(in_wav),
        "-af", filter_spec,
        "-ac", "1",
        "-ar", "16000",
        "-f", "wav",
        str(out_wav),
    ]
    proc = sp.run(args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if proc.returncode == 0:
        return
    if DEBUG_MODE: print(proc.stdout)
    raise RuntimeError("ffmpeg afftdn 失敗")

# Safeguard: Repack WAV header if format is strange
def ffmpeg_repack_wav(in_wav: Path, out_wav: Path, sr=16000):
    """Re-encode *in_wav* as 16-bit PCM mono at *sr* Hz into *out_wav*.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
    """
    args = [
        "ffmpeg", "-y",
        "-i", str(in_wav),
        "-acodec", "pcm_s16le",
        "-ac", "1",
        "-ar", str(sr),
        str(out_wav),
    ]
    proc = sp.run(args, stdout=sp.PIPE, stderr=sp.STDOUT, text=True)
    if proc.returncode == 0:
        return
    if DEBUG_MODE: print(proc.stdout)
    raise RuntimeError("ffmpeg 重包 WAV 失敗")

# [4/8] Parse Source (Transcription) - Uses 'filename' and 'save_video_to_google_drive'
if DEBUG_MODE: print("[4/8] Parsing input source ...")
try:
    # YouTube is tested first: its URLs would also satisfy is_http_url.
    if is_youtube_url(filename):
        src_path = ytdl(filename); out_base_dir = WHISPER_DIR
    elif is_http_url(filename):
        src_path = http_dl(filename); out_base_dir = WHISPER_DIR
    else:
        # Anything else is treated as a Drive path; outputs go next to the source file.
        src_path = to_abs_mydrive(filename)
        if not src_path.exists(): raise FileNotFoundError(f"找不到檔案：{src_path}")
        out_base_dir = src_path.parent
except Exception as e:
    # Show the recovery hint, then re-raise so the run stops here.
    print(f"\n⛔ 來源解析/下載失敗：{e}")
    print("🔄 請刪除執行階段並重新啟動後重跑。"); suggest_runtime_reset(); raise

print(f"→ 來源檔：{src_path}")
print(f"→ 輸出資料夾：{out_base_dir}")

# [5/8] Extract Audio & CPU Denoising (Transcription) - Uses 'denoise_method' and 'DENOISE_NOISE_FLOOR_DB'
AUDIO_16K = Path("/tmp/audio_16k.wav")
if DEBUG_MODE: print("[5/8] Extracting audio (ffmpeg → 16k/mono WAV) ...")
ffmpeg_extract_wav(src_path, AUDIO_16K, sr=16000)

# Start from the raw extraction; switch to the denoised file only when it was
# requested and the result passes verification.
denoised_audio = AUDIO_16K
if denoise_method.startswith("afftdn"):
    if DEBUG_MODE: print("[5.5/8] Denoising (ffmpeg afftdn, CPU) ...")
    DENOISED = Path("/tmp/audio_16k_denoised.wav")
    ffmpeg_afftdn(AUDIO_16K, DENOISED, noise_floor_db=DENOISE_NOISE_FLOOR_DB)
    if verify_wav_ok(DENOISED):
        denoised_audio = DENOISED

# Last resort: rewrite the WAV as plain PCM when the chosen file still fails verification.
if not verify_wav_ok(denoised_audio):
    if DEBUG_MODE: print("  - 音訊格式異常；嘗試重包 WAV ...")
    FIXED = Path("/tmp/audio_16k_fixed.wav")
    ffmpeg_repack_wav(denoised_audio, FIXED, sr=16000)
    denoised_audio = FIXED

if DEBUG_MODE: print(f"→ 最終輸入音訊：{denoised_audio}")

# [6/8] Load faster-whisper (GPU enforced) - Uses 'model_size'
if DEBUG_MODE: print("[6/8] Loading faster-whisper model (GPU) ...")
device = "cuda"  # Enforce GPU
model = None; last_err = None
# Try compute types in order of preference; the first that loads wins.
for ctype in ("int8_float16", "float16", "int8"):
    if DEBUG_MODE: print(f"  - Trying compute_type={ctype}")
    try:
        model = WhisperModel(model_size, device=device, compute_type=ctype)
    except Exception as e:
        last_err = e
        if DEBUG_MODE: print(f"  - Load failed: {e}")
        continue
    if DEBUG_MODE: print("  - Model loaded successfully")
    break
else:
    # Loop finished without a break: every compute type failed.
    print("\n⛔ GPU 模型載入失敗。請確認『變更執行階段類型』選了 GPU（T4/A100），或刪除執行階段後重試。")
    suggest_runtime_reset()
    raise RuntimeError(f"無法載入模型：{last_err}")

gc.collect()  # Clean up before transcription (safety)

# [7/8] Transcribe (GPU; real-time progress per segment) - Uses 'language_code', 'TRANSCRIPTION_BEAM_SIZE_PRIMARY', 'TRANSCRIPTION_CHUNK_LENGTH_PRIMARY', 'TRANSCRIPTION_BEAM_SIZE_FALLBACK', 'TRANSCRIPTION_CHUNK_LENGTH_FALLBACK'
if DEBUG_MODE: print(f"[7/8] Starting transcription (GPU: beam={TRANSCRIPTION_BEAM_SIZE_PRIMARY} / chunk={TRANSCRIPTION_CHUNK_LENGTH_PRIMARY}s / no VAD) ...")

def transcribe_gpu(_beam=TRANSCRIPTION_BEAM_SIZE_PRIMARY, _chunk=TRANSCRIPTION_CHUNK_LENGTH_PRIMARY):
    """Call ``model.transcribe`` on the prepared audio with fixed decoding options.

    Args:
        _beam: beam size forwarded as ``beam_size``.
        _chunk: chunk length forwarded as ``chunk_length``.

    Returns:
        Whatever ``model.transcribe`` returns (unpacked by the caller as
        ``(segments, info)``).
    """
    decode_options = dict(
        task="transcribe",
        language=language_code,
        temperature=0.0,
        condition_on_previous_text=True,
        compression_ratio_threshold=2.4,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        beam_size=_beam,
        chunk_length=_chunk,
        vad_filter=False,
        word_timestamps=False,
    )
    return model.transcribe(str(denoised_audio), **decode_options)

def _transcribe_and_filter(beam, chunk):
    """Run one complete transcription pass and return its results.

    faster-whisper hands back the segments as a lazy generator, so decoding
    errors surface while the segments are being iterated, not when
    ``transcribe_gpu`` is called. Keeping the call AND the iteration in one
    function lets the caller retry the whole pass with other settings.

    Args:
        beam: beam size for this pass.
        chunk: chunk length (seconds) for this pass.

    Returns:
        ``(info, duration, segments, filtered)`` where *segments* holds every
        decoded segment and *filtered* only those kept by the FILTER_* rules.
    """
    seg_iter, info = transcribe_gpu(_beam=beam, _chunk=chunk)

    # Display percentage based on total video duration
    duration = float(getattr(info, "duration", 0.0) or 0.0)
    if duration <= 0: duration = 1.0  # avoid division by zero in the progress figure

    if DEBUG_MODE:
        print(f"  - Detected language: {getattr(info,'language','未知')} (p={getattr(info,'language_probability',0):.2f})")
        print(f"  - Audio length: {duration:.2f}s")

    segments = []
    filtered = []
    for s in seg_iter:
        pct = int(min(100, round((s.end / duration) * 100)))
        print(f"[{pct:3d}%] {fmt_ts_srt(s.start)} → {fmt_ts_srt(s.end)}  {s.text.strip()}", flush=True)
        segments.append(s)

        # Low confidence/high no-speech short segment filtering (no blacklist) - Uses FILTER_* parameters
        keep = True
        seg_dur = float(s.end - s.start)
        if seg_dur < FILTER_MIN_DURATION_SHORT and getattr(s, "avg_logprob", None) is not None and s.avg_logprob < FILTER_AVG_LOGPROB_THRESHOLD:
            keep = False
        if seg_dur < FILTER_MIN_DURATION_SPEECH_PROB and getattr(s, "no_speech_prob", None) is not None and s.no_speech_prob > FILTER_NO_SPEECH_PROB_THRESHOLD:
            keep = False
        if keep:
            filtered.append(s)
    return info, duration, segments, filtered


try:
    info, duration, segments, filtered = _transcribe_and_filter(
        TRANSCRIPTION_BEAM_SIZE_PRIMARY, TRANSCRIPTION_CHUNK_LENGTH_PRIMARY
    )
except Exception as e:
    if DEBUG_MODE: print(f"  - First transcription failed: {e}\n    → Trying more conservative (beam={TRANSCRIPTION_BEAM_SIZE_FALLBACK}, chunk={TRANSCRIPTION_CHUNK_LENGTH_FALLBACK}) ...")
    # The retry restarts from the beginning, so progress lines are printed again.
    print("⚠️ 轉錄中斷，改用較保守的參數從頭重新轉錄 ...", flush=True)
    info, duration, segments, filtered = _transcribe_and_filter(
        TRANSCRIPTION_BEAM_SIZE_FALLBACK, TRANSCRIPTION_CHUNK_LENGTH_FALLBACK
    )

if DEBUG_MODE: print(f"  - Number of segments: Before filtering {len(segments)} → After filtering {len(filtered)}")

# ---- OpenCC Normalization (for output text) ---- - Uses 'text_postprocess'
pipeline = build_opencc_pipeline(text_postprocess)
def norm(txt: str) -> str:
    """Apply the configured OpenCC pipeline to *txt*; pass it through when the pipeline is empty."""
    if not pipeline:
        return txt
    return apply_opencc(txt, pipeline)

# [8/8] Output (text after OpenCC) - Uses 'out_base_dir' (derived from 'filename')
print("[8/8] 輸出 SRT / TXT ...")
# Determine the output directory for transcription based on input type
# If input is a network source, output to WHISPER_DIR
# If input is a local file, output to the same directory as the input file
if is_youtube_url(filename) or is_http_url(filename):
    out_base_dir = WHISPER_DIR
else:
    src_path_abs = to_abs_mydrive(filename)
    out_base_dir = src_path_abs.parent

# Create the transcription output directory if it doesn't exist
out_dir = out_base_dir
out_dir.mkdir(exist_ok=True, parents=True)

# Determine the stem from the original source file path
stem = Path(src_path).stem
SRT = out_dir / f"{stem}.srt"
TXT = out_dir / f"{stem}.txt"

# SRT: numbered cues with timestamps; only the filtered segments are written.
with open(SRT, "w", encoding="utf-8") as f:
    for i, s in enumerate(filtered, 1):
        text_out = norm(s.text.strip())
        f.write(f"{i}\n{fmt_ts_srt(s.start)} --> {fmt_ts_srt(s.end)}\n{text_out}\n\n")

# TXT: the same text without timing.
with open(TXT, "w", encoding="utf-8") as f:
    for s in filtered:
        f.write(norm(s.text.strip()) + "\n")  # Each segment on a new line

print(f"→ 完成！\n  SRT: {SRT}\n  TXT: {TXT}")

# Release model (release GPU memory)
try:
    del model
except NameError:
    pass  # the name was already removed; nothing left to release
gc.collect()
if DEBUG_MODE: print("→ Model released; can run again directly if needed.")


# ===== Summarization Logic Starts Here =====
# Use SRT from transcription step for summarization
summary_srt_path_abs = SRT
# Explicit check instead of `assert`, which is skipped when Python runs with -O.
if not summary_srt_path_abs.exists():
    raise FileNotFoundError(f"SRT 檔不存在：{summary_srt_path_abs}")

# ===== Summary 1/6) Check GPU and Install Dependencies (llama-cpp-python specific) =====
# llama-cpp-python installation logic - Keep this separate as it has specific CUDA requirements
# Moved this section to just before reading the SRT for summarization
if DEBUG_MODE: print("[Summary 1/6] Checking GPU and installing llama-cpp-python ...")

def detect_cuda_tag():
    """Map the CUDA version reported by ``nvidia-smi`` to a wheel tag.

    Returns ``"cu125"`` for CUDA 12.5 or newer and ``"cu124"`` otherwise,
    including when ``nvidia-smi`` cannot be run or its output cannot be parsed.
    """
    default_tag = "cu124"
    try:
        smi_output = sp.check_output(["nvidia-smi"], text=True)
        found = re.search(r"CUDA Version:\s*([\d.]+)", smi_output)
        if found is None:
            return default_tag
        major, minor = (int(part) for part in found.group(1).split(".")[:2])
        # Tuple comparison: newer major, or same major with minor >= 5.
        return "cu125" if (major, minor) >= (12, 5) else default_tag
    except Exception:
        return default_tag

cuda_tag = detect_cuda_tag()
if DEBUG_MODE: print(f"GPU 0: Detected CUDA version tag {cuda_tag}")

def try_import_llama():
    """Return the ``llama_cpp.Llama`` class, or None when the module is not installed."""
    try:
        from llama_cpp import Llama as llama_cls
    except ModuleNotFoundError:
        return None
    else:
        return llama_cls

Llama = try_import_llama()
if Llama is None:
    # Strategy: pre-built CUDA wheels from the extra index first, then a source build.
    # dict.fromkeys keeps the order but drops the repeat of cuda_tag, so the
    # same wheel index is never tried twice.
    candidates = list(dict.fromkeys([cuda_tag, "cu125", "cu124", "cu122", "cu121"]))
    ok = False
    for tag in candidates:
        idx = f"https://abetlen.github.io/llama-cpp-python/whl/{tag}"
        if DEBUG_MODE: print(f"→ Attempting to install llama-cpp-python ({tag}) ...")
        r = pip_install(["llama-cpp-python"], extra_args=["--extra-index-url", idx])
        if r.returncode == 0:
            # The package appeared after interpreter start; refresh the import
            # system's caches so the new module can be found.
            importlib.invalidate_caches()
            Llama = try_import_llama()
            if Llama is not None:
                ok = True
                break
        else:
            if DEBUG_MODE: print("  ✗ Installation failed (summary):", "\n".join(r.stdout.splitlines()[-5:]))
    if not ok:
        if DEBUG_MODE: print("→ Pre-compiled wheels not available, switching to 'source compilation (CUDA=ON)' ... (takes longer)")
        try:
            import ninja # noqa: F401 # Import ninja to check if installed
        except ModuleNotFoundError:
            if DEBUG_MODE: print("→ Installing missing package: ninja")
            r = pip_install(["ninja"])
            if r.returncode != 0:
                if DEBUG_MODE: print(r.stdout)
                raise RuntimeError("安裝 ninja 失敗。請重啟後重試。")
        env = os.environ.copy()
        # Only GGML_CUDA is passed: the older LLAMA_CUBLAS switch is deprecated
        # upstream and is rejected by current llama.cpp builds.
        env["CMAKE_ARGS"] = "-DGGML_CUDA=on"
        env["FORCE_CMAKE"] = "1"
        r = pip_install(["llama-cpp-python"], env=env)
        if r.returncode != 0:
            if DEBUG_MODE: print(r.stdout)
            raise RuntimeError("無法安裝 GPU 版 llama-cpp-python。")
        importlib.invalidate_caches()
        Llama = try_import_llama()
    # Fail here with a clear message instead of a later "'NoneType' is not callable".
    if Llama is None:
        raise RuntimeError("llama-cpp-python 已安裝但無法匯入，請重啟執行階段後重試。")


if DEBUG_MODE: print("[Summary 2/6] Reading SRT ...")
srt_text = Path(summary_srt_path_abs).read_text(encoding="utf-8")
subs = list(_srt.parse(srt_text)) # Use _srt as srt module was imported as _srt
def td2s(td):
    """Convert a timedelta to seconds (float)."""
    return td.total_seconds()
# (start_seconds, end_seconds, text) for every subtitle with non-empty text.
segments = [
    (td2s(sub.start), td2s(sub.end), sub.content.strip())
    for sub in subs
    if sub.content.strip()
]
if segments:
    # Span from the first subtitle's start to the last subtitle's end.
    total_secs = segments[-1][1] - segments[0][0]
else:
    total_secs = 0
if DEBUG_MODE: print(f"→ Number of subtitle segments: {len(segments)}；Video length (est): {total_secs/60:.1f} minutes")


# ===== Summary 3/6) Download and Load GGUF Model (Summary) - Uses summary model parameters (REPO_ID, GGUF_FILE, ctx_window, etc.)
# Moved this section to just after installing llama-cpp-python
if DEBUG_MODE: print("[Summary 3/6] Loading GPT-OSS-20B (GGUF, CUDA) ...")
# Fetch only the quantized weights and, when the repo has one, tokenizer_config.json.
local_repo = snapshot_download(REPO_ID, allow_patterns=[GGUF_FILE, "tokenizer_config.json"])
gguf_path = str(Path(local_repo)/GGUF_FILE)

tokenizer_config_path = Path(local_repo)/"tokenizer_config.json"
if tokenizer_config_path.exists():
    try:
        # Read as UTF-8 explicitly so parsing does not depend on the locale's default encoding.
        tokenizer_config_data = json.loads(tokenizer_config_path.read_text(encoding="utf-8"))
        if DEBUG_MODE: print("→ Loaded tokenizer_config.json (Harmony template)")
    except Exception as exc:
        # Best-effort: the formatter can still be built from GGUF metadata later.
        if DEBUG_MODE: print("  ✗ Failed to parse tokenizer_config.json:", exc)

llm = Llama(
    model_path=gguf_path,
    n_ctx=ctx_window,
    n_gpu_layers=-1,
    seed=0,
    logits_all=False,
    verbose=True          # Display the actual chat format used
)
if DEBUG_MODE: print("→ Model loaded successfully (GPU)")

def ensure_harmony_formatter():
    """Return the cached Harmony chat formatter, building it on first use.

    Two sources are tried in order: the parsed ``tokenizer_config.json``
    (``tokenizer_config_data``), then the chat template stored in the loaded
    model's GGUF metadata (``llm.metadata``). The reason each source failed is
    collected and included in the final error, so the cause is visible even
    when DEBUG_MODE is off.

    Returns:
        The formatter object stored in the module-level ``harmony_chat_formatter``.

    Raises:
        RuntimeError: if the llama_cpp chat-format helpers cannot be imported,
            or if neither source yields a formatter.
    """
    global harmony_chat_formatter
    if harmony_chat_formatter is not None:
        return harmony_chat_formatter
    try:
        from llama_cpp.llama_chat_format import (
            hf_tokenizer_config_to_chat_formatter,
            Jinja2ChatFormatter,
        )
    except Exception as exc:
        raise RuntimeError("llama_cpp Harmony chat helpers are unavailable") from exc

    failures = []  # one human-readable reason per source that did not work
    formatter = None

    # Source 1: tokenizer_config.json.
    if tokenizer_config_data:
        try:
            formatter = hf_tokenizer_config_to_chat_formatter(
                tokenizer_config_data,
                add_generation_prompt=True,
            )
        except Exception as exc:
            failures.append(f"tokenizer_config.json: {exc}")
            if DEBUG_MODE:
                print("  ✗ Failed to initialize Harmony formatter from tokenizer_config.json:", exc)
    else:
        failures.append("tokenizer_config.json: not loaded")

    # Source 2: chat template embedded in the GGUF metadata.
    if formatter is None:
        template = None
        # `llm` only exists once the summary model has been loaded.
        if 'llm' in globals() and hasattr(llm, 'metadata'):
            template = llm.metadata.get('tokenizer.chat_template')
        if template:
            try:
                # Prefer the token strings from metadata; otherwise detokenize the BOS/EOS ids.
                bos_token = (
                    llm.metadata.get('tokenizer.bos_token')
                    or llm.metadata.get('tokenizer.ggml.bos_token')
                    or llm.detokenize([llm.token_bos()], special=True).decode('utf-8', errors='ignore')
                )
                eos_token = (
                    llm.metadata.get('tokenizer.eos_token')
                    or llm.metadata.get('tokenizer.ggml.eos_token')
                    or llm.detokenize([llm.token_eos()], special=True).decode('utf-8', errors='ignore')
                )
                formatter = Jinja2ChatFormatter(
                    template,
                    eos_token=eos_token,
                    bos_token=bos_token,
                    stop_token_ids=[llm.token_eos()],
                )
            except Exception as exc:
                failures.append(f"GGUF metadata: {exc}")
                if DEBUG_MODE:
                    print("  ✗ Failed to build Harmony formatter from GGUF metadata:", exc)
        else:
            failures.append("GGUF metadata: no tokenizer.chat_template available")

    if formatter is None:
        raise RuntimeError(
            "Harmony chat formatter could not be prepared (" + "; ".join(failures) + ")"
        )

    harmony_chat_formatter = formatter
    return harmony_chat_formatter



# Build the formatter once up front so a broken template fails before any summarization work.
try:
    ensure_harmony_formatter()
    if DEBUG_MODE: print("→ Harmony formatter prepared")
except Exception as exc:
    # `from exc` keeps the original traceback attached to the wrapped error.
    raise RuntimeError(f"Failed to prepare Harmony formatter: {exc}") from exc


# ===== Summary 4/6) Token-aware Segmentation (Summary) - Uses ctx_window, map_max_new_tokens, prompt_overhead
if DEBUG_MODE: print("[Summary 4/6] Generating segments (token-aware; single segment ≤ safety limit) ...")

def count_tokens_text(text: str) -> int:
    """Return the number of tokens the loaded LLM produces for *text* (UTF-8 encoded).

    Raises:
        RuntimeError: when no model is loaded in the module-level ``llm``.
    """
    loaded_model = globals().get("llm")
    if loaded_model is None:
        raise RuntimeError("LLM model is not loaded. Cannot count tokens.")
    encoded = text.encode("utf-8")
    return len(loaded_model.tokenize(encoded))

# System message: summarize meeting transcripts, remove noise and repetition, stay
# factual, mark unclear items, and answer in Traditional Chinese Markdown.
SYSTEM_INSTR = (
  "你是一個會議總結機器人。根據使用者提供的逐字稿（可能雜訊、重複、錯字），"
  "請去除雜訊與重複、嚴守事實、不腦補。遇到不明確資訊以「待補充／未明確」標註。"
  "輸出為 Markdown（繁體中文），不要輸出任何系統／思考標記。"
)

# Map-step user prompt. Placeholders: {topic} (topic hint) and {chunk} (transcript slice).
# NOTE(review): the prompt asks for 500–900 characters per chunk, while the comment on
# map_max_new_tokens (512) estimates roughly 350–450 characters of room — the request
# may exceed the generation budget and be cut off; confirm which figure is intended.
MAP_USER_TMPL = textwrap.dedent("""\
主題（可留空）：{topic}

以下是逐字稿片段（非完整全文）：
{chunk}

請就此片段輸出「條列式重點摘要」（500–900 字，繁體中文），注意：
- 只寫最終內容，不要寫解題想法、不要出現任何系統提示或中英括號標記。
- 聚焦可驗證事實（時間、人物、任務、結論、未決事項、行動）。
- 結構：可用小標題＋項目符號，語句務必短、準確、無贅詞。
""")

# Reduce-step user prompt. Placeholders: {topic} and {maps} (the collected chunk summaries).
# Requests three sections: overall summary, chapter points with rough times, action items.
REDUCE_USER_TMPL = textwrap.dedent("""\
主題（可留空）：{topic}

以下是所有片段的重點摘要彙整（仍可能有重疊）：
{maps}

請整合為一份會議筆記（Markdown，繁體）：
1) **整體提要**（3–6 句，避免冗言）
2) **章節要點（含時間脈絡）**：條列呈現，每點一行，可附粗略時間
3) **可執行重點**：具體待辦（每條以動詞開頭）
請只輸出最終筆記，不要出現系統或思考標記，不要加入未出現的新資訊。
""")

# Single segment token budget (reserve space for prompt and generation)
prompt_overhead = 700
# Context window minus prompt and generation room, clamped to the range [1024, 3072].
chunk_target    = max(1024, min(3072, ctx_window - prompt_overhead - map_max_new_tokens))

# Greedy packing: keep adding subtitles to the open chunk while its token total
# stays within chunk_target; otherwise close it and open a new one.
chunks: List[Tuple[float,float,str]] = []
buf, t0, t1, cur = [], None, None, 0
for (s, e, txt) in segments:
    t = count_tokens_text(txt)
    fits_open_chunk = bool(buf) and cur + t <= chunk_target
    if fits_open_chunk:
        buf.append(txt); t1 = e; cur += t
        continue
    if buf:
        # Open chunk is full: emit it before starting over with this subtitle.
        chunks.append((t0, t1, "\n".join(buf)))
    buf, t0, t1, cur = [txt], s, e, t
if buf:
    chunks.append((t0, t1, "\n".join(buf)))

if DEBUG_MODE: print(f"→ Generated {len(chunks)} segments (target ~{chunk_target} tokens per segment)")

# ===== Common: Streaming Tools (No regex cleaning; use correct stop sequence) - Uses temperature, top_p, repeat_penalty, map_max_new_tokens, reduce_max_new_tokens


def stream_harmony_final_pieces(text_chunks: Iterable[str]) -> Iterator[str]:
    """Yield the user-visible text from a streamed Harmony-formatted completion.

    The stream is scanned for ``<|tag|>`` markers (``start``, ``channel``,
    ``message``, ``end``, ``return``). Text in the ``final`` channel is yielded
    as it arrives. Text in an ``assistant`` channel is cached and only yielded
    if the stream ends (or hits ``<|return|>``) without any ``final`` text, so
    models that never emit a final channel still produce output. Text in any
    other channel is dropped.

    A lone trailing ``<`` is held back until the next piece arrives, because it
    may be the first half of a ``<|`` marker split across two stream pieces.

    Args:
        text_chunks: iterable of text pieces in stream order; empty pieces are ignored.

    Yields:
        Pieces of visible text, in order.
    """
    buffer = ""
    current_channel: Optional[str] = None
    pending_channel = False          # True right after <|channel|>, while the name is being read
    channel_name_buffer = ""
    in_message = False
    assistant_cache: List[str] = []  # assistant-channel text withheld until we know if "final" exists
    final_seen = False

    def canonical_channel(name: Optional[str]) -> str:
        """Normalize a raw channel name to 'final', 'assistant', the lowered name, or ''."""
        if not name:
            return ""
        lowered = name.strip().lower()
        if not lowered:
            return ""
        if "final" in lowered:
            return "final"
        if "assistant" in lowered:
            return "assistant"
        return lowered

    def is_visible() -> bool:
        """True while text at the current position should be routed through emit_text."""
        return in_message or canonical_channel(current_channel) in {"assistant", "final"}

    def hold_back_partial_tag(text: str) -> Tuple[str, str]:
        """Split off a trailing '<' that could be the start of a '<|' marker."""
        if text.endswith("<"):
            return text[:-1], "<"
        return text, ""

    def emit_text(text: str):
        """Route *text* according to the current channel (yield, cache, or drop)."""
        nonlocal final_seen
        if not text:
            return
        channel = canonical_channel(current_channel)
        if channel == "final":
            final_seen = True
            if assistant_cache:
                # First final text: release what the assistant channel said before it.
                cached = assistant_cache[:]
                assistant_cache.clear()
                for cached_piece in cached:
                    if cached_piece:
                        yield cached_piece
            yield text
        elif channel == "assistant":
            if final_seen:
                yield text
            else:
                assistant_cache.append(text)
        elif final_seen:
            yield text

    for piece in text_chunks:
        if not piece:
            continue
        buffer += piece
        while True:
            if pending_channel:
                # Reading a channel name: it runs until the next marker begins.
                idx = buffer.find("<|")
                if idx == -1:
                    head, buffer = hold_back_partial_tag(buffer)
                    channel_name_buffer += head
                    break
                channel_name_buffer += buffer[:idx]
                buffer = buffer[idx:]
                channel = channel_name_buffer.strip()
                channel_name_buffer = ""
                pending_channel = False
                current_channel = channel
                channel_canonical = canonical_channel(current_channel)
                in_message = bool(channel_canonical in {"assistant", "final"})
                continue
            tag_start = buffer.find("<|")
            if tag_start == -1:
                if is_visible():
                    head, buffer = hold_back_partial_tag(buffer)
                    yield from emit_text(head)
                else:
                    # Outside any visible message: keep only a short tail.
                    if len(buffer) > 128:
                        buffer = buffer[-128:]
                break
            if tag_start > 0:
                # Plain text precedes the marker.
                text = buffer[:tag_start]
                if is_visible():
                    yield from emit_text(text)
                buffer = buffer[tag_start:]
            tag_end = buffer.find("|>")
            if tag_end == -1:
                # Marker is incomplete; wait for more input.
                break
            tag = buffer[2:tag_end].strip().lower()
            buffer = buffer[tag_end + 2:]
            if tag == "start":
                current_channel = None
                in_message = False
            elif tag == "channel":
                pending_channel = True
            elif tag == "message":
                in_message = True
            elif tag == "end":
                in_message = False
                current_channel = None
            elif tag == "return":
                # End of the response: fall back to assistant text if no final text appeared.
                if not final_seen and assistant_cache:
                    for cached_piece in assistant_cache:
                        if cached_piece:
                            yield cached_piece
                    assistant_cache.clear()
                return
            else:
                continue
    # Stream ended without <|return|>: flush what is left (including a held-back '<').
    if is_visible() and buffer:
        yield from emit_text(buffer)
    if not final_seen and assistant_cache:
        for cached_piece in assistant_cache:
            if cached_piece:
                yield cached_piece

def llm_stream(messages, max_tokens):
    """Stream a completion and yield the pieces kept by the Harmony filter.

    ``messages`` is rendered into a prompt by the formatter returned from
    ``ensure_harmony_formatter()``. The prompt is sent to the module-level
    ``llm`` through ``create_completion(stream=True)``, and the raw text
    deltas are passed through ``stream_harmony_final_pieces``. Sampling
    settings are read from the module-level ``temperature``, ``top_p`` and
    ``repeat_penalty`` values.

    Args:
        messages: Chat messages, forwarded to the formatter as ``messages=``.
        max_tokens: Generation cap; coerced with ``int()``.

    Yields:
        Every non-empty piece produced by ``stream_harmony_final_pieces``.

    Raises:
        RuntimeError: If the global ``llm`` is missing or ``None``. Because
            this is a generator function, the check runs when iteration
            starts, not when ``llm_stream`` is called.
    """
    if globals().get("llm") is None:
        raise RuntimeError("LLM model is not loaded. Cannot stream generation.")

    rendered = ensure_harmony_formatter()(
        messages=messages,
        add_generation_prompt=True,
    )

    request = {
        "prompt": rendered.prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "repeat_penalty": float(repeat_penalty),
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    # Stop settings are forwarded only when the formatter supplied them.
    if rendered.stop:
        request["stop"] = rendered.stop
    if rendered.stopping_criteria is not None:
        request["stopping_criteria"] = rendered.stopping_criteria

    def _text_deltas(events):
        # Pull the text of the first choice out of each streamed event;
        # a missing "choices" or "text" key falls back to an empty string.
        for event in events:
            first_choice = event.get("choices", [{}])[0]
            yield first_choice.get("text", "")

    events = llm.create_completion(**request)
    for piece in stream_harmony_final_pieces(_text_deltas(events)):
        if not piece:
            continue
        yield piece

# ===== Summary 5/6) Segment summary (map) - uses map_max_new_tokens, ctx_window, prompt_overhead, topic_hint
if DEBUG_MODE:
    print("[Summary 5/6] Segment summarization (map) ...")

# Display handle that is updated in place while each segment summary streams in.
live = display(Markdown(""), display_id=True)

# Collected per-segment summaries, appended in chunk order.
maps: List[str] = []


def escape_braces(text: str) -> str:
    """Return ``text`` with every ``{`` and ``}`` doubled.

    A doubled brace is how ``str.format`` spells a literal brace, so the
    result can be embedded in a format template without creating
    placeholders.
    """
    doubled = {ord("{"): "{{", ord("}"): "}}"}
    return text.translate(doubled)


# "（無）" is the placeholder used when no topic hint was supplied.
# The brace-escaped form is retained only for code outside this section that
# still reads `safe_topic_hint`; the map prompt below uses the raw text.
safe_topic_hint = escape_braces(topic_hint or "（無）")
map_topic = topic_hint or "（無）"


def shrink_to_budget(text: str, budget_tokens: int) -> str:
    """Trim ``text`` from the end until it fits within ``budget_tokens``.

    Up to six rounds are attempted. Each round keeps the leading 85% of the
    characters, but never fewer than 800 characters. The result may still
    exceed the budget if six rounds (or the 800-character floor) are not
    enough.

    Args:
        text: Segment body to shorten.
        budget_tokens: Maximum size as measured by ``count_tokens_text``.

    Returns:
        ``text`` itself when it already fits, otherwise a prefix of it.
    """
    cur = text
    for _ in range(6):
        if count_tokens_text(cur) <= budget_tokens:
            return cur
        keep = max(800, int(len(cur) * 0.85))
        if keep >= len(cur):
            # Already at or below the 800-character floor: more rounds would
            # only re-tokenize the same text.
            break
        cur = cur[:keep]
    return cur


# Shrink to a safe budget before sending (prevents prompt + segment from
# exceeding the window and making the model stop early). The budget does not
# depend on the chunk, so it is computed once.
budget_tokens = max(512, ctx_window - map_max_new_tokens - prompt_overhead)
total_chunks = len(chunks)

for i, (s, e, body) in enumerate(chunks, 1):
    pct = i / max(total_chunks, 1) * 100
    sys.stdout.write(f"  - 處理分段 {i}/{total_chunks}（~{pct:.1f}%）\n"); sys.stdout.flush()

    body2 = shrink_to_budget(body, budget_tokens)

    # str.format does not re-scan substituted values, so braces inside the
    # topic or the transcript cannot be mistaken for placeholders. Passing the
    # raw text also avoids a post-format "{{" -> "{" pass, which would collapse
    # adjacent braces that the template itself emits.
    user_txt = MAP_USER_TMPL.format(topic=map_topic, chunk=body2)
    messages = [
        {"role": "system", "content": SYSTEM_INSTR},
        {"role": "user",   "content": user_txt},
    ]

    part_buf = []  # Pieces streamed for the current segment only
    for token in llm_stream(messages, map_max_new_tokens):
        part_buf.append(token)
        # Refresh the live display and the character count every 24 pieces
        if len(part_buf) % 24 == 0:
            cur_txt = "".join(part_buf)
            live.update(Markdown(cur_txt))
            sys.stdout.write(f"    ↳ 分段 {i} 已產生字元：{len(cur_txt)}\n"); sys.stdout.flush()
    # Final refresh so the display shows the complete segment output
    cur_txt = "".join(part_buf)
    live.update(Markdown(cur_txt))
    sys.stdout.write(f"    ↳ 分段 {i} 已產生字元：{len(cur_txt)}\n"); sys.stdout.flush()

    # Include the model's final output directly, no regex cleaning
    maps.append(cur_txt.strip())

if DEBUG_MODE: print("→ Segment summarization complete")

if DEBUG_MODE: print("[Summary 6/6] Consolidating summary (reduce) ...")
# One Markdown section per segment, separated by horizontal rules.
maps_md = "\n\n---\n\n".join(f"### 片段 {n} 要點\n\n{m}" for n, m in enumerate(maps, 1))

# If the combined segment summaries exceed the context window, trim them from the end before the reduce step.
def fit_reduce_payload(md_text: str, max_ctx_tokens: int) -> str:
    """Shorten ``md_text`` until it leaves room for the reduce generation.

    The text fits when its ``count_tokens_text`` size plus
    ``reduce_max_new_tokens`` plus 400 is at most ``max_ctx_tokens``. Up to
    eight rounds are tried; each round drops the trailing 10% of the
    characters. The cut is by character position, so it can land inside a
    segment. After eight rounds the text is returned even if it still does
    not fit.

    NOTE(review): the 400-token margin is presumably headroom for the prompt
    template and system message -- confirm.

    Args:
        md_text: Combined Markdown of the per-segment summaries.
        max_ctx_tokens: Total token capacity to fit into.

    Returns:
        ``md_text`` unchanged when it already fits, otherwise a prefix of it.
    """
    rounds_left = 8
    while rounds_left:
        used = count_tokens_text(md_text)
        if used + reduce_max_new_tokens + 400 <= max_ctx_tokens:
            return md_text
        shorter_len = int(len(md_text) * 0.9)
        md_text = md_text[:shorter_len]
        rounds_left -= 1
    return md_text

# Trim the combined segment summaries so the reduce prompt fits the window.
md_cur = fit_reduce_payload(maps_md, ctx_window)

# str.format does not re-scan substituted values, so braces inside the topic
# or the segment summaries cannot be mistaken for placeholders. Passing the
# raw text also avoids a post-format "{{" -> "{" pass, which would collapse
# adjacent braces that the template itself emits. "（無）" is the placeholder
# used when no topic hint was supplied.
user_txt = REDUCE_USER_TMPL.format(topic=topic_hint or "（無）", maps=md_cur)
messages = [
    {"role": "system", "content": SYSTEM_INSTR},
    {"role": "user", "content": user_txt},
]

# Separate display handle so the segment output shown earlier stays visible.
live2 = display(Markdown(""), display_id=True)
final_buf = []
for token in llm_stream(messages, reduce_max_new_tokens):
    final_buf.append(token)
    # Refresh the live display and the character count every 24 pieces
    if len(final_buf) % 24 == 0:
        partial_txt = "".join(final_buf)
        live2.update(Markdown(partial_txt))
        sys.stdout.write(f"    ↳ 彙整 已產生字元：{len(partial_txt)}\n"); sys.stdout.flush()
# Final refresh so the display shows the complete summary
streamed_txt = "".join(final_buf)
live2.update(Markdown(streamed_txt))
sys.stdout.write(f"    ↳ 彙整 已產生字元：{len(streamed_txt)}\n"); sys.stdout.flush()

final_text = streamed_txt.strip()

# Determine and create the summary output directory
summary_output_dir_abs = out_base_dir
summary_output_dir_abs.mkdir(parents=True, exist_ok=True)

# Name the summary after the stem of the input SRT
out_md = summary_output_dir_abs / f"{Path(summary_srt_path_abs).stem}_summary.md"


with open(out_md, "w", encoding="utf-8") as f:
    f.write(final_text)

print(f"→ 完成 ✅  {out_md}")
# Drop the model reference and collect garbage; only a missing name is
# tolerated here, any other error should surface.
try:
    del llm
except NameError:
    pass
gc.collect()
if DEBUG_MODE: print("（顯存已釋放，如需重跑可直接再次執行）")