### Hugging Face Pyannote Diarization Implementation

pip install torch==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html ## I don't think this is particularly useful
pip install pyannote.audio

In [2]:
import os
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import sounddevice as sd
import torch
import torchaudio
from pyannote.audio import Pipeline

  torchaudio.set_audio_backend("soundfile")
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Report the installed PyTorch build and whether a CUDA device is usable.
print(torch.__version__)
print(torch.cuda.is_available())

2.3.0+cpu
False


In [4]:
## Get the Hugging Face Model for Voice Diarization
# The access token is read from the HF_TOKEN environment variable so that no
# credential is stored in (or leaked through) the notebook. Any token that was
# previously committed in this cell should be revoked and replaced.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "Set the HF_TOKEN environment variable to a Hugging Face access token "
        "before loading pyannote/speaker-diarization-3.1."
    )

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=hf_token)

torchvision is not available - cannot save figures


If a GPU is available, we can move the pipeline onto it:
```
import torch
pipeline.to(torch.device("cuda"))
```
The authors claim that on a GPU it can process a 1-hour conversation in about 1.5 minutes.
Currently, on CPU, it takes over 3 minutes to process a 4-minute audio file.

In [5]:
# run the pipeline on an audio file - took 3 mins and 5 seconds to run for 4 mins audio
#diarization = pipeline("multiple_voice.wav")

In [6]:
## Alternative to passing a file path: load the audio into memory first and hand
## the pipeline a mapping with "waveform" and "sample_rate" keys.
## Took the same 3 mins and 5 seconds to run on a 4 mins audio file.
waveform, sample_rate = torchaudio.load("Debate example.wav")
diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate})

In [7]:
# Print the diarization output: one line per speech turn showing its time span,
# track name and speaker label.
print(diarization)

[ 00:00:00.008 -->  00:01:14.388] A SPEAKER_00
[ 00:01:15.254 -->  00:01:17.088] B SPEAKER_01
[ 00:01:17.292 -->  00:01:29.889] C SPEAKER_01
[ 00:01:30.127 -->  00:01:40.144] D SPEAKER_01
[ 00:01:40.636 -->  00:01:41.315] E SPEAKER_01
[ 00:01:41.740 -->  00:01:55.271] F SPEAKER_01
[ 00:01:55.475 -->  00:02:46.273] G SPEAKER_00
[ 00:02:46.460 -->  00:02:52.928] H SPEAKER_00
[ 00:02:53.285 -->  00:02:54.864] I SPEAKER_01
[ 00:02:56.290 -->  00:02:58.904] J SPEAKER_01
[ 00:02:59.193 -->  00:03:36.511] K SPEAKER_01
[ 00:03:37.665 -->  00:03:38.938] L SPEAKER_01
[ 00:03:39.006 -->  00:03:46.494] M SPEAKER_01
[ 00:03:46.680 -->  00:03:48.361] N SPEAKER_01
[ 00:03:48.887 -->  00:04:32.266] O SPEAKER_00
[ 00:04:33.200 -->  00:04:35.628] P SPEAKER_01
[ 00:04:36.188 -->  00:04:40.874] Q SPEAKER_01
[ 00:04:41.315 -->  00:04:45.356] R SPEAKER_01
[ 00:04:45.899 -->  00:04:50.025] S SPEAKER_01
[ 00:04:50.365 -->  00:04:50.823] T SPEAKER_01
[ 00:04:51.179 -->  00:04:57.071] U SPEAKER_01
[ 00:04:57.47

In [8]:
# Print the distinct speaker labels present in the diarization result.
print(diarization.labels())

['SPEAKER_00', 'SPEAKER_01']


In [9]:
### Show the time spans reported by get_overlap(); these are what we would
### remove if overlapped content has to be excluded.
print(diarization.get_overlap())

[]


## Now let's see if we can separate the audio by speaker, based on the timelines given by pyannote

In [10]:
def play_wav(waveform, sample_rate, playback_duration=10):
    """Play the beginning of ``waveform`` and block until playback finishes.

    Parameters
    ----------
    waveform : torch.Tensor or array-like
        Either a 1-D sequence of samples (mono) or a 2-D array laid out as
        (channels, samples) -- the layout the rest of this notebook slices
        along axis 1.
    sample_rate : int
        Sampling rate, in Hz, passed through to ``sd.play``.
    playback_duration : float, default 10
        Maximum number of seconds to play; shorter audio is played in full.
    """
    # Work on a NumPy array regardless of whether a tensor or an array came in.
    if isinstance(waveform, torch.Tensor):
        samples = waveform.detach().cpu().numpy()
    else:
        samples = np.asarray(waveform)

    # Promote mono 1-D input to (1, samples) so both layouts share one code path.
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]

    # Limit along the sample axis (the last one), not along the channel axis.
    num_samples_to_play = min(samples.shape[-1], int(playback_duration * sample_rate))
    waveform_to_play = samples[:, :num_samples_to_play]

    # sd.play is given (frames, channels), hence the transpose.
    sd.play(waveform_to_play.T, sample_rate)
    sd.wait()

