In [1]:
import datasets, transformers, huggingface_hub as hub
from datasets import load_dataset

In [42]:
import torch
import numpy as np
from torch import nn

In [3]:
import loaders
from DatasetMixer import DatasetMixer, DatasetMixerConfig

In [219]:
def get_common_voice(lang, streaming=False) -> datasets.DatasetDict:
    """
    Loads Ukrainian Common Voice dataset from here
    https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0
    makes it compatible with DatasetMixer
    :param lang: the language of common voice dataset. Should be in ["en", "uk"]
    :param streaming (optional): pass this parameter to load_dataset function 
    """
    uk_speech_dataset = load_dataset("mozilla-foundation/common_voice_11_0", lang, streaming=streaming, split="validation")
    uk_speech_dataset = uk_speech_dataset.rename_columns({"client_id": "speaker_id", "sentence": "transcription"})
    return uk_speech_dataset

In [222]:
speech_dataset = loaders.get_common_voice("en", streaming=True)

In [223]:
from datasets import Audio

In [224]:
speech_dataset = speech_dataset.cast_column("audio", Audio(16000))

In [225]:
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration

In [261]:
class WhisperBaseEncoderForCTC(nn.Module):
    def __init__(self, num_tokens, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.feature_extractor = WhisperFeatureExtractor(return_attention_mask=True, sampling_rate=16000)
        self.whisper_encoder = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").get_encoder()
        self.linear = nn.Linear(512, num_tokens)
    
    def forward(self, X: list[np.ndarray], device="cuda"):
        features = self.feature_extractor(X, sampling_rate=self.sampling_rate)
        features["input_features"] = torch.tensor(np.stack(features["input_features"])).to(device)
        features["attention_mask"] = torch.tensor(features["attention_mask"]).to(device)
        features_input_size = features["attention_mask"].shape[1]
        encoder_output = self.whisper_encoder(**features)
        encoder_output_size = encoder_output.last_hidden_state.shape[1]
        Y = self.linear(encoder_output.last_hidden_state)
        mask = features["attention_mask"][:,::features_input_size // encoder_output_size]
        return Y.transpose(0, 1), mask

In [297]:
WhisperFeatureExtractor(return_attention_mask=True)(example1["audio"]["array"])

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


{'input_features': [array([[-0.7125068 , -0.7125068 , -0.66656506, ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ],
       [-0.7125068 , -0.7125068 , -0.6105869 , ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ],
       [-0.7125068 , -0.7125068 , -0.4616264 , ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ],
       ...,
       [-0.7125068 , -0.7125068 , -0.42897058, ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ],
       [-0.7125068 , -0.7125068 , -0.57720065, ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ],
       [-0.7125068 , -0.7125068 , -0.5909358 , ..., -0.7125068 ,
        -0.7125068 , -0.7125068 ]], dtype=float32)], 'attention_mask': array([[1, 1, 1, ..., 0, 0, 0]])}

In [306]:
def generate(X: np.ndarray):
    with torch.no_grad():
        features_extraction = model.feature_extractor(X)
        features = features_extraction["input_features"][0]
        features = torch.tensor(features)
        attention_mask = features_extraction["attention_mask"][0]
        features_input_size = attention_mask.shape[0]
        features = torch.unsqueeze(features, dim=0).cuda()
        output = model.whisper_encoder(features).last_hidden_state
        encoder_output_size = output.shape[1]
        attention_length = torch.tensor(attention_mask[::features_input_size // encoder_output_size]).sum()
        Y = model.linear(output)[0][:attention_length,:]
        Y = torch.argmax(Y, dim=1).cpu().numpy()
        return tokenizer.decode(Y)
        

In [307]:
generate(example1["audio"]["array"])

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


'THE TR EA A P R S ON THE C OM P I L A T ION AL B O M C R E F OR K S <unk>'

In [308]:
example1["transcription"]

'The track appears on the compilation album "Kraftworks".'

In [262]:
model = WhisperBaseEncoderForCTC(32).to("cuda")

In [263]:
from transformers import Wav2Vec2CTCTokenizer

In [264]:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-100h")

In [265]:
tokenizer.vocab

{"'": 27,
 '</s>': 2,
 '<pad>': 0,
 '<s>': 1,
 '<unk>': 3,
 'A': 7,
 'B': 24,
 'C': 19,
 'D': 14,
 'E': 5,
 'F': 20,
 'G': 21,
 'H': 11,
 'I': 10,
 'J': 29,
 'K': 26,
 'L': 15,
 'M': 17,
 'N': 9,
 'O': 8,
 'P': 23,
 'Q': 30,
 'R': 13,
 'S': 12,
 'T': 6,
 'U': 16,
 'V': 25,
 'W': 18,
 'X': 28,
 'Y': 22,
 'Z': 31,
 '|': 4}

In [266]:
from torch.nn.functional import ctc_loss

In [267]:
def whisper_ctc_loss(logits, logits_mask, tokens, tokens_mask, blank):
    log_softmax_vectors = torch.log_softmax(logits, dim=2)
    logits_lengths = torch.sum(logits_mask, dim=1)
    tokens = torch.tensor(tokens)
    tokens_lengths = torch.sum(torch.tensor(tokens_mask), dim=1)
    loss = ctc_loss(
        log_probs=log_softmax_vectors,
        targets=tokens,
        input_lengths=logits_lengths,
        target_lengths=tokens_lengths,
        blank=blank,
    )
    return loss

In [268]:
def default_list_speech_collator(examples):
    new_examples = {
        "arrays": [examples[i]["audio"]["array"] for i in range(len(examples))],
        "transcriptions": [examples[i]["transcription"] for i in range(len(examples))],
        "sampling_date": examples[0]["audio"]["sampling_rate"],
    }
    return new_examples

In [269]:
from torch.utils.data import DataLoader, default_collate

In [270]:
loader = DataLoader(speech_dataset["train"], batch_size=4, collate_fn=default_list_speech_collator)

In [282]:
class TrainerBase:
    def __init__(self, model, dataloader, optimizer):
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
    
    def loss(self, **kwargs):
        raise NotImplemented
    
    def _prepare_model_input(self, batch):
        raise NotImplemented
    
    def _prepare_loss_input(self, batch, model_output):
        raise NotImplemented
    
    def train(self, epochs):
        for epoch in range(epochs):
            losses = []
            print(f"EPOCH {epoch}")
            for i, batch in enumerate(self.dataloader):
                model_input = self._prepare_model_input(batch)
                model_output = self.model(**model_input)
                loss_input = self._prepare_loss_input(batch, model_output)
                loss = self.loss(**loss_input)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(f"loss: {loss.clone().detach().cpu().numpy()}")

In [283]:
class WhisperTrainer(TrainerBase):
    def __init__(self, model, dataloader, optimizer, tokenizer):
        super().__init__(model, dataloader, optimizer)
        self.tokenizer = tokenizer
    
    def loss(self, logits, logits_mask, tokens, tokens_mask, blank):
        log_softmax_vectors = torch.log_softmax(logits, dim=2)
        logits_lengths = torch.sum(logits_mask, dim=1)
        tokens = torch.tensor(tokens)
        tokens_lengths = torch.sum(torch.tensor(tokens_mask), dim=1)
        loss = ctc_loss(
            log_probs=log_softmax_vectors,
            targets=tokens,
            input_lengths=logits_lengths,
            target_lengths=tokens_lengths,
            blank=blank,
            zero_infinity=True,
        )
        return loss
    
    def _prepare_model_input(self, batch):
        return {
            "X": batch["arrays"]
        }
    
    def _prepare_loss_input(self, batch, model_output):
        tokenizer_output = tokenizer.batch_encode_plus([exm.upper() for exm in batch["transcriptions"]], padding=True)
        
        return {
            "logits": model_output[0], 
            "logits_mask": model_output[1], 
            "tokens": torch.tensor(tokenizer_output["input_ids"]).to("cuda"),
            "tokens_mask": torch.tensor(tokenizer_output["attention_mask"]).to("cuda"), 
            "blank": self.tokenizer.word_delimiter_token_id
        }

In [284]:
from torch.optim import Adam

In [285]:
optimizer = Adam(model.parameters(), lr=0.00001)

In [286]:
trainer = WhisperTrainer(model, loader, optimizer, tokenizer)

In [288]:
trainer.train(1)

EPOCH 0


Reading metadata...: 948736it [00:33, 28058.17it/s]
  tokens = torch.tensor(tokens)
  tokens_lengths = torch.sum(torch.tensor(tokens_mask), dim=1)


loss: 0.3525269031524658
loss: 0.8549197912216187
loss: 0.9841049313545227
loss: 0.7877734899520874
loss: 0.4664420187473297
loss: 0.5828015208244324
loss: 0.7376987338066101
loss: 0.7080963850021362
loss: 0.3800029754638672
loss: 0.4763419032096863
loss: 0.6368874311447144
loss: 0.7256090641021729
loss: 0.6380287408828735


KeyboardInterrupt: 