In [3]:
import time
import numpy as np
import torch
import torchaudio

import matplotlib.pyplot as plt
import seaborn as sns

from itertools import groupby
from operator import itemgetter
from matplotlib.collections import LineCollection

# Compute device for tensors created in this notebook; fall back to CPU so the
# notebook does not hard-fail on machines without a CUDA GPU.
# NOTE(review): `device` is not referenced by the visible cells — the diarizer
# components appear to choose their own device (outputs show 'cuda:0').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from diarizer import Audio
from diarizer import SileroVAD
from diarizer import OnlineSpeakerClustering
from diarizer import MSDD

In [5]:
from utils import load_audio

In [6]:
# Build the diarization pipeline components.
# Constructing MSDD restores NeMo's `diar_msdd_telephonic` checkpoint (see the log
# output of this cell).
# NOTE(review): the meaning of `threshold` (0.8 here, 0.5 for the VAD) is defined in
# the diarizer package — presumably a probability cut-off; confirm there.
msdd = MSDD(
    threshold=0.8
)
# The MSDD wrapper's speaker-embedding model is reused as the Audio
# speech_embedding_model in the next cell.
titanet_l = msdd.speech_embedding_model
vad = SileroVAD(
    threshold=0.5
)
osc = OnlineSpeakerClustering()

[NeMo I 2025-02-20 21:29:10 cloud:58] Found existing object /home/raid/.cache/torch/NeMo/NeMo_2.1.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.
[NeMo I 2025-02-20 21:29:10 cloud:64] Re-using file from: /home/raid/.cache/torch/NeMo/NeMo_2.1.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo
[NeMo I 2025-02-20 21:29:10 common:826] Instantiating model from pre-trained checkpoint


[NeMo W 2025-02-20 21:29:11 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: true
    
[NeMo W 2025-02-20 21:29:11 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: null
    emb_dir: null
    sample_rate: 16000
    num_spks: 2
    soft_label_thres: 0.5
    labels: null
    batch_size: 15
    emb_batch_size: 0
    shuffle: false
    
[NeMo W 2025-02-20 21:29:11 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple

[NeMo I 2025-02-20 21:29:11 features:305] PADDING: 16
[NeMo I 2025-02-20 21:29:12 features:305] PADDING: 16
[NeMo I 2025-02-20 21:29:12 save_restore_connector:275] Model EncDecDiarLabelModel was successfully restored from /home/raid/.cache/torch/NeMo/NeMo_2.1.0/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.


In [7]:
# Multi-scale segmentation window lengths in seconds, longest first.
scales = [1.5, 1.25, 1.0, 0.75, 0.5]
# Every hop is half of its window (50% overlap):
# evaluates to [0.75, 0.625, 0.5, 0.375, 0.25].
hops = [window / 2 for window in scales]

# Wire the models from the previous cell into the streaming Audio pipeline.
a = Audio(
    scales,
    hops,
    speech_embedding_model=titanet_l,
    voice_activity_detection_model=vad,
    multi_scale_diarization_model=msdd,
    speaker_clustering=osc,
)

In [23]:
# Alternative input kept for quick switching between test files:
# waveform, sr = load_audio('test.wav')
waveform, sr = load_audio('toefl_eg.mp3')
# Keep the first row only — presumably the first channel of a (channels, samples)
# tensor; TODO confirm against utils.load_audio.
# NOTE(review): `sr` is never checked; the NeMo config reports sample_rate: 16000,
# so confirm load_audio resamples to that rate.
waveform = waveform[0]

In [24]:
# Feed samples [0, 500_000) to the streaming pipeline. `a` appears to keep state
# across calls (a.base_scale_segments is read later), so chunks must be fed in order.
# NOTE(review): this cell is In [24], i.e. it ran after the later chunk cells; restart
# the kernel and run top-to-bottom to reproduce the results below.
proba, labels = a(waveform[:500_000])

torch.Size([1, 5, 192, 1])


In [10]:
# Next contiguous chunk: samples [500_000, 1_000_000).
proba, labels = a(waveform[500_000:1_000_000])

torch.Size([1, 5, 192, 1])


In [13]:
# Next contiguous chunk: samples [1_000_000, 4_000_000). The printed shape's last
# dimension grows from 1 to 3 here — presumably the number of speakers found so far.
proba, labels = a(waveform[1_000_000:4_000_000])

torch.Size([1, 5, 192, 3])


In [17]:
# Next contiguous chunk: samples [4_000_000, 10_000_000). Bind the result like the
# other chunk cells so `proba`/`labels` refer to this chunk, then display it.
proba, labels = a(waveform[4_000_000:10_000_000])
proba, labels

torch.Size([1, 5, 192, 3])


(tensor([[0.2468, 0.3410, 0.4089,  ..., 0.4451, 0.4427, 0.4259],
         [0.9999, 1.0000, 1.0000,  ..., 0.4617, 0.3913, 0.2931],
         [0.3836, 0.3176, 0.3007,  ..., 1.0000, 1.0000, 1.0000]],
        device='cuda:0'),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 1.]], device='cuda:0'))