In [11]:
def diarization_overlap_to_dataframe(diarization):
    """Parse the text form of ``diarization.get_overlap()`` into a DataFrame.

    Each line of the text that contains a ``start --> end`` pair contributes
    one row; surrounding brackets and spaces are stripped from both values.
    Lines without such a pair (for example the bare ``[]`` printed when there
    is no overlap) are skipped, so an empty result yields an empty DataFrame
    instead of raising ``IndexError``.

    Parameters
    ----------
    diarization : object
        Any object whose ``get_overlap()`` result has a string form made of
        bracketed ``start --> end`` lines.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Start Time'`` and ``'End Time'`` holding the timestamp
        strings exactly as printed.
    """
    overlap_str = str(diarization.get_overlap())

    start_times = []
    end_times = []

    for line in overlap_str.split('\n'):
        # Nothing to extract from lines that carry no time span.
        if ' --> ' not in line:
            continue
        start, end = line.split(' --> ', 1)
        # The first/last lines carry an extra enclosing bracket; strip them all.
        start_times.append(start.strip(' []'))
        end_times.append(end.strip(' []'))

    return pd.DataFrame({
        'Start Time': start_times,
        'End Time': end_times
    })

In [12]:
def diarization_to_dataframe(diarization):
    """Parse the text form of a diarization result into a DataFrame.

    The text is expected to contain one line per speech turn, shaped like
    ``[ 00:00:00.008 -->  00:01:14.388] A SPEAKER_00``: a bracketed time span,
    then a track name, then the speaker label. Lines that do not have this
    shape (including the single empty line produced by empty text) are
    skipped rather than raising ``IndexError``.

    The label is taken as everything after the first whitespace-separated
    token that follows the time span, so track names of any length are
    handled. Track names are assumed not to contain spaces.

    Parameters
    ----------
    diarization : object
        Any object whose ``str()`` form follows the line format above.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Start Time'``, ``'End Time'`` (timestamp strings as
        printed) and ``'Speaker Label'``.
    """
    start_times = []
    end_times = []
    speaker_labels = []

    for line in str(diarization).split('\n'):
        # Skip blank or malformed lines instead of failing on them.
        if '] ' not in line or ' --> ' not in line:
            continue

        time_span, track_and_label = line.split('] ', 1)
        if ' --> ' not in time_span:
            continue
        start, end = time_span.split(' --> ', 1)

        # First token is the track name; the remainder is the speaker label.
        fields = track_and_label.split(None, 1)
        if len(fields) < 2:
            continue

        start_times.append(start.strip(' []'))
        end_times.append(end.strip(' []'))
        speaker_labels.append(fields[1].strip())

    return pd.DataFrame({
        'Start Time': start_times,
        'End Time': end_times,
        'Speaker Label': speaker_labels
    })

In [13]:
# Convert the diarization result into a DataFrame and preview its first 10 rows.
df_da = diarization_to_dataframe(diarization)
df_da.head(10)

Unnamed: 0,Start Time,End Time,Speaker Label
0,00:00:00.008,00:01:14.388,SPEAKER_00
1,00:01:15.254,00:01:17.088,SPEAKER_01
2,00:01:17.292,00:01:29.889,SPEAKER_01
3,00:01:30.127,00:01:40.144,SPEAKER_01
4,00:01:40.636,00:01:41.315,SPEAKER_01
5,00:01:41.740,00:01:55.271,SPEAKER_01
6,00:01:55.475,00:02:46.273,SPEAKER_00
7,00:02:46.460,00:02:52.928,SPEAKER_00
8,00:02:53.285,00:02:54.864,SPEAKER_01
9,00:02:56.290,00:02:58.904,SPEAKER_01


In [14]:
######## Check if there is an overlap between the voices, if yes then extract that and show to the user
try:
    # flag=1 marks every row of this frame as an overlapped time span.
    df_da1 = diarization_overlap_to_dataframe(diarization).assign(flag=1)
except IndexError:
    # The text parser raises IndexError when no overlap is reported (the text
    # is just "[]"). Only that case falls back to an empty frame; any other
    # error is a real problem and is allowed to surface.
    df_da1 = pd.DataFrame({
        'Start Time': [],
        'End Time': [],
        'flag': []
    })

# Last expression of the cell so the preview is actually displayed.
df_da1.head(10)
    

In [15]:
def calculate_total_seconds(time_str):
    """Convert an ``HH:MM:SS.fff`` timestamp string into a number of seconds.

    Surrounding whitespace is ignored. The string is parsed with the
    ``'%H:%M:%S.%f'`` format and the offset from midnight is returned as a
    float.
    """
    parsed = datetime.strptime(time_str.strip(), '%H:%M:%S.%f')
    # Offset from 00:00:00 on the same (default) date, as a timedelta.
    midnight = parsed.replace(hour=0, minute=0, second=0, microsecond=0)
    since_midnight = parsed - midnight
    return since_midnight.total_seconds()



