In [116]:
from pathlib import Path
from typing import cast, Any
from datetime import datetime
from dataclasses import dataclass
import re
from collections import Counter

import numpy as np
import numpy.typing as npt
from pisets import load_sound # type: ignore
import gigaam
from gigaam.model import GigaAMASR
import matplotlib.pyplot as plt

from asr_eval.models.gigaam_wrapper import transcribe_with_gigaam_ctc, FREQ, encode
from asr_eval.ctc.forced_alignment import forced_alignment
from asr_eval.ctc.chunking import chunked_ctc_prediction, average_logp_windows
from asr_eval.utils.serializing import save_to_json, load_from_json
from asr_eval.utils.misc import groupby_into_spans

In [119]:
def get_gigaam_log_probs(
    waveform: npt.NDArray[np.floating[Any]], model: GigaAMASR
) -> npt.NDArray[np.floating[Any]]:
    return average_logp_windows(chunked_ctc_prediction(
        waveform=waveform,
        ctc_model=lambda waveform: transcribe_with_gigaam_ctc(model, [waveform])[0].log_probs,
        model_tick_size_sec=1 / FREQ,
        segment_size_sec=20,
        segment_shift_sec=6,
        sampling_rate=16_000,
    ))

def filter_encodable_tokens(text: str, model: GigaAMASR) -> str:
    assert model.decoding.tokenizer.charwise
    text = text.lower().replace('ё', 'е').replace('-', ' ')
    return ''.join(char for char in text if char in model.decoding.tokenizer.vocab)

@dataclass
class DionSegment:
    approx_start_time: float
    speaker: str | list[str]
    text: str
    text_start_pos: int
    text_end_pos: int
    
def get_dion_segments(text_path: Path, model: GigaAMASR) -> tuple[list[DionSegment], str]:
    joined_text = ''
    segments: list[DionSegment] = []
    for record in text_path.read_text(encoding='utf-8-sig').split('\n\n'):
        try:
            start_time_str, speaker, _, text = record.split('\n')
        except ValueError:
            start_time_str, speaker1, speaker2, _, text = record.split('\n')
            speaker = [speaker1, speaker2]
        
        text = text[2:].strip()  # skip "- " prefix in the file
        text = filter_encodable_tokens(text, model)
        
        try:
            start_time = datetime.strptime(start_time_str, '%H:%M:%S') - datetime(1900, 1, 1)
        except ValueError:
            start_time = datetime.strptime(start_time_str, '%M:%S') - datetime(1900, 1, 1)
        
        if len(joined_text):
            joined_text += ' '
        joined_text += text
        
        segments.append(DionSegment(
            approx_start_time=start_time.total_seconds() - 1.5,  # most often Dion timings are shifted forward by 1.5 sec
            speaker=speaker,
            text=text,
            text_start_pos=len(joined_text) - len(text),
            text_end_pos=len(joined_text),
        ))
    
    return segments, joined_text
    
@dataclass
class TimedWord:
    start_time: float
    end_time: float
    text: str
    
@dataclass
class AlignedDionTimedWord:
    start_time: float
    end_time: float
    speaker: str | list[str]
    text: str
    delta: float

def get_gigaam_argmax_timed_words(
    log_probs: npt.NDArray[np.floating[Any]],
    model: GigaAMASR, 
) -> list[TimedWord]:
    idx_per_frame = log_probs.argmax(axis=1)
    spans = [
        (token_id, start, end)
        for token_id, start, end in groupby_into_spans(idx_per_frame.tolist())
        if token_id != model.decoding.blank_id
    ]
    text = ''.join(model.decoding.tokenizer.vocab[token_id] for token_id, _start, _end in spans)

    words: list[TimedWord] = []
    for match in re.finditer(r'\w+', text):
        _, start, _ = spans[match.start()]
        _, _, end = spans[match.end() - 1]
        words.append(TimedWord(
            start_time=start / FREQ,
            end_time=end / FREQ,
            text=match.group(),
        ))
    
    return words

