In [1]:
%%capture
# Install / upgrade the third-party packages used by this notebook (transformers,
# wandb, librosa) plus jiwer, which is not imported in the cells shown here.
# NOTE(review): no versions are pinned, so a re-run may resolve newer releases than
# the ones this run used. %%capture suppresses pip's output for this cell.
! pip install transformers
! pip install jiwer
! pip install --upgrade wandb
! pip install --upgrade librosa

In [2]:
import os
import json
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor
) 

import wandb
import librosa

import warnings
warnings.simplefilter('ignore')

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [3]:
# Experiment configuration; the whole dict is also handed to wandb.init as the run config.
Config = dict(
    # Directory holding one '<id>.mp3' file per row of train.csv
    audio_dir='/kaggle/input/bengaliai-speech/train_mp3s',
    # Hugging Face checkpoint to fine-tune
    model_name='facebook/wav2vec2-base',
    # AdamW learning rate / weight decay
    lr=3e-4,
    wd=1e-5,
    # CosineAnnealingWarmRestarts settings. train_one_epoch calls scheduler.step()
    # once per batch, so T_0 is measured in optimizer steps, not epochs.
    T_0=10,
    T_mult=2,
    eta_min=1e-6,
    # Training loop
    nb_epochs=2,
    train_bs=16,
    valid_bs=16,
    # Audio sampling rate fed to the feature extractor
    sampling_rate=16000,
    _wandb_kernel='tanaym',
)

In [4]:
# W&B logging: read the API key from Kaggle's secret store, authenticate,
# and open a run that records Config as its configuration.
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=wb_key)

