# 03a2 – LLM Postprocessing v3 (VTT-native) on All Dev Sessions

## Context

`02i_` tested the VTT-native Qwen3-8B approach on 5 sessions.
This notebook scales it to **all dev sessions** via a glob pattern.

**Differences from `02i_`:**
- `SESSIONS_GLOB` is a glob pattern → all sessions are picked up automatically
- `OVERWRITE=False` → resume mode: VTTs that were already processed are skipped
- Separate results CSV: `final_results_llm_qwen3_v3_by_session.csv`
- Baseline comparison against `output_final_bs12_len20` (not E09)
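
The resume behaviour behind `OVERWRITE=False` can be sketched as follows. This is a minimal illustration under assumptions, not the notebook's actual loop: `process_dir` and the `transform` callback are hypothetical names, and the real run calls the LLM where this sketch applies a plain string function.

```python
from pathlib import Path
import tempfile


def process_dir(in_dir: Path, out_dir: Path, transform, overwrite: bool = False) -> int:
    """Process every .vtt file in in_dir and return how many files were written.

    Hypothetical helper: with overwrite=False, files that already exist in
    out_dir are skipped, so an interrupted run can simply be restarted.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    for src in sorted(in_dir.glob("*.vtt")):
        dst = out_dir / src.name
        if dst.exists() and not overwrite:
            continue  # already produced by an earlier run
        dst.write_text(transform(src.read_text(encoding="utf-8")), encoding="utf-8")
        written += 1
    return written


# Small demonstration in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    in_dir, out_dir = Path(tmp) / "in", Path(tmp) / "out"
    in_dir.mkdir()
    for name in ("spk_0.vtt", "spk_1.vtt"):
        (in_dir / name).write_text("WEBVTT\n", encoding="utf-8")

    first_run = process_dir(in_dir, out_dir, str.upper)   # writes both files
    second_run = process_dir(in_dir, out_dir, str.upper)  # skips both files
    print(first_run, second_run)
```

One caveat of an existence check like this: a file left half-written by a crashed run would also count as done, so it may be worth deleting the most recent output before resuming.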

## Result (Preview)

No WER improvement on the full set of sessions either. LLM postprocessing is thereby concluded for good.

**Note on the bugfix:** This final run is based on the state **before the segmentation bugfix**. The bugfix run follows in `03b_`.

## 1 – GPU Check & Selection

In [1]:
!nvidia-smi

Fri Jan 30 21:25:04 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:01:00.0 Off |                    0 |
| N/A   35C    P0             81W /  500W |    7093MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  |   00

In [2]:
import os

# Physical GPU selection: index 1 in PCI bus order, i.e. the second GPU listed by nvidia-smi
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # adjust depending on availability
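
A minimal sketch of the same selection wrapped in a function; `select_physical_gpu` is a hypothetical name, not part of the notebook. The variables have to be set before CUDA is initialised in the process. Afterwards the chosen card is the only one visible and is addressed as device 0, which is consistent with the verification cell reporting a single device.

```python
import os


def select_physical_gpu(index: int) -> None:
    """Expose exactly one physical GPU to this process (hypothetical helper).

    CUDA_DEVICE_ORDER=PCI_BUS_ID makes CUDA number the devices by PCI bus ID,
    which matches the numbering shown by nvidia-smi.
    Must be called before CUDA is initialised.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)


select_physical_gpu(1)  # nvidia-smi index 1, i.e. the second card
print(os.environ["CUDA_VISIBLE_DEVICES"])
```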

## 2 – CUDA Verification

In [3]:
import torch

In [4]:
print("CUDA available:", torch.cuda.is_available())
print("CUDA devices:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Device 0 name:", torch.cuda.get_device_name(0))
    print("Memory allocated:", torch.cuda.memory_allocated(0) / 1024**3, "GB")

CUDA available: True
CUDA devices: 1
Device 0 name: NVIDIA A100-SXM4-80GB
Memory allocated: 0.0 GB


## 3 – Setup

In [5]:
from pathlib import Path
import re
import shutil
from glob import glob
from tqdm.auto import tqdm

project_baseline_path = "/home/josch080/Projektgruppe/mcorec_baseline"
os.chdir(project_baseline_path)

from script.pg_utils_experiments import append_eval_results_for_experiments



In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer

## 4 – Configuration

`SESSIONS_GLOB` is a glob pattern instead of a list, so all sessions are found automatically.

In [7]:
MODEL_NAME = "Qwen/Qwen3-8B"
BASE_DIR = "data-bin/dev"
SESSIONS_GLOB = "data-bin/dev/session_*"  # all sessions
INPUT_OUTPUT_DIRNAME = "output_final_bs12_len20"  # input: result from 03_
OUTPUT_OUTPUT_DIRNAME = "output_pp_qwen3_8b_v3_bs12_len20"

In [8]:
# Inference settings
MAX_NEW_TOKENS = 8192      # large enough for an entire VTT file
OVERWRITE = False          # if False: skip files that already exist in the target folder
SKIP_IF_MISSING_INPUT = True  # note: not passed to process_sessions, which always skips sessions without the input dir
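
One detail of resolving `SESSIONS_GLOB`: `sorted(glob(...))` orders paths lexicographically, so `session_132` comes before `session_40`. This only affects the processing order, not the results. If numeric order is preferred, a natural sort key is one option. The helper `natural_key` and the sample names below are illustrative assumptions, not something the notebook defines.

```python
import re


def natural_key(name: str):
    """Split a name into text and integer chunks so that numbers compare numerically.

    Assumes all names share the same text/number layout (e.g. 'session_<n>').
    """
    return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r"(\d+)", name)]


names = ["session_40", "session_132", "session_133", "session_5"]
lexicographic = sorted(names)
natural = sorted(names, key=natural_key)

print(lexicographic)  # ['session_132', 'session_133', 'session_40', 'session_5']
print(natural)        # ['session_5', 'session_40', 'session_132', 'session_133']
```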

## 5 – Load Model

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval()

# Set a pad token so that generation runs without warnings
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id


Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.37it/s]


## 6 – Helper Functions, Prompt, Cue Guard, Session Processing

Identical to `02i_`.
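
As orientation for the cell that follows, here is a compact sketch of two safety nets around the model output: extracting the WebVTT payload from the raw generation, and falling back to the original file when the timestamps differ. It restates the notebook's `_extract_webvtt` and timestamp check in simplified form. The names `extract_webvtt` and `postprocess` and the sample strings are illustrative assumptions, and the line-level guard is left out here.

```python
import re
from typing import Optional

TS_LINE = re.compile(r"\d{2}:\d{2}:\d{2}\.\d{3}\s-->\s\d{2}:\d{2}:\d{2}\.\d{3}")
FENCE = "`" * 3  # built programmatically so this example contains no literal fence


def extract_webvtt(raw: str) -> Optional[str]:
    """Drop <think> blocks and code fences, then keep the text from the last 'WEBVTT' on."""
    raw = re.sub(r"<think>.*?</think>\s*", "", raw, flags=re.DOTALL | re.IGNORECASE)
    raw = re.sub(r"`{3,}(?:webvtt|vtt)?\s*", "", raw)
    idx = raw.rfind("WEBVTT")
    return None if idx < 0 else raw[idx:].strip() + "\n"


def postprocess(original: str, raw_model_output: str) -> str:
    """Return the cleaned output, or the original if extraction or the timestamp check fails."""
    cleaned = extract_webvtt(raw_model_output)
    if cleaned is None:
        return original
    if TS_LINE.findall(cleaned) != TS_LINE.findall(original):
        return original
    return cleaned


original = "WEBVTT\n\n00:00:01.000 --> 00:00:02.500\nI LIKE THE WHETHER TODAY\n"
raw_ok = (
    "<think>The topic is weather.</think>\n"
    + FENCE + "webvtt\n"
    + "WEBVTT\n\n00:00:01.000 --> 00:00:02.500\nI LIKE THE WEATHER TODAY\n"
    + FENCE
)
raw_shifted = "WEBVTT\n\n00:00:01.000 --> 00:00:03.000\nI LIKE THE WEATHER TODAY\n"

accepted = postprocess(original, raw_ok)        # corrected text, timestamps intact
rejected = postprocess(original, raw_shifted)   # timestamp changed, so the original is kept
print(accepted)
print(rejected == original)
```

The fallback is deliberately conservative: any deviation in the timestamp sequence discards the whole LLM output for that file rather than trying to repair it.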

In [10]:
_TS_LINE = re.compile(r"\d{2}:\d{2}:\d{2}\.\d{3}\s-->\s\d{2}:\d{2}:\d{2}\.\d{3}")

def _timestamps(vtt_text: str):
    return _TS_LINE.findall(vtt_text)

def _extract_webvtt(raw: str) -> str | None:
    # remove reasoning blocks if present
    raw = re.sub(r"<think>.*?</think>\s*", "", raw, flags=re.DOTALL | re.IGNORECASE)

    # remove code fences if model used them
    raw = raw.replace("```webvtt", "```").replace("```vtt", "```")
    raw = re.sub(r"```+\s*", "", raw)

    # take the LAST WEBVTT in case the model mentions 'WEBVTT' in text above
    idx = raw.rfind("WEBVTT")
    if idx < 0:
        return None

    vtt = raw[idx:].strip() + "\n"
    return vtt


def _build_prompt(vtt_text: str) -> str:
    return (
        "You are working on a transcription of a video. You are only transcribing one speaker "
        "and not a complete dialog. Attached you find a first draft of the transcription. "
        "The draft probably includes mistakes due to rare, unclear, or domain-specific words.\n\n"

        "You should do two things:\n"
        "1. Review the attached transcription and summarise what the transcription is about.\n"
        "2. Based on your own summary, rewrite the transcription and exchange words which might "
        "be wrong due to the context of the overall transcription.\n\n"

        "IMPORTANT: Use the summary ONLY internally. Do NOT output the summary. Output ONLY the corrected WEBVTT starting with 'WEBVTT'.\n"
        "Also do NOT output <think> tags or any reasoning.\n\n"

        "Only exchange certain words with better fitting words which may sound similar from a "
        "pronunciation perspective.\n"
        "You MAY also fix missing or wrong small function words (e.g., a/the/to/of/your/I'm/it's) "
        "if it clearly improves grammatical correctness.\n"
        "Do NOT paraphrase or change meaning. Keep the same tone and level of formality.\n"
        "Keep the structure of the attached file exactly as it is. "
        "Keep all timestamps exactly as they are. "
        "The transcription must be complete. Do not leave any parts out. "
        "Do not add new content. Do not hallucinate.\n\n"
        "IMPORTANT: Use the summary ONLY internally. Do NOT output the summary. "
        "Output ONLY the corrected WEBVTT starting with 'WEBVTT'.\n"
        "Also do NOT output <think> tags or any reasoning.\n\n"

        "ADDITIONAL STRICT FORMAT RULES:\n"
        "- Keep the exact casing style of each cue line (if it is ALL CAPS, keep ALL CAPS).\n"
        "- Do NOT add punctuation that wasn't there (no commas, question marks, brackets, etc.).\n"
        "- NEVER insert placeholders like [DEVICE], [INAUDIBLE], [MUSIC], or anything in brackets.\n"
        "- Only replace individual words when highly confident; otherwise keep the original words.\n"
        "- Output ONLY the corrected WEBVTT starting with 'WEBVTT'. No summary, no reasoning.\n\n"
        
        "=== TRANSCRIPTION (WEBVTT) START ===\n\n"
        f"{vtt_text.strip()}\n\n"
        "=== TRANSCRIPTION END ==="
    )


@torch.inference_mode()
def smooth_vtt_with_qwen(vtt_text: str) -> str:
    prompt = _build_prompt(vtt_text)

    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)

    generated = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.2,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    gen_ids = generated[0][inputs["input_ids"].shape[1]:]
    raw = tokenizer.decode(gen_ids, skip_special_tokens=True)

    cleaned = _extract_webvtt(raw)
    if cleaned is None:
        return vtt_text
    
    # Cue guards: reject ?, commas, [DEVICE], case drift, etc.
    cleaned = apply_line_level_guard(vtt_text, cleaned)

    # Additional check: if the timestamps changed, fall back to the original
    if _timestamps(cleaned) != _timestamps(vtt_text):
        return vtt_text
    return cleaned


def _resolve_sessions(sessions_spec, base_dir: str | Path | None = None):
    """
    sessions_spec:
      - str  : glob pattern (e.g. "data-bin/.../session_*")
      - list : ["session_40", "..."] or ["/abs/path/session_40", ...]
    base_dir:
      - if the list contains only session names, base_dir is prepended
    """
    if isinstance(sessions_spec, (str, Path)):
        return [Path(p) for p in sorted(glob(str(sessions_spec)))]
    if isinstance(sessions_spec, (list, tuple)):
        out = []
        for s in sessions_spec:
            p = Path(s)
            if not p.is_absolute() and base_dir is not None:
                p = Path(base_dir) / p
            out.append(p)
        return out
    raise TypeError(f"sessions_spec must be str/Path or list/tuple, got {type(sessions_spec)}")

def process_sessions(
    sessions_glob,
    input_output_dirname: str,
    output_output_dirname: str,
    overwrite: bool = True,
    base_dir_for_list: str | Path | None = None,
):
    session_dirs = _resolve_sessions(sessions_glob, base_dir=base_dir_for_list)

    print(f"Found {len(session_dirs)} sessions.")
    print("CWD:", Path().resolve())

    missing_session_dirs = [p for p in session_dirs if not p.exists()]
    if missing_session_dirs:
        print("WARNING: these session dirs do not exist (first 5):", missing_session_dirs[:5])

    processed_sessions = 0
    skipped_missing_input = 0
    processed_vtts = 0

    for sdir in tqdm(session_dirs, desc="Sessions"):
        sdir = Path(sdir)
        in_dir = sdir / input_output_dirname
        out_dir = sdir / output_output_dirname

        if not in_dir.exists():
            skipped_missing_input += 1
            continue

        out_dir.mkdir(parents=True, exist_ok=True)
        processed_sessions += 1

        # copy non-vtt side files
        for p in in_dir.iterdir():
            if p.is_file() and p.suffix.lower() != ".vtt":
                dst = out_dir / p.name
                if overwrite or (not dst.exists()):
                    shutil.copy2(p, dst)

        vtt_files = sorted([p for p in in_dir.iterdir() if p.is_file() and p.suffix.lower() == ".vtt"])
        for vtt_path in tqdm(vtt_files, desc=f"{sdir.name}: VTTs", leave=False):
            dst_path = out_dir / vtt_path.name
            if dst_path.exists() and (not overwrite):
                continue
            vtt_text = vtt_path.read_text(encoding="utf-8", errors="replace")
            fixed = smooth_vtt_with_qwen(vtt_text)
            dst_path.write_text(fixed, encoding="utf-8")
            processed_vtts += 1

    print("Done.")
    print(f"Processed sessions: {processed_sessions}")
    print(f"Skipped (missing input dir '{input_output_dirname}'): {skipped_missing_input}")
    print(f"Processed VTT files: {processed_vtts}")

import re

def _cue_text_lines(vtt_text: str):
    # very simple: take non-empty lines that are not WEBVTT and not timestamps
    lines = []
    for line in vtt_text.splitlines():
        s = line.strip("\n")
        if not s.strip():
            continue
        if s.strip() == "WEBVTT":
            continue
        if re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3}\s-->\s\d{2}:\d{2}:\d{2}\.\d{3}$", s.strip()):
            continue
        lines.append(s)
    return lines

def _is_all_caps(s: str) -> bool:
    letters = [c for c in s if c.isalpha()]
    return bool(letters) and all(c.isupper() for c in letters)

def _valid_line_change(orig: str, new: str) -> bool:
    # 1) No bracket placeholders
    if "[" in new or "]" in new:
        return False

    # 2) Preserve ALL CAPS style if original was ALL CAPS
    if _is_all_caps(orig) and not _is_all_caps(new):
        return False

    # 3) Disallow introducing NEW punctuation characters
    # (allow letters/digits/spaces/apostrophes; allow punctuation only if it already existed in orig)
    allowed_base = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 '")
    orig_extra = set([c for c in orig if c not in allowed_base])
    for c in new:
        if c in allowed_base:
            continue
        if c not in orig_extra:   # new punctuation char that didn't exist before
            return False

    return True

def apply_line_level_guard(original_vtt: str, corrected_vtt: str) -> str:
    orig_text = _cue_text_lines(original_vtt)
    new_text  = _cue_text_lines(corrected_vtt)

    if len(orig_text) != len(new_text):
        return original_vtt if original_vtt.endswith("\n") else (original_vtt + "\n")

    text_i = 0
    out_lines = []
    for line in corrected_vtt.splitlines():
        s = line.strip()
        if not s:
            out_lines.append(line)
            continue
        if s == "WEBVTT" or re.match(r"^\d{2}:\d{2}:\d{2}\.\d{3}\s-->\s\d{2}:\d{2}:\d{2}\.\d{3}$", s):
            out_lines.append(line)
            continue

        # this is a text cue line
        orig_line = orig_text[text_i]
        new_line  = new_text[text_i]
        if _valid_line_change(orig_line, new_line):
            out_lines.append(new_line)
        else:
            out_lines.append(orig_line)
        text_i += 1

    return "\n".join(out_lines).rstrip() + "\n"
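
To make the guard's rules concrete, the sketch below restates the three checks of `_valid_line_change` compactly and runs them on a few invented cue lines. The function names without underscore and the sample lines are illustrative assumptions. In the notebook, a rejected line is reverted to its original text individually, while a differing number of text lines reverts the whole file.

```python
import string

ALLOWED_BASE = set(string.ascii_letters + string.digits + " '")


def is_all_caps(s: str) -> bool:
    letters = [c for c in s if c.isalpha()]
    return bool(letters) and all(c.isupper() for c in letters)


def valid_line_change(orig: str, new: str) -> bool:
    """Same three rules as the guard above, restated compactly."""
    if "[" in new or "]" in new:
        return False  # bracket placeholders such as [INAUDIBLE]
    if is_all_caps(orig) and not is_all_caps(new):
        return False  # casing drift
    orig_extra = {c for c in orig if c not in ALLOWED_BASE}
    # Any character outside the base set must already occur in the original line.
    return all(c in ALLOWED_BASE or c in orig_extra for c in new)


cases = [
    ("I LIKE THE WHETHER", "I LIKE THE WEATHER"),   # plain word swap
    ("I LIKE THE WHETHER", "I like the weather"),   # casing changed
    ("I LIKE THE WHETHER", "I LIKE THE WEATHER?"),  # new punctuation
    ("I LIKE THE WHETHER", "I LIKE THE [DEVICE]"),  # placeholder
    ("WELL, I LIKE IT", "WELL, I LIKED IT"),        # comma already present
]
results = [valid_line_change(orig, new) for orig, new in cases]
for (orig, new), ok in zip(cases, results):
    print(f"{new!r}: {'accept' if ok else 'reject'}")
```

Only the first and last case are accepted. Note that the base set contains ASCII letters only, so a non-ASCII letter introduced by the model would also be rejected unless it already appears in the original line.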




## 8 – Run Postprocessing

In [11]:
process_sessions(
    sessions_glob=SESSIONS_GLOB,
    input_output_dirname=INPUT_OUTPUT_DIRNAME,
    output_output_dirname=OUTPUT_OUTPUT_DIRNAME,
    overwrite=OVERWRITE,
)


Found 25 sessions.
CWD: /home/josch080/Projektgruppe/mcorec_baseline


Sessions:   0%|          | 0/25 [00:00<?, ?it/s]
session_132: VTTs:   0%|          | 0/6 [00:00<?, ?it/s]
session_132: VTTs:  17%|█▋        | 1/6 [03:06<15:30, 186.13s/it]
session_132: VTTs:  33%|███▎      | 2/6 [05:06<09:49, 147.35s/it]
session_132: VTTs:  50%|█████     | 3/6 [07:07<06:45, 135.28s/it]
session_132: VTTs:  67%|██████▋   | 4/6 [08:42<03:59, 119.64s/it]
session_132: VTTs:  83%|████████▎ | 5/6 [10:17<01:50, 110.75s/it]
session_132: VTTs: 100%|██████████| 6/6 [11:57<00:00, 107.11s/it]
Sessions:   4%|▍         | 1/25 [11:57<4:47:11, 717.97s/it]
session_133: VTTs:   0%|          | 0/6 [00:00<?, ?it/s]
session_133: VTTs:  17%|█▋        | 1/6 [01:53<09:28, 113.61s/it]
session_133: VTTs:  33%|███▎      | 2/6 [03:53<07:49, 117.50s/it]
session_133: VTTs:  50%|█████     | 3/6 [06:09<06:17, 125.69s/it]
session_133: VTTs:  67%|██████▋   | 4/6 [07:15<03:24, 102.02s/it]
session_133: VTTs:  83%|████████▎ | 5/6 [09:16<01:48, 108.99s/it]
session_133: VTTs: 100%|██████████| 6/6 [11:2

Done.
Processed sessions: 25
Skipped (missing input dir 'output_final_bs12_len20'): 0
Processed VTT files: 139





## 9 – Evaluation & Aggregation

In [14]:
EXPERIMENTS = {
    "pp_qwen3_8b_v3_bs12_len20": { 
        "llm_model": "qwen3_8b",
        "description": "Qwen3 8B v3"
    },
}


from glob import glob
from pathlib import Path

SESSIONS_GLOB = "data-bin/dev/session_*"
SESSION_IDS = [Path(p).name for p in sorted(glob(SESSIONS_GLOB))]   # -> ["session_0", "session_1", ...]

df_dev = append_eval_results_for_experiments(
    experiments=EXPERIMENTS,
    session_ids=SESSION_IDS,
    target_csv="final_results_llm_qwen3_v3_by_session.csv",
)



########## Evaluate für session_132 ##########
Starte Evaluate: /home/josch080/Projektgruppe/mcorec_train/bin/python script/evaluate.py --session_dir data-bin/dev_without_central_videos/dev/session_132 --output_dir_name output_ --label_dir_name labels
Evaluating 1 sessions

=== Evaluating session session_132 ===

--- Evaluating output dir: output_auto_avsr ---
Conversation clustering F1 score: 1.0
Speaker to WER: {'spk_0': 0.8698, 'spk_1': 0.8679, 'spk_2': 0.9183, 'spk_3': 0.8807, 'spk_4': 0.8512, 'spk_5': 0.8877}
Speaker clustering F1 score: {'spk_0': 1.0, 'spk_1': 1.0, 'spk_2': 1.0, 'spk_3': 1.0, 'spk_4': 1.0, 'spk_5': 1.0}
Joint ASR-Clustering Error Rate: {'spk_0': 0.4349, 'spk_1': 0.43395, 'spk_2': 0.45915, 'spk_3': 0.44035, 'spk_4': 0.4256, 'spk_5': 0.44385}

--- Evaluating output dir: output_avsr_cocktail ---
Conversation clustering F1 score: 1.0
Speaker to WER: {'spk_0': 0.5022, 'spk_1': 0.6208, 'spk_2': 0.4942, 'spk_3': 0.4947, 'spk_4': 0.657, 'spk_5': 0.7181}
Speaker clusteri



## 10 – Results Analysis: LLM v3 vs. Final Baseline

The baseline here is `output_final_bs12_len20` (from `03_`), not E09 as in `02f_`–`02i_`.

In [20]:
import pandas as pd
from pathlib import Path
from glob import glob

CSV_BASE = Path("final_results.csv")
CSV_QWEN = Path("final_results_llm_qwen3_v3_by_session.csv")

BASELINE_EXP = "output_final_bs12_len20"
NEW_EXP      = "output_pp_qwen3_8b_v3_bs12_len20"

SESSIONS_GLOB = "data-bin/dev/session_*"
SESSIONS = [Path(p).name for p in sorted(glob(SESSIONS_GLOB))]
N_TOTAL = len(SESSIONS)

# --- baseline WER (aggregated) ---
df_base = pd.read_csv(CSV_BASE)
baseline_wer = float(df_base.loc[df_base["exp"] == BASELINE_EXP, "avg_speaker_wer"].iloc[0])

# --- qwen WER (mean over sessions) ---
df_q = pd.read_csv(CSV_QWEN)

if "timestamp" in df_q.columns:
    df_q["timestamp"] = pd.to_datetime(df_q["timestamp"], errors="coerce")
    df_q = (df_q.sort_values(["session", "exp", "timestamp"])
              .drop_duplicates(["session", "exp"], keep="last"))

df_q_new = df_q[(df_q["session"].isin(SESSIONS)) & (df_q["exp"] == NEW_EXP)].copy()
n_used = df_q_new["session"].nunique()

new_wer = float(df_q_new["avg_speaker_wer"].mean())

# --- deltas: positive = improvement ---
delta_abs = baseline_wer - new_wer
delta_rel = (delta_abs / baseline_wer) * 100

# --- one-line output ---
row = pd.DataFrame([{
    "baseline_exp": BASELINE_EXP,
    "new_exp": NEW_EXP,
    "baseline_WER": baseline_wer,
    "new_WER": new_wer,
    "sessions_used": f"{n_used}/{N_TOTAL}",
    "ΔWER_abs": delta_abs,
    "ΔWER_rel_%": delta_rel,
}])

display(row.style.format({
    "baseline_WER": "{:.6f}",
    "new_WER": "{:.6f}",
    "ΔWER_abs": "{:+.6f}",
    "ΔWER_rel_%": "{:+.3f}%",
}))

# Optional: print exactly as "one line" (tab-separated)
r = row.iloc[0]
print(
    f"{r['baseline_exp']}\t{r['new_exp']}\t"
    f"{r['baseline_WER']:.6f}\t{r['new_WER']:.6f}\t"
    f"{r['sessions_used']}\t{r['ΔWER_abs']:+.6f}\t{r['ΔWER_rel_%']:+.3f}%"
)


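| | baseline_exp | new_exp | baseline_WER | new_WER | sessions_used | ΔWER_abs | ΔWER_rel_% |
|---|--------------|---------|--------------|---------|---------------|----------|------------|
| 0 | output_final_bs12_len20 | output_pp_qwen3_8b_v3_bs12_len20 | 0.497967 | 0.499619 | 25/25 | -0.001652 | -0.332% |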


output_final_bs12_len20	output_pp_qwen3_8b_v3_bs12_len20	0.497967	0.499619	25/25	-0.001652	-0.332%
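
As a worked check of the sign convention, the two WER values from the output above reproduce the reported deltas. The snippet only redoes the arithmetic of the analysis cell with the printed numbers hard-coded.

```python
baseline_wer = 0.497967  # aggregated baseline WER, taken from the output above
new_wer = 0.499619       # mean per-session WER after LLM postprocessing

delta_abs = baseline_wer - new_wer          # positive would mean an improvement
delta_rel = delta_abs / baseline_wer * 100  # relative change in percent

print(f"{delta_abs:+.6f} {delta_rel:+.3f}%")  # -0.001652 -0.332%
```

Both deltas are negative, so under this convention the LLM variant is slightly worse than the baseline.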


## 11 – Interpretation

| Metric | Baseline (`output_final_bs12_len20`) | LLM v3 (`output_pp_qwen3_8b_v3_bs12_len20`) | Δ (LLM v3 − Baseline) |
|--------|--------------------------------------|---------------------------------------------|---|
| Speaker WER (↓) | 0.4980 | 0.4996 | +0.0017 |
| Conv F1 (↑) | 0.8153 | 0.8153 | 0.0000 |
| Joint Error (↓) | 0.3453 | 0.3566 | +0.0113 |

LLM postprocessing **worsens** WER by about 0.0017 absolute (0.497967 → 0.499619, roughly 0.33% relative). Note that the table reports Δ as new minus baseline, whereas the analysis cell in section 10 uses baseline minus new, so the signs are flipped.

**Conclusion for LLM postprocessing:** None of the tested variants (E38–E55, `03a_`, `03a2_`)
yields a significant or robust WER improvement; E55/`03a2_` even makes it slightly worse.
The best configuration for the submission comes from the min_duration grid (`02k_`):
min_on=1.0, min_off=1.2, beam=12, len=20 → WER 0.4943 on all 25 dev sessions (`04_dev_final_results`).
Note: E72 reached 0.4902 on the 5 grid-search sessions (`02k_`); on all 25 sessions the value is 0.4943.
