In [None]:
import torch
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from pydantic import BaseModel
from typing import Dict, List, Tuple
import numpy as np

In [16]:
class SpeakerFeatures(BaseModel):
    """Per-speaker statistics computed from that speaker's diarized turns."""
    total_speaking_duration: float  # sum of this speaker's turn durations (seconds)
    total_turns: int  # number of segments attributed to this speaker
    speech_ratio: float  # total_speaking_duration / conv_length

    # A bunch of statistics on the turn lengths distribution
    # --------------------------
    mean_turn_duration: float
    median_turn_duration: float
    std_turn_duration: float
    min_turn_duration: float
    max_turn_duration: float
    percentiles: Dict[str, float]  # keyed as "percentile_<p>", e.g. "percentile_10"
    # --------------------------

    interruptions_made: int  # segments of this speaker that started while another speaker was active
    interruptions_received: int  # total of the counts stored in interrupted_by
    interrupted_by: Dict[str, int]  # other speaker label -> times that speaker interrupted this one

class ConversationMetrics(BaseModel):
    """Conversation-level aggregates computed over all speakers' segments."""
    num_speakers: int  # number of distinct speakers in the diarization
    total_speaking_time: float  # sum of every speaker's speaking duration (overlaps counted once per speaker)
    overlap_duration: float  # speaking time in excess of the time covered by at least one speaker
    silence_duration: float  # audio length not covered by any segment
    overlap_ratio: float  # overlap_duration / audio length
    silence_ratio: float  # silence_duration / audio length
    total_interruptions: int  # sum of interruptions made by all speakers
    interruption_rate: float  # interruptions per minute
    
class BasicMetricsResponse(BaseModel):
    """Result bundle pairing per-speaker features with conversation-level metrics."""
    speakers: Dict[str, SpeakerFeatures]  # speaker label -> that speaker's features
    conversation: ConversationMetrics  # aggregates over the whole recording
    
class SpeakerSegment(BaseModel):
    """One contiguous stretch of speech attributed to a single speaker."""
    start: float  # segment start time (seconds from the start of the audio)
    end: float  # segment end time (seconds from the start of the audio)
    speaker: str  # speaker label, e.g. "speaker_SPEAKER_04"

class DiarizationResponse(BaseModel):
    """Typed view of a diarization run: the segments plus the speaker count."""
    segments: list[SpeakerSegment]  # segments in the order the pipeline yielded them
    num_speakers: int  # number of distinct speaker labels among the segments

In [None]:
# Community-1 open-source speaker diarization pipeline
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1",
)

# Pick the best available accelerator instead of assuming Apple-Silicon MPS,
# so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
pipeline.to(device)

# Audio file to diarize (path is relative to the notebook's working directory)
AUDIO_PATH = "../data/processed/audios/test_audio.wav"

# Apply the pretrained pipeline locally; ProgressHook reports progress while it runs
with ProgressHook() as hook:
    output = pipeline(AUDIO_PATH, hook=hook)

Output()

In [18]:
def create_diarization_response(output) -> DiarizationResponse:
    """Convert a pipeline result into a DiarizationResponse.

    Iterates ``output.speaker_diarization`` as ``(turn, label)`` pairs, wraps
    each pair in a SpeakerSegment whose speaker name is the label prefixed
    with ``"speaker_"``, and reports the number of distinct labels seen.
    """
    # Materialise the pairs once so both passes below see the same data
    labelled_turns = [(turn, label) for turn, label in output.speaker_diarization]

    segments = [
        SpeakerSegment(start=turn.start, end=turn.end, speaker=f"speaker_{label}")
        for turn, label in labelled_turns
    ]
    distinct_labels = {label for _, label in labelled_turns}

    return DiarizationResponse(segments=segments, num_speakers=len(distinct_labels))

In [19]:
# Wrap the raw pipeline output in the typed DiarizationResponse model
diarization_result = create_diarization_response(output=output)

In [23]:
# Display a summary, then a preview of the first segments.
# The limit is a named constant so the comment and the slice cannot drift apart
# (the old comment said 5 while the code printed 100).
PREVIEW_LIMIT = 100
print(f"Number of speakers: {diarization_result.num_speakers}")
print(f"Total segments: {len(diarization_result.segments)}")
print("\nSample segments:")
for segment in diarization_result.segments[:PREVIEW_LIMIT]:  # first PREVIEW_LIMIT segments
    print(f"From {segment.start:.1f}s to {segment.end:.1f}s: {segment.speaker}")

Number of speakers: 5
Total segments: 279

