# Introduction

This notebook continues the pseudolabeling pipeline described in https://www.kaggle.com/code/reasat/pseudolabeling-step-1-download-speech-audio

The model weights and the inference notebook are copied from: https://www.kaggle.com/competitions/bengaliai-speech/discussion/447970

## STT Model:

* OpenAI whisper-medium
* Huggingface trainer
* Trained on 8x 48GB RTX A6000
* bs=8 and lr=1e-5
* Train steps 50k
* Spectrogram dithering
* Spectrogram time and frequency masking
* Resampling 16 kHz -> 8 kHz -> 16 kHz as augmentation
* Inference with max_length=260, num_beams=4 and chunk_length_s=20.1
* Libsonic based speed/pitch augmentation
* Datasets: OpenSLR 37, OpenSLR 53, MadASR, Shrutilipi, Macro, Kathbath, GoogleTTS generated audios and pseudo labeled YouTube videos
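
The time and frequency masking listed above can be illustrated with a small, dependency-free sketch. This is not the training code of the original solution: the function name `mask_spectrogram`, the mask widths, the use of a single mask per axis and the zero fill value are all assumptions chosen for illustration.

```python
import random


def mask_spectrogram(spec, max_freq_width=2, max_time_width=3, rng=None):
    """Return a copy of `spec` with one frequency band and one time span zeroed.

    `spec` is a list of rows (frequency bins), each a list of frames.
    All parameters are illustrative assumptions, not the original settings.
    """
    rng = rng or random.Random()
    n_freq, n_time = len(spec), len(spec[0])
    masked = [row[:] for row in spec]  # copy so the input stays untouched

    # Frequency mask: zero a band of consecutive frequency bins.
    f_width = rng.randint(1, max_freq_width)
    f_start = rng.randint(0, n_freq - f_width)
    for f in range(f_start, f_start + f_width):
        masked[f] = [0.0] * n_time

    # Time mask: zero a span of consecutive frames in every bin.
    t_width = rng.randint(1, max_time_width)
    t_start = rng.randint(0, n_time - t_width)
    for row in masked:
        for t in range(t_start, t_start + t_width):
            row[t] = 0.0
    return masked


spec = [[1.0] * 10 for _ in range(4)]  # toy spectrogram: 4 bins x 10 frames
masked = mask_spectrogram(spec, rng=random.Random(0))
for row in masked:
    print(row)
```

The sketch assumes the spectrogram has at least `max_freq_width` bins and `max_time_width` frames. Real pipelines usually apply such masks on tensors inside the data collator, which is not shown here.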

## Punctuation Model:

* AutoModelForTokenClassification google/muril-base-cased
* Huggingface trainer
* Labels: period, comma and question mark
* bs=64, lr=2e-4 and max_seq_length=512
* Ensemble of 4 models (using 6, 8, 11 and 12 layers of google/muril-base-cased)
* Normalized IndicCorp v2 Bangla dataset



In [25]:
import jiwer  # install with `pip install jiwer` if it is missing

# Per-district weights for the final score; keys are lowercased to match the `domain` column.
domain_weights = {
    'Barishal': 0.125,
    'Chittagong': 0.083,
    'Habiganj': 0.125,
    'Kishoreganj': 0.083,
    'Narail': 0.083,
    'Narsingdi': 0.083,
    'Rangpur': 0.083,
    'Sylhet': 0.125,
    'Sandwip': 0.125,
    'Tangail': 0.083,
}
domain_weights = {key.lower(): value for key, value in domain_weights.items()}
# unseen: Habiganj, Barishal, Sylhet, Sandwip


def mean_wer(solution, submission):
    """Return the sum of per-domain WER scores, each scaled by its domain weight."""
    joined = solution.merge(submission.rename(columns={'sentence': 'predicted'}), on='file_name')
    domain_scores = joined.groupby('domain').apply(
        # Given lists of strings, jiwer.wer pools the errors over all reference words
        # of the domain, so longer sentences count for more than in a per-sentence mean.
        lambda df: jiwer.wer(df['sentence'].to_list(), df['predicted'].to_list()),
    )
    print(domain_scores)
    # Scale each domain's WER by its weight; a weighted domain missing from the data raises a KeyError.
    for key, value in domain_weights.items():
        domain_scores.loc[key] = domain_scores.loc[key].item() * value
    print(domain_scores)
    return domain_scores.sum()
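
To make the scoring concrete, here is a minimal pure-Python sketch of the same idea: a pooled WER per domain followed by a weighted sum. It does not call `jiwer`; the helpers `word_errors` and `pooled_wer`, the toy sentences and the two-domain weights are all made up for illustration.

```python
def word_errors(reference, hypothesis):
    """Word-level Levenshtein distance between two sentences."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1]


def pooled_wer(references, hypotheses):
    """Total word errors divided by total reference words."""
    errors = sum(word_errors(r, h) for r, h in zip(references, hypotheses))
    words = sum(len(r.split()) for r in references)
    return errors / words


# Toy data (hypothetical): two domains with made-up sentences and weights.
data = {
    'a': (['one two three four', 'five six'], ['one two tree four', 'five six']),
    'b': (['seven eight'], ['seven']),
}
weights = {'a': 0.75, 'b': 0.25}

domain_scores = {d: pooled_wer(refs, hyps) for d, (refs, hyps) in data.items()}
score = sum(domain_scores[d] * weights[d] for d in weights)
print(domain_scores)
print(score)
```

Domain `a` has one error over six reference words and domain `b` one error over two, so the weighted score is 0.75/6 + 0.25/2 = 0.25. Note that the weights in `domain_weights` add up to 0.998 rather than 1, presumably because 1/12 was rounded to 0.083.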

In [1]:
import os
import csv
import time
import glob

dir_root = '/home/phison/LargeFiles/'

# MODEL = '/kaggle/input/bengali-ai-asr-submission/bengali-whisper-medium/'
PUNCT_MODELS = [
    '/home/phison/LargeFiles/bengali-ai-asr-submission/punct-model-6layers/',
    '/home/phison/LargeFiles/bengali-ai-asr-submission/punct-model-8layers/',
    '/home/phison/LargeFiles/bengali-ai-asr-submission/punct-model-11layers/',
    '/home/phison/LargeFiles/bengali-ai-asr-submission/punct-model-12layers/'
]
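# Per-class multipliers applied to the averaged ensemble probabilities,
# in label order: no punctuation, '।', ',', '?'.
PUNCT_WEIGHTS = [[1.0, 1.4, 1.0, 0.8]]

CHUNK_LENGTH_S = 20.1
ENABLE_BEAM = True

# The same batch size is used with and without beam search.
BATCH_SIZE = 32 * 8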

DATASET_PATH = '/home/phison/LargeFiles/iut-comp-dataset/16_kHz_test_audio'
# MODEL = 'BengaliAI/tugstugi_bengaliai-asr_whisper-medium'
MODEL = '/home/phison/LargeFiles/regional-asr/whisper-base-bn'

In [2]:
import shutil
import librosa
import argparse
import warnings
from pathlib import Path

import pandas as pd
import transformers
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

print(transformers.__version__)
warnings.filterwarnings("ignore")

# Load the test manifest and build the full path of every audio file.
df_sub = pd.read_csv(os.path.join(dir_root, 'iut-comp-dataset/test.csv'))
print(df_sub.head())
df_sub['paths'] = df_sub['file_name'].apply(lambda x: os.path.join(DATASET_PATH, x))
files = df_sub['paths'].to_list()
print('files', len(files))
print(files[:3])
# NOTE: uncomment the next line to run on a few samples for demonstration
# files = files[:10]

4.41.2
              file_name                                        transcripts  \
0  test_sandwip (1).wav  হুম, আই দাওয়াত খাইবো। খাইবো অনকা তো বুজি <> হব...   
1  test_sandwip (2).wav  এন্নে বিয়া-সাদি চুয়াইছে, ঘরে সংসারও চইলছে, হুম...   
2  test_sandwip (3).wav  হুম, হাক্কন হিডা ন নি হারা যাইবো নে? এক্কবারে ...   
3  test_sandwip (4).wav  আম-কাডল <> ধর ইয়া আছে, হল আছে তারফরে দুধ আছে আ...   
4  test_sandwip (5).wav  এই তুই গোসল কইত্তি ন? হিয়া হরে কইরবো এরি। গোসল...   

  district  
0  sandwip  
1  sandwip  
2  sandwip  
3  sandwip  
4  sandwip  
files 1700
['/home/phison/LargeFiles/iut-comp-dataset/16_kHz_test_audio/test_sandwip (1).wav', '/home/phison/LargeFiles/iut-comp-dataset/16_kHz_test_audio/test_sandwip (2).wav', '/home/phison/LargeFiles/iut-comp-dataset/16_kHz_test_audio/test_sandwip (3).wav']


In [3]:
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL,
    tokenizer=MODEL,
    chunk_length_s=CHUNK_LENGTH_S,
    device=0,  # first CUDA device
    # batch_size=BATCH_SIZE,
)
# Force Bengali transcription instead of relying on language detection.
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="bn", task="transcribe")

print("model loaded!")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


model loaded!


In [4]:
from collections import Counter


def fix_repetition(text, max_count):
    """Drop every occurrence of any word that appears more than `max_count` times.

    Intended to suppress runaway repetition in a transcript.
    """
    words = text.split()
    counts = Counter(words)
    return " ".join(w for w in words if counts[w] <= max_count)
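
A short usage example makes the behaviour of `fix_repetition` easier to see. The function is redefined here so the snippet runs on its own, and the input strings are made up. Note that a word exceeding the limit is removed entirely, not trimmed down to `max_count` occurrences.

```python
from collections import Counter


def fix_repetition(text, max_count):
    # Same behaviour as the notebook's function, written with a Counter.
    words = text.split()
    counts = Counter(words)
    return " ".join(w for w in words if counts[w] <= max_count)


looped = "ok " + "na " * 5 + "done"
cleaned = fix_repetition(looped, max_count=3)           # "na" occurs 5 times and is dropped
unchanged = fix_repetition("to be or not to be", max_count=2)  # no word exceeds the limit
print(cleaned)
print(unchanged)
```

Because the count is taken over the whole transcript, a legitimately frequent word in a long recording can also be removed, so `max_count` should be chosen with the expected transcript length in mind.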

In [5]:
from tqdm.auto import tqdm
import torch
import time


def batchify(inputs, batch_size):
    """Yield successive slices of `inputs` with at most `batch_size` items each."""
    for i in range(0, len(inputs), batch_size):
        yield inputs[i:i + batch_size]


# Beam search settings; with ENABLE_BEAM off the pipeline uses its default decoding.
generate_kwargs = {"max_length": 260, "num_beams": 4} if ENABLE_BEAM else None
start = time.time()
texts = []
for batch in batchify(files, BATCH_SIZE):
    texts += pipe(batch, generate_kwargs=generate_kwargs)
    elapsed = time.time() - start
    print('completed: {}, avg. sec/sample: {:.2f}'.format(len(texts), elapsed / len(texts)))
print('total time: {:.2f}'.format(time.time() - start))

completed: 256, avg. sec/sample: 2.79
completed: 512, avg. sec/sample: 2.70
completed: 768, avg. sec/sample: 2.66
completed: 1024, avg. sec/sample: 2.63
completed: 1280, avg. sec/sample: 2.58
completed: 1536, avg. sec/sample: 2.57
completed: 1700, avg. sec/sample: 2.60
total time: 4413.94
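
The progress lines above follow directly from how `batchify` slices the 1700 files into groups of 256. The snippet below reproduces the batch sizes and the final average with placeholder file names (the names are made up, only the counts and the logged total time come from the run above).

```python
def batchify(inputs, batch_size):
    for i in range(0, len(inputs), batch_size):
        yield inputs[i:i + batch_size]


dummy_files = [f"clip_{i}.wav" for i in range(1700)]  # placeholder names, not the real paths
sizes = [len(batch) for batch in batchify(dummy_files, 256)]
avg_seconds = 4413.94 / 1700  # total time from the log divided by the number of files

print(sizes)                  # six full batches and a final partial one
print(round(avg_seconds, 2))  # average seconds per sample over the whole run
```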


In [6]:
# del pipe
import torch

models = [
    AutoModelForTokenClassification.from_pretrained(f).eval().cuda() for f in PUNCT_MODELS
]
tokenizer = AutoTokenizer.from_pretrained(PUNCT_MODELS[0])
PUNCT_MARKS = ['', '।', ',', '?']  # label id -> punctuation mark


def punctuate(text):
    """Restore punctuation in `text` with the ensemble of token classifiers."""
    input_ids = tokenizer(text).input_ids
    batch = torch.LongTensor([input_ids]).cuda()
    with torch.no_grad():
        # Average the per-token class probabilities over all models,
        # dropping the first and last positions (the special tokens).
        probs = sum(
            torch.nn.functional.softmax(model(input_ids=batch).logits[0, 1:-1], dim=1).cpu()
            for model in models
        ) / len(models)
    # Re-weight the classes, then pick the most likely label for each token.
    probs = probs * torch.FloatTensor(PUNCT_WEIGHTS)
    label_ids = torch.argmax(probs, dim=-1)

    # Rebuild the text from word pieces, appending the predicted mark after each token.
    tokens = tokenizer(text, add_special_tokens=False).input_ids
    punct_text = ""
    for index, token in enumerate(tokens):
        token_str = tokenizer.decode(token)
        if '##' not in token_str:
            punct_text += " " + token_str
        else:
            punct_text += token_str[2:]
        punct_text += PUNCT_MARKS[label_ids[index].item()]

    return punct_text.strip()
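
The effect of the class weights in `PUNCT_WEIGHTS` is easiest to see on a single token. The sketch below averages the softmax outputs of two models and compares the chosen label with and without the weights. It uses plain Python instead of torch, and the logits are made-up numbers, not outputs of the real models.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


PUNCT_MARKS = ['', '।', ',', '?']
PUNCT_WEIGHTS = [1.0, 1.4, 1.0, 0.8]

# Hypothetical logits for one token from two models (made-up numbers).
model_logits = [
    [2.0, 1.8, 0.1, 0.0],
    [1.9, 1.7, 0.2, 0.1],
]

probs = [softmax(logits) for logits in model_logits]
mean = [sum(p[c] for p in probs) / len(probs) for c in range(4)]
weighted = [m * w for m, w in zip(mean, PUNCT_WEIGHTS)]

plain_label = max(range(4), key=lambda c: mean[c])
weighted_label = max(range(4), key=lambda c: weighted[c])
print(repr(PUNCT_MARKS[plain_label]), repr(PUNCT_MARKS[weighted_label]))
```

With these numbers the plain average prefers "no punctuation", while the 1.4 multiplier tips the decision towards the danda. The weights therefore act as a per-class bias on the ensemble, making full stops more likely and question marks less likely.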

In [7]:
# texts = [{'text': 'text'} for _ in range(len(files))]

In [27]:
predictions = []
with open("submission.csv", 'wt', encoding="utf8") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['file_name', 'sentence'])
    for f, text in zip(files, texts):
        file_id = Path(f).stem
        pred = text['text'].strip()
        pred = fix_repetition(pred, max_count=8)
        if len(pred) == 0:
            print('empty prediction on', f)
            pred = ' '
    
        prediction = [file_id, pred]
        writer.writerow(prediction)
        predictions.append(prediction)
print("inference finished!")

submission = pd.read_csv("submission.csv")
print(submission.head())

# Build the reference table from the test manifest loaded earlier (`df_sub`).
solution = df_sub.rename(columns={'transcripts': 'sentence', 'district': 'domain'}).drop(columns='paths')
solution['file_name'] = solution['file_name'].apply(lambda x: x.replace('.wav', ''))
print(solution.head())

mean_wer(solution, submission)

inference finished!
          file_name                                           sentence
0  test_sandwip (1)                                        হুম। খাইবো�
1  test_sandwip (2)  <> ঘরের সংসারে চলছে। হুম। বেশি থাকলে বেশি খরচ ...
2  test_sandwip (3)  হুম। হাক্কুনিডা নোনি হারো যাইবানে এক্কবারে, এই...
3  test_sandwip (4)  আম কাডোলা এই দরো ইয়া আছে, হল আছে। তারপরে, দুধ ...
4  test_sandwip (5)  এই তো গোসল করতি না? হিয়া হরে কইরবুরি? হুম। গোস...
          file_name                                           sentence  \
0  test_sandwip (1)  হুম, আই দাওয়াত খাইবো। খাইবো অনকা তো বুজি <> হব...   
1  test_sandwip (2)  এন্নে বিয়া-সাদি চুয়াইছে, ঘরে সংসারও চইলছে, হুম...   
2  test_sandwip (3)  হুম, হাক্কন হিডা ন নি হারা যাইবো নে? এক্কবারে ...   
3  test_sandwip (4)  আম-কাডল <> ধর ইয়া আছে, হল আছে তারফরে দুধ আছে আ...   
4  test_sandwip (5)  এই তুই গোসল কইত্তি ন? হিয়া হরে কইরবো এরি। গোসল...   

    domain  
0  sandwip  
1  sandwip  
2  sandwip  
3  sandwip  
4  sandwip  
domain
barishal       0.672002


0.7032243544529495

In [28]:
predictions = []
with open("submission.csv", 'wt', encoding="utf8") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['file_name', 'sentence'])
    for f, text in zip(files, texts):
        file_id = Path(f).stem
        pred = text['text'].strip()
        pred = fix_repetition(pred, max_count=8)
        if len(pred) == 0:
            print('empty prediction on', f)
            pred = ' '
        else:
            # Restore punctuation and make sure the sentence ends with a mark.
            pred = punctuate(pred)
            if not pred.endswith(('।', '?', ',')):
                pred = pred + '।'
        prediction = [file_id, pred]
        writer.writerow(prediction)
        predictions.append(prediction)
print("inference finished!")

submission = pd.read_csv("submission.csv")
print(submission.head())

solution = pd.read_csv(os.path.join(dir_root, 'iut-comp-dataset/test.csv'))
solution = solution.rename(columns={'transcripts': 'sentence', 'district': 'domain'})
solution['file_name'] = solution['file_name'].apply(lambda x: x.replace('.wav', ''))
print(solution.head())

mean_wer(solution, submission)

inference finished!
          file_name                                           sentence
0  test_sandwip (1)                                       হুম । খাইবো।
1  test_sandwip (2)  < > ঘরের সংসারে চলছে । হুম । বেশি থাকলে বেশি খ...
2  test_sandwip (3)  হুম । হাক্কুনিডা নোনি হারো যাইবানে এক্কবারে ,।...
3  test_sandwip (4)  আম কাডোলা এই দরো ইয়া আছে , হল আছে । তারপরে , দ...
4  test_sandwip (5)  এই তো গোসল করতি না ? হিয়া হরে কইরবুরি ? হুম । ...
          file_name                                           sentence  \
0  test_sandwip (1)  হুম, আই দাওয়াত খাইবো। খাইবো অনকা তো বুজি <> হব...   
1  test_sandwip (2)  এন্নে বিয়া-সাদি চুয়াইছে, ঘরে সংসারও চইলছে, হুম...   
2  test_sandwip (3)  হুম, হাক্কন হিডা ন নি হারা যাইবো নে? এক্কবারে ...   
3  test_sandwip (4)  আম-কাডল <> ধর ইয়া আছে, হল আছে তারফরে দুধ আছে আ...   
4  test_sandwip (5)  এই তুই গোসল কইত্তি ন? হিয়া হরে কইরবো এরি। গোসল...   

    domain  
0  sandwip  
1  sandwip  
2  sandwip  
3  sandwip  
4  sandwip  
domain
barishal       0.849765


0.9185856242790656
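
The punctuated samples printed above carry a space before each mark (for example `হুম ।`), whereas the reference transcripts attach marks to the preceding word. Since WER is computed on whitespace-separated words, a detached mark counts as an extra word. This may contribute to the higher score of the punctuated run, though that is not verified here. A minimal sketch of a post-processing step that reattaches the marks is shown below; `tidy_punctuation` is a hypothetical helper, not part of the original notebook.

```python
import re

# Marks produced by the punctuation model: danda, comma, question mark.
_SPACE_BEFORE_MARK = re.compile(r'\s+([।,?])')


def tidy_punctuation(text):
    """Attach each punctuation mark to the preceding word (hypothetical helper)."""
    return _SPACE_BEFORE_MARK.sub(r'\1', text)


print(tidy_punctuation('হুম । খাইবো।'))
print(tidy_punctuation('আছে , হল আছে । তারপরে'))
```

If used, it would be applied to `pred` right after the call to `punctuate`, before the row is written to `submission.csv`.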