In [18]:
# Per-step speaker assignments read off the pipeline's base-scale segments
# (entries may be None; process_timeline below normalises them).
timeline = [base_segment.speakers for base_segment in a.base_scale_segments]

In [19]:
def process_timeline(data):
    """Normalise raw per-step speaker entries.

    Args:
        data: Iterable of per-step speaker collections, where ``None`` marks a
            step with no speaker.

    Returns:
        A new list with one sorted list per step; ``None`` becomes ``[]``.
    """
    normalised = []
    for entry in data:
        if entry is None:
            normalised.append([])
        else:
            normalised.append(sorted(entry))
    return normalised

def merge_segments(timeline, step=0.25):
    """Collapse a per-step speaker timeline into contiguous per-speaker segments.

    Args:
        timeline: Sequence with one speaker collection per time step (e.g. the
            output of ``process_timeline``); ``speaker in timeline[t]`` means the
            speaker is active at step ``t``. ``None`` entries count as silence.
        step: Duration of one timeline step in seconds. Defaults to 0.25,
            presumably the hop of the finest scale (``hops[-1]``) — confirm if
            the scales change.

    Returns:
        List of dicts with keys 'speaker', 'start', 'end' and 'duration' (all
        times in seconds), sorted by (start, speaker). A run still active on the
        last step ends at ``len(timeline) * step``.
    """
    steps = [() if active is None else active for active in timeline]
    n_steps = len(steps)
    merged_segments = []

    for speaker in {spk for active in steps for spk in active}:
        start_idx = None
        # Walk one step past the end (treated as silence) so a run that reaches the
        # end of the timeline is closed by the same branch as any other run.
        for t in range(n_steps + 1):
            is_active = t < n_steps and speaker in steps[t]
            if is_active and start_idx is None:
                start_idx = t
            elif not is_active and start_idx is not None:
                merged_segments.append({
                    'speaker': speaker,
                    'start': start_idx * step,
                    'end': t * step,
                    'duration': (t - start_idx) * step,
                })
                start_idx = None

    return sorted(merged_segments, key=lambda seg: (seg['start'], seg['speaker']))

def generate_rttm(segments, file_id="unknown", channel=1):
    """Render diarization segments as RTTM text.

    Each segment becomes one 10-field ``SPEAKER`` line:
    ``SPEAKER <file_id> <channel> <onset> <duration> <NA> <NA> SPEAKER_<id> <NA> <NA>``.

    Args:
        segments: Iterable of dicts with 'speaker', 'start' and 'duration' keys
            (seconds), e.g. the output of ``merge_segments``.
        file_id: Recording identifier written in field 2 (must not contain
            whitespace). Defaults to the previous hard-coded "unknown".
        channel: Channel number written in field 3. Defaults to 1.

    Returns:
        The RTTM text with every line newline-terminated, or "" when there are
        no segments.
    """
    lines = [
        f"SPEAKER {file_id} {channel} {seg['start']:.3f} {seg['duration']:.3f} "
        f"<NA> <NA> SPEAKER_{seg['speaker']} <NA> <NA>"
        for seg in segments
    ]
    # Terminate the final line too, so RTTM files can be concatenated (e.g. with
    # `cat`) without fusing the last line of one file with the first of the next.
    return "".join(line + "\n" for line in lines)