In [16]:
def extract_speaker_audio(waveform, sample_rate, df):
    """Cut ``waveform`` into per-speaker audio using the segments listed in ``df``.

    Parameters
    ----------
    waveform : torch.Tensor
        2-D tensor that is sliced along axis 1, one slice per row of ``df``.
    sample_rate : int
        Number of samples per second, used to turn timestamps into indices.
    df : pandas.DataFrame
        One row per segment with ``'Start Time'``, ``'End Time'`` (strings
        understood by ``calculate_total_seconds``) and ``'Speaker Label'``.

    Returns
    -------
    dict
        Maps each speaker label to the concatenation (along axis 1) of all of
        that speaker's segments, in row order.
    """
    pieces_by_speaker = {}

    for record in df.to_dict('records'):
        # Timestamps -> sample positions (truncated towards zero).
        first = int(calculate_total_seconds(record['Start Time']) * sample_rate)
        last = int(calculate_total_seconds(record['End Time']) * sample_rate)

        # Collect the slice under its speaker, creating the list on first sight.
        pieces_by_speaker.setdefault(record['Speaker Label'], []).append(
            waveform[:, first:last]
        )

    # Stitch every speaker's slices together along the sample axis.
    return {
        speaker: torch.cat(pieces, dim=1)
        for speaker, pieces in pieces_by_speaker.items()
    }

In [17]:
### Optional: drop the overlapped content from the voices
# Left-join the overlap spans onto the segments; rows that found no match keep
# a null `flag` and are the ones retained.
# NOTE(review): the join is on exact 'Start Time'/'End Time' equality, so only
# segments whose span is identical to an overlap span are removed -- confirm
# that partially overlapped segments are meant to be kept.
df_merge = df_da.merge(df_da1, how='left', on=['Start Time', 'End Time'])
df_non_overlap = df_merge.loc[df_merge['flag'].isnull()].drop(columns=['flag'])
df_non_overlap.head(5)

Unnamed: 0,Start Time,End Time,Speaker Label
0,00:00:00.008,00:01:14.388,SPEAKER_00
1,00:01:15.254,00:01:17.088,SPEAKER_01
2,00:01:17.292,00:01:29.889,SPEAKER_01
3,00:01:30.127,00:01:40.144,SPEAKER_01
4,00:01:40.636,00:01:41.315,SPEAKER_01


In [18]:
## Alternative: use the full audio, including overlapped segments
#speaker_waveforms = extract_speaker_audio(waveform, sample_rate, df_da)

## Default: build the per-speaker audio from the segments left after dropping overlaps
speaker_waveforms = extract_speaker_audio(waveform, sample_rate, df_non_overlap)

In [19]:
# Inspect the per-speaker waveforms returned by extract_speaker_audio.
speaker_waveforms

{'SPEAKER_00': tensor([[ 0.0000,  0.0000,  0.0000,  ..., -0.0026, -0.0026, -0.0024],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0026, -0.0026, -0.0024]]),
 'SPEAKER_01': tensor([[-0.0013, -0.0010, -0.0016,  ..., -0.0044, -0.0048, -0.0051],
         [-0.0013, -0.0010, -0.0016,  ..., -0.0044, -0.0048, -0.0051]])}

In [20]:
def play_wav(waveform, sample_rate, playback_duration=30):
    """Play the beginning of ``waveform`` and block until playback finishes.

    Parameters
    ----------
    waveform : torch.Tensor or array-like
        Either a 1-D sequence of samples (mono) or a 2-D array laid out as
        (channels, samples) -- the layout the rest of this notebook slices
        along axis 1.
    sample_rate : int
        Sampling rate, in Hz, passed through to ``sd.play``.
    playback_duration : float, default 30
        Maximum number of seconds to play; shorter audio is played in full.
    """
    # Convert to NumPy first; detach/cpu keeps this working for tensors that
    # track gradients or live on a GPU, and plain arrays pass straight through.
    if isinstance(waveform, torch.Tensor):
        waveform_np = waveform.detach().cpu().numpy()
    else:
        waveform_np = np.asarray(waveform)

    # Promote mono 1-D input to (1, samples) *before* slicing, so the
    # two-axis slice below is valid for it as well.
    if waveform_np.ndim == 1:
        waveform_np = waveform_np.reshape(1, -1)

    num_samples_to_play = min(waveform_np.shape[-1], int(playback_duration * sample_rate))
    waveform_to_play = waveform_np[:, :num_samples_to_play]

    # sd.play is given (frames, channels), hence the transpose.
    sd.play(waveform_to_play.T, sample_rate)
    sd.wait()

In [24]:
# Play a preview for one speaker and save every speaker's audio to its own file.
# The loop variable is deliberately not named `waveform`: that name holds the
# full recording loaded earlier, and rebinding it here would make a re-run of
# the extraction cell operate on a single speaker's audio instead.
for speaker_label, speaker_waveform in speaker_waveforms.items():
    if speaker_label == 'SPEAKER_00':
        print(f"Playing audio for speaker {speaker_label}")
        play_wav(speaker_waveform, sample_rate, playback_duration=30)

    ## Save the files
    torchaudio.save("debate_{}.wav".format(speaker_label), speaker_waveform, sample_rate)

Playing audio for speaker SPEAKER_00
