-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
218 lines (170 loc) · 9.88 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
from pyannote.audio import Pipeline
import torchaudio, gc, os
# from dotenv import load_dotenv
from datetime import datetime
from Levenshtein import distance as levenshtein_distance
import numpy, whisperx
from docx import Document
from typing import Type, Union
from utils import get_audio, get_segments, post_process_bn, numpytobytes
# load_dotenv()
class CONFIG:
    '''
    Central configuration for TranscriberAgent.

    Attributes:
    -----------
    device: Inference device; prefers the first CUDA GPU, falls back to CPU.
    chunk_length_s: Whisper audio chunk length in seconds (30s is the model maximum).
    batch_size: Batch size used for pipeline inference.
    torch_dtype: Inference dtype; float16 is required for fast Whisper inference.
    token: Hugging Face access token, read from the environment (None when unset).
    transcription_model: HF repository for the Bengali transcription model.
    diarization_model: HF repository for the pyannote diarization model.
    en_transcription_model: HF repository for the English transcription model.
    punctuation_model: HF repository for the Bengali punctuation model.
    default_language: Default language code ('bn' Bengali, 'en' English).
    keywords: Default keyword list for TranscriberAgent.get_keywords.
    threshold: Levenshtein-similarity threshold for keyword matching.
    id2label / label2id: Label maps for the punctuation token-classification model.
    id2punc: Map from punctuation class id to the punctuation character.
    '''
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    chunk_length_s = 30
    batch_size = 12
    torch_dtype = torch.float16
    # BUG FIX: TranscriberAgent.__init__ reads CONFIG.token for the diarization
    # pipeline; this attribute was commented out and raised AttributeError.
    # Restore it from the environment (None when the variable is unset).
    token = os.getenv('HUGGINGFACE_TOKEN')
    transcription_model = 'kabir5297/whisper_bn_medium'
    diarization_model = "pyannote/speaker-diarization-3.1"
    en_transcription_model = 'distil-whisper/distil-large-v2'
    punctuation_model = 'kabir5297/bn_punctuation_model'
    default_language = 'bn'
    keywords = ['ডিউলাক্স','নিরোলাক']
    threshold = 0.75
    id2label = {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2', 3: 'LABEL_3'}
    label2id = {'LABEL_0': 0, 'LABEL_1': 1, 'LABEL_2': 2, 'LABEL_3': 3}
    id2punc = {0: '', 1: '।', 2: ',', 3: '?'}
class TranscriberAgent():
    def __init__(self, CONFIG: CONFIG = CONFIG) -> None:
        '''
        Main transcription agent. Provides methods for raw transcription,
        diarized conversation generation and keyword frequency counting.

        Arguments:
        -----------
        CONFIG: Configuration class; see CONFIG for all fields (device,
            chunk_length_s, batch_size, torch_dtype, model repositories,
            default_language, keywords, threshold, punctuation label maps).
        '''
        self.CONFIG = CONFIG
        # Use flash-attention 2 only when the installed stack supports it;
        # otherwise fall back to PyTorch SDPA attention.
        attn_kwargs = (
            {"attn_implementation": "flash_attention_2"}
            if is_flash_attn_2_available()
            else {"attn_implementation": "sdpa"}
        )
        self.transcription_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=self.CONFIG.transcription_model,
            torch_dtype=self.CONFIG.torch_dtype,
            device=self.CONFIG.device,
            model_kwargs=attn_kwargs,
        )
        self.en_transcription_pipeline = pipeline(
            task="automatic-speech-recognition",
            model=self.CONFIG.en_transcription_model,
            torch_dtype=self.CONFIG.torch_dtype,
            device=self.CONFIG.device,
            model_kwargs=attn_kwargs,
        )
        # BUG FIX: CONFIG.token is commented out in the shipped CONFIG class,
        # so reading self.CONFIG.token directly raised AttributeError here.
        # Tolerate a missing attribute and fall back to the environment.
        hf_token = getattr(self.CONFIG, 'token', None) or os.getenv('HUGGINGFACE_TOKEN')
        self.diarization_pipeline = whisperx.DiarizationPipeline(
            model_name=self.CONFIG.diarization_model,
            use_auth_token=hf_token,
            device=self.CONFIG.device,
        )
        self.punctuation_pipeline = pipeline(
            task='ner',
            model=self.CONFIG.punctuation_model,
            device=self.CONFIG.device,
        )

    def _cleanup(self) -> None:
        '''Release cached GPU memory and run garbage collection after inference.'''
        torch.cuda.empty_cache()
        gc.collect()

    def get_raw_transcription(self, audio_path: Union[str, numpy.ndarray], language: str = 'bn') -> Union[str, list]:
        '''
        Get the raw transcription of an audio path or in-memory audio.

        Arguments:
        -----------
        audio_path (str or numpy.ndarray): Path to an audio file, or a numpy
            array of audio samples (a list of such inputs yields a list result).
        language (str, Optional): 'bn' for Bengali (default) or any other
            value for English.

        Returns:
        --------
        For a single input, the transcribed string (Bengali output is passed
        through post_process_bn). For batched input, the pipeline's list of
        dicts with post-processed 'text' entries.
        '''
        if language == 'bn':
            transcriptions = self.transcription_pipeline(
                audio_path,
                batch_size=self.CONFIG.batch_size,
                chunk_length_s=self.CONFIG.chunk_length_s,
                return_timestamps=False,
            )
            # Single input -> dict; batched input -> list of dicts.
            if isinstance(transcriptions, dict):
                transcriptions = post_process_bn(transcriptions['text'])
            elif isinstance(transcriptions, list):
                for transcription in transcriptions:
                    transcription['text'] = post_process_bn(transcription['text'])
        else:
            transcriptions = self.en_transcription_pipeline(
                audio_path,
                batch_size=self.CONFIG.batch_size,
                chunk_length_s=self.CONFIG.chunk_length_s,
                return_timestamps=False,
            )
            # English output needs no post-processing; just unwrap single results.
            if isinstance(transcriptions, dict):
                transcriptions = transcriptions['text']
        self._cleanup()
        return transcriptions

    def create_conversation(self, audio_path: str, language: str = 'bn') -> list:
        '''
        Diarize the audio file, transcribe each speaker segment and build a
        conversation transcript.

        Arguments:
        -----------
        audio_path (str): Path to an audio file. Only a path is supported;
            an in-memory numpy array will not work with diarization.
        language (str, Optional): 'bn' for Bengali or 'en' for English.

        Returns:
        --------
        List of [speaker_tag, transcribed_text] pairs; segments whose
        transcription is empty are dropped.
        '''
        segments, _, speakers = get_segments(audio_path=audio_path, diarization_pipeline=self.diarization_pipeline)
        diarized = self.get_raw_transcription(segments, language=language)
        # Free the segment audio before pairing results with speakers.
        del segments
        conversation = []
        for speaker, transcription in zip(speakers, diarized):
            if transcription['text'] != '':
                conversation.append([speaker, transcription['text']])
        self._cleanup()
        return conversation

    def get_keywords(self, audio_path: Union[str, numpy.ndarray], keywords: list = None, language: str = 'bn') -> dict:
        '''
        Count specified keywords in the transcription.

        Arguments:
        -----------
        audio_path (str or numpy.ndarray): Path to an audio file or a numpy
            array of audio samples.
        keywords (list, Optional): Words to search for; defaults to
            CONFIG.keywords when omitted.
        language (str, Optional): 'bn' for Bengali or 'en' for English.

        Returns:
        --------
        Dict with 'keys' (the matched transcript tokens) and 'count' (a
        mapping keyword -> number of DISTINCT transcript tokens whose
        Levenshtein similarity meets CONFIG.threshold; repeated occurrences
        of the same token are counted once because tokens are de-duplicated).
        '''
        sentence = self.get_raw_transcription(audio_path, language=language)
        if keywords is None:
            keywords = self.CONFIG.keywords
        # De-duplicate tokens so each distinct word is scored once.
        tokens = list(set(sentence.split()))
        keys = []
        key_dict = {}
        for keyword in keywords:
            count = 0
            for token in tokens:
                # Normalized Levenshtein similarity in [0, 1].
                similarity = 1 - levenshtein_distance(token, keyword) / max(len(token), len(keyword))
                if similarity >= self.CONFIG.threshold:
                    count += 1
                    keys.append(token)
            key_dict[keyword] = count
        self._cleanup()
        return {'keys': keys, 'count': key_dict}
if __name__ == '__main__':
    # Quick smoke test: run every public method against a sample recording.
    agent = TranscriberAgent()
    sample_audio = 'test_audio_file.wav'
    print(agent.get_raw_transcription(sample_audio))
    print(agent.create_conversation(sample_audio))
    print(agent.get_keywords(sample_audio))