<a href="https://colab.research.google.com/github/Ehsan77e/ASR-with-Diarization/blob/main/Wav2vec2_With_LM_combined_with_Diarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1>Dependencies </h1>
<p>Run the cell below, then restart the runtime so that the newly installed packages are picked up.</p>
<p>If you want to use diarization 3.0, also install either onnxruntime or its GPU variant (onnxruntime-gpu).</p>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install transformers
import transformers
import torch
import librosa
!pip install pyctcdecode https://github.com/kpu/kenlm/archive/master.zip # kenlm
import IPython.display as display
import librosa


!pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip   # install pyannote
from pyannote.audio import Model, Pipeline

# some requirements for diarization 3.0, install one
#pip install onnxruntime
#pip install onnxruntime-gpu

<h1> ASR MODEL </h1>
<p> The basic class used to transcribe audio. Keep in mind that it also uses a language model (LM), which you need to have ready before you proceed. Check out KenLM to learn how to build a simple language model.</p>

In [None]:
class ASR:
    """Speech recogniser: a Wav2Vec2 CTC model decoded with a language model.

    The acoustic model is loaded with ``Wav2Vec2ForCTC`` and the processor
    with ``Wav2Vec2ProcessorWithLM``, so ``processor_path`` must point to a
    processor that bundles a language-model decoder.
    """

    # Sampling rate the audio is resampled to before it is fed to the model.
    SAMPLE_RATE = 16_000

    def __init__(self, model_path, processor_path, computation_method = 'cpu'):
        """Load the model and processor.

        Args:
            model_path: Name or path passed to ``Wav2Vec2ForCTC.from_pretrained``.
            processor_path: Name or path passed to
                ``Wav2Vec2ProcessorWithLM.from_pretrained``.
            computation_method: Torch device string (e.g. ``'cpu'``, ``'cuda'``)
                the model and its inputs are moved to.
        """
        self.computation_method = computation_method
        self.device = torch.device(computation_method)
        self.model = transformers.Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = transformers.Wav2Vec2ProcessorWithLM.from_pretrained(processor_path)

    def transcribe(self, audio_path, start=0, end=None):
        """Transcribe an audio file, or a time window of it.

        Args:
            audio_path: Path of the audio file to read.
            start: Offset in seconds at which reading starts.
            end: Time in seconds at which reading stops; ``None`` reads to the
                end of the file.

        Returns:
            The decoded transcription string.
        """
        # `start` is honoured even when `end` is omitted; duration=None makes
        # librosa read everything from the offset to the end of the file.
        duration = None if end is None else end - start
        speech, _ = librosa.load(
            audio_path, sr=self.SAMPLE_RATE, offset=start, duration=duration
        )

        inputs = self.processor(
            speech, sampling_rate=self.SAMPLE_RATE, return_tensors='pt'
        ).to(self.device)

        with torch.no_grad():   # inference only, no gradients needed
            logits = self.model(
                inputs.input_values,
                # Not every processor emits an attention mask; pass None then.
                attention_mask=inputs.get('attention_mask'),
            ).logits

        # The LM decoder works on NumPy arrays, so the logits must be brought
        # back to host memory whatever device the model ran on.
        logits = logits.cpu().numpy()
        transcription = self.processor.batch_decode(logits).text
        return transcription[0]

    def display_audio(self, audio_path):
        """Return an IPython ``Audio`` object for ``audio_path``."""
        return display.Audio(audio_path)

# creating an instance of it

In [None]:
# Placeholders: replace with the name or path of your acoustic model and of
# the processor (the processor is loaded with Wav2Vec2ProcessorWithLM).
model_path = 'model_name_or_path'
processor_path = 'processor_name_or_path'
# computation_method is left at its default ('cpu').
asr_model = ASR(model_path = model_path, processor_path=processor_path)

# testing it out on a sample audio

In [None]:
# Placeholder: replace with the path of an audio file to transcribe.
audio_path = 'sample_audio_path'
# No start/end given, so the whole file is transcribed; returns the text.
asr_model.transcribe(audio_path)

# displaying the sample audio if you need

In [None]:
# Returns an IPython Audio object for the sample file.
asr_model.display_audio(audio_path)

# the basic class for a diarization model
<p> Speaker diarization is the task of determining which speaker is talking at each point in time of a recording.</p>