def create_visualization(segments, total_duration, output_path='speaker_diarization.png'):
    """Render diarization segments as a per-speaker timeline and save it to disk.

    Args:
        segments: Iterable of dicts with at least 'speaker', 'start' and 'end'
            keys (seconds), e.g. the output of ``merge_segments``.
        total_duration: Right edge of the time axis, in seconds.
        output_path: Destination image file. Defaults to
            'speaker_diarization.png'.

    The figure is always closed after saving (also on error), so it is not shown
    inline and does not accumulate in pyplot's figure registry.
    """
    segments = list(segments)  # iterated more than once below
    unique_speakers = sorted({seg['speaker'] for seg in segments})

    fig, ax = plt.subplots(figsize=(20, 5))
    try:
        for row, speaker in enumerate(unique_speakers):
            speaker_segments = [seg for seg in segments if seg['speaker'] == speaker]
            # One hlines call per speaker gives exactly one legend entry each. The
            # explicit colour is required: hlines otherwise draws every call in
            # rcParams['lines.color'], making all speakers indistinguishable.
            ax.hlines(
                y=[row] * len(speaker_segments),
                xmin=[seg['start'] for seg in speaker_segments],
                xmax=[seg['end'] for seg in speaker_segments],
                colors=f"C{row % 10}",
                linewidth=4,
                label=f"Speaker {speaker}",
            )

        # Fixed limits spanning the whole timeline (no autoscale margins apply once
        # limits are set explicitly). Avoid identical limits for empty input or a
        # zero duration, which matplotlib warns about.
        ax.set_xlim(0, total_duration if total_duration > 0 else 1)
        ax.set_ylim(-0.5, max(len(unique_speakers), 1) - 0.5)

        ax.set_yticks(range(len(unique_speakers)))
        ax.set_yticklabels([f"Speaker {speaker}" for speaker in unique_speakers])

        # Legend outside the axes on the right; skipped when there is nothing to label.
        if unique_speakers:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Speakers")
        ax.set_title("Speaker Diarization Timeline")
        ax.grid(True, alpha=0.3)

        # Prevent label/legend cutoff in the saved image.
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        plt.close(fig)

# Turn the raw timeline from the previous cell into merged speaker segments.
# NOTE(review): `timeline` is overwritten with its processed form; re-running this
# cell is safe only because process_timeline maps an already-processed list to itself.
timeline = process_timeline(timeline)
merged_segments = merge_segments(timeline)

# Timeline length in seconds; the 0.25 s per step must match the step merge_segments uses.
total_duration = len(timeline) * 0.25

# Write the segments as an RTTM file in the current working directory.
rttm_content = generate_rttm(merged_segments)
with open('output.rttm', 'w') as f:
    f.write(rttm_content)

# Save the timeline plot to disk (create_visualization closes the figure, so nothing is shown inline).
create_visualization(merged_segments, total_duration)

# Echo the segments for a quick manual check.
for seg in merged_segments:
    print(f"Speaker {seg['speaker']}: {seg['start']:.2f}s - {seg['end']:.2f}s (duration: {seg['duration']:.2f}s)")

Speaker 0: 62.00s - 120.00s (duration: 58.00s)
Speaker 0: 120.25s - 134.00s (duration: 13.75s)
Speaker 0: 135.75s - 137.75s (duration: 2.00s)
Speaker 1: 135.75s - 140.25s (duration: 4.50s)
Speaker 1: 140.75s - 145.00s (duration: 4.25s)
Speaker 2: 145.50s - 155.25s (duration: 9.75s)
Speaker 2: 155.50s - 156.50s (duration: 1.00s)
Speaker 1: 156.75s - 160.00s (duration: 3.25s)
Speaker 2: 159.00s - 168.75s (duration: 9.75s)
Speaker 1: 168.50s - 178.25s (duration: 9.75s)
Speaker 2: 171.75s - 172.75s (duration: 1.00s)
Speaker 2: 175.25s - 175.75s (duration: 0.50s)
Speaker 2: 178.00s - 181.75s (duration: 3.75s)
Speaker 2: 182.00s - 187.25s (duration: 5.25s)
Speaker 1: 187.50s - 191.25s (duration: 3.75s)
Speaker 2: 191.00s - 191.25s (duration: 0.25s)
Speaker 1: 191.50s - 191.75s (duration: 0.25s)
Speaker 2: 191.50s - 197.00s (duration: 5.50s)
Speaker 1: 196.50s - 206.75s (duration: 10.25s)
Speaker 1: 207.00s - 215.75s (duration: 8.75s)
Speaker 2: 215.25s - 220.25s (duration: 5.00s)
Speaker 1: 