run = wandb.init(
    config=Config,
    project='pytorch',
    group='asr',
    job_type='train',
)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mxanderex[0m ([33mpilotteams[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.15.8
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20230812_161703-bvjsr2rs[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mprime-sponge-5[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/pilotteams/pytorch[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/pilotteams/pytorch/runs/bvjsr2rs[0m


In [5]:
def read_audio(mp3_path, target_sr=16000):
    """
    Load an audio file as a waveform sampled at `target_sr` Hz.

    Wav2Vec2 training needs 16 kHz input, hence the default.

    Args:
        mp3_path: path to the audio file.
        target_sr: sampling rate of the returned waveform.

    Returns:
        The waveform as returned by `librosa.load` (mono, librosa's default).
    """
    # librosa.load resamples from the file's native rate straight to `sr`.
    # This avoids decoding to a hard-coded 32 kHz and then resampling a second time.
    audio, _ = librosa.load(mp3_path, sr=target_sr)
    return audio

def construct_vocab(texts):
    """
    Return the sorted list of unique characters that appear in `texts`.

    The texts are joined with a single space first, so " " is included whenever
    there are two or more texts. save_vocab maps that space to the word delimiter.

    The result is sorted because set iteration order for str varies between
    interpreter sessions (hash randomisation). An unsorted list would give
    different token ids in vocab.json on every kernel restart.
    """
    return sorted(set(" ".join(texts)))

def wandb_log(**kwargs):
    """
    Log every keyword argument to W&B as a single history row.

    Every wandb.log call advances the W&B step, so the metrics are sent in
    one call. That keeps values logged together (e.g. train_loss and val_loss)
    on the same step. Does nothing when called with no arguments.
    """
    if kwargs:
        wandb.log(dict(kwargs))

def save_vocab(dataframe, path='vocab.json'):
    """
    Build a character vocabulary from `dataframe['sentence']` and write it as JSON.

    The space character is stored under the key "__" (the tokenizer's
    word_delimiter_token), keeping the space's index. "[UNK]" and "[PAD]" are
    appended with the next free ids.

    Args:
        dataframe: DataFrame with a 'sentence' column of transcripts.
        path: output file, read later by Wav2Vec2CTCTokenizer.
    """
    # Sort so token ids are identical across runs, whatever order
    # construct_vocab returns the characters in.
    chars = sorted(construct_vocab(dataframe['sentence'].tolist()))
    vocab_dict = {ch: idx for idx, ch in enumerate(chars)}

    # NOTE(review): Wav2Vec2CTCTokenizer replaces " " with the word delimiter and then
    # splits text into single characters, so the two-character "__" may be tokenized as
    # "_", "_" ([UNK]). A one-character delimiter such as "|" is the usual choice. Change
    # it here and in the tokenizer construction together, after confirming with the
    # installed transformers version.
    if " " in vocab_dict:
        vocab_dict["__"] = vocab_dict.pop(" ")
    else:
        # No space seen (e.g. a single one-word sentence): still reserve the delimiter
        vocab_dict["__"] = len(vocab_dict)
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)

    # Bengali text: write UTF-8 explicitly rather than relying on the platform default
    with open(path, 'w', encoding='utf-8') as fl:
        json.dump(vocab_dict, fl, ensure_ascii=False)

    print("Created Vocab file!")

In [6]:
class ASRDataset(Dataset):
    """
    Map-style dataset yielding processed audio and tokenized transcripts.

    Args:
        df: DataFrame with a 'path' column (audio file paths) and, unless
            `is_test` is True, a 'sentence' column. Rows are fetched with
            `df.loc[idx]`, so the index is assumed to be 0..len(df)-1
            (e.g. after `reset_index(drop=True)`).
        config: dict providing 'sampling_rate'.
        is_test: if True, transcripts are not tokenized and the label is -1.
        processor: Wav2Vec2Processor-like object. When omitted, the
            module-level `processor` is looked up each time an item is fetched.
    """

    def __init__(self, df, config, is_test=False, processor=None):
        self.df = df
        self.config = config
        self.is_test = is_test
        self.processor = processor

    def _get_processor(self):
        """Return the processor passed to __init__, else the global `processor`."""
        if self.processor is not None:
            return self.processor
        return processor

    def __getitem__(self, idx):
        proc = self._get_processor()
        row = self.df.loc[idx]
        sr = self.config['sampling_rate']

        # Resample to the configured rate (not a hard-coded 16 kHz), then let the
        # processor turn the waveform into model input values.
        audio = read_audio(row['path'], target_sr=sr)
        audio = proc(audio, sampling_rate=sr).input_values[0]

        # Return -1 for label if in test-only mode
        if self.is_test:
            return {'audio': audio, 'label': -1}

        # Tokenize the transcript with the tokenizer directly. This is what
        # `with processor.as_target_processor(): processor(text)` did, without
        # the deprecated context manager.
        labels = proc.tokenizer(row['sentence']).input_ids
        return {'audio': audio, 'label': labels}

    def __len__(self):
        return len(self.df)
    
def ctc_data_collator(batch):
    """
    Dynamically pad a list of ASRDataset samples into one model-ready batch.

    Audio is padded with the processor's feature extractor. Label id sequences
    are padded with its tokenizer, and the padded positions are set to -100.
    Wav2Vec2ForCTC treats negative label ids as padding, so those positions are
    excluded from the CTC loss. When every sample carries the integer placeholder
    label of a test-mode dataset (-1), no "labels" entry is produced.

    Relies on the module-level `processor`.
    """
    input_features = [{"input_values": sample["audio"]} for sample in batch]
    # Calling the feature extractor directly is what processor.pad does outside
    # the deprecated as_target_processor context.
    collated = processor.feature_extractor.pad(
        input_features,
        padding=True,
        return_tensors="pt",
    )

    # Test-mode samples have no transcript to pad
    if all(isinstance(sample["label"], int) for sample in batch):
        return collated

    label_features = [{"input_ids": sample["label"]} for sample in batch]
    labels_batch = processor.tokenizer.pad(
        label_features,
        padding=True,
        return_tensors="pt",
    )
    collated["labels"] = labels_batch["input_ids"].masked_fill(
        labels_batch["attention_mask"].ne(1), -100
    )
    return collated

In [7]:
def train_one_epoch(model, train_loader, optimizer, scheduler, device='cuda:0'):
    """
    Run one training epoch and return the mean per-batch loss.

    The scheduler is stepped after every optimizer step, i.e. once per batch.
    Each batch loss is logged to W&B as `train_step_loss`.

    Args:
        model: module whose forward(**batch) returns an object with a `.loss` tensor.
        train_loader: iterable of dict batches of tensors that supports len().
        optimizer: torch optimizer over the model's parameters.
        scheduler: LR scheduler, stepped per batch.
        device: device each batch tensor is moved to.

    Returns:
        The mean batch loss, or NaN if the loader yielded no batches.
    """
    model.train()
    pbar = tqdm(train_loader, total=len(train_loader))
    total_loss = 0.0
    n_batches = 0
    for data in pbar:
        data = {k: v.to(device) for k, v in data.items()}
        loss = model(**data).loss
        loss_itm = loss.item()

        total_loss += loss_itm
        n_batches += 1
        pbar.set_description(f"loss: {loss_itm:.4f}")
        wandb_log(train_step_loss=loss_itm)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

    # Average over the batches actually seen; an empty loader would otherwise divide by zero
    return total_loss / n_batches if n_batches else float('nan')

@torch.no_grad()
def valid_one_epoch(model, valid_loader, device='cuda:0'):
    pbar = tqdm(valid_loader, total=len(valid_loader))
    avg_loss = 0
    for data in pbar:
        data = {k: v.to(device) for k, v in data.items()}
        loss = model(**data).loss
        loss_itm = loss.item()
        
        avg_loss += loss_itm
        pbar.set_description(f"val_loss: {loss_itm:.4f}")
        wandb_log(valid_step_loss=loss_itm)

    return avg_loss / len(valid_loader)

In [8]:
if __name__ == "__main__":
    # Seed numpy and torch so the subsample, the newly initialised lm_head and the
    # DataLoader shuffling order are reproducible. Config may optionally carry a 'seed'.
    seed = Config.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Read in the dataframe and split by training and validation splits
    df = pd.read_csv("/kaggle/input/bengaliai-speech/train.csv")

    # Build each row's mp3 path from its id for reading during dataloading
    df['path'] = df['id'].apply(lambda x: os.path.join(Config['audio_dir'], x + '.mp3'))
    # Fixed random_state so the 15% subsample is the same on every run
    train_df = df[df['split'] == 'train'].sample(frac=.15, random_state=seed).reset_index(drop=True)
    valid_df = df[df['split'] == 'valid'].sample(frac=.15, random_state=seed).reset_index(drop=True)
    print(f"Training on samples: {len(train_df)}, Validation on samples: {len(valid_df)}")

    # Construct and save the vocab file from all sentences (train + valid)
    save_vocab(df)

    # Init the tokenizer, feature_extractor, processor and model.
    # NOTE(review): Wav2Vec2CTCTokenizer replaces " " with word_delimiter_token and then
    # splits into single characters, so the two-character "__" may tokenize as "_", "_"
    # ([UNK]). Switching to e.g. "|" requires changing save_vocab's key as well.
    tokenizer = Wav2Vec2CTCTokenizer(
        "./vocab.json",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="__"
    )
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=Config['sampling_rate'],
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=False
    )
    # `processor` stays a module-level name: ASRDataset and ctc_data_collator read it
    processor = Wav2Vec2Processor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer
    )

    model = Wav2Vec2ForCTC.from_pretrained(
        Config['model_name'],
        ctc_loss_reduction="mean",
        # Zero out infinite CTC losses (e.g. audio too short for its transcript)
        # instead of letting them poison the gradients
        ctc_zero_infinity=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(tokenizer),
    )
    wandb.watch(model)

    model.to(device)
    # Freeze the convolutional feature encoder; the rest of the model is fine-tuned
    model.freeze_feature_encoder()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=Config['lr'],
        weight_decay=Config['wd']
    )
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=Config['T_0'],
        T_mult=Config['T_mult'],
        eta_min=Config['eta_min']
    )

    # Construct training and validation dataloaders
    train_ds = ASRDataset(train_df, Config)
    valid_ds = ASRDataset(valid_df, Config)

    train_loader = DataLoader(
        train_ds,
        batch_size=Config['train_bs'],
        # Reshuffle every epoch instead of replaying the same batch order
        shuffle=True,
        collate_fn=ctc_data_collator,
    )
    valid_loader = DataLoader(
        valid_ds,
        batch_size=Config['valid_bs'],
        collate_fn=ctc_data_collator,
    )

    # Train the model, keeping the checkpoint with the lowest validation loss
    best_loss = float('inf')
    for epoch in range(Config['nb_epochs']):
        print(f"{'='*40} Epoch: {epoch+1} / {Config['nb_epochs']} {'='*40}")
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, device=device)
        valid_loss = valid_one_epoch(model, valid_loader, device=device)
        wandb_log(train_loss=train_loss, val_loss=valid_loss)
        print(f"train_loss: {train_loss:.4f}, valid_loss: {valid_loss:.4f}")

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), "wav2vec2_base_bengaliAI.pt")
            print(f"Saved the best model so far with val_loss: {valid_loss:.4f}")

