# Automatic Speech Recognition combined with Speaker Diarization

In [None]:
"""
You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.

Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""
# If you're using Google Colab and not running locally, run this cell.

## Install dependencies
# wget is the Python package the notebook imports later to download files.
!pip install wget
# System packages installed through apt-get.
!apt-get install sox libsndfile1 ffmpeg
!pip install unidecode

## Install NeMo
# BRANCH names the NeMo git ref to install; IPython expands $BRANCH in the shell line below.
BRANCH = 'v1.0.0'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]

## Install TorchAudio
!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html

# Introduction

In the early years, speaker diarization algorithms were developed for speech recognition on multispeaker audio recordings to enable speaker-adaptive processing, but over
time they also gained value as a stand-alone application, providing speaker-specific meta information for downstream tasks such as audio retrieval.
Automatic Speech Recognition output when combined with Speaker labels has shown immense use in many tasks, ranging from analyzing telephonic conversation to decoding meeting transcriptions. 

In this tutorial we demonstrate how one can get ASR transcriptions combined with Speaker labels along with voice activity time stamps using NeMo asr collections. 

For detailed understanding of transcribing words with ASR refer to this [ASR tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/01_ASR_with_NeMo.ipynb), and for detailed understanding of speaker diarizing an audio refer to this [Diarization inference](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_recognition/Speaker_Diarization_Inference.ipynb) tutorial

Let's first import nemo asr and other libraries for visualization purposes

In [None]:
import nemo.collections.asr as nemo_asr
import numpy as np
from IPython.display import Audio, display
import librosa
import os
import wget
import matplotlib.pyplot as plt

We demonstrate this tutorial using a merged an4 audio file that has two speakers (male and female) speaking dates in different formats. If the file does not already exist locally, download it, then listen to it.

In [None]:
# Keep everything the tutorial downloads in ./data under the current working directory.
ROOT = os.getcwd()
data_dir = os.path.join(ROOT, 'data')
os.makedirs(data_dir, exist_ok=True)

# Fetch the sample recording only when no local copy is present.
an4_audio_url = "https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav"
AUDIO_FILENAME = os.path.join(data_dir, 'an4_diarize_test.wav')
if not os.path.exists(AUDIO_FILENAME):
    AUDIO_FILENAME = wget.download(an4_audio_url, data_dir)

# sr=None asks librosa to keep the file's native sampling rate.
signal, sample_rate = librosa.load(AUDIO_FILENAME, sr=None)
display(Audio(signal, rate=sample_rate))

In [None]:
def show_figure(signal,text='Audio',overlay_color=[]):
    """Scatter-plot a waveform, optionally recoloured sample by sample.

    Args:
        signal: sequence of audio samples; one point is drawn per sample.
        text: figure title.
        overlay_color: optional sequence of matplotlib colours, one per
            sample, drawn on top of the black base layer. An empty value
            (the default) skips the overlay.

    The x tick labels are converted from sample indices to seconds using the
    notebook-level ``sample_rate`` variable, which must exist before calling.
    """
    sample_positions = np.arange(len(signal))

    fig, _ = plt.subplots(1, 1)
    fig.set_figwidth(20)
    fig.set_figheight(2)

    # Base layer: every sample in black.
    plt.scatter(sample_positions, signal, s=1, marker='o', c='k')
    # Optional second layer with per-sample colours.
    if len(overlay_color) > 0:
        plt.scatter(sample_positions, signal, s=1, marker='o', c=overlay_color)

    fig.suptitle(text, fontsize=16)
    plt.xlabel('time (secs)', fontsize=18)
    plt.ylabel('signal strength', fontsize=14)
    plt.axis([0, len(signal), -0.5, +0.5])

    # Re-label the automatically chosen ticks (all but the last) in seconds.
    tick_positions, _ = plt.xticks()
    shown_ticks = tick_positions[:-1]
    plt.xticks(shown_ticks, shown_ticks / sample_rate)


Plot the audio signal.

In [None]:
# Plot the raw waveform; overlay_color keeps its empty default, so no colour overlay is drawn.
show_figure(signal)

We start our demonstration by first transcribing the audio using our pretrained model `QuartzNet15x5Base-En` and use the CTC output probabilities to get timestamps for words spoken. We then later use these timestamps to get speaker label information using speaker diarizer model. 

Download and load pretrained quartznet asr model

In [None]:
#Load model
# Fetch the pretrained QuartzNet15x5Base-En CTC model by name.
# NOTE(review): strict=False is presumably forwarded to checkpoint loading — confirm against the NeMo API.
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='QuartzNet15x5Base-En', strict=False)

Transcribe the audio 

In [None]:
# transcribe() is given a list of paths; with a single file we keep element 0 of its result.
files = [AUDIO_FILENAME]
transcript = asr_model.transcribe(paths2audio_files=files)[0]
print(f'Transcript: "{transcript}"')

Get CTC log probabilities with output labels

In [None]:
# softmax implementation in NumPy
def softmax(logits):
    """Return the softmax of ``logits`` taken over the last axis.

    Args:
        logits: array-like of any rank >= 1; the last axis holds the
            per-class scores.

    Returns:
        ``numpy.ndarray`` with the same shape as ``logits`` whose values along
        the last axis are non-negative and sum to 1.
    """
    logits = np.asarray(logits)
    # Shift each row by its own maximum: this leaves the result unchanged but
    # guarantees at least one exp() argument per row is 0, so no row can
    # underflow to an all-zero numerator and produce 0/0.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(shifted)
    # keepdims makes the normaliser broadcast for any input rank.
    return e / e.sum(axis=-1, keepdims=True)

# One output frame of the model covers 20 ms of audio.
time_stride = 0.02

# Run inference a second time, now requesting the per-frame outputs
# (logprobs=True) instead of decoded text, and normalise them per frame.
logits = asr_model.transcribe(files, logprobs=True)[0]
probs = softmax(logits)

# Model alphabet with the CTC blank appended as the last symbol; entry 0 is
# renamed to a readable 'space'.
labels = [*asr_model.decoder.vocabulary, 'blank']
labels[0] = 'space'

We use CTC labels for voice activity detection. To detect speech and non-speech segments in the audio, we use blank and space labels in the CTC outputs. Consecutive labels with spaces or blanks longer than a threshold are considered non-speech segments

In [None]:

def _find_runs(best_path, start_label, filler_labels):
    """Collect [first_frame, last_frame] runs found in a per-frame label path.

    A run opens on a frame whose label equals ``start_label`` and stays open
    while the following labels are in ``filler_labels``; it closes on the
    frame before the first label outside ``filler_labels``. A run still open
    at the end of the path is closed on the last frame.
    """
    runs = []
    run_start = None
    for frame, label in enumerate(best_path):
        if run_start is not None and label not in filler_labels:
            runs.append([run_start, frame - 1])
            run_start = None
        if run_start is None and label == start_label:
            run_start = frame
    if run_start is not None:
        runs.append([run_start, len(best_path) - 1])
    return runs


# Greedy (argmax) label index for every output frame.
best_path = np.argmax(probs, axis=-1)

# Index 0 is the space symbol; the CTC blank is the last class of the model
# output (index 28 for the 29-class alphabet used in this tutorial).
SPACE_IDX = 0
BLANK_IDX = probs.shape[-1] - 1
SILENT_LABELS = (SPACE_IDX, BLANK_IDX)

# Runs that start on a blank and continue through blanks/spaces.
blanks = _find_runs(best_path, BLANK_IDX, SILENT_LABELS)

threshold = 20  # minimum width (in frames) to consider a run non-speech activity
non_speech = [run for run in blanks if run[1] - run[0] > threshold]

# get timestamps for space symbols
# Runs that start on a space and continue through blanks/spaces. The trailing
# run is closed with len(probs)-1, the same way as for blanks.
spaces = _find_runs(best_path, SPACE_IDX, SILENT_LABELS)

# calibration offset for timestamps: 180 ms
offset = -0.18

# split the transcript into words
words = transcript.split()

Frame level stamps for non speech frames 

In [None]:
# Each entry is a [first_frame, last_frame] pair of model output frames.
print(non_speech)

Write the speech segments to an RTTM file for later use in extracting speaker labels.

In [None]:
# The timestamp calibration offset expressed in model output frames.
frame_offset = offset / time_stride
speech_labels = []
uniq_id = os.path.basename(AUDIO_FILENAME).split('.')[0]
num_frames = len(probs)


def _frame_to_sec(frame):
    """Convert a frame index to calibrated seconds, never below zero."""
    return max(0.0, (frame + frame_offset) * time_stride)


# Speech is whatever lies between two consecutive non-speech runs.
# NOTE(review): frames before the first non-speech run are not emitted as
# speech, matching the previous behaviour — confirm this is intended.
if non_speech:
    speech_spans = [(prev[1], nxt[0]) for prev, nxt in zip(non_speech, non_speech[1:])]
    # Add a trailing span only when the audio does not end inside the last
    # non-speech run (whose last frame would then be num_frames - 1).
    if non_speech[-1][1] < num_frames - 1:
        speech_spans.append((non_speech[-1][1], num_frames))
else:
    # No pause long enough was detected: treat the whole recording as speech.
    speech_spans = [(0, num_frames)]

with open(uniq_id + '.rttm', 'w') as f:
    for first_frame, last_frame in speech_spans:
        start = _frame_to_sec(first_frame)
        end = _frame_to_sec(last_frame)
        if end <= start:
            continue  # nothing to write for an empty span
        # RTTM stores onset and duration; speech_labels stores onset and end.
        f.write("SPEAKER {} 1 {:.3f} {:.3f} <NA> <NA> speech <NA>\n".format(uniq_id, start, end - start))
        speech_labels.append("{:.3f} {:.3f} speech".format(start, end))

Time stamps for speech frames

In [None]:
# Each entry is a "start end speech" string with both times in seconds.
print(speech_labels)

In [None]:
# Single-letter matplotlib colour codes, indexed by speaker number.
COLORS="b g c m y".split()
def get_color(signal,speech_labels,sample_rate=16000):
    """Build a per-sample array of matplotlib colour codes for ``signal``.

    Args:
        signal: sequence of audio samples; only its length is used.
        speech_labels: iterable of ``"start end label"`` strings with times in
            seconds. ``label`` is either ``"speech"`` (painted ``'r'``) or a
            name ending in ``_<int>``, painted ``COLORS[<int>]`` (wrapping
            around when there are more speakers than colours).
        sample_rate: samples per second used to convert times to indices.

    Returns:
        ``numpy.ndarray`` of one-character colour codes, ``'k'`` for every
        sample not covered by a label.
    """
    c = np.full(len(signal), 'k')
    for time_stamp in speech_labels:
        start, end, label = time_stamp.split()
        # Clamp at 0 so a slightly negative calibrated time cannot turn into
        # a from-the-end slice index.
        start = max(0, int(float(start) * sample_rate))
        end = max(0, int(float(end) * sample_rate))
        if label == "speech":
            # One-character code so it fits the array's element width.
            code = 'r'
        else:
            code = COLORS[int(label.split('_')[-1]) % len(COLORS)]
        c[start:end] = code

    return c

With voice activity time stamps extracted from CTC outputs, here we show the voice activity regions in <span style="color:red">**red**</span> and the remaining (non-speech) signal in **black**.

In [None]:
# Per-sample colour array built from the speech time stamps, then drawn over the waveform.
color=get_color(signal,speech_labels)
show_figure(signal,'an4 audio signal with vad',color)

We use a helper function from speaker utils to convert the voice activity RTTM file into a manifest,
which the clustering-based speaker diarizer then uses for inference.

In [None]:
from nemo.collections.asr.parts.speaker_utils import write_rttm2manifest
# The manifest is written to ./oracle_vad/oracle_manifest.json.
output_dir = os.path.join(ROOT, 'oracle_vad')
os.makedirs(output_dir,exist_ok=True)
oracle_manifest = os.path.join(output_dir,'oracle_manifest.json')
# One RTTM path per audio path: the RTTM written earlier for this recording.
write_rttm2manifest(paths2audio_files=files,
                    paths2rttm_files=[uniq_id+'.rttm'],
                    manifest_file=oracle_manifest)

In [None]:
# Print the manifest file; IPython substitutes {output_dir} before running the shell command.
!cat {output_dir}/oracle_manifest.json

Set up diarizer model 

In [None]:
from omegaconf import OmegaConf

# Name of the pretrained speaker-embedding model handed to the diarizer.
pretrained_speaker_model = 'speakerdiarization_speakernet'

# Re-use a previously downloaded diarization config if there is one.
# NOTE(review): the config is fetched from the main branch while the setup
# cell installs a tagged NeMo release — confirm the two stay compatible.
MODEL_CONFIG = os.path.join(data_dir, 'speaker_diarization.yaml')
if not os.path.exists(MODEL_CONFIG):
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_recognition/conf/speaker_diarization.yaml"
    MODEL_CONFIG = wget.download(config_url, data_dir)

config = OmegaConf.load(MODEL_CONFIG)

# Inputs and output location (intermediate files and predictions go to out_dir).
config.diarizer.paths2audio_files = files
config.diarizer.out_dir = output_dir
config.diarizer.oracle_num_speakers = 2

# Speaker embeddings: no VAD model is run, the manifest created from the
# CTC-derived speech segments is passed in instead.
config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
config.diarizer.speaker_embeddings.oracle_vad_manifest = oracle_manifest

Diarize the audio at provided time stamps

In [None]:
from nemo.collections.asr.models import ClusteringDiarizer
# Build the diarizer from the prepared config and run it; the trailing
# semicolons suppress the cell's value output.
oracle_model = ClusteringDiarizer(cfg=config);
oracle_model.diarize();

In [None]:
from nemo.collections.asr.parts.speaker_utils import rttm_to_labels
# The predicted RTTM is read from <output_dir>/pred_rttms/<uniq_id>.rttm.
pred_rttm=os.path.join(output_dir,'pred_rttms',uniq_id+'.rttm')
# NOTE: this rebinds `labels`, which previously held the CTC alphabet list.
labels=rttm_to_labels(pred_rttm)
print("speaker labels with time stamps\n",labels)

Now let us see the audio plot color coded per speaker

In [None]:
# Colour every sample by its predicted speaker and plot the result.
color=get_color(signal,labels)
show_figure(signal,'audio with speaker labels',color)
# Play back at the rate the file was actually loaded with, as in the first
# playback cell, instead of a hard-coded 16000.
display(Audio(signal,rate=sample_rate))

Finally transcribe audio with time stamps and speaker label information

In [None]:
def _print_segment_header(label):
    """Print ``speaker [start - end sec]`` for one label and return its end time."""
    start_point, end_point, speaker = label.split()
    print("{} [{:.2f} - {:.2f} sec]".format(speaker, float(start_point), float(end_point)), end=" ")
    return float(end_point)


# Calibrated time (seconds) of the middle of each space run, i.e. the point
# where one word ends and the next begins.
word_ends = [offset + (spot[0] + spot[1]) / 2 * time_stride for spot in spaces]

if not labels:
    # No speaker segments were produced: fall back to the plain transcript.
    print(" ".join(words))
else:
    idx = 0
    segment_end = _print_segment_header(labels[idx])
    pos_prev = 0
    for j, word in enumerate(words):
        # A word belongs to the next speaker segment once the previous word
        # ended at or after the current segment's end. Stay on the last
        # segment when there is no further one to move to.
        if pos_prev >= segment_end and idx + 1 < len(labels):
            print()
            idx += 1
            segment_end = _print_segment_header(labels[idx])
        print(word, end=" ")
        # The last word has no trailing space run, so there may be no boundary for it.
        if j < len(word_ends):
            pos_prev = word_ends[j]