In [None]:
class Diarization_Model:
    """Wrapper around a pyannote ``Pipeline`` for two-speaker diarization.

    After ``diarize`` has run, the instance exposes:

    * ``spoken_time_periods``: list of ``{'start', 'end', 'speaker'}`` dicts,
      sorted by start time.
    * ``spoken_time_periods_concated``: the same list with consecutive
      segments of one speaker merged.
    * ``concurrent_time_periods``, ``concurrent_speech_time``,
      ``speaker_00_interupts``, ``speaker_01_interupts``: overlap statistics.
    * ``word_for_concurrent_speeches``: the label passed to ``diarize``.
    """

    def __init__(self, model_path,computation_method = 'cpu', auth_token=None):
        """Load the diarization pipeline and move it to the requested device.

        Args:
            model_path: Name or path passed to ``Pipeline.from_pretrained``.
            computation_method: Torch device string the pipeline is moved to.
            auth_token: Optional access token, forwarded as ``use_auth_token``.

        Raises:
            RuntimeError: If ``Pipeline.from_pretrained`` returns ``None``.
        """
        if auth_token is None:
            pipeline = Pipeline.from_pretrained(model_path)
        else:
            pipeline = Pipeline.from_pretrained(model_path, use_auth_token=auth_token)

        # NOTE(review): some pyannote versions appear to return None instead of
        # raising when the pipeline cannot be fetched; fail early and clearly.
        if pipeline is None:
            raise RuntimeError(
                f'could not load diarization pipeline from {model_path!r}; '
                'check the path and the access token'
            )

        # Honour computation_method whether or not a token was supplied.
        pipeline.to(torch.device(computation_method))
        self.model = pipeline

    def create_spoken_time_periods(self, diarization, speaker_00_name, speaker_01_name):
        """Build ``self.spoken_time_periods`` from a diarization result.

        Each entry is a dict with keys ``'start'``, ``'end'`` and ``'speaker'``.
        The label ``speaker_00`` (case-insensitive) is renamed to
        ``speaker_00_name``; every other label is renamed to
        ``speaker_01_name``, so with more than two speakers all the remaining
        ones share one name. The list is sorted by start time.

        Args:
            diarization: Object whose ``itertracks(yield_label=True)`` yields
                ``(turn, _, speaker)`` with ``turn.start`` / ``turn.end``.
            speaker_00_name: Display name for ``speaker_00``.
            speaker_01_name: Display name for any other speaker label.
        """
        periods = [
            {
                'start': turn.start,
                'end': turn.end,
                'speaker': speaker_00_name if speaker.lower() == 'speaker_00' else speaker_01_name,
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        periods.sort(key=lambda period: period['start'])
        self.spoken_time_periods = periods

    def concat_same_speaker(self):
        """Merge runs of consecutive segments spoken by the same speaker.

        Reads ``self.spoken_time_periods`` and writes
        ``self.spoken_time_periods_concated``. A merged entry starts where the
        first segment of the run starts and ends where the last segment of the
        run ends. The input dicts are not modified.
        """
        merged = []
        for period in self.spoken_time_periods:
            if merged and merged[-1]['speaker'] == period['speaker']:
                # Same speaker keeps talking: stretch the current run.
                merged[-1]['end'] = period['end']
            else:
                merged.append({
                    'start': period['start'],
                    'end': period['end'],
                    'speaker': period['speaker'],
                })
        self.spoken_time_periods_concated = merged

    def detect_concurrent_speaking(self, speaker_00_name, speaker_01_name,
                                   word_for_concurrent_speeches, margin_for_interruption):
        """Count interruptions and collect the overlapping time spans.

        For every pair of segments (earlier index, later index) from different
        speakers, an interruption is recorded when the earlier segment ends
        more than ``margin_for_interruption`` seconds after the later segment
        starts. The speaker of the later segment is counted as the
        interrupter. An empty segment list simply yields zero counts.

        Args:
            speaker_00_name: Name whose interruptions go to
                ``self.speaker_00_interupts``.
            speaker_01_name: Unused here; any interrupter that is not
                ``speaker_00_name`` goes to ``self.speaker_01_interupts``.
            word_for_concurrent_speeches: Label stored on the instance.
            margin_for_interruption: Minimum overlap in seconds.
        """
        self.concurrent_speech_time = 0
        self.speaker_00_interupts = 0
        self.speaker_01_interupts = 0
        self.word_for_concurrent_speeches = word_for_concurrent_speeches
        self.concurrent_time_periods = []

        periods = self.spoken_time_periods
        # Bounded loops: the previous `while True` never terminated on an
        # empty list because its exit test was an equality against len == 0.
        for i, earlier in enumerate(periods):
            for j in range(i + 1, len(periods)):
                later = periods[j]
                if earlier['speaker'] == later['speaker']:
                    continue
                if earlier['end'] <= later['start'] + margin_for_interruption:
                    continue

                if later['speaker'] == speaker_00_name:
                    self.speaker_00_interupts += 1
                else:
                    self.speaker_01_interupts += 1

                # Overlap runs from where the later segment starts to
                # whichever of the two segments finishes first.
                start_point = later['start']
                end_point = min(earlier['end'], later['end'])
                self.concurrent_speech_time += end_point - start_point
                self.concurrent_time_periods.append([start_point, end_point])

    def diarize(self, audio_path, num_speakers = 2,speaker_00_name = 'speaker_00', speaker_01_name = 'speaker_01',
                word_for_concurrent_speeches = 'concurrent speaking', margin_for_interruption = 0.10):
        """Run the pipeline on a file and compute all derived segment lists.

        Args:
            audio_path: Path of the audio file passed to the pipeline.
            num_speakers: Number of speakers passed to the pipeline.
            speaker_00_name: Display name for ``speaker_00``.
            speaker_01_name: Display name for any other speaker label.
            word_for_concurrent_speeches: Label stored for concurrent speech.
            margin_for_interruption: Minimum overlap in seconds for an
                interruption to be counted.

        Returns:
            A tuple ``(spoken_time_periods, spoken_time_periods_concated)``.
        """
        diarization = self.model(audio_path, num_speakers=num_speakers)
        self.create_spoken_time_periods(diarization, speaker_00_name, speaker_01_name)
        self.detect_concurrent_speaking(speaker_00_name, speaker_01_name,
                                        word_for_concurrent_speeches, margin_for_interruption)
        self.concat_same_speaker()

        return self.spoken_time_periods,self.spoken_time_periods_concated

# initializing the diarization model

In [None]:
# Identifier of the pretrained diarization pipeline to load.
model_path = 'pyannote/speaker-diarization-3.0'
# Placeholder: replace with your own access token before running.
auth = 'your_auth_token'
# The token is forwarded to Pipeline.from_pretrained by the constructor.
diarization_model = Diarization_Model(model_path = model_path, auth_token = auth)

# you can use code below if you have diarization model mounted on your drive:

# model_path = 'model_name_or_path'
# diarization_model = Diarization_Model(model_path = model_path)



# combining ASR and diarization model

In [None]:
class SpeechMapper:
    """Combine an ASR model with a diarization model.

    The diarization model decides who spoke when; the ASR model then
    transcribes each speaker turn separately.
    """

    def __init__(self, asr_model: "ASR", diarization_model: "Diarization_Model"):
        """Store the two models used by the mapper."""
        self.asr_model = asr_model
        self.diarization_model = diarization_model

    def transcribe_with_diarization_list(self, audio_path, padding_for_end = 0.1, padding_for_start = 0.1):
        """Diarize ``audio_path`` and transcribe every merged speaker turn.

        Each turn is transcribed on a window widened by the paddings and
        clamped to ``[0, audio length]``. The ``'start'`` and ``'end'`` stored
        in the result remain the unpadded diarization times.

        Args:
            audio_path: Path of the audio file.
            padding_for_end: Seconds added after each turn before transcribing.
            padding_for_start: Seconds added before each turn before
                transcribing.

        Returns:
            ``self.mapped_transcription``: a list of dicts with keys
            ``'start'``, ``'end'``, ``'speaker'`` and ``'transcription'``.
        """
        _, concatenated_periods = self.diarization_model.diarize(audio_path = audio_path)

        # Only the duration is needed, so keep the file's native sampling
        # rate (sr=None) instead of paying for a resample.
        audio, native_rate = librosa.load(audio_path, sr=None)
        audio_length = librosa.get_duration(y=audio, sr=native_rate)

        self.mapped_transcription = []
        for speech in concatenated_periods:
            window_start = max(speech['start'] - padding_for_start, 0)
            window_end = min(speech['end'] + padding_for_end, audio_length)

            # The transcription is attached to the very dict returned by the
            # diarization model, as before, so it is visible there as well.
            speech['transcription'] = self.asr_model.transcribe(audio_path, window_start, window_end)
            self.mapped_transcription.append(speech)

        return self.mapped_transcription

    def beautified_transcription(self, audio_path):
        """Run the full pipeline on ``audio_path`` and print a readable transcript.

        One block is printed per speaker turn, with times rounded to two
        decimals. Nothing is returned.
        """
        self.transcribe_with_diarization_list(audio_path)

        concurrent_label = self.diarization_model.word_for_concurrent_speeches
        for speech in self.mapped_transcription:
            start = round(speech['start'], 2)
            end = round(speech['end'], 2)
            speaker = speech['speaker']
            transcription = speech['transcription']

            if speaker != concurrent_label:
                print(f'speaker {speaker} from {start} to {end} said:\n {transcription}' )
            else:
                # NOTE(review): no code in this file labels a segment with the
                # concurrent-speech word, so this branch looks unreachable
                # unless a speaker is given that same name; confirm.
                print(f'from {start} to {end} they are concurrently speaking:\n {transcription}' )



# creating an instance of it

In [None]:
# Reuses the ASR and diarization instances created in the cells above.
speech_mapper = SpeechMapper(asr_model, diarization_model)

# testing on a sample Audio

In [None]:
# Placeholder: replace with the path of an audio file to process.
audio_path = 'sample_audio_path'
# Diarizes, transcribes each speaker turn and prints the result.
speech_mapper.beautified_transcription(audio_path)