In [None]:
!pip install mltu
!pip install opencv-python
!pip install opencv-python-headless
!pip install onnx
!pip install torch==1.13.1+cu111
!pip install transformers==4.33.1
!pip install onnxruntime


In [None]:
import os
import pandas as pd
import torch
from torch import nn
from transformers import Wav2Vec2ForCTC
import torch.nn.functional as F
from datetime import datetime

In [None]:
import mltu
from mltu.torch.model import Model
from mltu.torch.losses import CTCLoss
from mltu.torch.dataProvider import DataProvider
from mltu.torch.metrics import CERMetric, WERMetric
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay

from mltu.augmentors import RandomAudioNoise, RandomAudioPitchShift, RandomAudioTimeStretch

In [None]:
from mltu.preprocessors import AudioReader
from mltu.transformers import LabelIndexer, LabelPadding, AudioPadding

In [None]:
from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    """Hyper-parameters and output location for one wav2vec2 fine-tuning run.

    Each instance gets its own output folder under ``Models/10_wav2vec2_torch``,
    named after the minute (``YYYYmmddHHMM``) at which it was created.
    """

    def __init__(self):
        super().__init__()

        # Output folder for this run.
        run_stamp = datetime.now().strftime("%Y%m%d%H%M")
        self.model_path = os.path.join("Models/10_wav2vec2_torch", run_stamp)

        # Data loading / training loop.
        self.batch_size = 8
        self.train_epochs = 60
        self.train_workers = 20  # passed to DataProvider as `workers`

        # Optimizer and learning-rate schedule (read by AdamW and WarmupCosineDecay below).
        self.init_lr = 1.0e-8
        self.lr_after_warmup = 1e-05
        self.final_lr = 5e-06
        self.warmup_epochs = 10
        self.decay_epochs = 40
        self.weight_decay = 0.005
        self.mixed_precision = True

        # Audio is padded to this length and the ONNX export uses it as the input size.
        self.max_audio_length = 246000
        self.max_label_length = 256

        # One entry per character, in label-index order: space, apostrophe and the
        # letters a-z without 'q', with 'ɛ' placed after 'e' and 'ɔ' after 'o'.
        self.vocab = list(" 'abcdeɛfghijklmnoɔprstuvwxyz")

In [None]:
# Instantiate the run configuration; its model_path is stamped with the current time.
configs = ModelConfigs()

Dataset Path

In [None]:
# Root folder of the Ashanti dataset; the metadata file and the wav folder are
# both derived from it so every path in the notebook shares one spelling.
# NOTE(review): confirm this matches the folder name on disk.
dataset_path = "Datasets/ashanti_wav"
metadata_path = os.path.join(dataset_path, "data.csv")
wavs_path = os.path.join(dataset_path, "wavs")

Read metadata file and parse it

In [None]:
# Load the tab-separated metadata file. The first row is read as the header
# because the following cells address columns by name ('Audio Filepath',
# 'Unnamed: 0'); with header=None the columns would be plain integers and those
# lookups would raise KeyError.
# quoting=3 is csv.QUOTE_NONE: quote characters are kept as ordinary text.
metadata_df = pd.read_csv(metadata_path, sep="\t", header=0, quoting=3)
dataset = []  # filled further down with [audio_path, label] pairs
# Characters kept in the labels; anything else is stripped from the transcriptions.
vocab = [' ', "'", 'a', 'b', 'c', 'd', 'e','ɛ', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'ɔ','p', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

In [None]:
# Point every entry at the converted .wav file. regex=False makes '.ogg' a literal
# substring; under regex matching (the default before pandas 2.0) the leading '.'
# would match any character.
metadata_df['Audio Filepath'] = metadata_df['Audio Filepath'].str.replace('.ogg', '.wav', regex=False)

Remove the 'Unnamed: 0' column if it exists

In [None]:
# A column named 'Unnamed: 0' is dropped when present; the parsing loop below
# unpacks exactly three values per row.
stray_column = 'Unnamed: 0'
if stray_column in metadata_df.columns:
    metadata_df = metadata_df.drop(columns=[stray_column])

In [None]:
# Build the [audio_path, label] pairs consumed by the data provider.
# Each metadata row must hold exactly three values: file name, transcription
# and one field that is not used here.
allowed_chars = set(vocab)
dataset = []  # rebuilt from scratch so re-running this cell does not duplicate entries
for file_name, transcription, _ in metadata_df.values.tolist():
    # The file name may already end in ".wav" (the '.ogg' suffix is rewritten
    # earlier), so only append the extension when it is missing.
    wav_name = file_name if file_name.endswith(".wav") else f"{file_name}.wav"
    # Use the wav folder derived from dataset_path instead of a second hard-coded path.
    path = os.path.join(wavs_path, wav_name)
    # Lower-case the text and drop every character outside the vocabulary.
    new_label = "".join(ch for ch in transcription.lower() if ch in allowed_chars)
    dataset.append([path, new_label])

Create a data provider for the dataset

In [None]:
# Pipeline stages are assembled first and then handed to the provider.
audio_preprocessors = [AudioReader(sample_rate=16000)]
label_transformers = [LabelIndexer(vocab)]
padding_postprocessors = [
    # Every clip in a batch is padded with zeros up to the configured audio length.
    AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True),
    # Labels are padded with len(vocab), the same index CTCLoss is given as blank below.
    LabelPadding(padding_value=len(vocab), use_on_batch=True),
]

data_provider = DataProvider(
    dataset=dataset,
    data_preprocessors=audio_preprocessors,
    transformers=label_transformers,
    batch_postprocessors=padding_postprocessors,
    batch_size=configs.batch_size,
    workers=configs.train_workers,
    use_multiprocessing=True,
    max_queue_size=10,
    use_cache=False,
    skip_validation=True,
)

In [None]:
# Split the provider into a training part and a validation part.
# NOTE(review): 0.9 is presumably the training fraction — confirm against mltu's DataProvider.split.
train_dataProvider, test_dataProvider = data_provider.split(split=0.9)

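Optional audio augmentation for the training split (this is a markdown cell, so it is not executed as written; copy it into a code cell to enable it):

```python
train_dataProvider.augmentors = [
    RandomAudioNoise(),
    RandomAudioPitchShift(),
    RandomAudioTimeStretch(),
]
```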

In [None]:
# Persist the vocabulary in exactly the order LabelIndexer was given, because the
# label indices the model is trained on follow that order. Sorting here would move
# 'ɛ' and 'ɔ' to the end of the list, so the vocab saved with the config (and
# embedded in the ONNX metadata) would no longer match the training labels.
configs.vocab = list(vocab)
configs.save()

In [None]:
class CustomWav2Vec2Model(nn.Module):
    """Pretrained wav2vec2 CTC model that returns log-probabilities.

    Args:
        hidden_states: Passed to ``from_pretrained`` as ``vocab_size``, i.e. the
            number of output classes of the CTC head.
        dropout_rate: Accepted for interface compatibility; not used by this class.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, hidden_states, dropout_rate=0.2, **kwargs):
        super().__init__(**kwargs)
        # ignore_mismatched_sizes lets the checkpoint load even though the requested
        # vocab_size differs from the one the checkpoint was trained with.
        self.model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h",
            vocab_size=hidden_states,
            ignore_mismatched_sizes=True,
        )
        # The feature encoder is kept frozen; it is not fine-tuned.
        self.model.freeze_feature_encoder()

    def forward(self, inputs):
        """Run the wrapped model on ``inputs`` and return log-softmax over the last axis."""
        logits = self.model(inputs, attention_mask=None).logits
        return F.log_softmax(logits, -1)

In [None]:
# One output class per vocabulary character plus one extra class; the CTC loss
# below uses index len(vocab) as its blank.
num_output_classes = len(vocab) + 1
custom_model = CustomWav2Vec2Model(hidden_states=num_output_classes)

Move the model to the CUDA device if one is available.

In [None]:
# Use the GPU when PyTorch can see one; otherwise the model stays where it is.
gpu_available = torch.cuda.is_available()
if gpu_available:
    custom_model = custom_model.to("cuda")

Create the training callbacks.

In [None]:
# Learning-rate schedule callback; every value comes from the run configuration.
# Arguments are listed in schedule order: start -> after warm-up -> final.
warmupCosineDecay = WarmupCosineDecay(
    initial_lr=configs.init_lr,
    lr_after_warmup=configs.lr_after_warmup,
    warmup_epochs=configs.warmup_epochs,
    final_lr=configs.final_lr,
    decay_epochs=configs.decay_epochs,
    verbose=True,
)

In [None]:
# TensorBoard callback pointed at the "logs" sub-folder of this run's output directory.
tensorboard_dir = f"{configs.model_path}/logs"
tb_callback = TensorBoard(tensorboard_dir)

In [None]:
# Early-stopping callback: watches "val_CER" in "min" mode with a patience of 16.
earlyStopping = EarlyStopping(
    monitor="val_CER",
    mode="min",
    patience=16,
    verbose=1,
)

In [None]:
# Checkpoint callback writing to <model_path>/model.pt; it monitors "val_CER"
# in "min" mode with save_best_only enabled.
checkpoint_file = configs.model_path + "/model.pt"
modelCheckpoint = ModelCheckpoint(
    checkpoint_file,
    monitor="val_CER",
    mode="min",
    save_best_only=True,
    verbose=1,
)

In [None]:
# ONNX export callback. It is pointed at the same model.pt path the checkpoint
# callback writes, and the vocabulary is attached as metadata.
# Axes 0 and 1 of both the input and the output are declared dynamic.
onnx_dynamic_axes = {
    tensor_name: {0: "batch_size", 1: "sequence_length"}
    for tensor_name in ("input", "output")
}
model2onnx = Model2onnx(
    saved_model_path=configs.model_path + "/model.pt",
    input_shape=(1, configs.max_audio_length),
    dynamic_axes=onnx_dynamic_axes,
    metadata={"vocab": configs.vocab},
    verbose=1,
)

Create the model object that will handle training and testing of the network.

In [None]:
# CTC loss with the blank at index len(vocab), i.e. one past the last character.
ctc_loss = CTCLoss(blank=len(configs.vocab), zero_infinity=True)

# AdamW starts at the initial learning rate from the configuration.
optimizer = torch.optim.AdamW(
    custom_model.parameters(),
    lr=configs.init_lr,
    weight_decay=configs.weight_decay,
)

# Character- and word-error-rate metrics, both built from the configured vocabulary.
error_metrics = [CERMetric(configs.vocab), WERMetric(configs.vocab)]

model = Model(
    custom_model,
    loss=ctc_loss,
    optimizer=optimizer,
    metrics=error_metrics,
    mixed_precision=configs.mixed_precision,
)

Save training and validation datasets as csv files

In [None]:
# Write both splits next to the model so the exact train/validation partition is kept.
split_outputs = (
    (train_dataProvider, "train.csv"),
    (test_dataProvider, "val.csv"),
)
for split_provider, csv_name in split_outputs:
    split_provider.to_csv(os.path.join(configs.model_path, csv_name))

In [None]:
# Train on the training split and validate on the held-out split.
# Callbacks are passed in this order: LR schedule, TensorBoard logging,
# early stopping, checkpointing, ONNX export.
training_callbacks = [
    warmupCosineDecay,
    tb_callback,
    earlyStopping,
    modelCheckpoint,
    model2onnx,
]
model.fit(
    train_dataProvider,
    test_dataProvider,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
)