<a href="https://colab.research.google.com/github/HamdanXI/nlp_adventure/blob/main/wav2vec2-noCTC-base-lj-speech-DifferentStructure-ManyChanges.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets>=1.18.3
!pip install transformers==4.11.3
!pip install librosa
!pip install jiwer
!pip install transformers[torch]
!apt install git-lfs

In [2]:
from huggingface_hub import notebook_login

# Show the Hugging Face login widget so a Hub access token can be entered
# (presumably required by the `trainer.push_to_hub()` call at the end of the notebook).
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
import torch
from torch.nn import Transformer, Linear, Embedding
from transformers import Wav2Vec2Processor, Wav2Vec2Config
from datasets import load_dataset, load_metric
from transformers import Trainer, TrainingArguments
import numpy as np
import json

import torch.nn as nn
from transformers import Wav2Vec2Model, Wav2Vec2Config
from torch.nn import Transformer

class Wav2Vec2Seq2Seq(nn.Module):
    """Wav2Vec2 encoder followed by a Transformer decoder and a vocabulary projection.

    The encoder is built from the ``facebook/wav2vec2-base`` *configuration* only, so its
    weights are randomly initialised (NOTE(review): confirm that loading the pretrained
    weights was not intended).

    Args:
        vocab_size: number of output tokens; also the size of the decoder input embedding.
    """

    def __init__(self, vocab_size):
        super().__init__()
        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
        self.encoder = Wav2Vec2Model(config)
        # Token ids have to be embedded before they can be fed to the decoder,
        # which operates on float vectors of size ``hidden_size``.
        self.embedding = nn.Embedding(vocab_size, config.hidden_size)
        self.decoder = Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=0,
            num_decoder_layers=config.num_hidden_layers,
            # The encoder output and the embedded ids are laid out (batch, sequence, feature).
            batch_first=True,
        )

        self.proj = nn.Linear(config.hidden_size, vocab_size)

    def forward(self, input_values, decoder_input_ids):
        """Return next-token logits for every position of ``decoder_input_ids``.

        Args:
            input_values: batch of raw audio passed to the Wav2Vec2 encoder.
            decoder_input_ids: integer token ids, one row per example.

        Returns:
            Logits of shape ``decoder_input_ids.shape + (vocab_size,)``.
        """
        encoder_outputs = self.encoder(input_values).last_hidden_state
        tgt = self.embedding(decoder_input_ids)

        # Causal mask: position t may only attend to positions <= t.
        causal_mask = self.decoder.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        # Call the decoder stack directly: ``nn.Transformer.forward`` takes ``src``/``tgt``
        # (it has no ``memory`` argument) and would first run its zero-layer encoder.
        decoder_outputs = self.decoder.decoder(
            tgt=tgt,
            memory=encoder_outputs,
            tgt_mask=causal_mask,
        )
        # NOTE(review): no positional encoding is added to the token embeddings.
        logits = self.proj(decoder_outputs)
        return logits

from transformers import PreTrainedTokenizer
import json

class CharTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer backed by a JSON ``{token: id}`` vocabulary file.

    Args:
        vocab_file: path to the JSON vocabulary.
        word_delimiter_token: token that stands for a space in the vocabulary. It is
            only used when the vocabulary contains it and does *not* contain ``" "``,
            so vocabularies that store the space itself behave as before.
        **kwargs: forwarded to ``PreTrainedTokenizer`` (``unk_token``, ``pad_token``, ...).
    """

    def __init__(self, vocab_file, word_delimiter_token="|", **kwargs):
        # The vocabulary must exist before the base initialiser runs.
        with open(vocab_file, 'r', encoding="utf-8") as f:
            self.vocab = json.load(f)
        self.ids_to_tokens = {idx: token for token, idx in self.vocab.items()}
        self.word_delimiter_token = word_delimiter_token
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary (needed by ``len(tokenizer)``)."""
        return len(self.vocab)

    def _uses_word_delimiter(self):
        """True when spaces are stored in the vocabulary as the delimiter token."""
        return self.word_delimiter_token in self.vocab and " " not in self.vocab

    def _tokenize(self, text):
        """Split ``text`` into single characters, writing spaces as the delimiter token."""
        if self._uses_word_delimiter():
            # Without this every space would be looked up as " " and become "[UNK]".
            text = text.replace(" ", self.word_delimiter_token)
        return list(text)

    def _convert_token_to_id(self, token):
        """Look up a token id, falling back to the id of ``"[UNK]"``."""
        return self.vocab.get(token, self.vocab.get("[UNK]"))

    def _convert_id_to_token(self, id):
        """Look up the token for an id, falling back to ``"[UNK]"``."""
        return self.ids_to_tokens.get(id, "[UNK]")

    def convert_tokens_to_string(self, tokens):
        """Join tokens into text, turning the delimiter token back into spaces."""
        text = ''.join(tokens)
        if self._uses_word_delimiter():
            text = text.replace(self.word_delimiter_token, " ")
        return text

    def get_vocab(self):
        """Return a copy of the ``{token: id}`` mapping."""
        return dict(self.vocab)

