# Visualizing Speaker Diarization Results

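This notebook loads an audio recording together with its speaker diarization output (per-speaker segments and a speaker-attributed transcript). It then plots the speaker timeline, prints the transcript, plays back each speaker in isolation, and summarizes per-speaker speaking-time statistics.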

In [None]:
import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import matplotlib.patches as mpatches
import IPython.display as ipd
from typing import Dict, List, Tuple

# Add parent directory to path to import project modules
sys.path.append(os.path.abspath('..'))
from src.data_utils.audio_processor import AudioProcessor

## Load Audio and Diarization Results

In [None]:
# Input locations -- edit these to point at your own recording and its diarization output.
AUDIO_FILE, DIARIZATION_JSON = (
    "../data/input/example.wav",                # input audio file that was diarized
    "../data/output/example_diarization.json",  # diarization result JSON (expected to match AUDIO_FILE)
)

In [None]:
# Decode the recording with the project's loader; `audio` and `sr` feed the plotting
# and playback cells further down.
audio_processor = AudioProcessor()
audio, sr = audio_processor.load_audio(AUDIO_FILE)

# Report the length and embed an inline player so the recording can be heard directly
duration_seconds = len(audio) / sr
print(f"Audio duration: {duration_seconds:.2f} seconds")
ipd.display(ipd.Audio(audio, rate=sr))

In [None]:
# Load diarization results. Decode explicitly as UTF-8 (the JSON interchange encoding)
# rather than the platform default, so transcripts containing non-ASCII text load the
# same way on every OS.
with open(DIARIZATION_JSON, 'r', encoding='utf-8') as f:
    diarization_results = json.load(f)

# `speaker_segments`: mapping speaker id -> list of segments with 'start'/'end' (seconds).
# `transcribed_segments`: flat list of segments with 'start', 'end', 'speaker_id', 'text'.
speaker_segments = diarization_results["speaker_segments"]
transcribed_segments = diarization_results["transcribed_segments"]

# One key per speaker, so the mapping's length is the speaker count
print(f"Number of speakers detected: {len(speaker_segments)}")

## Visualize Speaker Segments

In [None]:
def visualize_diarization(audio, sr, speaker_segments, figsize=(15, 8)):
    """Plot the waveform above a per-speaker timeline of diarization segments.

    Args:
        audio: Audio samples, drawn with ``librosa.display.waveshow``.
        sr: Sample rate of ``audio`` in Hz.
        speaker_segments: Mapping of speaker id -> list of segments, each a dict
            with ``'start'`` and ``'end'`` times in seconds.
        figsize: Matplotlib figure size.

    Returns:
        ``(fig, (ax1, ax2))``: ``ax1`` holds the waveform and ``ax2`` the speaker lanes.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1, 3]})

    # Top panel: waveform
    librosa.display.waveshow(audio, sr=sr, ax=ax1)
    ax1.set_title('Waveform')
    ax1.set_xlabel('')

    # Bottom panel: one horizontal lane per speaker
    ax2.set_title('Speaker Diarization')
    ax2.set_xlabel('Time (seconds)')
    ax2.set_ylabel('Speaker')

    n_speakers = len(speaker_segments)
    lane_height = 0.8
    # Speaker i gets a tick at y = i + 1; its lane is centred on that tick (see below).
    ax2.set_ylim(0, n_speakers + 1)
    ax2.set_yticks(range(1, n_speakers + 1))
    ax2.set_yticklabels([f'Speaker {spk}' for spk in speaker_segments.keys()])

    # Both panels share the audio's time span
    duration = len(audio) / sr
    ax1.set_xlim(0, duration)
    ax2.set_xlim(0, duration)

    # tab10 has only 10 colours; integer indices past the end are clamped to the
    # colormap's "over" colour, so wrap around to keep later speakers distinguishable.
    cmap = plt.get_cmap('tab10')
    colors = [cmap(i % cmap.N) for i in range(n_speakers)]

    legend_patches = []
    for i, (speaker_id, segments) in enumerate(speaker_segments.items()):
        # Bottom edge chosen so the lane is vertically centred on tick i + 1
        lane_bottom = (i + 1) - lane_height / 2
        for segment in segments:
            start = segment['start']
            end = segment['end']
            rect = mpatches.Rectangle((start, lane_bottom), end - start, lane_height,
                                      edgecolor='none', facecolor=colors[i], alpha=0.8)
            ax2.add_patch(rect)

        legend_patches.append(mpatches.Patch(color=colors[i], label=f'Speaker {speaker_id}'))

    # Skip the legend entirely when there are no speakers to describe
    if legend_patches:
        ax2.legend(handles=legend_patches, loc='upper right')

    plt.tight_layout()
    return fig, (ax1, ax2)

In [None]:
# Render the waveform + speaker timeline. Keep only the figure: assigning the axes to
# `_` would stop IPython from updating its `_` last-output variable for the session.
fig = visualize_diarization(audio, sr, speaker_segments)[0]
plt.show()

## Display Transcriptions by Speaker

In [None]:
def display_transcript(transcribed_segments):
    """Display transcript with timestamps and speaker IDs."""
    # Sort segments by start time
    sorted_segments = sorted(transcribed_segments, key=lambda x: x["start"])
    
    # Display each segment
    for segment in sorted_segments:
        start = segment["start"]
        end = segment["end"]
        speaker_id = segment["speaker_id"]
        text = segment["text"]
        
        # Format: [MM:SS.ms -> MM:SS.ms] Speaker X: Text
        start_str = f"{int(start // 60):02d}:{start % 60:06.3f}"
        end_str = f"{int(end // 60):02d}:{end % 60:06.3f}"
        print(f"[{start_str} -> {end_str}] Speaker {speaker_id}: {text}")

In [None]:
# Print the full speaker-attributed transcript, earliest segment first
display_transcript(transcribed_segments=transcribed_segments)

## Listen to Individual Speaker Segments

In [None]:
def extract_speaker_audio(audio, sr, speaker_segments, speaker_id):
    """Return a copy of ``audio`` in which only ``speaker_id``'s segments remain audible.

    Everything outside the speaker's segments is zeroed, so the result has the same
    shape and dtype as ``audio``. Segment times (seconds) are converted to sample
    indices by truncation and clipped to ``[0, len(audio)]``.

    Returns:
        The masked array, or ``None`` (after printing a message) if ``speaker_id``
        is not a key of ``speaker_segments``.
    """
    if speaker_id not in speaker_segments:
        print(f"Speaker {speaker_id} not found in the diarization results.")
        return None

    n_samples = len(audio)
    masked = np.zeros_like(audio)
    for seg in speaker_segments[speaker_id]:
        lo = max(0, int(seg['start'] * sr))
        hi = min(n_samples, int(seg['end'] * sr))
        if hi <= lo:
            continue  # empty or inverted after clipping: nothing to copy
        masked[lo:hi] = audio[lo:hi]

    return masked

In [None]:
# Play back each speaker in isolation (all other speakers silenced)
for speaker_id in speaker_segments:
    print(f"Speaker {speaker_id}:")
    speaker_audio = extract_speaker_audio(audio, sr, speaker_segments, speaker_id)
    if speaker_audio is None:
        continue
    ipd.display(ipd.Audio(speaker_audio, rate=sr))

## Analyze Speaker Statistics

In [None]:
def speaker_statistics(speaker_segments):
    """Summarise speaking time per speaker.

    Args:
        speaker_segments: Mapping of speaker id -> list of segments with
            ``'start'`` and ``'end'`` times in seconds.

    Returns:
        Dict keyed by speaker id (same order as the input), each value holding
        ``'total_duration'`` (sum of segment lengths), ``'num_turns'`` (segment
        count) and ``'avg_turn_duration'`` (0 when the speaker has no segments).
    """
    def _summarise(segments):
        total = sum(seg['end'] - seg['start'] for seg in segments)
        turns = len(segments)
        return {
            'total_duration': total,
            'num_turns': turns,
            'avg_turn_duration': total / turns if turns > 0 else 0,
        }

    return {spk: _summarise(segs) for spk, segs in speaker_segments.items()}

In [None]:
# Compute per-speaker speaking-time statistics
stats = speaker_statistics(speaker_segments)

# Tabulate them; the separator is sized from the header so the two always line up.
header = f"{'Speaker ID':<15} {'Total Time (s)':<20} {'Number of Turns':<20} {'Avg Turn Duration (s)':<20}"
print("Speaker Statistics:")
print(header)
print("-" * len(header))

for speaker_id, speaker_stats in stats.items():
    print(f"{speaker_id:<15} {speaker_stats['total_duration']:<20.2f} {speaker_stats['num_turns']:<20} {speaker_stats['avg_turn_duration']:<20.2f}")

# Bar chart of total speaking time per speaker. Pass plain lists rather than dict views,
# and string labels so speaker ids are plotted as categories even if they look numeric.
speaker_labels = [str(spk) for spk in stats]
speaking_times = [s['total_duration'] for s in stats.values()]

fig_stats, ax_stats = plt.subplots(figsize=(10, 6))
ax_stats.bar(speaker_labels, speaking_times)
ax_stats.set_title('Speaking Time Distribution')
ax_stats.set_xlabel('Speaker ID')
ax_stats.set_ylabel('Speaking Time (seconds)')
ax_stats.grid(axis='y', linestyle='--', alpha=0.7)
fig_stats.tight_layout()
plt.show()