Sample segments:
From 10.9s to 13.5s: speaker_SPEAKER_04
From 28.0s to 28.7s: speaker_SPEAKER_04
From 60.5s to 65.0s: speaker_SPEAKER_04
From 70.3s to 92.1s: speaker_SPEAKER_04
From 93.7s to 94.2s: speaker_SPEAKER_04
From 94.3s to 137.6s: speaker_SPEAKER_01
From 139.0s to 139.2s: speaker_SPEAKER_04
From 139.2s to 139.2s: speaker_SPEAKER_01
From 139.6s to 170.0s: speaker_SPEAKER_03
From 171.7s to 172.3s: speaker_SPEAKER_03
From 172.5s to 172.5s: speaker_SPEAKER_03
From 172.5s to 173.3s: speaker_SPEAKER_00
From 173.3s to 220.4s: speaker_SPEAKER_00
From 202.2s to 202.5s: speaker_SPEAKER_01
From 220.8s to 220.9s: speaker_SPEAKER_00
From 221.2s to 221.6s: speaker_SPEAKER_00
From 224.0s to 224.0s: speaker_SPEAKER_00
From 224.0s to 244.4s: speaker_SPEAKER_04
From 245.0s to 267.6s: speaker_SPEAKER_04
From 269.6s to 270.3s: speaker_SPEAKER_04
From 273.2s to 284.7s: speaker_SPEAKER_04
From 287.7s to 288.4s: speaker_SPEAKER_04
From 291.9s to 295.4s: spea

In [21]:
def _count_interruptions(
    ordered_segments: List[Tuple[float, float, str]],
    speakers: List[str],
) -> Tuple[Dict[str, int], Dict[str, Dict[str, int]]]:
    """Tally interruptions from segments sorted by start time.

    A segment that starts while another speaker's segment is still running
    counts as one interruption of that speaker, so a single segment can
    interrupt several speakers at once.

    Returns:
        ``(made, received_from)`` where ``made[s]`` is the number of
        interruptions speaker ``s`` made and ``received_from[s][o]`` is how
        many times speaker ``o`` interrupted ``s``.
    """
    made = {speaker: 0 for speaker in speakers}
    received_from = {
        speaker: {other: 0 for other in speakers if other != speaker}
        for speaker in speakers
    }

    active_until: Dict[str, float] = {}  # speaker -> time their speech ends
    for start, end, speaker in ordered_segments:
        for other, other_end in active_until.items():
            if other != speaker and start < other_end:
                made[speaker] += 1
                received_from[other][speaker] += 1

        # Keep the latest end time: a shorter segment nested inside a longer
        # one from the same speaker must not shrink that speaker's active window.
        active_until[speaker] = max(end, active_until.get(speaker, end))

        # Drop speakers who had already finished by the current start
        active_until = {s: e for s, e in active_until.items() if e > start}

    return made, received_from


def _covered_duration(ordered_segments: List[Tuple[float, float, str]]) -> float:
    """Return the time covered by at least one segment (input sorted by start)."""
    covered = 0.0
    frontier = 0.0  # end of the region already accounted for
    for start, end, _ in ordered_segments:
        if start > frontier:
            # Disjoint from everything seen so far
            covered += end - start
            frontier = end
        elif end > frontier:
            # Extends past the covered region: count only the new part
            covered += end - frontier
            frontier = end
    return covered