def do_forced_alignment(
    log_probs: npt.NDArray[np.floating[Any]],
    model: GigaAMASR,
    joined_text: str,
    segments: list[DionSegment],
) -> list[AlignedDionTimedWord]:
    _idx_per_frame, _scores_per_frame, true_tokens_pos = forced_alignment(
        log_probs=log_probs,
        true_tokens=encode(model, joined_text),
        blank_id=model.decoding.blank_id,
    )

    timed_words: list[AlignedDionTimedWord] = []
    for segment in segments:
        
        start_frame, _ = true_tokens_pos[segment.text_start_pos]
        delta = start_frame / FREQ - segment.approx_start_time
        
        for match in re.finditer(r'\w+', segment.text):
            start_frame, _ = true_tokens_pos[segment.text_start_pos + match.start()]
            _, end_frame = true_tokens_pos[segment.text_start_pos + match.end() - 1]
            timed_words.append(AlignedDionTimedWord(
                start_time=start_frame / FREQ,
                end_time=end_frame / FREQ,
                speaker=segment.speaker,
                text=match.group(),
                delta=round(delta, 3), # usually values like 1.5000000001 or 1.4999999999
            ))
    
    return timed_words

def draw_time_deltas(timed_words: list[AlignedDionTimedWord]):
    plt.figure(figsize=(10, 2)) # type: ignore
    deltas = [x.delta for x in timed_words]
    plt.title( # type: ignore
        f'Match between GigaAM forced alignment and Dion timings:\nmedian delta'
        f' {np.median(deltas):.2f}, min delta {min(deltas):.2f}, max delta {max(deltas):.2f}'
    )
    plt.plot(deltas) # type: ignore
    plt.axhline(0, color='lightgray', zorder=0) # type: ignore
    plt.show() # type: ignore

In [58]:
model = cast(GigaAMASR, gigaam.load_model('ctc', device='cuda'))

In [None]:
for video_path in Path('/asr_datasets/dion/').glob('*.mp4'):
    print(video_path)
    
    text_path = video_path.with_suffix('.txt')
    output_path = video_path.with_suffix('.json')
    
    if output_path.exists():
        continue

    if video_path.name in (  # multispeaker dion transcriptions
        'kondrashuk_2025-06-25T12_12_07+03_00.mp4',  # too long for CTC
    ):
        continue

    # get conference transcription from Dion
    dion_segments, joined_text = get_dion_segments(text_path, model)

    # get log probs from GigaAM
    waveform = load_sound(video_path)
    log_probs = get_gigaam_log_probs(waveform, model)
    
    gigaam_timed_words = get_gigaam_argmax_timed_words(log_probs, model)

    # align both
    aligned_timed_words = do_forced_alignment(log_probs, model, joined_text, dion_segments)

    # display(dion_segments[:5])
    # print(joined_text[:500])
    # display(timed_words[:5])
    draw_time_deltas(aligned_timed_words)
    
    save_to_json({
        'gigaam_timed_words': gigaam_timed_words,
        'aligned_timed_words': aligned_timed_words,
    }, output_path)

In [123]:
all_counts: Counter[str | tuple[str, ...]] = Counter()

for json_path in Path('/asr_datasets/dion/').glob('*.json'):
    timed_words: list[AlignedDionTimedWord] = load_from_json(json_path)['aligned_timed_words']
    counts = Counter([
        (x.speaker if isinstance(x.speaker, str) else tuple(x.speaker))
        for x in timed_words
    ])
    all_counts.update(counts)

all_counts.most_common()

[('Иван Бондаренко', 13088),
 ('Roman', 11990),
 ('Elena Bruches', 11354),
 ('Виктория Кондрашук', 10895),
 ('Георгий', 9941),
 ('Vadim', 9411),
 ('Анастасия Рыбенко', 5058),
 ('Irina', 4296),
 ('Арсений', 4073),
 ('Анастасия', 3143),
 ('Георгий М', 3070),
 ('Даниил Гребенкин', 3036),
 ('Nikolay Bushkov', 2989),
 ('Дмитрий', 2866),
 ('v.alperovich | T-Bank', 2764),
 ('Ирина', 2576),
 ('Стас', 2248),
 ('ilyas', 1173),
 ('Михаил', 872),
 ('Den', 649),
 ('Ильдар (Онтико)', 498),
 ('Askar Timirgazin', 349),
 ('Игорь', 155),
 ('Иван Бондаренко ', 138),
 ('Денис Бондаренко ', 107),
 ('ivan_chernov', 104),
 ('Дари Батурова', 93),
 ('Konsatntin RSHB', 74),
 ('aleksandr medvedev', 64),
 ('Федьков Дмитрий', 18),
 (('Elena Bruches', 'Виктория Кондрашук'), 17),
 ('Дари', 17),
 ('Андрей Башкиров', 16),
 (('Георгий', 'Виктория Кондрашук'), 16),
 (('Иван Бондаренко', 'Виктория Кондрашук'), 11),
 (('Игорь', 'ilyas'), 6),
 (('Анастасия Рыбенко', 'Виктория Кондрашук'), 5),
 (('Den', 'Виктория Кондрашук'