# PytorchLightning version of training Notebook
- easy support for multiple GPUs
- easy support for fp16 training

In [1]:
from datasets import load_dataset, load_metric
import pandas as pd

#Set all sources of data
commonvoice = "data/commonvoice/train.csv"
singlespeaker = "data/singlespeaker/train.csv"
speechcollector = "data/speechcollector/train.csv"
voxpopuli = "data/fi/train.csv"

test = "data/commonvoice/test.csv"

train_df = pd.concat([pd.read_csv(commonvoice), pd.read_csv(singlespeaker), pd.read_csv(speechcollector), pd.read_csv(voxpopuli)])
test_df = pd.read_csv(test)

print(f"Training set contains {len(train_df)} Samples")
train_df.head()

Training set contains 15920 Samples


Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,Mitä nyt tekisimme?
1,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,"Rupeatko remmiin, vai et?"
3,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestin näin ollen mietinnön puolesta.
4,/home/sampo/.cache/huggingface/datasets/downlo...,"Kiitos, että tulitte ja opetitte meille viisau..."


In [2]:
import random
import pandas as pd
from IPython.display import display, HTML
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\...\…\–\é]'

def custom_remove_special_characters(sent):
    sent = re.sub(chars_to_ignore_regex, '', sent).lower() + " "
    return sent

train_df['sentence'] = train_df['sentence'].apply(custom_remove_special_characters)
test_df['sentence'] = test_df['sentence'].apply(custom_remove_special_characters)
train_df.head()

Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,mitä nyt tekisimme
1,/home/sampo/.cache/huggingface/datasets/downlo...,äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,rupeatko remmiin vai et
3,/home/sampo/.cache/huggingface/datasets/downlo...,äänestin näin ollen mietinnön puolesta
4,/home/sampo/.cache/huggingface/datasets/downlo...,kiitos että tulitte ja opetitte meille viisaut...


In [3]:
import itertools

def get_chars(df):
    return set(itertools.chain(*[list(x) for x in df['sentence'].values]))

vocab_list = list(get_chars(train_df).union(get_chars(test_df)))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

{'x': 0, 'o': 1, 'v': 2, 'e': 3, 'k': 4, 'ö': 5, 'h': 6, 'm': 7, 'i': 8, 's': 9, ' ': 10, 't': 11, 'u': 12, 'l': 13, 'ä': 14, 'c': 15, 'a': 16, 'f': 17, 'r': 18, 'b': 19, 'd': 20, 'å': 21, 'q': 22, 'p': 23, 'z': 24, 'g': 25, 'j': 26, 'w': 27, 'y': 28, 'n': 29}


In [4]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

#for key in vocab_dict.keys():
#    if key != "[PAD]":
#        vocab_dict[key] +=1

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [5]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [6]:
import torch
import numpy as np
import torchaudio
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import librosa

def resample(audio, source_sr, target_sr = 16000):
    audio = librosa.resample(np.asarray(audio), source_sr, target_sr)
    return audio

class CTCDataset(Dataset):
    """
    Dataset class used for Speech recognition with ctc loss
    enables precomputing data as arrays or transforming on the fly
    if dataset does not fit into ram
    """
    def __init__(self, dataframe, processor, mode="otf"):
        
        self.data = dataframe
        self.data.sort_values(by="sentence", key=lambda x: x.str.len(), inplace=True, ascending=False)
        self.processor = processor
        self.mode = mode
        if mode!="otf":
            raise NotImplemented
    
    def _processaudio(self, path):
        data, sr = torchaudio.load(path)
        data = data[0].numpy()
        data = resample(data, sr, 16000)
        
        return data
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):     
        if self.mode == 'otf':
            sent = self.data.iloc[idx, 1]
            data = self._processaudio(self.data.iloc[idx, 0])
            return data, sent
        
    def _precompute(self):
        pass
    
    def reorder_df(self):
        pass
    