def basic_metrics(diarization_result: DiarizationResponse,
                  audio_length: float,
                  percentiles: List[int] = [10, 25, 75, 90]) -> BasicMetricsResponse:
    """
    Comprehensive analysis of conversation dynamics from diarization results.

    Args:
        diarization_result: Segments (start, end, speaker) and speaker count.
        audio_length: Total length of the recording, in the same time unit as
            the segment boundaries. Ratios and the interruption rate are 0.0
            when this is not positive.
        percentiles: Percentiles of the turn-duration distribution to report
            per speaker. The default list is only read, never mutated.

    Returns:
        BasicMetricsResponse (pydantic) with 'speakers' and 'conversation' populated.

    Notes:
        ``overlap_duration`` is the summed speaking time minus the time covered
        by at least one speaker, so a region with k simultaneous speakers
        contributes (k - 1) times its length.
    """
    # Group turns by speaker and build one chronologically sorted timeline
    speaker_segments: Dict[str, List[Tuple[float, float]]] = {}
    timeline: List[Tuple[float, float, str]] = []
    for segment in diarization_result.segments:
        speaker_segments.setdefault(segment.speaker, []).append((segment.start, segment.end))
        timeline.append((segment.start, segment.end, segment.speaker))
    timeline.sort(key=lambda item: item[0])

    has_length = audio_length > 0

    interruptions, interrupted_by = _count_interruptions(timeline, list(speaker_segments))

    # Per-speaker features
    speaker_features: Dict[str, SpeakerFeatures] = {}
    total_speaking_time = 0.0

    for speaker, segments in speaker_segments.items():
        # Every speaker in the mapping has at least one segment, so the
        # duration array is never empty and the statistics are well defined.
        turn_durations = [end - start for start, end in segments]
        total_speaking_duration = sum(turn_durations)
        total_speaking_time += total_speaking_duration

        arr = np.array(turn_durations, dtype=float)
        percentile_values = np.percentile(arr, percentiles).tolist()
        percentile_dict = {
            f"percentile_{p}": float(v) for p, v in zip(percentiles, percentile_values)
        }

        speaker_features[speaker] = SpeakerFeatures(
            total_speaking_duration=float(total_speaking_duration),
            total_turns=len(turn_durations),
            speech_ratio=float(total_speaking_duration / audio_length) if has_length else 0.0,
            mean_turn_duration=float(np.mean(arr)),
            median_turn_duration=float(np.median(arr)),
            std_turn_duration=float(np.std(arr)),  # NumPy default: population std (ddof=0)
            min_turn_duration=float(np.min(arr)),
            max_turn_duration=float(np.max(arr)),
            interruptions_made=int(interruptions[speaker]),
            interruptions_received=int(sum(interrupted_by[speaker].values())),
            interrupted_by={k: int(v) for k, v in interrupted_by[speaker].items()},
            percentiles=percentile_dict,
        )

    # Conversation-level metrics
    total_coverage = _covered_duration(timeline)

    # Clamp at zero: with no real overlap, float rounding can leave a tiny negative
    overlap_duration = max(0.0, total_speaking_time - total_coverage)
    silence_duration = max(0.0, audio_length - total_coverage)
    total_interruptions = sum(interruptions.values())
    interruption_rate = total_interruptions / (audio_length / 60) if has_length else 0.0

    conversation_metrics = ConversationMetrics(
        num_speakers=int(diarization_result.num_speakers),
        total_speaking_time=float(total_speaking_time),
        overlap_duration=float(overlap_duration),
        silence_duration=float(silence_duration),
        overlap_ratio=float(overlap_duration / audio_length) if has_length else 0.0,
        silence_ratio=float(silence_duration / audio_length) if has_length else 0.0,
        total_interruptions=int(total_interruptions),
        interruption_rate=float(interruption_rate),
    )

    return BasicMetricsResponse(speakers=speaker_features, conversation=conversation_metrics)

In [22]:
# Run the analysis and print a plain-text report
audio_length = 1384  # length of the recording, in the segments' time unit (seconds)

# Compute per-speaker and conversation-level metrics
analysis = basic_metrics(diarization_result, audio_length)

# Per-speaker report: one block of lines per speaker
for speaker, features in analysis.speakers.items():
    speaker_report = (
        f"\n{speaker} statistics:",
        f"  Speaking time: {features.total_speaking_duration:.2f}s ({features.speech_ratio*100:.1f}% of conversation)",
        f"  Number of turns: {features.total_turns}",
        f"  Average turn duration: {features.mean_turn_duration:.2f}s (median: {features.median_turn_duration:.2f}s)",
        f"  Interruptions made: {features.interruptions_made}",
        f"  Interruptions received: {features.interruptions_received}",
    )
    print("\n".join(speaker_report))

# Conversation-level report
conv = analysis.conversation
conversation_report = (
    "\nConversation metrics:",
    f"  Total speaking time: {conv.total_speaking_time:.2f}s",
    f"  Silence: {conv.silence_duration:.2f}s ({conv.silence_ratio*100:.1f}%)",
    f"  Overlap: {conv.overlap_duration:.2f}s ({conv.overlap_ratio*100:.1f}%)",
    f"  Total interruptions: {conv.total_interruptions} ({conv.interruption_rate:.2f}/minute)",
)
print("\n".join(conversation_report))


speaker_SPEAKER_04 statistics:
  Speaking time: 453.70s (32.8% of conversation)
  Number of turns: 119
  Average turn duration: 3.81s (median: 1.69s)
  Interruptions made: 30
  Interruptions received: 33

speaker_SPEAKER_01 statistics:
  Speaking time: 102.85s (7.4% of conversation)
  Number of turns: 31
  Average turn duration: 3.32s (median: 0.47s)
  Interruptions made: 9
  Interruptions received: 14

speaker_SPEAKER_03 statistics:
  Speaking time: 229.92s (16.6% of conversation)
  Number of turns: 83
  Average turn duration: 2.77s (median: 0.62s)
  Interruptions made: 35
  Interruptions received: 15

speaker_SPEAKER_00 statistics:
  Speaking time: 243.52s (17.6% of conversation)
  Number of turns: 41
  Average turn duration: 5.94s (median: 1.62s)
  Interruptions made: 8
  Interruptions received: 23

speaker_SPEAKER_02 statistics:
  Speaking time: 1.97s (0.1% of conversation)
  Number of turns: 5
  Average turn duration: 0.39s (median: 0.17s)
  Interruptions made: 3
  Interruptions 