# Day 3. UCU Acoustic School Home task

## Imports and Constants

In [5]:
import whisper
from whisper import Whisper, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
from task1 import wer_torch

import torch
from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets import LIBRISPEECH
from torchaudio.compliance.kaldi import mfcc
from tqdm import tqdm

In [6]:
DEFAULT_BATCH_SIZE = 3
DEFAULT_LEARNING_RATE = 1e-3

## Data manipulations

In [7]:
class MyDataset ( LIBRISPEECH ):
    def __init__(self, root, tokenizer, url="train-clean-100"):
        super().__init__(".", url=url, download=True)
        self.root = root
        self.tokenizer = tokenizer    

    def __getitem__(self, index):
        wav, sr, text, speaker_id, chapter_id, utterance_id = super().__getitem__(index)

        padded_wav = pad_or_trim(wav)
        spectrogram = log_mel_spectrogram(padded_wav)

        text = text.lower()

        tokenized_text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        tokenized_labels = tokenized_text[1:] + [self.tokenizer.eot]

        spectrogram = spectrogram.squeeze()

        return spectrogram, tokenized_labels, tokenized_text

In [8]:
def collate_fn (batch):
    spectrograms, labels, texts = zip(*batch)

    max_label_len = max([len(label) for label in labels])
    padded_labels = [label + [0] * (max_label_len - len(label)) for label in labels]

    max_text_len = max([len(text) for text in texts])
    padded_texts = [text + [0] * (max_text_len - len(text)) for text in texts]

    spectrograms = torch.stack([torch.FloatTensor(spec) for spec in spectrograms])
    padded_labels = torch.LongTensor(padded_labels)
    padded_texts = torch.LongTensor(padded_texts)

    return spectrograms, padded_labels, padded_texts

## Training

In [9]:
class Trainer :
    def __init__(self,
            model: Whisper, train_dataset , valid_dataset, output_dir, lang, device, n_epoch, 
            batch_size=DEFAULT_BATCH_SIZE, lr=DEFAULT_LEARNING_RATE):
        self.model = model
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.output_dir = output_dir
        self.lang = lang
        self.device = device
        self.n_epoch = n_epoch
        self.batch_size = batch_size
        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)
        self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.best_val_loss = float("inf")

    def train_step(self, input_spec, target_labels):
        input_spec = input_spec.to(self.device)
        target_labels = target_labels.to(self.device)

        self.model.train()
        self.optimizer.zero_grad()

        output = self.model(input_spec, target_labels)
        output = output.transpose(1, 2)
        loss = self.criterion(output, target_labels)

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self):
        for epoch in range(self.n_epoch):
            epoch_loss = 0.0

            for input_spec, target_labels, _ in tqdm(self.train_dataloader):
                loss = self.train_step(input_spec, target_labels)
                epoch_loss += loss

            epoch_loss = epoch_loss / len(self.train_dataloader)

            print(f'EPOCH {epoch}')
            print(f"Training loss: {epoch_loss:.4f}")
            val_loss, wer = self.validate()

            print(f'Validation loss: {val_loss}')
            print(f'Validation WER:  {wer}')
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), "./whisper_librespeech_shevtsov.pt")
                print("Saved state dict!")

    
    def validate (self):
        val_loss = 0
        wers = 0
        self.model.eval()    
        
        for input_spec, target_labels, _ in tqdm(self.valid_dataloader):  # The last batch can't be a src
            with torch.no_grad():
                input_spec = input_spec.to(self.device)
                target_labels = target_labels.to(self.device)
                output = self.model(input_spec, target_labels)
                output = output.transpose(1, 2)
                loss = self.criterion(output, target_labels)              
                # prediction = prediction.reshape(batch_size * seq_len, -1)   
                loss = self.criterion(output, target_labels)
                output_labels = output.argmax(dim=1)

                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                val_loss += loss.item() * input_spec.size()[0]
                wers += wer_torch(output_labels, target_labels)
                # cers += cer(prediction, y_val)
        return val_loss / len(self.valid_dataloader), wers / len(self.valid_dataloader)

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epoch = 1
lang = 'en'
tokenizer = get_tokenizer(True, language=lang, task='transcribe')
train_dataset = MyDataset(".", tokenizer)
valid_dataset = MyDataset(".", tokenizer, "dev-clean")
output_dir = 'results'

In [12]:
params = torch.load("patriotic_whisper_mixed_en_uk.pt")
model = whisper.load_model("tiny", device=device)
model.load_state_dict(params)
print("Patriotic model loaded!")

Patriotic model loaded!


In [13]:
trainer = Trainer(model, train_dataset , valid_dataset, output_dir, lang, device, n_epoch )

In [14]:
trainer.train()

100%|██████████| 9513/9513 [1:29:07<00:00,  1.78it/s]  


EPOCH 0
Training loss: 0.1773


100%|██████████| 901/901 [18:34<00:00,  1.24s/it]


Validation loss: 0.10212980939721669
Validation WER:  0.0036770787555724382
Saved state dict!


## Testing trained model

In [32]:
spec, labels, text = next(trainer.valid_dataloader._get_iterator())

In [35]:
tokenizer.decode(text[0])

'<|startoftranscript|><|en|><|transcribe|><|notimestamps|>mister quilter is the apostle of the middle classes and we are glad to welcome his gospel!!!!!!!!!!!!!!!!'

In [38]:
spec = spec.to(device)
labels = labels.to(device)
output = model(spec, labels)
output_text = output.argmax(dim=2)
tokenizer.decode(output_text[0])

'<|en|><|transcribe|><|notimestamps|>mister forter is the apostle of the middle classes and we are glad to welcome his gospel<|endoftext|>!!!!!!!!!!!!!!!!'