from datasets import load_dataset, load_metric

# Download the dataset from the Hugging Face Hub.
# NOTE(review): the variable is called `timit` although the dataset id refers to LJ Speech;
# the name is used throughout the notebook, so it is left unchanged here.
timit = load_dataset("HamdanXI/lj_speech_DifferentStructure")
# NOTE(review): `repo_name` is not referenced again in the visible cells.
repo_name = "wav2vec2-noCTC-base-lj-speech-DifferentStructure"

def remove_vocab(text):
    """Delete every unsupported character from ``text`` and trim the ends.

    The dropped characters are digits, a few accented letters, curly quotes,
    parentheses, ``&`` and ``£``. Whitespace around removed characters is kept,
    only leading/trailing whitespace of the final result is stripped.
    """
    unwanted = ('1', 'é', '”', 'è', 'â', '6', 'à', '3', '&', ')', '£', '8', '7', '0', '“', 'ê', '’', '2', '5', 'ü', '9', '4', '(')
    cleaned = text
    for symbol in unwanted:
        # Splitting on the symbol and re-joining removes all of its occurrences.
        cleaned = "".join(cleaned.split(symbol))
    return cleaned.strip()

# Strip the unsupported characters from the transcripts of both splits.
for split_name in ("train", "test"):
    timit[split_name] = timit[split_name].map(
        lambda example: {"text": remove_vocab(example["text"])}
    )

def extract_all_chars(batch):
    """Gather the distinct characters used in a batch of transcripts.

    All ``batch["text"]`` entries are joined with single spaces. The result holds
    the list of unique characters of that string under ``"vocab"`` and the joined
    string itself under ``"all_text"``, each wrapped in a one-element list.
    """
    joined = " ".join(batch["text"])
    unique_chars = list(set(joined))
    return {"vocab": [unique_chars], "all_text": [joined]}

# Collect the characters used in each split (the whole split is processed as one batch).
vocabs = timit.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    keep_in_memory=True,
    remove_columns=timit.column_names["train"],
)

