<a href="https://colab.research.google.com/github/Raniamea/arabic-video-summarisation/blob/main/notebooks/04_validate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#  Transcript Validation using CAMeLBERT-MSA & Arabic Captions
Compare a cleaned Arabic ASR transcript against scene captions to improve transcript accuracy.
- Uses `diac`, `lemma`, `pos` from transcript segments
- Uses Arabic captions generated previously
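- Multilingual sentence embeddings (`sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`) for semantic validation (the code below loads this model rather than CAMeLBERT-MSA)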
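- Time-based sliding window (± a few seconds around each transcript segment) to tolerate misalignment between transcript and captions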
- Outputs: **replace**, **append**, and **flag** transcript versions

In [None]:
# ✅ Reset environment first
!pip uninstall -y torch torchvision torchaudio transformers tokenizers \
  sentence-transformers huggingface_hub camel_tools opencv-python opencv-contrib-python \
  opencv-python-headless numpy

# ✅ Core installs (compatible versions, no auto-upgrades)
!pip install --no-cache-dir numpy==1.23.5
!pip install --no-cache-dir torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
!pip install --no-cache-dir transformers==4.41.2 tokenizers==0.19.1
!pip install --no-cache-dir sentence-transformers==2.2.2
!pip install --no-cache-dir huggingface_hub==0.23.2 tqdm==4.66.5

# ✅ OpenCV (avoid headless conflicts)
!pip install --no-cache-dir opencv-python==4.7.0.72 opencv-contrib-python==4.7.0.72

# ✅ CAMeL Tools (needs old numpy, already pinned to 1.23.5)
!pip install --no-cache-dir camel-tools==1.5.2


In [None]:
import numpy
import torch
import transformers
import tokenizers
import sentence_transformers
import huggingface_hub
import cv2

# Report the installed version of every pinned dependency so a mismatch with
# the install cell is visible immediately.
for label, module in (
    ("NumPy", numpy),
    ("Torch", torch),
    ("Transformers", transformers),
    ("Tokenizers", tokenizers),
    ("Sentence-Transformers", sentence_transformers),
    ("HF Hub", huggingface_hub),
    ("OpenCV", cv2),
):
    print(f"{label}:", module.__version__)

# Smoke test: loading a small tokenizer exercises the transformers/tokenizers
# stack end to end.
from transformers import AutoTokenizer, AutoModel
tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
print("Tokenizer OK")


In [None]:
# Mount Google Drive at /content/drive; the project paths used in the
# following cells are all located under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os, json, re
from bisect import bisect_left, bisect_right
from typing import List, Dict, Any, Tuple, Iterable
from collections import Counter

import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

# ---------------------------
# Project Paths
# ---------------------------
base_path = "/content/drive/MyDrive/ArabicVideoSummariser"
params_path = os.path.join(base_path, "params.json")

with open(params_path, "r", encoding="utf-8") as f:
    params = json.load(f)

# The video to process is taken from params.json only. Do not reassign
# `video_filename` to a literal here: that silently ignores the configuration
# and makes the check below meaningless. Edit params.json to switch videos.
video_filename = params.get("video_file")
if not video_filename:
    raise ValueError("params.json must include 'video_file'.")
video_name = os.path.splitext(video_filename)[0]

videos_path       = os.path.join(base_path, "videos")
captions_path     = os.path.join(base_path, "captions")
preprocessed_path = os.path.join(base_path, "Preprocessed")
validated_path    = os.path.join(base_path, "Validated")
os.makedirs(validated_path, exist_ok=True)

# Inputs produced by earlier steps: scene captions and the cleaned transcript.
caption_path    = os.path.join(captions_path, f"{video_name}.json")
transcript_path = os.path.join(preprocessed_path, f"{video_name}_CleanTranscript.json")

# Fail early with an explicit error (not `assert`, which is stripped under -O).
if not os.path.exists(caption_path):
    raise FileNotFoundError(f"Missing captions file: {caption_path}")
if not os.path.exists(transcript_path):
    raise FileNotFoundError(f"Missing transcript file: {transcript_path}")

# ---------------------------
# Model (recommended)
# ---------------------------
# Use the GPU when one is available; otherwise run the encoder on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device=device)


In [None]:
# ---------------------------
# Arabic helpers
# ---------------------------

# Arabic combining marks (tashkeel and related annotation signs) that are
# stripped before any comparison.
_AR_DIAC = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")
# Any run of characters that is neither a word character nor inside the
# Arabic Unicode block; such runs are collapsed to a single space.
_AR_PUNCT = re.compile(r"[^\w\u0600-\u06FF]+", re.UNICODE)

# Stopwords in their usual spelling.
_AR_STOPWORDS_RAW = {
    "و", "في", "على", "من", "إلى", "عن", "أن", "إن", "كان", "كانت", "يكون", "مع", "هذا", "هذه",
    "ذلك", "تلك", "هناك", "هنا", "هو", "هي", "هم", "هن", "كما", "لكن", "بل", "قد", "تم", "ثم",
    "كل", "أي", "أو", "أمام", "خلال", "بعد", "قبل", "حتى", "حيث", "إذا", "إنما", "إما", "لدى",
    "لدي", "لها", "له", "لهم", "لنا", "ما", "ماذا", "لماذا", "كيف", "متى", "أيضا", "بدون",
    "داخل", "خارج", "بين", "أكثر", "أقل",
}

# Same letter folding that ar_normalize() applies (alef variants -> ا,
# ة -> ه, ى -> ي). Membership tests are made on normalized tokens, so without
# the folded spellings entries such as "إلى" or "على" could never match.
_AR_LETTER_FOLD = str.maketrans({"أ": "ا", "إ": "ا", "آ": "ا", "ة": "ه", "ى": "ي"})

# Final stopword set: original spellings plus their normalized forms, so both
# raw and normalized tokens can be tested against it.
AR_STOPWORDS = _AR_STOPWORDS_RAW | {w.translate(_AR_LETTER_FOLD) for w in _AR_STOPWORDS_RAW}

# very light clitic/affix list (heuristic, safe)
CLITIC_PREFIXES = ("و", "ف", "ب", "ك", "ل", "س")     # single-letter clitics
DEF_ART = "ال"                                       # definite article
CLITIC_SUFFIXES = ("ه", "ها", "هم", "هن", "كما", "كم", "نا", "ي")  # pronoun suffixes (approx)

def ar_normalize(text: str) -> str:
    """Return a comparison-friendly form of an Arabic string.

    Steps, in order: coerce non-strings (``None`` becomes ``""``), remove the
    marks matched by ``_AR_DIAC``, fold alef variants to ``ا``, ``ة`` to ``ه``
    and ``ى`` to ``ي``, replace everything matched by ``_AR_PUNCT`` with a
    space, then collapse whitespace and trim.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    folded = _AR_DIAC.sub("", text)
    # Single-character substitutions; no target letter is itself a source
    # letter, so the order of the replacements does not matter.
    for variant, canonical in (("أ", "ا"), ("إ", "ا"), ("آ", "ا"), ("ة", "ه"), ("ى", "ي")):
        folded = folded.replace(variant, canonical)

    spaced = _AR_PUNCT.sub(" ", folded)
    return re.sub(r"\s+", " ", spaced).strip()

def split_clitics(token: str) -> List[str]:
    """Heuristically split a normalized Arabic token into clitics and a core.

    Returns ``[prefix, ..., (definite article), core]``: every peeled
    single-letter prefix and the definite article are kept as separate pieces,
    a trailing pronoun suffix is discarded, and the core is always the last
    element. An empty token yields an empty list.

    Prefixes are only peeled while at least two characters remain afterwards
    (the same margin the suffix rule uses). Without that guard short words made
    up of prefix letters were consumed completely (e.g. "كل" -> ["ك", "ل"]), so
    a clitic was reported as the core and the real word was lost.
    """
    core = token
    pieces = []

    # Peel single-letter prefixes repeatedly (e.g., و + ب + ل + ...), but never
    # down to a stub shorter than two characters.
    while len(core) > 2 and core[0] in CLITIC_PREFIXES:
        pieces.append(core[0])
        core = core[1:]

    # Split off the definite article when something remains after it.
    if core.startswith(DEF_ART) and len(core) > 2:
        pieces.append(DEF_ART)
        core = core[2:]

    # Strip at most one suffix (first match in CLITIC_SUFFIXES order), leaving
    # at least two characters of core.
    for suffix in CLITIC_SUFFIXES:
        if core.endswith(suffix) and len(core) > len(suffix) + 1:
            core = core[: -len(suffix)]
            break

    # Content part goes last so callers can take pieces[-1] as the core.
    if core:
        pieces.append(core)
    return pieces

def ar_tokens(text: str) -> List[str]:
    """Return the core (clitic-stripped) token of every word in ``text``.

    The text is normalized with ``ar_normalize`` and split on whitespace; each
    word goes through ``split_clitics`` and only its last piece is kept.
    """
    words = ar_normalize(text).split()
    # str.split() without arguments never yields empty strings, so every word
    # here is non-empty; words whose split comes back empty are skipped.
    split_words = map(split_clitics, words)
    return [pieces[-1] for pieces in split_words if pieces]

def ngrams(tokens: List[str], n: int) -> Iterable[Tuple[str, ...]]:
    """Lazily yield every run of ``n`` consecutive tokens as a tuple.

    Yields nothing when ``tokens`` has fewer than ``n`` items or ``n`` is 0.
    """
    # One progressively shifted copy of the sequence per n-gram position;
    # zipping them stops at the shortest, i.e. at the last complete n-gram.
    shifted = []
    for offset in range(n):
        shifted.append(tokens[offset:])
    return zip(*shifted)


In [None]:
# ---------------------------
# Captions prep & time windows
# ---------------------------

def prep_captions(captions_data: Dict[str, Any]) -> Tuple[List[float], List[str], List[List[str]]]:
    """Flatten the captions mapping into three parallel lists sorted by time.

    Args:
        captions_data: mapping whose values are dicts read through the keys
            ``"scene_time"`` and ``"arabic"``; the mapping keys are ignored.

    Returns:
        ``(times, texts, tokens_core)`` where ``times[i]`` is the scene time as
        a float, ``texts[i]`` the raw Arabic caption and ``tokens_core[i]`` the
        result of ``ar_tokens(texts[i])``. Captions with equal times keep their
        input order.
    """
    # A missing or null scene_time falls back to 0.0, matching how a missing
    # or null caption text falls back to "".
    rows = [
        (float(entry.get("scene_time") or 0.0), entry.get("arabic", "") or "")
        for entry in captions_data.values()
    ]
    # Sort on the time only so the (stable) sort preserves input order on ties.
    rows.sort(key=lambda row: row[0])

    times = [t for t, _ in rows]
    texts = [text for _, text in rows]
    tokens_core = [ar_tokens(text) for text in texts]  # core tokens (pseudo-lemmas)
    return times, texts, tokens_core

def time_window_indices(times: List[float], center_time: float, half_window_sec: float) -> Tuple[int, int]:
    """Return the slice ``[lo:hi)`` of captions within ±half_window_sec of center_time.

    Args:
        times: caption times sorted in ascending order.
        center_time: time the window is centred on.
        half_window_sec: half-width of the window; both bounds are inclusive.

    Returns:
        ``(lo, hi)`` usable as ``times[lo:hi]``. ``(0, 0)`` when ``times`` is
        empty. When no caption lies inside the window, a one-element slice
        holding the caption closest in time to ``center_time`` (the earlier
        one on a tie).
    """
    if not times:
        return 0, 0

    lo = bisect_left(times, center_time - half_window_sec)
    hi = bisect_right(times, center_time + half_window_sec)
    if lo < hi:
        return lo, hi

    # Empty window: fall back to the single nearest caption. bisect_left gives
    # the first caption at or after center_time, so its predecessor must be
    # compared too; taking the insertion point alone would always pick the
    # later caption even when the earlier one is much closer.
    j = bisect_left(times, center_time)
    if j <= 0:
        nearest = 0
    elif j >= len(times):
        nearest = len(times) - 1
    else:
        gap_before = center_time - times[j - 1]
        gap_after = times[j] - center_time
        nearest = j - 1 if gap_before <= gap_after else j
    return nearest, nearest + 1


In [None]:
# ---------------------------
# Lexical overlap & fusion scoring
# ---------------------------

def jaccard(set_a: set, set_b: set) -> float:
    """Return |A ∩ B| / |A ∪ B|, defined as 0.0 when both sets are empty."""
    union_size = len(set_a | set_b)
    if union_size == 0:
        # The union is empty only when both inputs are empty.
        return 0.0
    shared = len(set_a & set_b)
    return shared / union_size

def lexical_overlap(
    seg_lemmas: List[str],
    cap_core_tokens_window: List[List[str]],  # list per caption in window
    seg_tokens_core: List[str],
    use_ngrams: bool = True
) -> float:
    """Score the lexical agreement between a transcript segment and nearby captions.

    Args:
        seg_lemmas: lemmas of the transcript segment; they are passed through
            ``ar_normalize`` here so they are comparable with caption tokens.
        cap_core_tokens_window: core tokens of each caption in the window, one
            list per caption.
        seg_tokens_core: core tokens of the segment text.
        use_ngrams: when False only the lemma Jaccard is returned.

    Returns:
        The lemma Jaccard, or ``0.6 * lemma + 0.25 * bigram + 0.15 * trigram``
        Jaccard overlap when ``use_ngrams`` is True.
    """
    # Caption content tokens in order, stopwords removed. The lemma bag below
    # excludes stopwords too, so leaving them in the caption bag would only
    # inflate the union and depress the score.
    cap_content = [
        tok
        for caption_tokens in cap_core_tokens_window
        for tok in caption_tokens
        if tok not in AR_STOPWORDS
    ]
    cap_bag = set(cap_content)

    # Caption tokens are always normalized; lemmas must be brought to the same
    # form or spellings with diacritics/hamza/taa-marbuta can never match.
    # Stopwords are checked on both the raw and the normalized spelling.
    lem_bag = set()
    for lemma in seg_lemmas:
        if not lemma or lemma in AR_STOPWORDS:
            continue
        lemma_norm = ar_normalize(lemma)
        if lemma_norm and lemma_norm not in AR_STOPWORDS:
            lem_bag.add(lemma_norm)
    lemma_j = jaccard(lem_bag, cap_bag)

    if not use_ngrams:
        return lemma_j

    # n-gram overlap from core content tokens (2-gram and 3-gram). Caption
    # n-grams are built over the concatenation of all captions in the window.
    seg_content = [tok for tok in seg_tokens_core if tok not in AR_STOPWORDS]
    bi_overlap = jaccard(set(ngrams(seg_content, 2)), set(ngrams(cap_content, 2)))
    tri_overlap = jaccard(set(ngrams(seg_content, 3)), set(ngrams(cap_content, 3)))

    # Weighted lexical score
    return 0.6 * lemma_j + 0.25 * bi_overlap + 0.15 * tri_overlap

def fusion_score(lex: float, cos: float, alpha: float = 0.5) -> float:
    """Blend the lexical and cosine scores linearly.

    ``alpha`` is the weight of the lexical score; the cosine score receives
    ``1 - alpha``.
    """
    lexical_part = alpha * lex
    semantic_part = (1.0 - alpha) * cos
    return lexical_part + semantic_part


In [None]:
# ---------------------------
# Safeguards
# ---------------------------

def is_propn(pos_tag: str) -> bool:
    """Return True when the POS tag looks like a proper-noun tag.

    The check is a case-insensitive substring match against a few known
    proper-noun markers; an empty or missing tag is never a proper noun.
    """
    if not pos_tag:
        return False
    upper_tag = pos_tag.upper()
    for marker in ("PROPN", "NNP", "NOUN_PROP", "PROPER"):
        if marker in upper_tag:
            return True
    return False

def should_backoff_too_much_removed(kept: List[str], dropped: List[str], max_removed_ratio: float = 0.3) -> bool:
    """Tell whether dropping ``dropped`` would remove too much of the content.

    Only content words count: at least three characters long and not in
    ``AR_STOPWORDS``. Returns True when the dropped share of content words is
    strictly greater than ``max_removed_ratio``; False when there are no
    content words at all.
    """
    def _content_count(words: List[str]) -> int:
        return sum(1 for word in words if len(word) >= 3 and word not in AR_STOPWORDS)

    n_dropped = _content_count(dropped)
    n_content = _content_count(kept) + n_dropped
    if n_content == 0:
        return False
    return (n_dropped / n_content) > max_removed_ratio


In [None]:
# ---------------------------
# Main validator
# ---------------------------

def validate_words_by_visual_support(
    transcript_data: List[Dict[str, Any]],
    captions_data: Dict[str, Any],
    model: SentenceTransformer,
    half_window_sec: float = 10.0,     # time-based window (± seconds)
    alpha_fusion: float = 0.5,         # weight of lexical vs embedding
    sim_threshold: float = 0.55,       # threshold applied to the fused score
    min_word_len: int = 3,             # ignore very short tokens
    propn_keep_margin: float = 0.15,   # keep PROPN unless score < (sim_threshold - margin)
    backoff_removed_ratio: float = 0.30 # backoff if >30% content words would be removed
):
    """Keep or drop transcript words depending on support from nearby captions.

    For every transcript segment the captions inside a time window around the
    segment midpoint are joined, and the segment is scored against them with a
    fused lexical + embedding-cosine score. Content words of a segment whose
    fused score is below the threshold are dropped; short words and stopwords
    are always kept, proper nouns get a lower threshold, and the whole segment
    is restored when too large a share of its content words would be removed.

    Args:
        transcript_data: segments; each is a dict (read through the keys
            ``text_norm``/``text``/``original``, ``lemmas``/``lemma``, ``pos``,
            ``start``, ``end``) or any other value, which is treated as text.
        captions_data: captions mapping as accepted by ``prep_captions``.
        model: sentence encoder used for the cosine similarity.
        half_window_sec: half-width in seconds of the caption window.
        alpha_fusion: weight of the lexical score in the fused score.
        sim_threshold: minimum fused score for a content word to be kept.
        min_word_len: normalized words shorter than this are always kept.
        propn_keep_margin: amount subtracted from the threshold for words
            tagged as proper nouns.
        backoff_removed_ratio: maximum tolerated share of dropped content words.

    Returns:
        ``(enriched, summary)``: ``enriched`` is a list of segment dicts (a copy
        of each input segment plus ``visual_window_idx``, ``visual_context``,
        ``scores``, ``validated_text``, ``dropped_words`` and ``kept_words``);
        ``summary`` holds the segment/word totals and the parameters used.

    Raises:
        ValueError: if ``captions_data`` contains no captions.
    """
    # Prep captions once
    cap_times, cap_texts, cap_core_tokens = prep_captions(captions_data)
    if not cap_texts:
        raise ValueError("No captions available for comparison.")

    def _to_seconds(value):
        """Convert a timestamp to float; a missing or null value stays None."""
        return None if value is None else float(value)

    # Neighbouring segments usually share the same caption window, so the
    # joined text and its embedding are computed once per (lo, hi) slice.
    window_cache = {}

    enriched = []
    kept_total, dropped_total = 0, 0

    for idx, seg in tqdm(enumerate(transcript_data), total=len(transcript_data)):
        seg_dict = seg if isinstance(seg, dict) else {"text": str(seg)}

        # Prefer text_norm > text > original
        seg_text   = seg_dict.get("text_norm") or seg_dict.get("text") or seg_dict.get("original") or ""
        seg_lemmas = seg_dict.get("lemmas", []) or seg_dict.get("lemma", [])
        seg_pos    = seg_dict.get("pos", [])  # can be list or string
        seg_start  = _to_seconds(seg_dict.get("start"))
        seg_end    = _to_seconds(seg_dict.get("end"))
        seg_mid    = ((seg_start + seg_end) / 2.0) if (seg_start is not None and seg_end is not None) else None

        # Time window around segment midpoint (or index fallback)
        if seg_mid is not None:
            lo, hi = time_window_indices(cap_times, seg_mid, half_window_sec)
        else:
            # fallback to a narrow window by index (kept for robustness)
            center = min(idx, len(cap_times) - 1)
            lo, hi = max(0, center - 3), min(len(cap_times), center + 4)

        # Window text + embedding (cached per slice)
        window_key = (lo, hi)
        if window_key not in window_cache:
            joined = " ".join(cap_texts[lo:hi])
            with torch.no_grad():
                joined_emb = model.encode(joined, convert_to_tensor=True, show_progress_bar=False)
            window_cache[window_key] = (joined, joined_emb)
        window_text, win_emb = window_cache[window_key]

        # Segment embedding vs window
        with torch.no_grad():
            seg_emb = model.encode(seg_text, convert_to_tensor=True, show_progress_bar=False)
            cos = float(util.cos_sim(seg_emb, win_emb).item())

        # Lexical overlap (lemmas + ngrams) vs caption cores in the window
        lex = lexical_overlap(
            seg_lemmas,
            cap_core_tokens[lo:hi],
            ar_tokens(seg_text),
            use_ngrams=True
        )

        fused = fusion_score(lex, cos, alpha=alpha_fusion)

        # Tokenize segment words for keep/drop decision
        # Keep original surface words to reconstruct validated_text
        raw_words = [w for w in re.sub(r"\s+", " ", seg_text).split(" ") if w]

        # A string POS value describes the whole segment, so it is evaluated
        # once instead of once per word.
        pos_is_list = isinstance(seg_pos, list)
        segment_is_name = is_propn(seg_pos) if isinstance(seg_pos, str) else False

        kept_words, dropped_words = [], []
        for i, w in enumerate(raw_words):
            w_norm = ar_normalize(w)
            if len(w_norm) < min_word_len or w_norm in AR_STOPWORDS:
                kept_words.append(w)
                continue

            # If POS exists and says PROPN, safeguard (don't drop unless very low support).
            # NOTE(review): assumes a list-valued `pos` is aligned one-to-one
            # with the whitespace-split words — confirm against the producer.
            if pos_is_list:
                is_name = is_propn(str(seg_pos[i])) if i < len(seg_pos) else False
            else:
                is_name = segment_is_name

            # Decision by fused score; for PROPN, allow a margin
            thr = sim_threshold - (propn_keep_margin if is_name else 0.0)
            if fused >= thr:
                kept_words.append(w)
            else:
                dropped_words.append(w)

        # Backoff if we removed too much
        if should_backoff_too_much_removed(kept_words, dropped_words, max_removed_ratio=backoff_removed_ratio):
            kept_words = raw_words
            dropped_words = []

        validated_text = " ".join(kept_words)

        seg_out = dict(seg_dict)
        seg_out["visual_window_idx"] = [lo, hi - 1]  # inclusive caption indices
        seg_out["visual_context"] = window_text
        seg_out["scores"] = {"cosine": cos, "lexical": lex, "fused": fused}
        seg_out["validated_text"] = validated_text
        seg_out["dropped_words"] = dropped_words
        seg_out["kept_words"] = kept_words

        enriched.append(seg_out)
        kept_total += len(kept_words)
        dropped_total += len(dropped_words)

    summary = {
        "segments": len(enriched),
        "kept_words": kept_total,
        "dropped_words": dropped_total,
        "params": {
            "half_window_sec": half_window_sec,
            "alpha_fusion": alpha_fusion,
            "sim_threshold": sim_threshold,
            "min_word_len": min_word_len,
            "propn_keep_margin": propn_keep_margin,
            "backoff_removed_ratio": backoff_removed_ratio
        }
    }
    return enriched, summary


In [None]:
# ---------------------------
# Load data
# ---------------------------
def _read_json(path):
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


captions_data = _read_json(caption_path)
transcript_data = _read_json(transcript_path)

# ---------------------------
# Run validation
# ---------------------------
enriched, stats = validate_words_by_visual_support(
    transcript_data=transcript_data,
    captions_data=captions_data,
    model=model,
    half_window_sec=10.0,     # tune 8–12s
    alpha_fusion=0.5,         # 0.4–0.6 usually stable
    sim_threshold=0.55,       # raise for stricter, lower for more permissive
    min_word_len=3,
    propn_keep_margin=0.15,
    backoff_removed_ratio=0.30
)

print("Summary:", stats)

# ---------------------------
# Save outputs into Validated/
# ---------------------------
# The output file is named after the transcript file (extension removed).
base_name, _ = os.path.splitext(os.path.basename(transcript_path))
out_json = os.path.join(validated_path, f"{base_name}_ValidatedWords.json")

payload = {"summary": stats, "segments": enriched}
with open(out_json, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps the Arabic text readable in the output file.
    json.dump(payload, f, ensure_ascii=False, indent=2)

print("✅ Saved:", out_json)
