<a href="https://colab.research.google.com/github/Ehsan77e/ASR-with-Diarization/blob/main/ASR_Implementation_With_LM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1>Dependencies </h1>
<p>Run the cell below, then restart the session through the Runtime menu for the changes to take effect.</p>
If you want to use pyannote speaker-diarization 3.0, install either `onnxruntime` (CPU) or its GPU variant, `onnxruntime-gpu`.

In [2]:
!pip install transformers
# import locale
# locale.getpreferredencoding = lambda: "UTF-8"
!pip install pyctcdecode https://github.com/kpu/kenlm/archive/master.zip # kenlm
!pip install -qq https://github.com/pyannote/pyannote-audio/archive/refs/heads/develop.zip   # install pyannote

# some requirements for diarization 3.0, install one
#!pip install onnxruntime       # if using CPU
!pip install onnxruntime-gpu    # if using GPU

!pip uninstall -y speechbrain   # -y skips the interactive confirmation prompt
!pip install speechbrain==0.5.16

Collecting https://github.com/kpu/kenlm/archive/master.zip
  Downloading https://github.com/kpu/kenlm/archive/master.zip (553 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 553.6/553.6 kB 1.7 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting pyctcdecode
  Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode)
  Downloading pygtrie-2.5.0-py3-none-any.whl (25 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode)
  Downloading hypothesis-6.103.1-py3-none-any.whl (461 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 461.2/461.2 kB 16.1 MB/s eta 0:00:00
Building wheels for collected packages: kenlm
  Building wheel for kenlm (pyproject.toml) ... done
  Created wheel for kenlm: filename=kenlm-0

In [1]:
from google.colab import drive
drive.mount('/content/drive')

import transformers
import torch
import librosa
import IPython.display as display

from pyannote.audio import Model, Pipeline


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


<h1> ASR MODEL </h1>
<p>The basic class for transcribing audio. Note that it decodes with an n-gram language model (LM), which you need to have before you proceed; check out KenLM to learn how to build a simple language model.</p>
<p>The ASR model used in this notebook is one I trained myself and cannot share publicly; however, any wav2vec2 CTC model that comes with an LM decoder loadable through <code>Wav2Vec2ProcessorWithLM</code> should work.</p>
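For context on what the LM adds: without a language model, wav2vec2 CTC logits are commonly decoded greedily, by taking the most likely token in each frame, collapsing consecutive repeats, and dropping the blank token. `Wav2Vec2ProcessorWithLM` instead runs a pyctcdecode beam search that rescores hypotheses with the KenLM model. The sketch below shows only the greedy baseline, in plain Python; the toy vocabulary, the blank index 0 (in Hugging Face wav2vec2 vocabularies the padding token usually doubles as the CTC blank) and the `|` word delimiter are assumptions for illustration.

```python
from typing import List, Sequence


def greedy_ctc_decode(frame_scores: Sequence[Sequence[float]],
                      vocab: List[str],
                      blank_id: int = 0,
                      word_delimiter: str = "|") -> str:
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks."""
    # most likely token id for each frame
    best_ids = [max(range(len(scores)), key=scores.__getitem__) for scores in frame_scores]
    tokens = []
    previous = None
    for token_id in best_ids:
        # keep a token only if it differs from the previous frame and is not the blank
        if token_id != previous and token_id != blank_id:
            tokens.append(vocab[token_id])
        previous = token_id
    # assumed convention: "|" separates words in the vocabulary
    return "".join(tokens).replace(word_delimiter, " ").strip()


vocab = ["<pad>", "|", "a", "b", "c"]   # toy vocabulary, index 0 is the blank
frames = [                              # one row of scores per audio frame
    [0.1, 0.0, 0.9, 0.0, 0.0],  # a
    [0.1, 0.0, 0.8, 0.0, 0.1],  # a again: collapsed into the previous a
    [0.9, 0.0, 0.1, 0.0, 0.0],  # blank: separates the two a's
    [0.1, 0.0, 0.7, 0.1, 0.1],  # a: kept, because a blank came in between
    [0.0, 0.9, 0.0, 0.1, 0.0],  # | (word boundary)
    [0.0, 0.0, 0.1, 0.2, 0.7],  # c
]
print(greedy_ctc_decode(frames, vocab))  # prints "aa c"
```

An LM-backed beam search keeps several such hypotheses alive and prefers the ones the language model finds plausible, which is why it tends to fix spelling-level mistakes that greedy decoding cannot.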

In [2]:
class ASR:
    def __init__(self, model_path, processor_path, computation_method='cpu'):
        self.device = torch.device(computation_method)
        self.model = transformers.Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        # the processor bundles the feature extractor, the tokenizer and the pyctcdecode/KenLM decoder
        self.processor = transformers.Wav2Vec2ProcessorWithLM.from_pretrained(processor_path)
        self.computation_method = computation_method

    def transcribe(self, audio_path, start=0, end=None):
        # load the [start, end) window in seconds (or everything from start when end is None),
        # resampled to the 16 kHz rate wav2vec2 expects
        duration = None if end is None else end - start
        speech, _ = librosa.load(audio_path, sr=16000, offset=start, duration=duration)

        inputs = self.processor(speech, sampling_rate=16_000, return_tensors='pt').to(self.device)
        with torch.no_grad():   # the actual computation
            # not every feature extractor returns an attention mask, hence .get()
            logits = self.model(inputs.input_values, attention_mask=inputs.get('attention_mask')).logits

        # the LM decoder works on NumPy arrays, so move the logits to the CPU first
        logits = logits.cpu().numpy()
        transcription = self.processor.batch_decode(logits).text
        return transcription[0]

    def display_audio(self, audio_path):
        return display.Audio(audio_path)
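`transcribe` re-reads the file on every call and uses librosa's `offset`/`duration` to decode only the requested window (in seconds). When many turns of the same file are transcribed, one alternative is to load the audio once at 16 kHz and slice the array. The helper below is hypothetical (not part of the notebook): a minimal sketch that converts a time window to sample indices, clamped to the signal length.

```python
from typing import Tuple


def segment_bounds(start: float, end: float, sample_rate: int, num_samples: int) -> Tuple[int, int]:
    """Convert a [start, end) window in seconds to clamped sample indices."""
    first = max(int(round(start * sample_rate)), 0)
    last = min(int(round(end * sample_rate)), num_samples)
    # never return a negative-length window
    return first, max(first, last)


speech = [0.0] * 48000  # stand-in for 3 s of audio loaded once at 16 kHz
first, last = segment_bounds(0.5, 1.25, 16000, len(speech))
print(first, last, len(speech[first:last]))  # prints 8000 20000 12000
```

The slice `speech[first:last]` could then be passed to the processor in place of a fresh `librosa.load` call; whether this is faster in practice depends on file length and format.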

# Creating an ASR instance
<p>Using a GPU is recommended; otherwise, the computations will take a long time.</p>

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'model_name_or_path'
processor_path = 'lm_decoder_name_or_path'
asr_model = ASR(model_path = model_path, processor_path=processor_path, computation_method = device)

Some weights of the model checkpoint at /content/drive/MyDrive/asr_related/model_training were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at /content/drive/MyDrive/asr_related/model_training and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.origi

# Testing it on a sample audio file

In [None]:
audio_path = 'sample_audio_path'
asr_model.transcribe(audio_path)

'ای در پایینو کامل در بیاره نه بابا شیششو عزیزم فقط شیشش هشتصد هزار تومن هشتصد هفتاد پنج شم با دستگیره چه خبره اینجوریه دیگه الآن زنگ زد به یه جای دیگه گفت پونص ششصد تا شیشت باید شماره سریال اجاق گازتو بفرستین شماره سریال اجاق گاز رو در بیار کجا در بیاره'

# Displaying the sample audio (optional)

In [None]:
asr_model.display_audio(audio_path)

# The basic class for a diarization model
<p>Speaker diarization is the task of determining which speaker is speaking at each point in time, i.e. "who spoke when".</p>
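Concretely, the class below turns pyannote's output into a list of `{'start', 'end', 'speaker'}` dictionaries, with times in seconds. To illustrate the "who spoke when" idea, this hypothetical helper queries such a list for the speakers active at a given moment; it uses rounded values from the `spoken_time_periods` output shown further down, and it can return more than one speaker when turns overlap.

```python
from typing import Dict, List, Union

Segment = Dict[str, Union[float, str]]


def speakers_at(segments: List[Segment], t: float) -> List[str]:
    """Return the speakers whose turn covers time t (in seconds)."""
    return [seg["speaker"] for seg in segments if seg["start"] <= t < seg["end"]]


segments = [
    {"start": 0.879, "end": 1.510, "speaker": "speaker_01"},
    {"start": 1.937, "end": 3.063, "speaker": "speaker_00"},
    {"start": 3.643, "end": 4.991, "speaker": "speaker_01"},
    {"start": 4.326, "end": 29.991, "speaker": "speaker_00"},
]
print(speakers_at(segments, 1.0))  # ['speaker_01']
print(speakers_at(segments, 1.7))  # [] (silence between turns)
print(speakers_at(segments, 4.5))  # ['speaker_01', 'speaker_00'] (overlapping turns)
```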

In [None]:
class Diarization_Model:
    def __init__(self, model_path, computation_method='cpu', auth_token=None):
        if auth_token is None:
            self.model = Pipeline.from_pretrained(model_path)
        else:
            # gated pyannote pipelines require a Hugging Face access token
            self.model = Pipeline.from_pretrained(model_path, use_auth_token=auth_token)
        self.model.to(torch.device(computation_method))


# ********************************************************************************************************************************************************

    def create_spoken_time_periods(self, diarization, speaker_00_name, speaker_01_name):

        # spoken_time_periods is a list of dicts with 'start' and 'end' (in seconds) and 'speaker';
        # any label other than speaker_00 is mapped to speaker_01_name, so this assumes two speakers
        self.spoken_time_periods = []

        for turn, _, speaker in diarization.itertracks(yield_label=True):
            if speaker.lower() == 'speaker_00':
                self.spoken_time_periods.append({'start': turn.start, 'end': turn.end, 'speaker': speaker_00_name})
            else:
                self.spoken_time_periods.append({'start': turn.start, 'end': turn.end, 'speaker': speaker_01_name})

        self.spoken_time_periods = sorted(self.spoken_time_periods, key=lambda x: x['start'])


# ********************************************************************************************************************************************************

    def concat_same_speaker(self):
        self.spoken_time_periods_concatenated = []

        i = 0 # for i'th item
        while i < len(self.spoken_time_periods):
            j = 1  # for j'th item in front of the i'th one
            still_same_person = True
            start = self.spoken_time_periods[i]['start']
            end = self.spoken_time_periods[i]['end']
            speaker = self.spoken_time_periods[i]['speaker']

            while still_same_person:
                if i + j < len(self.spoken_time_periods):
                    if speaker == self.spoken_time_periods[i+j]['speaker']:
                        end = self.spoken_time_periods[i+j]['end']
                        j += 1

                    else:
                        still_same_person = False
                        self.spoken_time_periods_concatenated.append({'start': start, 'end': end, 'speaker': speaker})
                        i += j
                else:
                    still_same_person = False
                    self.spoken_time_periods_concatenated.append({'start': start, 'end': end, 'speaker': speaker})
                    i += j


# *********************************************** concurrent speaking detection ("enhanced version") ***********************************************
# it counts how many times each speaker interrupted the other and the overall time they spoke concurrently

    def detect_concurrent_speaking(self, speaker_00_name, speaker_01_name,\
                                    word_for_concurrent_speeches, margin_for_interruption):


        self.concurrent_speech_time = 0
        self.speaker_00_interupts = 0
        self.speaker_01_interupts = 0
        self.word_for_concurrent_speeches = word_for_concurrent_speeches
        self.concurrent_time_periods = []

        for i in range(len(self.spoken_time_periods)):  # for the i'th speech (also safe when the list is empty)

            for j in range(i + 1, len(self.spoken_time_periods)):  # loop over the remaining speeches to detect concurrent speech

                if self.spoken_time_periods[i]['end'] > self.spoken_time_periods[j]['start'] + margin_for_interruption and \
                self.spoken_time_periods[i]['speaker'] != self.spoken_time_periods[j]['speaker']:   # detecting interruptions

                    interrupter = self.spoken_time_periods[j]['speaker']  # detecting the interrupter
                    if interrupter == speaker_00_name:
                        self.speaker_00_interupts += 1
                    else:
                        self.speaker_01_interupts += 1

                    start_point = self.spoken_time_periods[j]['start']  # the point when they start speaking at the same time
                    end_point = min(self.spoken_time_periods[i]['end'], self.spoken_time_periods[j]['end'])  # the point when it ends
                    self.concurrent_speech_time += end_point - start_point

                    self.concurrent_time_periods.append([start_point, end_point])


# ***************************************************************************************************************

    def diarize(self, audio_path, num_speakers = 2,speaker_00_name = 'speaker_00', speaker_01_name = 'speaker_01',\
                word_for_concurrent_speeches = 'concurrent speaking', margin_for_interruption = 0.10):


        diarization = self.model(audio_path, num_speakers=num_speakers)
        self.create_spoken_time_periods(diarization, speaker_00_name, speaker_01_name)
        self.detect_concurrent_speaking(speaker_00_name, speaker_01_name, word_for_concurrent_speeches, margin_for_interruption)
        self.concat_same_speaker()
        # self.detect_concurrent_speaking(diarization, speaker_00_name, speaker_01_name,\
        #                                 word_for_concurrent_speeches, margin_for_interruption)

        return self.spoken_time_periods,self.spoken_time_periods_concatenated
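`concat_same_speaker` merges runs of consecutive turns by the same speaker, keeping the start of the first turn and the end of the last one. The same rule can be written as a single pass. This is a sketch with a hypothetical function name, assuming the segments are already sorted by start time; it builds new dicts so the input list is left untouched.

```python
from typing import Dict, List


def merge_consecutive_turns(segments: List[Dict]) -> List[Dict]:
    """Merge runs of consecutive segments that share a speaker (input sorted by start)."""
    merged: List[Dict] = []
    for seg in segments:
        if merged and merged[-1]["speaker"] == seg["speaker"]:
            # same rule as concat_same_speaker: the run ends where its last segment ends
            merged[-1]["end"] = seg["end"]
        else:
            merged.append({"start": seg["start"], "end": seg["end"], "speaker": seg["speaker"]})
    return merged


turns = [
    {"start": 0.0, "end": 1.0, "speaker": "A"},
    {"start": 1.2, "end": 2.0, "speaker": "A"},
    {"start": 2.1, "end": 3.0, "speaker": "B"},
    {"start": 3.5, "end": 4.0, "speaker": "A"},
]
print(merge_consecutive_turns(turns))
# [{'start': 0.0, 'end': 2.0, 'speaker': 'A'}, {'start': 2.1, 'end': 3.0, 'speaker': 'B'}, {'start': 3.5, 'end': 4.0, 'speaker': 'A'}]
```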

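# Initializing the diarization model
<p>You need a Hugging Face access token to download the pyannote pipeline, and you must also accept the user conditions on the gated model pages (e.g. pyannote/speaker-diarization-3.0) while logged in.</p>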
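Rather than pasting the token into the notebook, you can read it from an environment variable. A minimal sketch, assuming the token is stored in `HF_TOKEN` (the name `huggingface_hub` conventionally reads, but check your own setup); `get_hf_token` is a hypothetical helper, not part of the notebook.

```python
import os
from typing import Optional


def get_hf_token(env_var: str = "HF_TOKEN") -> Optional[str]:
    """Read a Hugging Face access token from the environment instead of hard-coding it."""
    token = os.environ.get(env_var)
    # treat an empty string the same as "not set"
    return token or None


# usage (hypothetical):
# diarization_model = Diarization_Model(model_path=model_path, auth_token=get_hf_token())
```

When the variable is unset, `auth_token` is `None` and `Diarization_Model` falls back to `Pipeline.from_pretrained(model_path)` without a token, which only works for pipelines that are not gated or are already cached.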

In [None]:
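model_path = 'pyannote/speaker-diarization-3.0'
auth = 'your_hugging_face_auth_token'
# without computation_method the pipeline stays on the CPU; reuse the device chosen for the ASR model
diarization_model = Diarization_Model(model_path=model_path, computation_method=device, auth_token=auth)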


INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.1.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint drive/MyDrive/packages/segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.1.0+cu121. Bad things might happen unless you revert torch to 1.x.


INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.2.7 to v2.1.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint drive/MyDrive/packages/embedding.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.8.1+cu102, yours is 2.1.0+cu121. Bad things might happen unless you revert torch to 1.x.


# Checking out the diarization model
<p>Let's showcase the different kinds of output produced by the diarization model.</p>

In [None]:
audio_path = 'sample_audio_path'
spoken_time_periods,concatenated_spoken_time_periods =  diarization_model.diarize(audio_path = audio_path)


In [1]:
#asr_model.display_audio(audio_path)

In [None]:
spoken_time_periods

[{'start': 0.8788395904436861,
  'end': 1.5102389078498295,
  'speaker': 'speaker_01'},
 {'start': 1.9368600682593857,
  'end': 3.063139931740614,
  'speaker': 'speaker_00'},
 {'start': 3.643344709897611,
  'end': 4.991467576791809,
  'speaker': 'speaker_01'},
 {'start': 4.325938566552901,
  'end': 29.99146757679181,
  'speaker': 'speaker_00'}]

In [None]:
concatenated_spoken_time_periods

[{'start': 0.008532423208191127,
  'end': 0.8959044368600683,
  'speaker': 'speaker_01'},
 {'start': 0.5034129692832765,
  'end': 1.7832764505119454,
  'speaker': 'speaker_00'},
 {'start': 2.2098976109215016,
  'end': 3.063139931740614,
  'speaker': 'speaker_01'},
 {'start': 3.063139931740614,
  'end': 8.011945392491468,
  'speaker': 'speaker_00'},
 {'start': 7.1416382252559725,
  'end': 9.513651877133105,
  'speaker': 'speaker_01'},
 {'start': 9.513651877133105,
  'end': 12.107508532423209,
  'speaker': 'speaker_00'},
 {'start': 11.66382252559727,
  'end': 13.148464163822526,
  'speaker': 'speaker_01'},
 {'start': 13.063139931740615,
  'end': 24.75255972696246,
  'speaker': 'speaker_00'},
 {'start': 24.75255972696246,
  'end': 28.09726962457338,
  'speaker': 'speaker_01'}]

In [None]:
# interrupts0

4

In [None]:
# interrupts1

0

In [None]:
# concurrent_time

1.672354948805461

In [None]:
# concurrent_time_periods

[[0.5034129692832765, 0.8617747440273038],
 [7.1416382252559725, 8.011945392491468],
 [11.680887372013652, 12.107508532423209],
 [13.063139931740615, 13.080204778156997]]
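The interruption statistics above come from `detect_concurrent_speaking`. Its core rule can be sketched on its own: for segments sorted by start time, a later segment from a different speaker counts as an interruption when it starts more than `margin` seconds before the earlier one ends, and the overlap runs from its start to the earlier of the two ends. The function name and toy segments are illustrative; the rule follows the class code. Like the class, it compares turns pairwise, so time covered by more than two overlapping turns can be counted more than once.

```python
from typing import Dict, List, Tuple


def find_overlaps(segments: List[Dict], margin: float = 0.10
                  ) -> Tuple[Dict[str, int], float, List[Tuple[float, float]]]:
    """Count interruptions per speaker, total overlap time and overlap periods."""
    interruptions: Dict[str, int] = {}
    total = 0.0
    periods: List[Tuple[float, float]] = []
    for i, first in enumerate(segments):
        for second in segments[i + 1:]:
            if first["speaker"] != second["speaker"] and first["end"] > second["start"] + margin:
                # the later-starting speaker is treated as the interrupter
                interrupter = second["speaker"]
                interruptions[interrupter] = interruptions.get(interrupter, 0) + 1
                start, end = second["start"], min(first["end"], second["end"])
                total += end - start
                periods.append((start, end))
    return interruptions, total, periods


turns = [
    {"start": 0.0, "end": 2.0, "speaker": "A"},
    {"start": 1.5, "end": 3.0, "speaker": "B"},   # starts 0.5 s before A ends: interruption
    {"start": 2.95, "end": 4.0, "speaker": "A"},  # only 0.05 s early: within the margin, ignored
]
counts, total, periods = find_overlaps(turns)
print(counts, total, periods)  # {'B': 1} 0.5 [(1.5, 2.0)]
```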

# Combining ASR and diarization into SpeechMapper

In [None]:
class SpeechMapper:
    def __init__(self, asr_model: ASR, diarization_model: Diarization_Model):
        self.asr_model = asr_model
        self.diarization_model = diarization_model

    def transcribe_with_diarization_list(self, audio_path, padding_for_end=0.1, padding_for_start=0.1):

        # builds self.mapped_transcription: a list of dicts with 'start', 'end' and 'speaker'
        # (from diarization) plus 'transcription' (from the ASR model)
        spoken_time_periods, concatenated_spoken_time_periods = self.diarization_model.diarize(audio_path=audio_path)
        audio, sr = librosa.load(audio_path, sr=None)  # native sampling rate; only needed to compute the duration

        # Get the duration of the audio in seconds
        audio_length = librosa.get_duration(y=audio, sr=sr)

        self.mapped_transcription = []

        for speech in concatenated_spoken_time_periods:
            # pad each turn slightly so word edges are not clipped, but stay inside the file
            start = max(speech['start'] - padding_for_start, 0)
            end = min(speech['end'] + padding_for_end, audio_length)

            # every turn, including concurrent speech, is transcribed the same way;
            # copy the dict so the diarization model's own list is not modified
            mapped_speech = dict(speech)
            mapped_speech['transcription'] = self.asr_model.transcribe(audio_path, start, end)
            self.mapped_transcription.append(mapped_speech)

        return self.mapped_transcription


    def beutified_transcription(self, audio_path):
        self.transcribe_with_diarization_list(audio_path)

        for speech in self.mapped_transcription:
            start = round(speech['start'], 2)
            end = round(speech['end'], 2)
            speaker = speech['speaker']
            transcription = speech['transcription']
            if speaker != self.diarization_model.word_for_concurrent_speeches:
                print(f'speaker {speaker} from {start} to {end} said:\n {transcription}' )
            else:
                #print(f'\n //  from {start} to {end} the speakers are concurrently speaking // \n' )

                print(f'from {start} to {end} they are concurrently speaking:\n {transcription}' )
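`transcribe_with_diarization_list` widens every diarized turn by a small padding, clamps it to the file boundaries, and transcribes that window. The sketch below isolates that logic so it can be tried without loading any models. `map_transcriptions` and the stand-in `fake_asr` (which only reports the window it receives instead of running wav2vec2) are hypothetical names; the 0.1 s padding defaults match the method's.

```python
from typing import Callable, Dict, List


def map_transcriptions(segments: List[Dict], audio_length: float,
                       transcribe: Callable[[float, float], str],
                       pad_start: float = 0.1, pad_end: float = 0.1) -> List[Dict]:
    """Attach a transcription to each diarized segment, padding and clamping its window."""
    mapped = []
    for seg in segments:
        # pad the turn so word edges are not clipped, but never leave the file
        start = max(seg["start"] - pad_start, 0.0)
        end = min(seg["end"] + pad_end, audio_length)
        # new dict: the input segments are not modified
        mapped.append({**seg, "transcription": transcribe(start, end)})
    return mapped


def fake_asr(start: float, end: float) -> str:
    """Stand-in for ASR.transcribe that just reports the requested window."""
    return f"[{start:.2f}-{end:.2f}]"


segments = [
    {"start": 0.05, "end": 1.0, "speaker": "speaker_01"},
    {"start": 1.0, "end": 2.95, "speaker": "speaker_00"},
]
for row in map_transcriptions(segments, audio_length=3.0, transcribe=fake_asr):
    print(row["speaker"], row["transcription"])
# speaker_01 [0.00-1.10]
# speaker_00 [0.90-3.00]
```

Note how the padding makes adjacent windows overlap slightly (1.10 vs 0.90 here), so a word at a turn boundary may appear in both transcriptions.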



In [None]:
speech_mapper = SpeechMapper(asr_model, diarization_model)

In [None]:
audio_path = 'sample_audio_path'

In [None]:
speech_mapper.beutified_transcription(audio_path)

speaker speaker_01 from 0.01 to 0.9 said:
 کوچک وی
speaker speaker_00 from 0.5 to 1.78 said:
 وی مثل دریور درست
speaker speaker_01 from 2.21 to 3.06 said:
 بره آره آره
speaker speaker_00 from 3.06 to 8.01 said:
 بله شونه فقط می زدی دیگه ای گوشم که نمیزنین میزنیشون نه حالی
speaker speaker_01 from 7.14 to 9.51 said:
 نه نه می شونزده خالی بله
speaker speaker_00 from 9.51 to 12.11 said:
 الآن پشت خط هستی من یه امتحان بکنم
speaker speaker_01 from 11.66 to 13.15 said:
 مرسی لطف بکن
speaker speaker_00 from 13.06 to 24.75 said:
 می زنم و می کاتونای یه خر چی مشکرمه
speaker speaker_01 from 24.75 to 28.1 said:
 مسی خیلی لطف کردین فداتون به سم مربنه درامای فاطمه