# Sort the union so every run assigns the same id to the same character;
# the iteration order of a set of strings is not stable between interpreter sessions.
vocab_list = sorted(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {char: idx for idx, char in enumerate(vocab_list)}

# Store the word boundary as a visible "|" token instead of a bare space.
vocab_dict["|"] = vocab_dict.pop(" ")

# Special tokens. "[START]" and "[END]" are required by the data collator,
# which wraps every label sequence in them.
for special_token in ("[UNK]", "[PAD]", "[START]", "[END]"):
    vocab_dict[special_token] = len(vocab_dict)

with open('vocab.json', 'w', encoding="utf-8") as vocab_file:
    json.dump(vocab_dict, vocab_file)

tokenizer = CharTokenizer(
    vocab_file="./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    # Registering these as special tokens lets `skip_special_tokens=True` drop them on decode.
    bos_token="[START]",
    eos_token="[END]",
)

from transformers import Wav2Vec2FeatureExtractor

# Audio front-end: single-channel waveforms, zero padding, per-utterance normalisation.
# NOTE(review): the sampling rate is set to 22050 Hz while the model config comes from
# `facebook/wav2vec2-base` — confirm this matches the audio the encoder is meant for.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=22050,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

from transformers import Wav2Vec2Processor

# Bundle the audio feature extractor and the character tokenizer into one object.
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

def prepare_dataset(batch):
    """Add model inputs and target ids to a single dataset example.

    Writes ``batch["input_values"]`` (the processed waveform from ``batch["audio"]``)
    and ``batch["labels"]`` (the ids produced for ``batch["text"]`` while the
    processor is in target mode), then returns the same ``batch`` object.
    """
    sample = batch["audio"]
    processed_audio = processor(sample["array"], sampling_rate=sample["sampling_rate"])
    batch["input_values"] = processed_audio.input_values[0]

    # Inside this context the processor call is applied to the transcript.
    with processor.as_target_processor():
        encoded_text = processor(batch["text"])
        batch["labels"] = encoded_text.input_ids

    assert "labels" in batch, f"Failed to process batch: {batch}"
    return batch

# Convert every example into `input_values` / `labels` and drop the raw columns.
timit = timit.map(
    prepare_dataset,
    remove_columns=timit.column_names["train"],
    num_proc=1,
)

# Number of output classes. Taken from `vocab_dict` (the mapping the tokenizer was built
# from) so the value does not depend on the tokenizer implementing `vocab_size`/`__len__`.
vocab_size = len(vocab_dict)

class Wav2Vec2ForSeq2Seq(torch.nn.Module):
    """Wav2Vec2 encoder with a Transformer decoder trained by teacher forcing.

    The encoder is created from ``config`` only, i.e. with randomly initialised weights
    (NOTE(review): confirm that loading pretrained weights was not intended).

    Args:
        config: Wav2Vec2 configuration; also supplies the decoder width, head count and depth.
        vocab_size: number of tokens for the input embedding and the output layer.
    """

    # Label value that is excluded from the loss (PyTorch's cross-entropy default).
    IGNORE_INDEX = -100

    def __init__(self, config: Wav2Vec2Config, vocab_size: int):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model(config)
        self.decoder = Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=0,
            num_decoder_layers=config.num_hidden_layers,
            # Encoder states and embedded labels are laid out (batch, sequence, feature).
            batch_first=True,
        )
        self.embedding = Embedding(vocab_size, config.hidden_size)
        self.output_layer = Linear(config.hidden_size, vocab_size)

    def forward(self, input_values, labels=None):
        """Run a teacher-forced pass and compute the training loss.

        Args:
            input_values: batch of raw audio for the Wav2Vec2 encoder.
            labels: integer token ids, one row per example. ``labels[:, :-1]`` is fed to
                the decoder and ``labels[:, 1:]`` is the prediction target. Positions
                equal to ``-100`` are ignored by the loss.

        Returns:
            ``{"loss": scalar tensor, "logits": tensor}`` where ``logits`` has one row of
            vocabulary scores per decoder input position. A ``loss`` entry is what
            ``transformers.Trainer`` reads from the model output.

        Raises:
            NotImplementedError: if ``labels`` is ``None``; free-running generation is
                not implemented (the original code silently returned ``None``).
        """
        if labels is None:
            raise NotImplementedError(
                "Wav2Vec2ForSeq2Seq.forward requires `labels`; generation without "
                "labels is not implemented."
            )

        encoder_outputs = self.wav2vec2(input_values).last_hidden_state

        decoder_input = labels[:, :-1]
        targets = labels[:, 1:]

        # -100 is not a valid embedding index. Ignored positions are replaced by an
        # arbitrary valid id; their outputs never contribute to the loss.
        decoder_input = decoder_input.masked_fill(decoder_input == self.IGNORE_INDEX, 0)
        embedded = self.embedding(decoder_input)

        # Causal mask so position t cannot look at the tokens it has to predict.
        causal_mask = self.decoder.generate_square_subsequent_mask(embedded.size(1)).to(embedded.device)

        # Use the decoder stack directly with explicit keywords. Calling the Transformer
        # wrapper positionally means (src, tgt), which had the text as the source and the
        # audio as the target, and it would also run the zero-layer encoder.
        decoder_outputs = self.decoder.decoder(
            tgt=embedded,
            memory=encoder_outputs,
            tgt_mask=causal_mask,
        )
        # NOTE(review): no positional encoding is added to the token embeddings.
        logits = self.output_layer(decoder_outputs)

        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=self.IGNORE_INDEX,
        )
        return {"loss": loss, "logits": logits}

