In [1]:
from whisper import Whisper, load_model, log_mel_spectrogram, pad_or_trim
from whisper.tokenizer import get_tokenizer
from task1 import cer, wer

import torch
from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets import LIBRISPEECH
from torchaudio.compliance.kaldi import mfcc
from tqdm import tqdm

In [2]:
def collate_fn ( items : list ):
    data = torch.stack([torch.Tensor(item) for item in items])
    return data

In [3]:
class MyDataset ( LIBRISPEECH ):
    def __init__(self, root, tokenizer, url="train-clean-100"):
        super().__init__(".", url=url, download=True)
        self.root = root
        self.tokenizer = tokenizer    
    # def __make_frames(wav):
    #     return mfcc(wav)

    def __getitem__(self, index):
        wav, sr, text, speaker_id, chapter_id, utterance_id = super().__getitem__(index)
        padded_wav = pad_or_trim(wav)
        spectrogram = log_mel_spectrogram(padded_wav)

        text = text.lower()

        tokenized_text = self.tokenizer.sot_sequnce_including_notimestamps.extend(self.tokenizer.encode(text))
        tokenized_labels = tokenized_text[1:] + [self.tokenizer.eot]

        return spectrogram, tokenized_labels, tokenized_text

In [4]:
class Trainer :
    def __init__(self,
            model: Whisper, train_dataset , valid_dataset, output_dir, lang, device, n_epoch, 
            batch_size=128, lr=1e-3):
        self.model = model
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.output_dir = output_dir
        self.lang = lang
        self.device = device
        self.n_epoch = n_epoch
        self.batch_size = batch_size
        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)
        self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn)
        self.optimizer = torch.optim.Adam(lr=lr)
        self.criterion = wer


    
    def train_step ( self ):
        epoch_loss = 0
        self.model.train()    
        
        for X_train, y_train in tqdm(self.train_dataloader):  # The last batch can't be a src
            self.optimizer.zero_grad()

            X_train, y_train = X_train.to(self.device), y_train.to(self.device)
            prediction = self.model.transcribe(X_train)               

            # prediction = prediction.reshape(batch_size * seq_len, -1)   
            loss = torch.tensor(self.criterion(prediction, y_train), device=self.device)
            
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()
            epoch_loss += loss.item() * X_train.size()[0]
        return epoch_loss / len(self.train_dataloader)
    
    def train ( self ):
        for epoch in range(self.n_epoch):
            print(f'STARTING EPOCH {epoch}')
            train_loss = self.train_step()
            val_loss, wer, cer = self.validate()
            print(f'Training loss:   {train_loss}')
            print(f'Validation loss: {val_loss}')
            print(f'Validation WER:  {wer}')
            print(f'Validation CER:  {cer}')

    
    def validate (self):
        val_loss = 0
        wers = 0
        cers = 0
        self.model.train()    
        
        for X_val, y_val in tqdm(self.valid_dataloader):  # The last batch can't be a src
            with torch.no_grad():
                X_val, y_val = X_val.to(self.device), y_val.to(self.device)
                prediction = self.model.transcribe(X_val)               

                # prediction = prediction.reshape(batch_size * seq_len, -1)   
                loss = torch.tensor(self.criterion(prediction, y_val), device=self.device)
                
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                val_loss += loss.item() * X_val.size()[0]
                wers += wer(prediction, y_val)
                cers += cer(prediction, y_val)
        return val_loss / len(self.valid_dataloader), wers / len(self.valid_dataloader), cers / len(self.valid_dataloader)

In [8]:
device = torch.cuda if torch.cuda.is_available() else torch.cpu
n_epoch = 1
lang = 'en'
tokenizer = get_tokenizer(True, language=lang, task='transcribe')
train_dataset = MyDataset(".", tokenizer)
valid_dataset = MyDataset(".", tokenizer, "dev-clean")
output_dir = 'results'

In [9]:
torch.cuda.is_available()

False

In [7]:
params = torch.load("patriotic_whisper_mixed_en_uk.pt")
model = load_model("tiny" , device=device)
model.load_state_dict(params)
print("Patriotic model loaded!")

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
trainer = Trainer (
    model ,
    train_dataset ,
    valid_dataset ,
    output_dir ,
    lang ,
    device ,
    n_epoch )