Training on samples: 140107, Validation on samples: 4438
Created Vocab file!


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForCTC: ['project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors', 'project_hid.weight', 'project_hid.bias', 'project_q.bias']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predicti



  0%|          | 0/8757 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

train_loss: 3.4447, valid_loss: 3.4090
Saved the best model so far with val_loss: 3.4090


  0%|          | 0/8757 [00:00<?, ?it/s]

  0%|          | 0/278 [00:00<?, ?it/s]

train_loss: 3.4293, valid_loss: 3.4092


In [9]:
# Flush and close the W&B run opened by wandb.init in the logging cell.
run.finish()

[34m[1mwandb[0m: Waiting for W&B process to finish... [32m(success).[0m
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:      train_loss █▁
[34m[1mwandb[0m: train_step_loss ▂▁▃▄▁▃▃▂▂▂▃▁▃▂▂▃▃▄▃▄▃▂▂▂▁▂▂▃▂▂▂▃▃▂▂▂▃▃▃█
[34m[1mwandb[0m:        val_loss ▁█
[34m[1mwandb[0m: valid_step_loss ▃▁▃▃▃▂▁█▂▃▁▂▃▃▄▄▆▄▅▁▃▃▂▂▃▄▂▃▃▃▃▁▄▂▄▃▃▂▂▃
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:      train_loss 3.42929
[34m[1mwandb[0m: train_step_loss 3.40529
[34m[1mwandb[0m:        val_loss 3.40917
[34m[1mwandb[0m: valid_step_loss 3.48096
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mprime-sponge-5[0m at: [34m[4mhttps://wandb.ai/pilotteams/pytorch/runs/bvjsr2rs[0m
[34m[1mwandb[0m: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20230812_161703-bvjsr2rs/logs[0m
