# Overall Semantic Chunking Logic: A top -> down approach

## Why a Top -> Down Chunking Strategy Is Preferable for Semantic Segmentation

During the course of this assignment, multiple approaches to semantic chunking were explored. An initially appealing strategy was a bottom -> up approach, in which small units (words or short segments) are progressively merged to form larger chunks. However, deeper analysis revealed that a top -> down approach, which begins with the entire audio and recursively partitions it into smaller chunks until all of them are <15 seconds in length, might be more suitable for preserving semantic context and producing coherent segments.

### Context as a First-Class Signal

Semantic chunking is not merely about dividing audio into manageable lengths; it is about identifying meaningful conceptual boundaries. Such boundaries are inherently contextual: whether a transition is semantic depends not only on local text, but also on its relationship to surrounding content.
A top -> down approach naturally treats context as a global signal. At each split, both sides of a candidate cut are evaluated in the presence of their broader temporal neighborhood. This enables semantic decisions to be informed by:
what preceded the cut,
what follows it,
and how each side relates to its respective context.
In contrast, a bottom -> up approach constructs chunks incrementally from local units. Once early merges are made, the algorithm loses access to the original global structure. Context is implicitly baked into earlier decisions and cannot be revisited or corrected later.

### Moving-Window Context in Top -> Down Splitting

A major advantage of top -> down chunking is the ability to use a moving context window when evaluating candidate splits.
At any stage, a large chunk can be split,
embedding the left and right sub-chunks, and more importantly:

We can use the extra embeddings from the **extended left and right neighbours** and scoring the split based on semantic contrast and internal coherence.
Because the parent chunk still exists as a coherent whole, these embeddings are computed against meaningful, contiguous context. The algorithm can ask questions such as:


“Is this split a real conceptual transition, or just a change in phrasing inside the same idea?”


This form of contextual reasoning is difficult to replicate in bottom -> up systems, where context must be approximated from previously merged fragments that may not align with true semantic units.

# Installing Dependencies

In [1]:
# 1. First, install the core foundations with strict version caps
!pip install -q \
    "numpy<2.1.0" \
    "pandas==2.2.2" \
    "protobuf<5.0.0" \
    "huggingface-hub<1.0" \
    "pillow<12.0" \
    "fsspec<=2025.3.0"