def collate_fn_otf(batch):
    """
    collate function used for training and loading audio data on the fly
    """
    
    lists = list(zip(*batch))
    inputs = processor(lists[0], sampling_rate=16_000, return_tensors="pt", padding=True)
    with processor.as_target_processor():
        labels = processor(lists[1], padding=True, return_tensors="pt").input_ids
    return inputs.input_values, inputs.attention_mask, labels

trainset = CTCDataset(train_df, processor)
testset = CTCDataset(test_df, processor)

trainloader = DataLoader(trainset, batch_size = 4, collate_fn = collate_fn_otf, num_workers=4)
testloader = DataLoader(testset, batch_size=1, collate_fn = collate_fn_otf, num_workers=4)

# Make Lightning DataModule

In [9]:
from transformers import Wav2Vec2ForCTC
import pytorch_lightning as pl

def decode_output(logits):
    pred_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(pred_ids)
    return pred[0]

class Wav2Vec2Module(pl.LightningModule):
    
    def __init__(self, ref_sentences, device="cuda"):
        super().__init__()
        

        self.model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-large-100k-voxpopuli",
            attention_dropout=0.1,
            hidden_dropout=0.1,
            feat_proj_dropout=0.0,
            mask_time_prob=0.05,
            layerdrop=0.1,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            ctc_zero_infinity=False
        )
        self.model.train()
        self.model.freeze_feature_extractor()
        self.best_wer = 1.0
        self.predictions = []
        #model.to(device)
        self.ref_sentences = ref_sentences
        
    def forward(self, inputs, masks, labels):
        output = self.model(inputs, masks, labels=labels)
        return output
        
    def training_step(self, train_batch, batch_idx):
        inputs, masks, labels = train_batch
        loss = self(inputs, masks, labels)
        return loss.loss
    
    def validation_step(self, val_batch, batch_idx):
        inputs, masks, labels = val_batch
        output = self(inputs, masks, labels)
        loss = output.loss.item()
        pred = decode_output(output.logits)
        self.predictions.append(pred)
        return loss
    
    def configure_optimizers(self):
        lr = 0.0002
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        scheduler = scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, 
                                   steps_per_epoch=int(len(trainloader)/2),
                                   epochs = 30,
                                   anneal_strategy='linear')
        return [optimizer], [scheduler]
    

In [10]:
from pytorch_lightning.callbacks import Callback
from datasets import load_metric
from pytorch_lightning.callbacks import ModelCheckpoint


checkpoint_callback = ModelCheckpoint(monitor="wer", dirpath="lightning_run", filename="large_{wer:02f}", save_top_k=1, mode="min")
    

class EvalCallbacks(Callback):
    
    def on_validation_epoch_end(self, trainer, pl_module):
        wer = load_metric("wer")
        print(pl_module.predictions)
        print(pl_module.ref_sentences)
        wer_c = wer.compute(predictions=pl_module.predictions, references=pl_module.ref_sentences.to_list())
        if wer_c < pl_module.best_wer:
            pl_module.best_wer = wer_c
            print(f"new best wer: {wer_c}")
        pl_module.log("wer", wer_c)
        pl_module.predictions = []
        
        
        
m = Wav2Vec2Module(testset.data.sentence)
trainer = pl.Trainer(callbacks=[checkpoint_callback, EvalCallbacks()], accumulate_grad_batches=2, max_epochs=30, precision=16, gpus=1, num_sanity_val_steps=0)
trainer.fit(m, trainloader, testloader)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-100k-voxpopuli and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type           | Params
-----------------------------------------
0 | model | Wav2Vec2ForCTC | 315 M 
-----------------------------------------
311 M     Trainable params
4.2 M     Non-trainable params
315 M     Total params
1,261.886 Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '',

Traceback (most recent call last):
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


In [35]:
wer = load_metric("wer")
wer.compute(predictions=["moi", "hei"], references=["moi", "hei"])

0.0