class DataCollatorForSeq2Seq:
    """Pad a list of examples into one batch of audio inputs and label ids.

    Args:
        processor: supplies the feature extractor (audio padding) and tokenizer (label padding).
        model: stored on the instance; not used while collating.
        padding: padding strategy forwarded to both ``pad`` calls.
    """

    # Value written at padded label positions; PyTorch's cross-entropy ignores it by default.
    LABEL_IGNORE_INDEX = -100

    def __init__(self, processor: Wav2Vec2Processor, model: Wav2Vec2ForSeq2Seq, padding=True):
        self.processor = processor
        self.model = model
        self.padding = padding

    def __call__(self, batch):
        """Collate ``batch`` (a list of dicts with ``input_values`` and ``labels``).

        Returns:
            ``{"input_values": tensor, "labels": tensor}``. Every label row is
            ``[START] + labels + [END]``; padded label positions hold ``-100`` and
            must be replaced by a valid id before being embedded or decoded.
        """
        # Both `pad` methods expect features keyed by the model input name, not bare lists.
        audio_features = [{"input_values": feature["input_values"]} for feature in batch]
        input_values = self.processor.feature_extractor.pad(
            audio_features,
            padding=self.padding,
            return_tensors="pt"
        )["input_values"]

        # `vocab_dict` is the notebook-level token -> id mapping; it has to contain
        # the "[START]" and "[END]" entries.
        start_id = vocab_dict["[START]"]
        end_id = vocab_dict["[END]"]
        label_features = [
            {"input_ids": [start_id] + list(feature["labels"]) + [end_id]}
            for feature in batch
        ]

        padded_labels = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_attention_mask=True,
            return_tensors="pt"
        )
        # Mark padding so it can be excluded from the loss.
        labels = padded_labels["input_ids"].masked_fill(
            padded_labels["attention_mask"].ne(1), self.LABEL_IGNORE_INDEX
        )

        return {
            "input_values": input_values,
            "labels": labels
        }

# Architecture hyper-parameters are read from the `facebook/wav2vec2-base` configuration.
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")

# One output class per vocabulary entry.
vocab_size = len(vocab_dict)
model = Wav2Vec2ForSeq2Seq(config=config, vocab_size=vocab_size)

data_collator = DataCollatorForSeq2Seq(processor=processor, model=model)

# Word error rate, used by `compute_metrics`.
wer_metric = load_metric("wer")

def compute_metrics(pred):
    """Compute the word error rate of the arg-max predictions.

    Args:
        pred: evaluation output with ``predictions`` (logits) and ``label_ids``.

    Returns:
        ``{"wer": float}``.
    """
    pad_id = processor.tokenizer.pad_token_id

    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks positions without a real label (ignored labels and the padding Trainer
    # adds when it concatenates batches). It is not a token id, so it is swapped for the
    # pad id before decoding. `np.where` returns a copy, leaving `pred.label_ids` untouched.
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    # With teacher forcing the prediction at position t corresponds to label t + 1.
    # Blank out predictions whose target is padding, otherwise arbitrary tokens
    # predicted past the end of the sentence would be scored.
    targets = label_ids[:, 1:]
    if pred_ids.shape == targets.shape:
        pred_ids = np.where(targets == pad_id, pad_id, pred_ids)

    # Drop pad and any other registered special tokens from both sides.
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

# Training configuration.
# `gradient_checkpointing` is intentionally not set: the model is a plain `torch.nn.Module`,
# and Trainer turns checkpointing on through `model.gradient_checkpointing_enable()`,
# a method only `PreTrainedModel` subclasses provide.
training_args = TrainingArguments(
    output_dir="wav2vec2-seq2seq",
    group_by_length=True,
    per_device_train_batch_size=32,
    evaluation_strategy="steps",
    num_train_epochs=2,
    fp16=True,
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=timit["train"],
    eval_dataset=timit["test"],
    # The feature extractor is passed in the `tokenizer` slot, as in the original cell.
    tokenizer=processor.feature_extractor,
)

Map:   0%|          | 0/4620 [00:00<?, ? examples/s]

Map:   0%|          | 0/1680 [00:00<?, ? examples/s]

In [None]:
# Start training with the `trainer` configured in the previous cell.
trainer.train()

In [None]:
# Upload the training result to the Hugging Face Hub.
# NOTE(review): presumably relies on the login performed at the top of the notebook — confirm.
trainer.push_to_hub()