# 2. Then install your tools (PyTorch and Pyannote)
!pip install -q \
    "torch==2.5.1" \
    "torchvision==0.20.1" \
    "torchaudio==2.5.1" \
    pyannote.audio

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/294.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
grpcio-status 1.71.2 requires protobuf<6.0dev,>=5.26.1, but you have protobuf 4.25.8 which is incompatible.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 4.25.8 which is incompatible.
ydf 0.13.0 requires protobuf<7.0.0,>=5.29.1, but you have protobuf 4.25.8 which is incompatible.[0m[31m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.6/59.6 kB[0

In [2]:
import numpy as np
import pyannote.audio
print(f"NumPy version: {np.__version__}")
print("Pyannote imported successfully!")

NumPy version: 2.0.2
Pyannote imported successfully!


In [3]:
!apt-get install -y nodejs npm
!pip install -U yt-dlp
!pip install -U ffmpeg-python
!pip install silero-vad

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  gyp javascript-common libc-ares2 libjs-events libjs-highlight.js
  libjs-inherits libjs-is-typedarray libjs-psl libjs-source-map
  libjs-sprintf-js libjs-typedarray-to-buffer libnode-dev libnode72
  libnotify-bin libnotify4 libuv1-dev node-abab node-abbrev node-agent-base
  node-ansi-regex node-ansi-styles node-ansistyles node-aproba node-archy
  node-are-we-there-yet node-argparse node-arrify node-asap node-asynckit
  node-balanced-match node-brace-expansion node-builtins node-cacache
  node-chalk node-chownr node-clean-yaml-object node-cli-table node-clone
  node-color-convert node-color-name node-colors node-columnify
  node-combined-stream node-commander node-console-control-strings
  node-copy-concurrently node-core-util-is node-coveralls node-cssom
  node-cssstyle node-debug node-decompress-response node-defaults
  node-delayed-st

# Procuring and Preparing Data

We download the highest-quality audio-only stream using yt-dlp to minimize file size and avoid unnecessary video processing. The audio is converted to 16 kHz mono PCM WAV using ffmpeg, as this matches the training conditions of Whisper and Silero VAD, leading to more stable transcription and timestamp alignment.

In [55]:
!yt-dlp \
  -f bestaudio \
  -o "input_audio.%(ext)s" \
  "https://www.youtube.com/watch?v=Sby1uJ_NFIY"

[youtube] Extracting URL: https://www.youtube.com/watch?v=Sby1uJ_NFIY
[youtube] Sby1uJ_NFIY: Downloading webpage
[youtube] Sby1uJ_NFIY: Downloading android sdkless player API JSON
[youtube] Sby1uJ_NFIY: Downloading web safari player API JSON
[youtube] Sby1uJ_NFIY: Downloading m3u8 information
[info] Sby1uJ_NFIY: Downloading 1 format(s): 251-9
[download] input_audio.webm has already been downloaded
[K[download] 100% of   29.57MiB


The audio stream downloaded from YouTube is encoded in Opus format at 48 kHz stereo. We convert it to a 16 kHz mono PCM WAV file using ffmpeg, as this matches the expected input format of Whisper-based ASR and voice activity detection models. This conversion ensures stable transcription quality and accurate timestamp alignment.

In [56]:
!ffmpeg -y -i input_audio.webm -ac 1 -ar 16000 input_audio.wav

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

In [57]:
import torchaudio

AUDIO_PATH = "input_audio.wav"  # your extracted audio

waveform, sample_rate = torchaudio.load(AUDIO_PATH)

# Convert to mono if needed
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

# Resample to 16 kHz if needed
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(
        orig_freq=sample_rate,
        new_freq=16000
    )
    waveform = resampler(waveform)
    sample_rate = 16000

# Silero expects 1D tensor
audio = waveform.squeeze()

# Identifying Candidates for Chunk Split

Speaker diarization allows us to map chunks of audio to speakers. When we run this cell, we will need to use an HF token, as well as accept terms and provide some details to HF using the 2 links in the code below

In [58]:
import torch
from getpass import getpass
from pyannote.audio import Pipeline

# 1. Configuration
AUDIO_PATH = "input_audio.wav"

# 2. Authentication
# Note: Ensure you accepted terms at:
# https://hf.co/pyannote/speaker-diarization-3.1
# https://hf.co/pyannote/segmentation-3.0
HF_TOKEN = getpass("Enter Hugging Face token: ").strip()

# 3. Load Pipeline
print("Loading pipeline...")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HF_TOKEN
)

# 4. Enable CUDA (GPU)
device = torch.device("cuda")
if pipeline is not None:
    pipeline.to(device)
    print(f"Pipeline loaded and running on: {device}")

    # 5. Run Diarization
    print("Processing audio (this may take a few minutes)...")
    diarization = pipeline(AUDIO_PATH)

    # 6. Format and Print Results
    diar_segments = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        diar_segments.append({
            "start": round(float(turn.start)),
            "end": round(float(turn.end)),
            "speaker": speaker
        })

    # Sort and display
    diar_segments.sort(key=lambda x: x["start"])
    print(f"\n✅ Done! Found {len(diar_segments)} segments.")

    # Print first 10 segments as a preview
    for seg in diar_segments[:10]:
        print(f"[{seg['start']}s - {seg['end']}s] {seg['speaker']}")
else:
    print("Error: Pipeline could not be loaded. Check your token/permissions.")


Enter Hugging Face token: ··········
Loading pipeline...


/usr/local/lib/python3.12/dist-packages/lightning_fabric/utilities/cloud_io.py:73: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Pipeline loaded and running on: cuda
Processing audio (this may take a few minutes)...


  std = sequences.std(dim=-1, correction=1)



✅ Done! Found 115 segments.
[0s - 4s] SPEAKER_03
[4s - 4s] SPEAKER_04
[4s - 4s] SPEAKER_03
[9s - 10s] SPEAKER_04
[12s - 17s] SPEAKER_04
[18s - 19s] SPEAKER_04
[20s - 21s] SPEAKER_04
[22s - 22s] SPEAKER_04
[23s - 25s] SPEAKER_04
[25s - 27s] SPEAKER_04


Running silero vad to identify silences. Therefore, we won't have to run vad again during whisper transcription. The silero vad arguments have NOT been tweaked from their default values. As I explored some other minimum silence duration values, I realised tweaking it was not necessary.

In [59]:
import torch
from silero_vad import get_speech_timestamps, load_silero_vad

# Load model
vad_model = load_silero_vad()

# Run VAD
speech_timestamps = get_speech_timestamps(
    audio,
    vad_model,
    sampling_rate=16000,
    min_speech_duration_ms=300,
    min_silence_duration_ms=100,
)

In [60]:
speech_segments = [
    {
        "start": ts["start"] / 16000,
        "end": ts["end"] / 16000,
        "duration": (ts["end"] - ts["start"]) / 16000,
    }
    for ts in speech_timestamps
]

In [61]:
speech_segments = sorted(speech_segments, key=lambda x: x["start"])

This is a list of all silences detected by silero vad. We make sure that the minimum silence that is enforced is 100 milliseconds which coinicides with that given for vad.

In [62]:
vad_silences = []

MIN_SILENCE = 0.1  # seconds (100 ms)

for i in range(len(speech_segments) - 1):
    silence_start = speech_segments[i]["end"]
    silence_end = speech_segments[i + 1]["start"]

    silence_duration = silence_end - silence_start

    if silence_duration >= MIN_SILENCE:
        vad_silences.append({
            "start": silence_start,
            "end": silence_end,
            "duration": silence_duration
        })


faster-whisper allows us to run whisper but using a much more efficient CTranslate2 engine. It drastically speeds up inference without actually losing out on accuracy.

In [63]:
!pip install faster-whisper soundfile



In [64]:
from faster_whisper import WhisperModel

model = WhisperModel(
    "large-v3",
    device="cuda",        # "cuda" if available or "cpu"
    compute_type="float16"  # int8 for CPU, float16 for GPU
)


Embedder is required to find the embeddings of any given text. This is the best free one I could find.

In [65]:
from sentence_transformers import SentenceTransformer
import numpy as np

_embedder = SentenceTransformer("all-MiniLM-L6-v2")

def embed(text: str) -> np.ndarray:
    if not text.strip():
        # handle empty text safely
        return np.zeros(_embedder.get_sentence_embedding_dimension())
    return _embedder.encode(text, normalize_embeddings=True)


In [66]:
segments, info = model.transcribe(
    "input_audio.wav",               # path to your audio
    beam_size=5,
    word_timestamps=True,      # <-- THIS IS THE KEY
    vad_filter=False           # you already ran your own VAD
)

We start with words as our atomic blocks, retaining ALL information about every word including start, end and the speaker. The reason I worked with words was to get maximum control and finer granularity over where I can create the next chunk (chunking at arbitrary time points instead of whisper segment boundaries).

In [67]:
words = []

for seg in segments:
    if seg.words is None:
        continue

    for w in seg.words:
        # w.word already includes leading space sometimes → strip it
        words.append({
            "start": float(w.start),
            "end": float(w.end),
            "text": w.word.strip()
        })


These 2 functions enable word level diarization and speaker based chunk scoring. We only believe the diarization when there is at least 0.2 seconds of overlap between words and speaker.

In [68]:
from collections import Counter

def assign_speakers_to_words(
    words,
    diar_segments,
    default_speaker="UNK",
    min_overlap=0.2,      # seconds of overlap required to trust a label
    max_snap=0.25          # seconds: if no overlap, snap to nearest segment if close
):
    """
    Assign speaker label to each word by maximizing overlap between:
      word interval [w.start, w.end] and diarization segment [seg.start, seg.end].
    Falls back to nearest segment if within max_snap.
    """
    diar_segments = sorted(diar_segments, key=lambda x: x["start"])

    for w in words:
        ws = float(w["start"])
        we = float(w["end"])
        if we <= ws:
            w["speaker"] = default_speaker
            continue

        best_spk = default_speaker
        best_ov = 0.0

        # 1) Max-overlap match
        for seg in diar_segments:
            ss = float(seg["start"])
            se = float(seg["end"])
            ov = max(0.0, min(we, se) - max(ws, ss))
            if ov > best_ov:
                best_ov = ov
                best_spk = seg["speaker"]

        if best_ov >= min_overlap:
            w["speaker"] = best_spk
            continue

        # 2) No overlap: snap to nearest diar segment if close
        mid = 0.5 * (ws + we)
        nearest_spk = default_speaker
        nearest_dist = float("inf")
        for seg in diar_segments:
            ss = float(seg["start"])
            se = float(seg["end"])
            # distance from mid to segment (0 if inside)
            dist = 0.0 if (ss <= mid <= se) else min(abs(mid - ss), abs(mid - se))
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_spk = seg["speaker"]

        w["speaker"] = nearest_spk if nearest_dist <= max_snap else default_speaker

    return words

def speaker_stats_between(words, start, end, default_speaker="UNK"):
    spks = []
    for w in words:
        mid = 0.5 * (w["start"] + w["end"])
        if start <= mid < end:
            spks.append(w.get("speaker", default_speaker))

    if not spks:
        return {"dominant": default_speaker, "purity": 0.0, "counts": Counter()}

    counts = Counter(spks)
    dominant, dom_n = counts.most_common(1)[0]
    return {"dominant": dominant, "purity": dom_n / len(spks), "counts": counts}

words = assign_speakers_to_words(words, diar_segments)
print("First labeled words:", words[:20])

First labeled words: [{'start': 0.0, 'end': 0.44, 'text': 'Congratulations', 'speaker': 'SPEAKER_03'}, {'start': 0.44, 'end': 0.84, 'text': 'to', 'speaker': 'SPEAKER_03'}, {'start': 0.84, 'end': 0.96, 'text': 'you', 'speaker': 'SPEAKER_03'}, {'start': 0.96, 'end': 1.2, 'text': 'Mr.', 'speaker': 'SPEAKER_03'}, {'start': 1.22, 'end': 1.56, 'text': 'Raghavan', 'speaker': 'SPEAKER_03'}, {'start': 1.56, 'end': 1.74, 'text': 'for', 'speaker': 'SPEAKER_03'}, {'start': 1.74, 'end': 1.94, 'text': 'that.', 'speaker': 'SPEAKER_03'}, {'start': 2.08, 'end': 2.38, 'text': 'Thank', 'speaker': 'SPEAKER_03'}, {'start': 2.38, 'end': 2.48, 'text': 'you', 'speaker': 'SPEAKER_03'}, {'start': 2.48, 'end': 2.56, 'text': 'so', 'speaker': 'SPEAKER_03'}, {'start': 2.56, 'end': 2.76, 'text': 'much', 'speaker': 'SPEAKER_03'}, {'start': 2.76, 'end': 2.86, 'text': 'for', 'speaker': 'SPEAKER_03'}, {'start': 2.86, 'end': 3.1, 'text': 'joining', 'speaker': 'SPEAKER_03'}, {'start': 3.1, 'end': 3.36, 'text': 'us.', 'spe

This checks whether or not speaker has changed based on a minimum timing gap and minimum amount of words on either side of speaker. Since we anyways have a minimum of 3 seconds of audio per chunk (I enforce this later on), we dont need to tamper with this at all. Any speaker will naturally speak at least 3-4 words in a duration of 3 seconds.

In [69]:
def speaker_change_candidates(
    words,
    segment,
    ignore_speakers={"UNK"},
    min_gap=0.2,          # seconds; optionally require a small timing gap
    min_run_words=1       # require this many words on each side (reduces jitter)
):
    """
    Candidate cut times whenever the speaker label changes between adjacent words.
    Returns times in seconds, inside (segment.start, segment.end).
    """
    seg_start, seg_end = segment["start"], segment["end"]

    # Select words that belong to this segment (midpoint assignment = stable at boundaries)
    seg_words = []
    for w in words:
        mid = 0.5 * (float(w["start"]) + float(w["end"]))
        if seg_start <= mid < seg_end:
            seg_words.append(w)

    if len(seg_words) < (2 * min_run_words + 1):
        return []

    cands = []
    for i in range(min_run_words, len(seg_words) - min_run_words):
        w1 = seg_words[i - 1]
        w2 = seg_words[i]

        spk1 = w1.get("speaker", "UNK")
        spk2 = w2.get("speaker", "UNK")

        if spk1 in ignore_speakers or spk2 in ignore_speakers:
            continue
        if spk1 == spk2:
            continue

        gap = float(w2["start"]) - float(w1["end"])
        if gap < min_gap:
            continue

        # cut between words (midpoint between w1 end and w2 start is robust)
        t = 0.5 * (float(w1["end"]) + float(w2["start"]))
        if seg_start < t < seg_end:
            cands.append(t)

    # de-dup + sort
    return sorted(set(round(t, 3) for t in cands))

In [70]:
assert all(words[i]["start"] <= words[i+1]["start"] for i in range(len(words)-1))


In [71]:
full_text = " ".join(w["text"] for w in words)

This is a list of mid points of any given silence detected by silero vad. It makes it eaiser to check whether a silence falls in one specific time interval or not.

In [72]:
vad_silence_times = [
    0.5 * (s["start"] + s["end"])
    for s in vad_silences
]

We now check for *potential* points of split using VAD outputs, speaker changes and punctuation. This does not include semantics. We use semantics only to decide whether or not we actually want to use this potential point to split.

In [73]:
FILLER_WORDS = {
    "um", "uh", "umm", "uhh"
}

PUNCTUATION = {".", ",", "?", "!", "..."}

In [74]:
def word_boundary_candidates(words, segment, min_gap=0.05):
    """
    Returns candidate cut times based on linguistic + speaker cues.
    Used ONLY when no VAD silence exists.
    """
    candidates = []

    seg_words = [
        w for w in words
        if segment["start"] < w["start"] < segment["end"]
    ]

    for i in range(len(seg_words) - 1):
        w1 = seg_words[i]
        w2 = seg_words[i + 1]

        gap = w2["start"] - w1["end"]

        # Small acoustic gap (even < VAD threshold)
        if gap >= min_gap:
            candidates.append(0.5 * (w1["end"] + w2["start"]))

        # Filler words
        if w1["text"].lower() in FILLER_WORDS:
            candidates.append(w1["end"])

        # Punctuation cues
        if any(p in w1["text"] for p in PUNCTUATION):
            candidates.append(w1["end"])

    # NEW: speaker-change cues
    candidates += speaker_change_candidates(
        words,
        segment,
        ignore_speakers={"UNK"},
        min_gap=0.0, #create more potential avenues for chunking
        min_run_words=1
    )

    return sorted(set(round(t, 3) for t in candidates))

These helper functions reconstruct transcript text from a word-level, time-aligned ASR output. They are designed for use in top-down semantic chunking pipelines, where audio is recursively split into non-overlapping time intervals and each word must belong to exactly one chunk.

In [75]:
def text_between(words, start, end):
    # Assign each word to exactly one interval using midpoint-in-span.
    span_words = []
    for w in words:
        mid = 0.5 * (w["start"] + w["end"])
        if start <= mid < end:   # half-open to avoid boundary duplication
            span_words.append(w["text"])
    return " ".join(" ".join(span_words).split())

def text_for_chunks(words, chunk_list):
    return " ".join(
        text_between(words, ch["start"], ch["end"])
        for ch in chunk_list
    ).strip()

In [76]:
import math

def cosine_similarity(a, b):
    """
    Compute cosine similarity between two vectors a and b.
    Both a and b must be the same length.
    """
    # Dot product
    dot = sum(x * y for x, y in zip(a, b))
    # Norms
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

# **Semantic Scoring + Chunking**

*   Candidate split scoring with semantic, balance, and speaker aware constraints.
This function scores all viable candidate split points inside a given chunk and returns them ranked by quality. It is designed for top-down semantic chunking, where oversized chunks are recursively split until duration constraints are satisfied. As mentioned before, a candidate for a split is selected either via VAD, speaker change or text-based ruling (punctuation and filler words)
*   Aside from just using semantic scores from neighbouring chunks, we also maintain a context window that helps retain context over larger durations.

*   Apart from that, we also have a minimum of 3 seconds per chunk being enforced in any case.


A good semantic boundary should satisfy these properties:


1.   Local discontinuity: The text immediately to the left and right of the split should be semantically different.
2.   Contextual discontinuity
The surrounding context on the left and right should also differ, not just the immediate split regions.
3. Internal coherence
Each side of the split should be semantically consistent with its own broader context.
This function explicitly scores all three.



Additionally, I also add speaker diarization based rewards and penalties:

  

1.   Reward for splitting when speaker changes.
2.   Penalty for interrupting or chunking when a speaker is explaining something uninterrupted.





In [77]:
# 2) Add guards so the splitter doesn’t create tiny slivers / very imbalanced halves.
MIN_SIDE = 3.0          # seconds: disallow splits creating < 3s chunks
BALANCE_LAMBDA = 0.25   # soft penalty weight; increase if you still see tiny chunks

# Speaker-aware scoring knobs (tune these)
SPEAKER_PURITY_MIN = 0.90   # require each side to be mostly one speaker
SPEAKER_BONUS      = 1.00   # added when the split separates different speakers cleanly
SPEAKER_PENALTY    = 0.35   # optional: subtract when both sides are confidently same speaker

def imbalance_penalty(left_dur, right_dur):
    total = left_dur + right_dur
    if total <= 0:
        return 0.0
    return abs(left_dur - right_dur) / total  # 0 (balanced) .. ~1 (very imbalanced)


# 3) Patch score_all_splits to apply the min-size constraint + penalty.
def score_all_splits(
    current_chunk,
    all_chunks,
    words,
    silence_times,
    embed,
    left_context_size=1,
    right_context_size=1,
    include_speaker_candidates=True,
    ignore_speakers={"UNK"},
):
    """
    Scores candidate cut times inside current_chunk.

    Enhancements:
      1) Optionally adds speaker-change times as additional candidates.
      2) Adds a speaker-aware term to reward cuts that separate different dominant speakers.
    """
    scored = []

    # --- build candidate times ---
    candidate_times = list(silence_times) if silence_times is not None else []

    if include_speaker_candidates:
        candidate_times += speaker_change_candidates(
            words,
            current_chunk,
            ignore_speakers=ignore_speakers,
            min_gap=0.0,
            min_run_words=1,
        )

    # de-dup + sort
    candidate_times = sorted(set(round(float(t), 3) for t in candidate_times))

    # find index for context windows
    try:
        idx = all_chunks.index(current_chunk)
    except ValueError:
        idx = [
            i for i, ch in enumerate(all_chunks)
            if ch["start"] == current_chunk["start"] and ch["end"] == current_chunk["end"]
        ][0]

    for t in candidate_times:
        if not (current_chunk["start"] < t < current_chunk["end"]):
            continue

        left_dur = t - current_chunk["start"]
        right_dur = current_chunk["end"] - t

        # hard constraint to prevent tiny chunks
        if left_dur < MIN_SIDE or right_dur < MIN_SIDE:
            continue

        left_chunk  = {"start": current_chunk["start"], "end": t}
        right_chunk = {"start": t, "end": current_chunk["end"]}

        left_text  = text_between(words, left_chunk["start"], left_chunk["end"])
        right_text = text_between(words, right_chunk["start"], right_chunk["end"])
        if not left_text.strip() or not right_text.strip():
            continue

        # --- semantic score (your existing logic) ---
        E_left  = embed(left_text)
        E_right = embed(right_text)

        left_neighbors = all_chunks[max(0, idx - left_context_size): idx]
        extended_left = left_neighbors + [left_chunk]
        ext_left_text = text_for_chunks(words, extended_left)
        if not ext_left_text.strip():
            continue
        E_ext_left = embed(ext_left_text)

        right_neighbors = all_chunks[idx + 1: idx + 1 + right_context_size]
        extended_right = [right_chunk] + right_neighbors
        ext_right_text = text_for_chunks(words, extended_right)
        if not ext_right_text.strip():
            continue
        E_ext_right = embed(ext_right_text)

        S_local = 1.0 - cosine_similarity(E_left, E_right)
        S_ext   = 1.0 - cosine_similarity(E_ext_left, E_ext_right)
        C_left  = cosine_similarity(E_left, E_ext_left)
        C_right = cosine_similarity(E_right, E_ext_right)
        C_int   = C_left + C_right

        score = 1.0 * S_local + 0.5 * S_ext + 0.2 * C_int

        # soft penalty for very uneven splits (your existing knob)
        score -= BALANCE_LAMBDA * imbalance_penalty(left_dur, right_dur)

        # --- NEW: speaker-aware term ---
        sL = speaker_stats_between(words, left_chunk["start"], left_chunk["end"])
        sR = speaker_stats_between(words, right_chunk["start"], right_chunk["end"])

        domL, purL = sL["dominant"], float(sL["purity"])
        domR, purR = sR["dominant"], float(sR["purity"])

        if (domL not in ignore_speakers) and (domR not in ignore_speakers):
            min_pur = min(purL, purR)

            # reward clean separation into different dominant speakers
            if (domL != domR) and (purL >= SPEAKER_PURITY_MIN) and (purR >= SPEAKER_PURITY_MIN):
                score += SPEAKER_BONUS * min_pur

            # optional: penalize cuts that *don’t* align with a speaker boundary
            #if (domL == domR) and (purL >= SPEAKER_PURITY_MIN) and (purR >= SPEAKER_PURITY_MIN):
                #score -= SPEAKER_PENALTY * min_pur

        scored.append((t, score))

    scored.sort(key=lambda x: x[1], reverse=True)
    return scored


This is an auxilliary function that is used when there are no candidates for splitting from the VAD and speaker diarization alone.

In [78]:
def best_fallback_split(segment, all_chunks, words, embed,
                        left_context_size=1, right_context_size=1,
                        ignore_speakers={"UNK"}):
    """
    Fallback split: generate word/punctuation-based candidate times, then score them
    using the same scoring function as the main path (score_all_splits).
    """
    fallback_times = word_boundary_candidates(words, segment)
    if not fallback_times:
        return None

    scored = score_all_splits(
        current_chunk=segment,
        all_chunks=all_chunks,
        words=words,
        silence_times=fallback_times,          # reuse the same parameter for candidate cut times
        embed=embed,
        left_context_size=left_context_size,
        right_context_size=right_context_size,
        include_speaker_candidates=False,      # IMPORTANT: word_boundary_candidates already adds these
        ignore_speakers=ignore_speakers,
    )

    return float(scored[0][0]) if scored else None

This function actually calls the score_all_splits function and segments each chunk into smaller chunks. We then run this iteratively until we are sure that there are no more chunks >15 seconds left. We run this only on VAD and speaker based potential candidates first. If it doesnt work out (not enough silences or speaker changes in the right area), only then do we look for punctuation, etc. We call the auxilliary function - best_fallback_split for this.

In [79]:
MAX_LEN = 15

def run_segmentation(
    words,
    silence_times,
    embed,
    local_context=1,
    right_context=1,
    include_speaker_candidates=True,
    ignore_speakers={"UNK"},
):
    """
    Enforces that all chunks in global `current_chunks` have duration <= MAX_LEN.

    Change vs your current version:
      - Builds per-segment candidate times (VAD silences within the segment)
      - Optionally adds speaker-change times within the segment
      - Passes the combined candidate list into score_all_splits
        (and disables internal speaker-candidate addition to avoid double-counting)
    """
    global current_chunks

    i = 0
    while i < len(current_chunks):
        segment = current_chunks[i]
        start, end = float(segment["start"]), float(segment["end"])
        duration = end - start

        # already valid
        if duration <= MAX_LEN:
            i += 1
            continue

        # if too short to split safely, skip
        if duration < 2 * MIN_SIDE:
            i += 1
            continue

        # --- build per-segment candidate times ---
        seg_candidate_times = []

        # VAD silence midpoints (restricted to this segment)
        if silence_times is not None:
            seg_candidate_times.extend(
                float(t) for t in silence_times
                if start < float(t) < end
            )

        # Speaker-change times (restricted by speaker_change_candidates itself)
        if include_speaker_candidates:
            seg_candidate_times.extend(
                speaker_change_candidates(
                    words,
                    segment,
                    ignore_speakers=ignore_speakers,
                    min_gap=0.0,
                    min_run_words=1,
                )
            )

        # de-dup + sort
        seg_candidate_times = sorted(set(round(float(t), 3) for t in seg_candidate_times))

        # --- score splits using the combined candidates ---
        scored = score_all_splits(
            current_chunk=segment,
            all_chunks=current_chunks,
            words=words,
            silence_times=seg_candidate_times,          # combined candidates
            embed=embed,
            left_context_size=local_context,
            right_context_size=right_context,
            include_speaker_candidates=False,           # IMPORTANT: avoid double-adding
            ignore_speakers=ignore_speakers,
        )

        if scored:
            cut_time = float(scored[0][0])
        else:
            # --- Fallback: word-based semantic split ---
            cut_time = best_fallback_split(
                segment,
                all_chunks=current_chunks,
                words=words,
                embed=embed,
                left_context_size=local_context,
                right_context_size=right_context,
            )

        # --- absolute last resort ---
        if cut_time is None or not (start < float(cut_time) < end):
            cut_time = 0.5 * (start + end)

        # clamp to respect MIN_SIDE
        cut_time = float(cut_time)
        cut_time = max(start + MIN_SIDE, min(cut_time, end - MIN_SIDE))

        # if still cannot split (numerical edge), skip
        if not (start + MIN_SIDE < cut_time < end - MIN_SIDE):
            i += 1
            continue

        # --- split ---
        left_chunk = {"start": start, "end": cut_time}
        right_chunk = {"start": cut_time, "end": end}

        # replace current chunk with two
        current_chunks[i:i + 1] = [left_chunk, right_chunk]

        # Do NOT increment i: re-check the new left chunk first

In [80]:
# ...existing code...

def preview_chunks(words, chunks):
    for i, c in enumerate(sorted(chunks, key=lambda x: x["start"])):
        # speaker summary for this chunk
        s = speaker_stats_between(words, c["start"], c["end"])
        dom = s["dominant"]
        pur = float(s["purity"])

        text = text_between(words, c["start"], c["end"])
        text = text.replace("\n", " ").strip()

        dur = c["end"] - c["start"]
        print(f"[{i:02d}] {c['start']:.2f}–{c['end']:.2f} ({dur:.2f}s) | speaker={dom} (purity={pur:.2f})")
        print(f"     {text}\n")

# ...existing code...

In [81]:
# ...existing code...

# 0) Ensure these already exist from earlier cells:
# - words (list of dicts with start/end/text[/speaker])
# - vad_silence_times (list[float])  OR pass None
# - embed (function: str -> np.ndarray)
# - run_segmentation, MIN_SIDE, MAX_LEN, score_all_splits, etc.

# 1) Initialize the global chunk list (one big chunk, or any starting chunks you want)
current_chunks = [{"start": 0.0, "end": float(words[-1]["end"])}]  # or your known end time (e.g., 1575.0)

# 2) Run segmentation (this mutates current_chunks in-place)
run_segmentation(
    words=words,
    silence_times=vad_silence_times,  # can be None if you want purely fallback/word-based
    embed=embed,
    local_context=5,
    right_context=5,
    include_speaker_candidates=True,
    ignore_speakers={"UNK"},
)

# 3) Inspect results
durations = [c["end"] - c["start"] for c in current_chunks]
print("chunks:", len(current_chunks), "min:", min(durations), "max:", max(durations))

preview_chunks(words, current_chunks)
# ...existing code...

chunks: 175 min: 3.088000000000079 max: 14.82400000000007
[00] 0.00–6.54 (6.54s) | speaker=SPEAKER_03 (purity=1.00)
     Congratulations to you Mr. Raghavan for that. Thank you so much for joining us. Over to you.

[01] 6.54–11.15 (4.61s) | speaker=SPEAKER_04 (purity=0.80)
     Hi everybody. How are you?

[02] 11.15–24.93 (13.78s) | speaker=SPEAKER_04 (purity=0.97)
     I am not hearing this at all. It's like a post lunch energy downer or something. Let's hear it. Are you guys awake? Alright. You better be because we

[03] 24.93–33.07 (8.14s) | speaker=SPEAKER_04 (purity=1.00)
     have a superstar guest here. You heard the $41 million and I didn't hear honestly anything she said after that.

[04] 33.07–40.93 (7.86s) | speaker=SPEAKER_04 (purity=1.00)
     So we are going to ask for about $40 million from him by the end of this conversation. But let's get started.

[05] 40.93–52.85 (11.92s) | speaker=SPEAKER_04 (purity=1.00)
     I want to introduce Vivek and Pratyush, his co -founder 

# Final Output

In [82]:
# Build the assignment-required output format
chunks_sorted = sorted(current_chunks, key=lambda c: float(c["start"]))

output_chunks = []
for i, c in enumerate(chunks_sorted, start=1):
    start = float(c["start"])
    end = float(c["end"])
    text = text_between(words, start, end).strip()

    output_chunks.append({
        "chunk_id": i,
        "chunk_length": float(end - start),
        "text": text,
        "start_time": start,
        "end_time": end,
    })

In [83]:
print(output_chunks)

[{'chunk_id': 1, 'chunk_length': 6.544, 'text': 'Congratulations to you Mr. Raghavan for that. Thank you so much for joining us. Over to you.', 'start_time': 0.0, 'end_time': 6.544}, {'chunk_id': 2, 'chunk_length': 4.608, 'text': 'Hi everybody. How are you?', 'start_time': 6.544, 'end_time': 11.152}, {'chunk_id': 3, 'chunk_length': 13.776000000000002, 'text': "I am not hearing this at all. It's like a post lunch energy downer or something. Let's hear it. Are you guys awake? Alright. You better be because we", 'start_time': 11.152, 'end_time': 24.928}, {'chunk_id': 4, 'chunk_length': 8.144000000000002, 'text': "have a superstar guest here. You heard the $41 million and I didn't hear honestly anything she said after that.", 'start_time': 24.928, 'end_time': 33.072}, {'chunk_id': 5, 'chunk_length': 7.8559999999999945, 'text': "So we are going to ask for about $40 million from him by the end of this conversation. But let's get started.", 'start_time': 33.072, 'end_time': 40.928}, {'chunk_i